Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/336 : check if pandas df and drop_last default to True #338

Merged
merged 4 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
create_dataloaders,
define_device,
ComplexEncoder,
check_input
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
Expand All @@ -22,7 +23,7 @@
)
from pytorch_tabnet.metrics import MetricContainer, check_metrics
from sklearn.base import BaseEstimator
from sklearn.utils import check_array

from torch.utils.data import DataLoader
import io
import json
Expand Down Expand Up @@ -115,7 +116,7 @@ def fit(
batch_size=1024,
virtual_batch_size=128,
num_workers=0,
drop_last=False,
drop_last=True,
callbacks=None,
pin_memory=True,
from_unsupervised=None,
Expand Down Expand Up @@ -182,7 +183,7 @@ def fit(
else:
self.loss_fn = loss_fn

check_array(X_train)
check_input(X_train)

self.update_fit_params(
X_train,
Expand Down
8 changes: 4 additions & 4 deletions pytorch_tabnet/pretraining.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch
import numpy as np
from sklearn.utils import check_array
from torch.utils.data import DataLoader
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
create_explain_matrix,
filter_weights,
PredictDataset
PredictDataset,
check_input
)
from torch.nn.utils import clip_grad_norm_
from pytorch_tabnet.pretraining_utils import (
Expand Down Expand Up @@ -55,7 +55,7 @@ def fit(
batch_size=1024,
virtual_batch_size=128,
num_workers=0,
drop_last=False,
drop_last=True,
callbacks=None,
pin_memory=True,
):
Expand Down Expand Up @@ -118,7 +118,7 @@ def fit(
else:
self.loss_fn = loss_fn

check_array(X_train)
check_input(X_train)

self.update_fit_params(
weights,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_tabnet/pretraining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pytorch_tabnet.utils import (
create_sampler,
PredictDataset,
check_input
)
from sklearn.utils import check_array


def create_dataloaders(
Expand Down Expand Up @@ -93,7 +93,7 @@ def validate_eval_set(eval_set, eval_name, X_train):
), "eval_set and eval_name have not the same length"

for set_nb, X in enumerate(eval_set):
check_array(X)
check_input(X)
msg = (
f"Number of columns is different between eval set {set_nb}"
+ f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
Expand Down
15 changes: 14 additions & 1 deletion pytorch_tabnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy
import json
from sklearn.utils import check_array
import pandas as pd


class TorchDataset(Dataset):
Expand Down Expand Up @@ -271,7 +272,7 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
len(elem) == 2 for elem in eval_set
), "Each tuple of eval_set need to have two elements"
for name, (X, y) in zip(eval_name, eval_set):
check_array(X)
check_input(X)
msg = (
f"Dimension mismatch between X_{name} "
+ f"{X.shape} and X_train {X_train.shape}"
Expand Down Expand Up @@ -337,3 +338,15 @@ def default(self, obj):
return int(obj)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, obj)


def check_input(X):
"""
Raise a clear error if X is a pandas dataframe
and check array according to scikit rules
"""
if isinstance(X, (pd.DataFrame, pd.Series)):
err_message = "Pandas DataFrame are not supported: apply X.values when calling fit"
raise(ValueError, err_message)
check_array(X)
return