Skip to content

Commit cf1454a

Browse files
snehilchatterjeepre-commit-ci[bot]manujosephv
authored
Bug fix for "Categorical" dtype (#493)
* Categorical bug fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * dataloader kwargs part removed --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent a3272ea commit cf1454a

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/pytorch_tabular/categorical_encoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def transform(self, X):
6262
not X[self.cols].isnull().any().any()
6363
), "`handle_missing` = `error` and missing values found in columns to encode."
6464
X_encoded = X.copy(deep=True)
65+
category_cols = X_encoded.select_dtypes(include="category").columns
66+
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
6567
for col, mapping in self._mapping.items():
6668
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
6769

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301301
else:
302302
raise ValueError(f"{config.task} is an unsupported task.")
303303
if self.train is not None:
304+
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
305+
self.train[category_cols] = self.train[category_cols].astype("object")
304306
categorical_cardinality = [
305307
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
306308
]
307309
else:
310+
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
311+
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
308312
categorical_cardinality = [
309313
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
310314
]

0 commit comments

Comments
 (0)