# Helper function: Plot confusion matrix

In [None]:
def plot_confusion_matrix(y_pred, y_true, n_classes, labels, title, save=False, filepath="", filename=""):
    """
    Plot a confusion matrix of predicted vs. true labels, optionally saving it as a PNG.

    Input: y_pred: predicted labels
           y_true: true (test) labels
           n_classes: number of distinct classes; the confusion matrix computed
                      from y_true / y_pred must be n_classes x n_classes
           labels: list of n_classes display names for the tick marks, in the
                   order sklearn's confusion_matrix orders the classes
                   (sorted class values when no explicit label list is given)
           title: text appended to "Confusion Matrix: " to form the plot title
           save: if True, the figure is written to filepath + filename + ".png"
                 (plain string concatenation, so include a trailing path
                 separator in filepath when it is a directory)
    Output: None; the figure is displayed with plt.show()
    Raises: ValueError if the computed confusion matrix is not n_classes x n_classes
    """
    import numpy as np
    from matplotlib import pyplot as plt
    from sklearn.metrics import confusion_matrix

    # NOTE: these style settings are applied globally and persist for later
    # plots in the same session; kept as-is for backward compatibility.
    plt.style.use('ggplot')
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["axes.titleweight"] = "bold"

    # Rows index the true class, columns index the predicted class.
    conf_mx = confusion_matrix(y_true, y_pred)
    # Validate before creating the figure, so a bad call does not leave an
    # empty figure behind or fail with an opaque IndexError in the loop below.
    if conf_mx.shape != (n_classes, n_classes):
        raise ValueError(
            f"confusion matrix has shape {conf_mx.shape}, expected "
            f"({n_classes}, {n_classes}); check n_classes and that every class "
            f"appears in y_true or y_pred"
        )

    fig, ax = plt.subplots(1, 1, figsize=(8, 8), dpi=320)
    # Pass the colormap by name: matplotlib.cm.get_cmap was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    ax.matshow(conf_mx, cmap='Blues')

    # Annotate every cell with its count.
    for (i, j), count in np.ndenumerate(conf_mx):
        ax.text(j, i, count, ha="center", va="center", color="black")

    ax.grid(False)
    ax.set_xticks(np.arange(n_classes), labels=labels)
    ax.set_yticks(np.arange(n_classes), labels=labels)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title("Confusion Matrix: " + title)

    # Save before showing: some backends close or flush the figure on show().
    if save:
        fig.savefig(filepath + filename + ".png", format="png")
    plt.show()