diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index ab2bbec2..94ed2182 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -1,9 +1,7 @@
 from torch.utils.data import Dataset
-import matplotlib.pyplot as plt
 from torch.utils.data import DataLoader, WeightedRandomSampler
 import torch
 import numpy as np
-from IPython.display import clear_output


 class TorchDataset(Dataset):
@@ -101,42 +99,6 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, batch_size):
     return train_dataloader, valid_dataloader


-def plot_losses(losses_train, losses_valid, metrics_train, metrics_valid):
-    """
-    Plot train and validation losses.
-
-    Parameters
-    ----------
-    losses_train : list
-        list of train losses per epoch
-    losses_valid : list
-        list of valid losses per epoch
-    metrics_train : list
-        list of train metrics per epoch
-    metrics_valid : list
-        list of valid metrics per epoch
-    Returns
-    ------
-    plot
-    """
-    clear_output()
-    plt.figure(figsize=(15, 5))
-    plt.subplot(1, 2, 1)
-    plt.plot(range(len(losses_train)), losses_train, label='Train')
-    plt.plot(range(len(losses_valid)), losses_valid, label='Valid')
-    plt.grid()
-    plt.title('Losses')
-    plt.legend()
-
-    plt.subplot(1, 2, 2)
-    plt.plot(range(len(metrics_train)), metrics_train, label='Train')
-    plt.plot(range(len(metrics_valid)), metrics_valid, label='Valid')
-    plt.grid()
-    plt.title('Training Metrics')
-    plt.legend()
-    plt.show()
-
-
 def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim):
     """
     This is a computational trick.