In [1]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

def create_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10),
                            text_size=15, normalized=False, save_figure=False):
  """
  Plot a confusion matrix comparing predicted labels against ground-truth labels.

  Each cell is annotated with its raw count. When ``normalized`` is True, the
  row-normalized percentage (share of that true class) is shown next to it.

  Parameters:
  y_true (array-like) = Ground-truth labels
  y_pred (array-like) = Predicted labels
  classes (sequence of str, optional) = Tick labels for the classes, in the
      order sklearn's confusion_matrix uses (sorted unique labels of y_true
      and y_pred). Falls back to 0..n_classes-1 when None or empty.
  figsize (tuple of two numbers) = Size of the output figure
  text_size (int) = Font size of the per-cell annotations
  normalized (bool) = Also show the row-normalized percentage in each cell
  save_figure (bool) = Save the figure to "confusion_matrix.png" in the
      current working directory

  Returns:
  None; the figure is drawn with matplotlib.
  """
  # Build the count matrix. The local must not be named `confusion_matrix`:
  # assigning that name inside the function would make the sklearn function a
  # local variable and raise UnboundLocalError on this very call.
  cm = confusion_matrix(y_true, y_pred)
  number_of_classes = cm.shape[0]

  # Row-normalize (fraction of each true class). A label that appears only in
  # y_pred has an all-zero row; leave it at 0 instead of producing NaN.
  row_sums = cm.sum(axis=1, keepdims=True)
  cm_normalized = np.divide(cm, row_sums,
                            out=np.zeros(cm.shape, dtype=float),
                            where=row_sums != 0)

  # Plot the counts as a heat map
  fig, ax = plt.subplots(figsize=figsize)
  cax = ax.matshow(cm, cmap=plt.cm.Blues)
  fig.colorbar(cax)

  # Use the caller's class names if given. `is not None` plus len() avoids the
  # ambiguous-truth-value error a NumPy array would raise under `if classes:`.
  if classes is not None and len(classes) > 0:
    labels = classes
  else:
    labels = np.arange(number_of_classes)

  # Label the axes
  ax.set(title="Confusion Matrix",
         xlabel="Predicted Label",
         ylabel="True Label",
         xticks=np.arange(number_of_classes),
         yticks=np.arange(number_of_classes),
         xticklabels=labels,
         yticklabels=labels)

  # matshow draws x ticks on top; move the x label and ticks to the bottom
  ax.xaxis.set_label_position("bottom")
  ax.xaxis.tick_bottom()

  # Cells above the midpoint of the count range get white text for contrast
  threshold = (cm.max() + cm.min()) / 2.

  # Annotate every cell (row i = true class, column j = predicted class)
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    if normalized:
      cell_text = f"{cm[i, j]} ({cm_normalized[i, j]*100:.1f}%)"
    else:
      cell_text = f"{cm[i, j]}"
    ax.text(j, i, cell_text,
            horizontalalignment="center",
            color="white" if cm[i, j] > threshold else "black",
            size=text_size)

  # Save once, after every cell has been annotated
  if save_figure:
    fig.savefig("confusion_matrix.png")

    
