# Libraries

In [1]:
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# Basic libraries
# 
import pandas as pd
import numpy  as np

# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# Visualization library
# 
from   plotly import graph_objects as go

# Functions to Visualise a Confusion Matrix Using a Sankey Chart

In [21]:
def tranform_confusion_matrix(confusion_matrix, targets_list=None):
    """
    Transform a confusion matrix into the long-form dataframe needed to plot a Sankey chart.

    Every cell ``confusion_matrix[i, j]`` becomes one link from the node
    ``"True <class i>"`` to the node ``"Predicted <class j>"``.

    Parameters
    --------------
    confusion_matrix : numpy.ndarray
        2-D array; rows are true classes, columns are predicted classes.
    targets_list : {'list', 'numpy.ndarray'}, optional
        Class names, one per row/column. When omitted, classes are named
        ``Class-1``, ``Class-2``, ...

    Returns
    --------------
    df : pandas.DataFrame
        One row per matrix cell (row-major order) with columns ``source`` and
        ``target`` (integer indices into ``labels``), ``value`` (the cell
        entry), ``colour`` (green for diagonal cells, red otherwise) and
        ``tooltip`` (hover text for the link).
    labels : numpy.ndarray
        Unique node labels: the ``"True ..."`` labels followed by the
        ``"Predicted ..."`` labels.

    Raises
    --------------
    ValueError
        If ``confusion_matrix`` is not 2-D, or ``targets_list`` does not have
        exactly one entry per row and per column.
    """
    # link colours: green illustrates correct predictions, red incorrect ones
    correct_colour = "rgba(211,255,216,0.6)"
    wrong_colour = "rgba(245,173,168,0.6)"

    cm = np.asarray(confusion_matrix)
    if cm.ndim != 2:
        raise ValueError(f"confusion_matrix must be 2-D, got shape {cm.shape}")
    n_rows, n_cols = cm.shape

    # class names shown after the "True" / "Predicted" prefixes
    if targets_list is None:
        row_names = [f"Class-{i+1}" for i in range(n_rows)]
        col_names = [f"Class-{j+1}" for j in range(n_cols)]
    else:
        names = [f"{t}" for t in targets_list]
        if len(names) != n_rows or len(names) != n_cols:
            raise ValueError(
                f"targets_list has {len(names)} entries but confusion_matrix "
                f"has shape {cm.shape}")
        row_names = col_names = names

    # flatten the matrix row by row (C order): cell (i, j) -> one link
    row_idx = np.repeat(np.arange(n_rows), n_cols)
    col_idx = np.tile(np.arange(n_cols), n_rows)
    # a prediction is correct exactly on the diagonal; decide this by position
    # instead of comparing label text, which can collide for distinct classes
    correct = row_idx == col_idx

    df = pd.DataFrame({
        "source": [f"True {row_names[i]}" for i in row_idx],
        "target": [f"Predicted {col_names[j]}" for j in col_idx],
        "value": cm.ravel(),
    })
    df["colour"] = [correct_colour if ok else wrong_colour for ok in correct]

    # unique node labels, in order of first appearance (sources, then targets)
    labels = pd.concat([df.source, df.target]).unique()
    labels_indices = {label: index for index, label in enumerate(labels)}

    # Sankey links reference nodes by integer index
    df["source"] = df["source"].map(labels_indices)
    df["target"] = df["target"].map(labels_indices)

    df["tooltip"] = [
        f"{value} {row_names[i]} instances "
        f"{'correctly classified' if ok else 'misclassified'} as {col_names[j]}"
        for value, i, j, ok in zip(df["value"], row_idx, col_idx, correct)
    ]

    return df, labels
    

def plot_sankey_for_confusion_matrix(df, labels, params = None):
    """
    Plot a Sankey diagram from the given link dataframe and node labels.

    Parameters
    --------------
    df : pandas.DataFrame
        Link table with integer ``source`` / ``target`` node indices and
        ``value``, ``colour`` and ``tooltip`` columns (the layout produced by
        ``tranform_confusion_matrix``).
    labels : {'list', 'numpy.ndarray'}
        Node labels; ``source`` / ``target`` index into this sequence.
    params : dict, optional
        Layout settings with keys ``'font_size'``, ``'width'`` and
        ``'height'``. Any key that is missing falls back to its default
        (13, 510 and 450 respectively).

    Returns
    --------------
    plotly.graph_objects.Figure
        The Sankey figure (not displayed; call ``.show()`` to render it).
    """
    default_params = {'font_size': 13, 'width': 510, 'height': 450}
    # merge over the defaults so a partial dict (e.g. only 'width') does not
    # raise KeyError below
    params = {**default_params, **(params or {})}

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=20,
            thickness=20,
            line=dict(color="black", width=1.0),
            label=labels,
            # template used to display text when hovering over nodes
            hovertemplate="%{label} has total %{value:d} instances<extra></extra>",
        ),
        link=dict(
            source=df['source'],
            target=df['target'],
            value=df['value'],
            color=df['colour'],
            customdata=df['tooltip'],
            # template used to display text when hovering over the links
            hovertemplate="%{customdata}<extra></extra>",
        ),
    )])

    fig.update_layout(title_text="Confusion Matrix Visualisation Using Sankey Diagram",
                      font_size=params['font_size'],
                      width=params['width'],
                      height=params['height'])

    return fig

In [22]:
# example 3-class confusion matrix: rows are true classes, columns are predictions
confusion_matrix = np.array([[10, 6, 1],
                             [2, 12, 3],
                             [4, 7, 45]])

params = {'font_size': 14,
          'width':     800,
          'height':    600}

df, labels = tranform_confusion_matrix(confusion_matrix, ['Class-1', 'Class-2', 'Class-3'])
fig = plot_sankey_for_confusion_matrix(df, labels, params)
# show explicitly: a bare trailing expression is only rendered inside a notebook,
# whereas fig.show() also works when this code runs as a plain script
fig.show()