Skip to content

Commit

Permalink
fix: pin memory available for training only
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and eduardocarvp committed Oct 15, 2020
1 parent 46a301f commit 28346c2
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
9 changes: 7 additions & 2 deletions pytorch_tabnet/abstract_model.py
Expand Up @@ -82,6 +82,7 @@ def fit(
num_workers=0,
drop_last=False,
callbacks=None,
pin_memory=True
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
Expand Down Expand Up @@ -119,6 +120,9 @@ def fit(
Whether to drop last batch during training
callbacks : list of callback function
List of custom callbacks
pin_memory: bool
Whether to set pin_memory to True or False during training
"""
# update model name

Expand All @@ -130,6 +134,8 @@ def fit(
self.drop_last = drop_last
self.input_dim = X_train.shape[1]
self._stop_training = False
self.pin_memory = pin_memory

eval_set = eval_set if eval_set else []

if loss_fn is None:
Expand Down Expand Up @@ -203,7 +209,6 @@ def predict(self, X):
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = []
Expand Down Expand Up @@ -237,7 +242,6 @@ def explain(self, X):
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

res_explain = []
Expand Down Expand Up @@ -590,6 +594,7 @@ def _construct_loaders(self, X_train, y_train, eval_set):
self.batch_size,
self.num_workers,
self.drop_last,
self.pin_memory,
)
return train_dataloader, valid_dataloaders

Expand Down
2 changes: 0 additions & 2 deletions pytorch_tabnet/multitask.py
Expand Up @@ -97,7 +97,6 @@ def predict(self, X):
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = {}
Expand Down Expand Up @@ -145,7 +144,6 @@ def predict_proba(self, X):
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = {}
Expand Down
1 change: 0 additions & 1 deletion pytorch_tabnet/tab_model.py
Expand Up @@ -93,7 +93,6 @@ def predict_proba(self, X):
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = []
Expand Down
6 changes: 3 additions & 3 deletions pytorch_tabnet/utils.py
Expand Up @@ -51,7 +51,7 @@ def __getitem__(self, index):


def create_dataloaders(
X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last
X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last, pin_memory
):
"""
Create dataloaders with or wihtout subsampling depending on weights and balanced.
Expand Down Expand Up @@ -117,7 +117,7 @@ def create_dataloaders(
shuffle=need_shuffle,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True
pin_memory=pin_memory
)

valid_dataloaders = []
Expand All @@ -128,7 +128,7 @@ def create_dataloaders(
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
pin_memory=pin_memory
)
)

Expand Down

0 comments on commit 28346c2

Please sign in to comment.