diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 8428722b..5241a395 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -290,11 +290,11 @@ def _update_config(self, config) -> InferredConfig: raise ValueError(f"{config.task} is an unsupported task.") if self.train is not None: categorical_cardinality = [ - int(self.train[col].fillna("NA").nunique()) + 1 for col in config.categorical_cols + int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values) ] else: categorical_cardinality = [ - int(self.train_dataset.data[col].nunique()) + 1 for col in config.categorical_cols + int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values) ] if getattr(config, "embedding_dims", None) is not None: embedding_dims = config.embedding_dims