diff --git a/pytorch_tabnet/multitask.py b/pytorch_tabnet/multitask.py index 9666bf4f..c69d4cdd 100644 --- a/pytorch_tabnet/multitask.py +++ b/pytorch_tabnet/multitask.py @@ -3,7 +3,7 @@ from scipy.special import softmax from pytorch_tabnet.utils import PredictDataset, filter_weights from pytorch_tabnet.abstract_model import TabModel -from pytorch_tabnet.multiclass_utils import infer_multitask_output +from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim from torch.utils.data import DataLoader @@ -60,6 +60,9 @@ def stack_batches(self, list_y_true, list_y_score): def update_fit_params(self, X_train, y_train, eval_set, weights): output_dim, train_labels = infer_multitask_output(y_train) + for _, y in eval_set: + for task_idx in range(y.shape[1]): + check_output_dim(train_labels[task_idx], y[:, task_idx]) self.output_dim = output_dim self.classes_ = train_labels self.target_mapper = [