From 28346c2259cabbed79e83963c4009eac3ae38f9e Mon Sep 17 00:00:00 2001 From: Optimox Date: Thu, 15 Oct 2020 16:37:41 +0200 Subject: [PATCH] fix: pin memory available for training only --- pytorch_tabnet/abstract_model.py | 9 +++++++-- pytorch_tabnet/multitask.py | 2 -- pytorch_tabnet/tab_model.py | 1 - pytorch_tabnet/utils.py | 6 +++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 64201b46..3a282232 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -82,6 +82,7 @@ def fit( num_workers=0, drop_last=False, callbacks=None, + pin_memory=True ): """Train a neural network stored in self.network Using train_dataloader for training data and @@ -119,6 +120,9 @@ def fit( Whether to drop last batch during training callbacks : list of callback function List of custom callbacks + pin_memory: bool + Whether to set pin_memory to True or False during training + """ # update model name @@ -130,6 +134,8 @@ def fit( self.drop_last = drop_last self.input_dim = X_train.shape[1] self._stop_training = False + self.pin_memory = pin_memory + eval_set = eval_set if eval_set else [] if loss_fn is None: @@ -203,7 +209,6 @@ def predict(self, X): PredictDataset(X), batch_size=self.batch_size, shuffle=False, - pin_memory=True ) results = [] @@ -237,7 +242,6 @@ def explain(self, X): PredictDataset(X), batch_size=self.batch_size, shuffle=False, - pin_memory=True ) res_explain = [] @@ -590,6 +594,7 @@ def _construct_loaders(self, X_train, y_train, eval_set): self.batch_size, self.num_workers, self.drop_last, + self.pin_memory, ) return train_dataloader, valid_dataloaders diff --git a/pytorch_tabnet/multitask.py b/pytorch_tabnet/multitask.py index c69d4cdd..148347bb 100644 --- a/pytorch_tabnet/multitask.py +++ b/pytorch_tabnet/multitask.py @@ -97,7 +97,6 @@ def predict(self, X): PredictDataset(X), batch_size=self.batch_size, shuffle=False, - pin_memory=True ) results = {} @@ -145,7 +144,6 @@ def predict_proba(self, X): PredictDataset(X), batch_size=self.batch_size, shuffle=False, - pin_memory=True ) results = {} diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index 9a42649f..3441548e 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -93,7 +93,6 @@ def predict_proba(self, X): PredictDataset(X), batch_size=self.batch_size, shuffle=False, - pin_memory=True ) results = [] diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index 5bac2713..ce890d21 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -51,7 +51,7 @@ def __getitem__(self, index): def create_dataloaders( - X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last + X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last, pin_memory ): """ Create dataloaders with or wihtout subsampling depending on weights and balanced. @@ -117,7 +117,7 @@ def create_dataloaders( shuffle=need_shuffle, num_workers=num_workers, drop_last=drop_last, - pin_memory=True + pin_memory=pin_memory ) valid_dataloaders = [] @@ -128,7 +128,7 @@ def create_dataloaders( batch_size=batch_size, shuffle=False, num_workers=num_workers, - pin_memory=True + pin_memory=pin_memory ) )