Skip to content

Commit

Permalink
feat: speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 12, 2020
1 parent d871406 commit 5a01359
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 7 deletions.
15 changes: 11 additions & 4 deletions pytorch_tabnet/abstract_model.py
Expand Up @@ -134,7 +134,7 @@ def fit(
self.drop_last = drop_last
self.input_dim = X_train.shape[1]
self._stop_training = False
self.eval_set = eval_set if eval_set else []
eval_set = eval_set if eval_set else []

if loss_fn is None:
self.loss_fn = self._default_loss
Expand Down Expand Up @@ -204,7 +204,10 @@ def predict(self, X):
"""
self.network.eval()
dataloader = DataLoader(
PredictDataset(X), batch_size=self.batch_size, shuffle=False
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = []
Expand Down Expand Up @@ -235,7 +238,10 @@ def explain(self, X):
self.network.eval()

dataloader = DataLoader(
PredictDataset(X), batch_size=self.batch_size, shuffle=False
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

res_explain = []
Expand Down Expand Up @@ -369,7 +375,8 @@ def _train_batch(self, X, y):
X = X.to(self.device).float()
y = y.to(self.device).float()

self._optimizer.zero_grad()
for param in self.network.parameters():
param.grad = None

output, M_loss = self.network(X)

Expand Down
10 changes: 8 additions & 2 deletions pytorch_tabnet/multitask.py
Expand Up @@ -91,7 +91,10 @@ def predict(self, X):
"""
self.network.eval()
dataloader = DataLoader(
PredictDataset(X), batch_size=self.batch_size, shuffle=False
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = {}
Expand Down Expand Up @@ -136,7 +139,10 @@ def predict_proba(self, X):
self.network.eval()

dataloader = DataLoader(
PredictDataset(X), batch_size=self.batch_size, shuffle=False
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = {}
Expand Down
5 changes: 4 additions & 1 deletion pytorch_tabnet/tab_model.py
Expand Up @@ -90,7 +90,10 @@ def predict_proba(self, X):
self.network.eval()

dataloader = DataLoader(
PredictDataset(X), batch_size=self.batch_size, shuffle=False
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
pin_memory=True
)

results = []
Expand Down
2 changes: 2 additions & 0 deletions pytorch_tabnet/utils.py
Expand Up @@ -117,6 +117,7 @@ def create_dataloaders(
shuffle=need_shuffle,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True
)

valid_dataloaders = []
Expand All @@ -127,6 +128,7 @@ def create_dataloaders(
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
)

Expand Down

0 comments on commit 5a01359

Please sign in to comment.