Skip to content

Commit

Permalink
fix: add check for evalset dim
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 15, 2020
1 parent 5df2dd1 commit ba09980
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pytorch_tabnet/multitask.py
Expand Up @@ -3,7 +3,7 @@
from scipy.special import softmax
from pytorch_tabnet.utils import PredictDataset, filter_weights
from pytorch_tabnet.abstract_model import TabModel
from pytorch_tabnet.multiclass_utils import infer_multitask_output
from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim
from torch.utils.data import DataLoader


Expand Down Expand Up @@ -60,6 +60,9 @@ def stack_batches(self, list_y_true, list_y_score):

def update_fit_params(self, X_train, y_train, eval_set, weights):
output_dim, train_labels = infer_multitask_output(y_train)
for _, y in eval_set:
for task_idx in range(y.shape[1]):
check_output_dim(train_labels[task_idx], y[:, task_idx])
self.output_dim = output_dim
self.classes_ = train_labels
self.target_mapper = [
Expand Down

0 comments on commit ba09980

Please sign in to comment.