From 6318a734baab99ba8c2db6c96c86aa72a60b8339 Mon Sep 17 00:00:00 2001 From: Cedric Ewen Date: Wed, 14 Jun 2023 16:40:46 +0200 Subject: [PATCH] fix typo --- src/data/jetnet_datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/data/jetnet_datamodule.py b/src/data/jetnet_datamodule.py index 613f02e7..391a3d97 100644 --- a/src/data/jetnet_datamodule.py +++ b/src/data/jetnet_datamodule.py @@ -287,9 +287,11 @@ def setup(self, stage: Optional[str] = None): mask_test = torch.tensor(np.expand_dims(mask_test[..., 0], axis=-1)) if self.hparams.normalize: - self.data_train = TensorDataset(tensor_train, mask_train) - self.data_val = TensorDataset(tensor_val, mask_val) - self.data_test = TensorDataset(tensor_test, mask_test) + self.data_train = TensorDataset( + tensor_train, mask_train, tensor_conditioning_train + ) + self.data_val = TensorDataset(tensor_val, mask_val, tensor_conditioning_val) + self.data_test = TensorDataset(tensor_test, mask_test, tensor_conditioning_test) self.means = torch.tensor(means) self.stds = torch.tensor(stds) @@ -315,7 +317,7 @@ def setup(self, stage: Optional[str] = None): self.mask_val = unnormalized_mask_val self.x_mean = x_mean self.x_cov = x_cov - self.tensor_conditioning_train = tensor_conditioning_test + self.tensor_conditioning_train = tensor_conditioning_train self.tensor_conditioning_val = tensor_conditioning_val self.tensor_conditioning_test = tensor_conditioning_test