From f96795ff46e02af4ca7c0ed6648276f4e4b788b0 Mon Sep 17 00:00:00 2001 From: Sebastien Fischman Date: Wed, 11 Mar 2020 15:05:19 +0100 Subject: [PATCH] fix: remove dead code for plots --- pytorch_tabnet/utils.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index ab2bbec2..94ed2182 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -1,9 +1,7 @@ from torch.utils.data import Dataset -import matplotlib.pyplot as plt from torch.utils.data import DataLoader, WeightedRandomSampler import torch import numpy as np -from IPython.display import clear_output class TorchDataset(Dataset): @@ -101,42 +99,6 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, batch_size): return train_dataloader, valid_dataloader -def plot_losses(losses_train, losses_valid, metrics_train, metrics_valid): - """ - Plot train and validation losses. - - Parameters - ---------- - losses_train : list - list of train losses per epoch - losses_valid : list - list of valid losses per epoch - metrics_train : list - list of train metrics per epoch - metrics_valid : list - list of valid metrics per epoch - Returns - ------ - plot - """ - clear_output() - plt.figure(figsize=(15, 5)) - plt.subplot(1, 2, 1) - plt.plot(range(len(losses_train)), losses_train, label='Train') - plt.plot(range(len(losses_valid)), losses_valid, label='Valid') - plt.grid() - plt.title('Losses') - plt.legend() - - plt.subplot(1, 2, 2) - plt.plot(range(len(metrics_train)), metrics_train, label='Train') - plt.plot(range(len(metrics_valid)), metrics_valid, label='Valid') - plt.grid() - plt.title('Training Metrics') - plt.legend() - plt.show() - - def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim): """ This is a computational trick.