Skip to content

Commit

Permalink
fix: resolve timer issue and warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Jan 10, 2020
1 parent 8a2cd87 commit ecd2cd9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
11 changes: 8 additions & 3 deletions census_example.ipynb
Expand Up @@ -19,7 +19,12 @@
"\n",
"import os\n",
"import wget\n",
"from pathlib import Path"
"from pathlib import Path\n",
"\n",
"\n",
"# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n",
"import warnings\n",
"warnings.simplefilter(\"ignore\", UserWarning)"
]
},
{
Expand Down Expand Up @@ -172,7 +177,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
"scrolled": true
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -336,7 +341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down
8 changes: 6 additions & 2 deletions forest_example.ipynb
Expand Up @@ -21,7 +21,11 @@
"import wget\n",
"from pathlib import Path\n",
"import shutil\n",
"import gzip"
"import gzip\n",
"\n",
"# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n",
"import warnings\n",
"warnings.simplefilter(\"ignore\", UserWarning)"
]
},
{
Expand Down Expand Up @@ -381,7 +385,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down
11 changes: 7 additions & 4 deletions pytorch_tabnet/tab_model.py
Expand Up @@ -150,14 +150,14 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
if self.verbose > 0:
print("Will train until validation stopping metric",
f"hasn't improved in {self.patience} rounds.")
msg_epoch = f'| EPOCH | train | valid | total time (s)'
msg_epoch = f'| EPOCH | train | valid | total time (s)'
print('---------------------------------------')
print(msg_epoch)

starting_time = time.time()
total_time = 0
while (self.epoch < self.max_epochs and
self.patience_counter < self.patience):
starting_time = time.time()
fit_metrics = self.fit_epoch(train_dataloader, valid_dataloader)

# leaving it here, may be used for callbacks later
Expand All @@ -181,9 +181,12 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
total_time += time.time() - starting_time
if self.verbose > 0:
if self.epoch % self.verbose == 0:
separator = "|"
msg_epoch = f"| {self.epoch:<5} | "
msg_epoch += f"{-np.round(fit_metrics['train']['stopping_loss'], 5):<5} | "
msg_epoch += f"{-np.round(fit_metrics['valid']['stopping_loss'], 5):<5} | "
msg_epoch += f"{-fit_metrics['train']['stopping_loss']:.5f}"
msg_epoch += f' {separator:<2} '
msg_epoch += f"{-fit_metrics['valid']['stopping_loss']:.5f}"
msg_epoch += f' {separator:<2} '
msg_epoch += f" {np.round(total_time, 1):<10}"
print(msg_epoch)

Expand Down
8 changes: 6 additions & 2 deletions regression_example.ipynb
Expand Up @@ -19,7 +19,11 @@
"\n",
"import os\n",
"import wget\n",
"from pathlib import Path"
"from pathlib import Path\n",
"\n",
"# This is due to torch1.3 bug : https://github.com/pytorch/pytorch/issues/27972\n",
"import warnings\n",
"warnings.simplefilter(\"ignore\", UserWarning)"
]
},
{
Expand Down Expand Up @@ -335,7 +339,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down

0 comments on commit ecd2cd9

Please sign in to comment.