fix: verbosity with schedulers
Optimox committed Jun 18, 2020
1 parent d40b02f commit d6fbf90
Showing 2 changed files with 31 additions and 15 deletions.
15 changes: 14 additions & 1 deletion census_example.ipynb
@@ -151,7 +151,10 @@
     "    cat_dims=cat_dims,\n",
     "    cat_emb_dim=1,\n",
     "    optimizer_fn=torch.optim.Adam,\n",
-    "    optimizer_params=dict(lr=2e-2))"
+    "    optimizer_params=dict(lr=2e-2),\n",
+    "    scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
+    "                      \"gamma\":0.9},\n",
+    "    scheduler_fn=torch.optim.lr_scheduler.StepLR)"
    ]
   },
   {
@@ -227,6 +230,16 @@
     "plt.plot([-x for x in clf.history['valid']['metric']])"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot learning rates\n",
+    "plt.plot([x for x in clf.history['train']['lr']])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
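For context: torch.optim.lr_scheduler.StepLR multiplies the learning rate by gamma every step_size epochs, so the cell above decays the 2e-2 starting rate by a factor of 0.9 every 50 epochs. A minimal standalone sketch of the resulting schedule (the throwaway parameter and epoch count are illustrative, not part of the commit):

import torch

# A throwaway parameter so the optimizer has something to manage.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

for epoch in range(150):
    optimizer.step()   # no gradients here, so this is a no-op
    scheduler.step()   # decays lr after epochs 50, 100, 150

# Three decays have fired: 2e-2 * 0.9**3
print(optimizer.param_groups[-1]["lr"])  # ~0.01458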
31 changes: 17 additions & 14 deletions pytorch_tabnet/tab_model.py
@@ -148,11 +148,11 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
         else:
             self.scheduler = None
 
-        losses_train = []
-        losses_valid = []
-
-        metrics_train = []
-        metrics_valid = []
+        self.losses_train = []
+        self.losses_valid = []
+        self.learning_rates = []
+        self.metrics_train = []
+        self.metrics_valid = []
 
         if self.verbose > 0:
             print("Will train until validation stopping metric",
@@ -165,13 +165,16 @@
         while (self.epoch < self.max_epochs and
                self.patience_counter < self.patience):
             starting_time = time.time()
+            # updates learning rate history
+            self.learning_rates.append(self.optimizer.param_groups[-1]["lr"])
+
             fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader)
 
             # leaving it here, may be used for callbacks later
-            losses_train.append(fit_metrics['train']['loss_avg'])
-            losses_valid.append(fit_metrics['valid']['total_loss'])
-            metrics_train.append(fit_metrics['train']['stopping_loss'])
-            metrics_valid.append(fit_metrics['valid']['stopping_loss'])
+            self.losses_train.append(fit_metrics['train']['loss_avg'])
+            self.losses_valid.append(fit_metrics['valid']['total_loss'])
+            self.metrics_train.append(fit_metrics['train']['stopping_loss'])
+            self.metrics_valid.append(fit_metrics['valid']['stopping_loss'])
 
             stopping_loss = fit_metrics['valid']['stopping_loss']
             if stopping_loss < self.best_cost:
@@ -201,10 +204,11 @@
             print(f"Training done in {total_time:.3f} seconds.")
             print('---------------------------------------')
 
-        self.history = {"train": {"loss": losses_train,
-                                  "metric": metrics_train},
-                        "valid": {"loss": losses_valid,
-                                  "metric": metrics_valid}}
+        self.history = {"train": {"loss": self.losses_train,
+                                  "metric": self.metrics_train,
+                                  "lr": self.learning_rates},
+                        "valid": {"loss": self.losses_valid,
+                                  "metric": self.metrics_valid}}
         # load best models post training
         self.load_best_model()
 
@@ -767,7 +771,6 @@ def train_epoch(self, train_loader):
 
         if self.scheduler is not None:
             self.scheduler.step()
-            print("Current learning rate: ", self.optimizer.param_groups[-1]["lr"])
         return epoch_metrics
 
     def train_batch(self, data, targets):
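With this change the scheduler no longer prints the learning rate on every epoch; the rates are collected in clf.history['train']['lr'] instead, which is what the new notebook cell plots. A minimal end-to-end sketch of the new behavior (the synthetic data are placeholders, not the census example, and constructor arguments beyond those shown in the diff are assumed to have defaults):

import numpy as np
import torch
import matplotlib.pyplot as plt
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder data; the census example uses real preprocessed features.
X_train = np.random.rand(256, 10).astype(np.float32)
y_train = np.random.randint(0, 2, 256)
X_valid = np.random.rand(64, 10).astype(np.float32)
y_valid = np.random.randint(0, 2, 64)

clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_params={"step_size": 50, "gamma": 0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
)
clf.fit(X_train, y_train, X_valid, y_valid)

# One learning-rate entry per epoch, recorded instead of printed.
plt.plot(clf.history['train']['lr'])
plt.xlabel('epoch')
plt.ylabel('learning rate')
plt.show()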
