From dbac2c1b5ef7f8dd080aafc19b6d4c64bfe9fe99 Mon Sep 17 00:00:00 2001 From: Andrea Rizzi Date: Fri, 22 Dec 2023 18:03:24 +0100 Subject: [PATCH] Fix bug unrepresented state. The current check raised an error even if all labels were present if one of the label was only in the first sample of the batch. --- mlcolvar/core/loss/tda_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlcolvar/core/loss/tda_loss.py b/mlcolvar/core/loss/tda_loss.py index 42ab3ebe..6ddc57ec 100644 --- a/mlcolvar/core/loss/tda_loss.py +++ b/mlcolvar/core/loss/tda_loss.py @@ -156,7 +156,7 @@ def tda_loss( for i in range(n_states): # check which elements belong to class i - if not torch.nonzero(labels == i).any(): + if not (labels == i).any(): raise ValueError( f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!" )