From 0ae114ff59900537cd3c48dc9d44669f52b9141e Mon Sep 17 00:00:00 2001 From: Optimox Date: Fri, 9 Oct 2020 18:21:15 +0200 Subject: [PATCH] feat: add easy schedulers --- .circleci/config.yml | 14 + Makefile | 4 + README.md | 2 + customizing_example.ipynb | 615 +++++++++++++++++++++++++++++++ pytorch_tabnet/abstract_model.py | 55 +-- pytorch_tabnet/callbacks.py | 66 +++- pytorch_tabnet/metrics.py | 2 +- 7 files changed, 723 insertions(+), 35 deletions(-) create mode 100755 customizing_example.ipynb diff --git a/.circleci/config.yml b/.circleci/config.yml index 11eece30..a388b54d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -165,6 +165,20 @@ jobs: shell: bash -leo pipefail command: | make test-nb-multi-task + test-nb-customization: + executor: python-executor + steps: + - checkout + # Download and cache dependencies + - restore_cache: + keys: + - v1-dependencies-{{ checksum "poetry.lock" }} + - install_poetry + - run: + name: run test-nb-customization + shell: bash -leo pipefail + command: | + make test-nb-customization workflows: version: 2 diff --git a/Makefile b/Makefile index 5b67f9c3..90cc15c1 100644 --- a/Makefile +++ b/Makefile @@ -93,6 +93,10 @@ test-nb-multi-task: ## run multi task classification example tests using noteboo $(MAKE) _run_notebook NB_FILE="./multi_task_example.ipynb" .PHONY: test-obfuscator +test-nb-customization: ## run customization example tests using notebooks + $(MAKE) _run_notebook NB_FILE="./customizing_example.ipynb" +.PHONY: test-obfuscator + help: ## Display help @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' .PHONY: help diff --git a/README.md b/README.md index 89cf7c15..501ab6b3 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,8 @@ clf.fit( ``` +A specific customization example notebook is available here : https://github.com/dreamquark-ai/tabnet/blob/develop/customizing_example.ipynb + # Useful links - explanatory video : 
https://youtu.be/ysBaZO8YmX8 diff --git a/customizing_example.ipynb b/customizing_example.ipynb new file mode 100755 index 00000000..34efe429 --- /dev/null +++ b/customizing_example.ipynb @@ -0,0 +1,615 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Customize a TabNet Model\n", + "\n", + "## This tutorial gives examples on how to easily customize a TabNet Model\n", + "\n", + "### 1 - Customizing your learning rate scheduler\n", + "\n", + "Almost all classical pytorch schedulers are now easy to integrate with pytorch-tabnet\n", + "\n", + "### 2 - Use your own loss function\n", + "\n", + "It's really easy to use any pytorch loss function with TabNet, we'll walk you through that\n", + "\n", + "\n", + "### 3 - Customizing your evaluation metric and evaluations sets\n", + "\n", + "Like XGBoost, you can easily monitor different metrics on different evaluation sets with pytorch-tabnet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_tabnet.tab_model import TabNetClassifier\n", + "\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.metrics import roc_auc_score\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "np.random.seed(0)\n", + "\n", + "\n", + "import os\n", + "import wget\n", + "from pathlib import Path\n", + "\n", + "from matplotlib import pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download census-income dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", + "dataset_name = 'census-income'\n", + "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "out.parent.mkdir(parents=True, exist_ok=True)\n", + "if out.exists():\n", + " print(\"File already exists.\")\n", + "else:\n", + " print(\"Downloading file...\")\n", + " wget.download(url, out.as_posix())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load data and split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train = pd.read_csv(out)\n", + "target = ' <=50K'\n", + "if \"Set\" not in train.columns:\n", + " train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", + "\n", + "train_indices = train[train.Set==\"train\"].index\n", + "valid_indices = train[train.Set==\"valid\"].index\n", + "test_indices = train[train.Set==\"test\"].index" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simple preprocessing\n", + "\n", + "Label encode categorical features and fill empty cells." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nunique = train.nunique()\n", + "types = train.dtypes\n", + "\n", + "categorical_columns = []\n", + "categorical_dims = {}\n", + "for col in train.columns:\n", + " if types[col] == 'object' or nunique[col] < 200:\n", + " print(col, train[col].nunique())\n", + " l_enc = LabelEncoder()\n", + " train[col] = train[col].fillna(\"VV_likely\")\n", + " train[col] = l_enc.fit_transform(train[col].values)\n", + " categorical_columns.append(col)\n", + " categorical_dims[col] = len(l_enc.classes_)\n", + " else:\n", + " train.fillna(train.loc[train_indices, col].mean(), inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define categorical features for categorical embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unused_feat = ['Set']\n", + "\n", + "features = [ col for col 
in train.columns if col not in unused_feat+[target]] \n", + "\n", + "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", + "\n", + "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1 - Customizing your learning rate scheduler\n", + "\n", + "TabNetClassifier, TabNetRegressor and TabNetMultiTaskClassifier all take two arguments:\n", + "- scheduler_fn : Any torch.optim.lr_scheduler should work\n", + "- scheduler_params : A dictionary that contains the parameters of your scheduler (without the optimizer)\n", + "\n", + "----\n", + "NB1 : Some schedulers like torch.optim.lr_scheduler.ReduceLROnPlateau depend on the evolution of a metric; pytorch-tabnet will use the early stopping metric you asked for (the last eval_metric, see 3-) to perform the scheduler updates\n", + "\n", + "EX1 : \n", + "```\n", + "scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau\n", + "scheduler_params={\"mode\":'max', # max because default eval metric for binary is AUC\n", + " \"factor\":0.1,\n", + " \"patience\":1}\n", + "```\n", + "\n", + "-----\n", + "NB2 : Some schedulers require updates at batch level; they can be used very easily: the only thing to do is to set `is_batch_level` to True in your `scheduler_params`\n", + "\n", + "EX2:\n", + "```\n", + "scheduler_fn=torch.optim.lr_scheduler.CyclicLR\n", + "scheduler_params={\"is_batch_level\":True,\n", + " \"base_lr\":1e-3,\n", + " \"max_lr\":1e-2,\n", + " \"step_size_up\":100\n", + " }\n", + "```\n", + "\n", + "-----\n", + "NB3: Note that you can also customize your optimizer function, any torch optimizer should work" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Network parameters\n", + "max_epochs = 20 if not os.getenv(\"CI\", False) else 2\n", + "batch_size = 1024\n", + "clf = TabNetClassifier(cat_idxs=cat_idxs,\n", + 
" cat_dims=cat_dims,\n", + " cat_emb_dim=1,\n", + " optimizer_fn=torch.optim.Adam, # Any optimizer works here\n", + " optimizer_params=dict(lr=2e-2),\n", + " scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,\n", + " scheduler_params={\"is_batch_level\":True,\n", + " \"max_lr\":5e-2,\n", + " \"steps_per_epoch\":int(train.shape[0] / batch_size)+1,\n", + " \"epochs\":max_epochs\n", + " },\n", + " mask_type='entmax', # \"sparsemax\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X_train = train[features].values[train_indices]\n", + "y_train = train[target].values[train_indices]\n", + "\n", + "X_valid = train[features].values[valid_indices]\n", + "y_valid = train[target].values[valid_indices]\n", + "\n", + "X_test = train[features].values[test_indices]\n", + "y_test = train[target].values[test_indices]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2 - Use your own loss function\n", + "\n", + "The default loss for classification is torch.nn.functional.cross_entropy\n", + "The default loss for regression is torch.nn.functional.mse_loss\n", + "\n", + "Any derivable loss function of the type lambda y_pred, y_true : loss(y_pred, y_true) should work if it uses torch computation (to allow gradients computations).\n", + "\n", + "In particular, any pytorch loss function should work.\n", + "\n", + "Once your loss is defined simply pass it loss_fn argument when defining your model.\n", + "\n", + "/!\\ : One important thing to keep in mind is that when computing the loss for TabNetClassifier and TabNetMultiTaskClassifier you'll need to apply first torch.nn.Softmax() to y_pred as the final model prediction is softmaxed automatically.\n", + "\n", + "NB : Tabnet also has an internal loss (the sparsity loss) which is summed to the loss_fn, the importance of the sparsity loss can be 
mitigated using `lambda_sparse` parameter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def my_loss_fn(y_pred, y_true):\n", + " \"\"\"\n", + " Dummy example similar to using default torch.nn.functional.cross_entropy\n", + " \"\"\"\n", + " softmax_pred = torch.nn.Softmax(dim=-1)(y_pred)\n", + " logloss = (1-y_true)*torch.log(softmax_pred[:,0])\n", + " logloss += y_true*torch.log(softmax_pred[:,1])\n", + " return -torch.mean(logloss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3 - Customizing your evaluation metric and evaluations sets\n", + "\n", + "When calling the `fit` method you can specify:\n", + "- eval_set : a list of tuples like (X_valid, y_valid)\n", + " Note that the last value of this list will be used for early stopping\n", + "- eval_name : a list to name each eval set\n", + " default will be val_0, val_1 ...\n", + "- eval_metric : a list of default metrics or custom metrics\n", + " Default : \"auc\", \"accuracy\", \"logloss\", \"balanced_accuracy\", \"mse\", \"rmse\"\n", + " \n", + " \n", + "NB : If no eval_set is given no early stopping will occur (patience is then ignored) and the weights used will be the last epoch's weights\n", + "\n", + "NB2 : If `patience<=0` this will disable early stopping\n", + "\n", + "NB3 : Setting `patience` to `max_epochs` ensures that training won't be early stopped, but best weights from the best epochs will be used (instead of the last weight if early stopping is disabled)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_tabnet.metrics import Metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class my_metric(Metric):\n", + " \"\"\"\n", + " 2xAUC.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " self._name = \"custom\" # write an understandable name here\n", + 
" self._maximize = True\n", + "\n", + " def __call__(self, y_true, y_score):\n", + " \"\"\"\n", + " Compute AUC of predictions.\n", + "\n", + " Parameters\n", + " ----------\n", + " y_true: np.ndarray\n", + " Target matrix or vector\n", + " y_score: np.ndarray\n", + " Score matrix or vector\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " AUC of predictions vs targets.\n", + " \"\"\"\n", + " return 2*roc_auc_score(y_true, y_score[:, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clf.fit(\n", + " X_train=X_train, y_train=y_train,\n", + " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", + " eval_name=['train', 'val'],\n", + " eval_metric=[\"auc\", my_metric],\n", + " max_epochs=max_epochs , patience=0,\n", + " batch_size=batch_size,\n", + " virtual_batch_size=128,\n", + " num_workers=0,\n", + " weights=1,\n", + " drop_last=False,\n", + " loss_fn=my_loss_fn\n", + ") \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot losses\n", + "plt.plot(clf.history['loss'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot auc\n", + "plt.plot(clf.history['train_auc'])\n", + "plt.plot(clf.history['val_auc'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot learning rates\n", + "plt.plot(clf.history['lr'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = clf.predict_proba(X_test)\n", + "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n", + "\n", + "print(f\"FINAL VALID SCORE FOR {dataset_name} : {clf.history['val_auc'][-1]}\")\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")" + ] + 
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save and load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save tabnet model\n", + "saving_path_name = \"./tabnet_model_test_1\"\n", + "saved_filepath = clf.save_model(saving_path_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define new model with basic parameters and load state dict weights\n", + "loaded_clf = TabNetClassifier()\n", + "loaded_clf.load_model(saved_filepath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_preds = loaded_clf.predict_proba(X_test)\n", + "loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)\n", + "\n", + "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_auc}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert(test_auc == loaded_test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Global explainability : feat importance summing to 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clf.feature_importances_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Local explainability and masks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "explain_matrix, masks = clf.explain(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(20,20))\n", + "\n", + "for i in range(3):\n", + " axs[i].imshow(masks[i][:50])\n", + " axs[i].set_title(f\"mask {i}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# XGB" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from xgboost import XGBClassifier\n", + "\n", + "clf_xgb = XGBClassifier(max_depth=8,\n", + " learning_rate=0.1,\n", + " n_estimators=1000,\n", + " verbosity=0,\n", + " silent=None,\n", + " objective='binary:logistic',\n", + " booster='gbtree',\n", + " n_jobs=-1,\n", + " nthread=None,\n", + " gamma=0,\n", + " min_child_weight=1,\n", + " max_delta_step=0,\n", + " subsample=0.7,\n", + " colsample_bytree=1,\n", + " colsample_bylevel=1,\n", + " colsample_bynode=1,\n", + " reg_alpha=0,\n", + " reg_lambda=1,\n", + " scale_pos_weight=1,\n", + " base_score=0.5,\n", + " random_state=0,\n", + " seed=None,)\n", + "\n", + "clf_xgb.fit(X_train, y_train,\n", + " eval_set=[(X_valid, y_valid)],\n", + " early_stopping_rounds=40,\n", + " verbose=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = np.array(clf_xgb.predict_proba(X_valid))\n", + "valid_auc = roc_auc_score(y_score=preds[:,1], y_true=y_valid)\n", + "print(valid_auc)\n", + "\n", + "preds = np.array(clf_xgb.predict_proba(X_test))\n", + "test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n", + "print(test_auc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": 
{}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 149e46be..98f35883 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -16,6 +16,7 @@ CallbackContainer, History, EarlyStopping, + LRSchedulerCallback, ) from pytorch_tabnet.metrics import MetricContainer, check_metrics from sklearn.base import BaseEstimator @@ -72,7 +73,7 @@ def fit( self, X_train, y_train, - eval_set=None, + eval_set=[], eval_name=None, eval_metric=None, loss_fn=None, @@ -133,9 +134,6 @@ def fit( self.input_dim = X_train.shape[1] self._stop_training = False - if eval_set is None: - eval_set = [] - if loss_fn is None: self.loss_fn = self._default_loss else: @@ -154,9 +152,8 @@ def fit( self._set_network() self._set_metrics(eval_metric, eval_names) - self._set_callbacks(callbacks) self._set_optimizer() - self._set_scheduler() + self._set_callbacks(callbacks) # Call method on_train_begin for all callbacks self._callback_container.on_train_begin() @@ -383,9 +380,6 @@ def _train_batch(self, X, y): batch_logs["loss"] = loss.cpu().detach().numpy().item() - if self._scheduler is not None: - self._scheduler.step() - return batch_logs def _predict_epoch(self, name, loader): @@ -513,16 +507,33 @@ def _set_callbacks(self, custom_callbacks): List of callback functions. 
""" - # Setup default callbacks history and early stopping + # Setup default callbacks history, early stopping and scheduler + callbacks = [] self.history = History(self, verbose=self.verbose) - early_stopping = EarlyStopping( - early_stopping_metric=self.early_stopping_metric, - is_maximize=( - self._metrics[-1]._maximize if len(self._metrics) > 0 else None - ), - patience=self.patience, - ) - callbacks = [self.history, early_stopping] + callbacks.append(self.history) + if (self.early_stopping_metric is not None) and (self.patience > 0): + early_stopping = EarlyStopping( + early_stopping_metric=self.early_stopping_metric, + is_maximize=( + self._metrics[-1]._maximize if len(self._metrics) > 0 else None + ), + patience=self.patience, + ) + callbacks.append(early_stopping) + else: + print("No early stopping will be performed, last training weights will be used.") + if self.scheduler_fn is not None: + # Add LR Scheduler call_back + is_batch_level = self.scheduler_params.pop("is_batch_level", False) + scheduler = LRSchedulerCallback( + scheduler_fn=self.scheduler_fn, + scheduler_params=self.scheduler_params, + optimizer=self._optimizer, + early_stopping_metric=self.early_stopping_metric, + is_batch_level=is_batch_level, + ) + callbacks.append(scheduler) + if custom_callbacks: callbacks.extend(custom_callbacks) self._callback_container = CallbackContainer(callbacks) @@ -534,14 +545,6 @@ def _set_optimizer(self): self.network.parameters(), **self.optimizer_params ) - def _set_scheduler(self): - """Setup scheduler.""" - self._scheduler = None - if self.scheduler_fn: - self._scheduler = self.scheduler_fn( - self._optimizer, **self.scheduler_params - ) - def _construct_loaders(self, X_train, y_train, eval_set): """Generate dataloaders for train and eval set. 
diff --git a/pytorch_tabnet/callbacks.py b/pytorch_tabnet/callbacks.py index 0e680f0f..51554c34 100644 --- a/pytorch_tabnet/callbacks.py +++ b/pytorch_tabnet/callbacks.py @@ -131,14 +131,15 @@ def on_epoch_end(self, epoch, logs=None): current_loss = logs.get(self.early_stopping_metric) if current_loss is None: return + loss_change = current_loss - self.best_loss max_improved = self.is_maximize and loss_change > self.tol min_improved = (not self.is_maximize) and (-loss_change > self.tol) if max_improved or min_improved: self.best_loss = current_loss + self.best_epoch = epoch self.wait = 1 self.best_weights = copy.deepcopy(self.trainer.network.state_dict()) - self.best_epoch = epoch else: if self.wait >= self.patience: self.stopped_epoch = epoch @@ -148,12 +149,10 @@ def on_epoch_end(self, epoch, logs=None): def on_train_end(self, logs=None): self.trainer.best_epoch = self.best_epoch self.trainer.best_cost = self.best_loss - final_weights = ( - self.best_weights - if self.best_weights is not None - else copy.deepcopy(self.trainer.network.state_dict()) - ) - self.trainer.network.load_state_dict(final_weights) + + if self.best_weights is not None: + self.trainer.network.load_state_dict(self.best_weights) + if self.stopped_epoch > 0: msg = f"\nEarly stopping occured at epoch {self.stopped_epoch}" msg += ( @@ -162,8 +161,11 @@ def on_train_end(self, logs=None): ) print(msg) else: - msg = f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" + msg = (f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" + + f" with best_epoch = {self.best_epoch} and " + + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}") print(msg) + print("Best weights from best epoch are automatically used!") @dataclass @@ -230,3 +232,51 @@ def __repr__(self): def __str__(self): return str(self.epoch_metrics) + + +@dataclass +class LRSchedulerCallback(Callback): + """Wrapper for most torch scheduler functions. 
+ + Parameters + --------- + scheduler_fn : torch.optim.lr_scheduler + Torch scheduling class + scheduler_params : dict + Dictionnary containing all parameters for the scheduler_fn + is_batch_level : bool (default = False) + If set to False : lr updates will happen at every epoch + If set to True : lr updates happen at every batch + Set this to True for OneCycleLR for example + """ + + scheduler_fn: Any + optimizer: Any + scheduler_params: dict + early_stopping_metric: str + is_batch_level: bool = False + + def __post_init__(self, ): + self.is_metric_related = hasattr(self.scheduler_fn, + "is_better") + self.scheduler = self.scheduler_fn(self.optimizer, + **self.scheduler_params) + super().__init__() + + def on_batch_end(self, batch, logs=None): + if self.is_batch_level: + self.scheduler.step() + else: + pass + + def on_epoch_end(self, epoch, logs=None): + current_loss = logs.get(self.early_stopping_metric) + if current_loss is None: + return + if self.is_batch_level: + pass + else: + if self.is_metric_related: + self.scheduler.step(current_loss) + else: + self.scheduler.step() diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py index df200065..c6f6a7b3 100644 --- a/pytorch_tabnet/metrics.py +++ b/pytorch_tabnet/metrics.py @@ -84,7 +84,7 @@ def get_metrics_by_names(cls, names): available_names = [metric()._name for metric in available_metrics] metrics = [] for name in names: - assert name in available_names, f"{name} is not available" + assert name in available_names, f"{name} is not available, choose in {available_names}" idx = available_names.index(name) metric = available_metrics[idx]() metrics.append(metric)