In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

def plot_confusion_matrix(predicted_labels, true_labels, classes, normalize=False, verbose=False):
    """
    Compute the confusion matrix of predictions against true labels and plot it
    as an annotated heatmap.

    Parameters
    ----------
    predicted_labels : array-like
        Labels predicted by the model.
    true_labels : array-like
        Ground-truth labels, aligned element-wise with ``predicted_labels``.
    classes : sequence of str
        Tick-label names, one per row/column of the matrix. Without an explicit
        ``labels`` argument, sklearn's ``confusion_matrix`` orders rows/columns
        by the sorted set of labels appearing in either input, so ``classes``
        must follow that same order and length.
    normalize : bool, default False
        If True, divide each row by its total so each cell is the fraction of
        that true class assigned to the predicted class. Rows with no true
        samples are shown as zeros rather than NaN.
    verbose : bool, default False
        If True, also print the plot title and the (possibly normalized) matrix.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on (on a newly created 8x8 figure).
    """
    cm = confusion_matrix(true_labels, predicted_labels)
    if normalize:
        # A label that occurs only in predicted_labels gets an all-zero row;
        # dividing by its zero total would yield NaN (and a RuntimeWarning),
        # and a NaN would also poison cm.max() used for the text threshold below.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    title = "Normalized confusion matrix" if normalize else "Confusion matrix"
    if verbose:
        print(title)
        print(cm)

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm, cmap=plt.cm.Greens)
    fig.colorbar(im, ax=ax)
    # Place a tick on every row/column and label them with the class names.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel="True label",
           xlabel="Predicted label")
    # Rotate x labels so long names don't collide; "anchor" keeps them aligned to their ticks.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell with its value: counts as integers, fractions with 2 decimals.
    fmt = '.2f' if normalize else 'd'
    # Cells above half the maximum are drawn darker, so switch to white text there.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    return ax