In [15]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

def plot_confusion_matrix(true_labels, pred_labels, class_labels, normalize=False, verbose=False,
                          labels=None):
    """
    Compute the confusion matrix of true vs. predicted labels and plot it as an annotated heatmap.

    Parameters
    ----------
    true_labels : array-like
        Ground-truth labels, passed to ``sklearn.metrics.confusion_matrix``.
    pred_labels : array-like
        Predicted labels, same length as ``true_labels``.
    class_labels : sequence of str
        Tick labels used for both axes; its length must equal the number of rows/columns
        of the confusion matrix.
    normalize : bool, default False
        If True, divide each row by its total so each cell shows the fraction of that
        true class assigned to each predicted class. Rows with no true samples are left
        as all zeros instead of NaN.
    verbose : bool, default False
        If True, print the raw (un-normalized) count matrix before plotting.
    labels : array-like, optional
        Forwarded to ``confusion_matrix`` to fix the set and order of classes, e.g. when
        some classes never occur in ``true_labels``/``pred_labels``. When None, sklearn
        uses the sorted labels that appear in either input.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the heatmap.
    """
    cm = confusion_matrix(true_labels, pred_labels, labels=labels)
    if verbose:
        print(cm)
    if normalize:
        # A class with no true samples has a zero row sum; dividing would yield NaNs
        # (and a RuntimeWarning), so those rows are left as zeros instead.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm, cmap=plt.cm.Greens)
    fig.colorbar(im, ax=ax)
    # Display and label every tick; rows are true labels, columns are predictions.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=class_labels, yticklabels=class_labels,
           title="Normalized confusion matrix" if normalize else "Confusion matrix",
           ylabel="True label",
           xlabel="Predicted label")
    # Rotate x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell; counts are integers, normalized values are shown as fractions.
    fmt = '.2f' if normalize else 'd'
    # Cells above half the maximum get white text so it stays readable on dark green.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    return ax

In [1]:
def plot_accuracies(train_accs, test_accs, legends=("Train", "Test")):
    """
    Plot per-epoch train and test accuracy curves on a single figure and show it.

    Parameters
    ----------
    train_accs : sequence of float
        One accuracy value per epoch, plotted at epochs 1..len(train_accs).
    test_accs : sequence of float
        One accuracy value per epoch, plotted at epochs 1..len(test_accs); it may have
        a different length than ``train_accs``.
    legends : sequence of str, default ("Train", "Test")
        Legend entries for the train and test curves, in that order.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # Build each curve's epoch axis from its own length so series of unequal
    # length do not raise a shape-mismatch ValueError.
    for accs in (train_accs, test_accs):
        ax.plot(np.arange(1, len(accs) + 1), accs, '-')
    ax.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy/Epoch')
    ax.legend(legends)
    plt.show()

In [None]:
def plot_loss(train_loss, val_loss, legends=("Train", "Val")):
    """
    Plot per-epoch train and validation loss curves on a single figure and show it.

    Parameters
    ----------
    train_loss : sequence of float
        One loss value per epoch, plotted at epochs 1..len(train_loss).
    val_loss : sequence of float
        One loss value per epoch, plotted at epochs 1..len(val_loss); it may have
        a different length than ``train_loss``.
    legends : sequence of str, default ("Train", "Val")
        Legend entries for the train and validation curves, in that order.
        An immutable tuple default avoids the shared-mutable-default pitfall.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # Build each curve's epoch axis from its own length so series of unequal
    # length do not raise a shape-mismatch ValueError.
    for losses in (train_loss, val_loss):
        ax.plot(np.arange(1, len(losses) + 1), losses, '-')
    ax.set(xlabel='Epoch', ylabel='Loss', title='Loss/Epoch')
    ax.legend(legends)
    plt.show()