# Confusion matrix decorators

Confusion matrix with a number of graphical decorators (introduced as "Confusion Matrix Chart" in https://towardsdatascience.com/the-confusion-matrix-visualized-e778584c8834 by Søren Laursen https://www.linkedin.com/in/soren02laursen/)

The concept comes from the article; the code is also mainly from the article, with several modifications.

## Imports

In [1]:
import pandas as pd
import numpy as np
import altair as alt #documentation for altair:  https://altair-viz.github.io/

## Odds Ratio calculations

First, I define the functions that compute the diagnostic odds ratio and the bounds of its confidence interval.

In [2]:
#DIAGNOSTIC ODDS RATIO = ratio of the odds of the classification being positive if the subject is actually positive
#relative to the odds of the classification being positive if the subject is not actually positive 

#a = TP,Yy
#b = FN, Yn
#c = FP, Ny
#d = TN, Nn


# Critical values used to build the odds-ratio confidence intervals, keyed by
# confidence level in percent. Despite the name, these are not alpha values:
# they are the (approximate, two-decimal) two-sided standard-normal z-scores
# that multiply the standard error of log(OR).
alpha_value_dict = {
    70: 1.04,
    75: 1.15,
    80: 1.28,
    85: 1.44,
    90: 1.64,
    95: 1.96,
    98: 2.33,
    99: 2.58,
}


def odds_ratio(TP, FN, FP, TN):
    """Return the diagnostic odds ratio (TP*TN)/(FP*FN) of a 2x2 table.

    If any of the four cells is zero or NaN, a continuity correction is
    applied to ALL cells before the ratio is taken: NaN cells become 0.5 and
    every other cell is increased by 0.5. This keeps the ratio finite.

    Parameters
    ----------
    TP, FN, FP, TN : scalar numbers
        True positives, false negatives, false positives, true negatives.
    """
    cells = (TP, FN, FP, TN)

    # One degenerate cell is enough to trigger the correction on every cell.
    if any(cell == 0 or np.isnan(cell) for cell in cells):
        TP, FN, FP, TN = (0.5 if np.isnan(cell) else cell + 0.5 for cell in cells)

    concordant = TP * TN
    discordant = FP * FN
    return concordant / discordant



def odds_ratio_lower_ci(OR, TP, FN, FP, TN, confidence_level):
    """Return the lower bound of the confidence interval of an odds ratio.

    The bound is computed on the log scale as
    exp(log(OR) - z * sqrt(1/TP + 1/FN + 1/FP + 1/TN)),
    where z is looked up in ``alpha_value_dict`` by ``confidence_level``.

    If any cell is zero or NaN, the same correction used when computing the
    odds ratio is applied to all cells first (NaN -> 0.5, otherwise +0.5).

    Parameters
    ----------
    OR : number
        The odds ratio the interval is built around.
    TP, FN, FP, TN : scalar numbers
        Cells of the 2x2 table.
    confidence_level : int
        Key of ``alpha_value_dict`` (e.g. 90, 95, 99).

    Raises
    ------
    KeyError
        If ``confidence_level`` is not a key of ``alpha_value_dict``.
    """
    counts = [TP, FN, FP, TN]
    if any(n == 0 or np.isnan(n) for n in counts):
        counts = [0.5 if np.isnan(n) else n + 0.5 for n in counts]
    tp, fn, fp, tn = counts

    log_or = np.log(OR)
    z_value = alpha_value_dict[confidence_level]
    std_err = np.sqrt(1/tp + 1/fn + 1/fp + 1/tn)  # standard error of log(OR)

    return np.exp(log_or - z_value*std_err)


def odds_ratio_upper_ci(OR, TP, FN, FP, TN, confidence_level):
    """Return the upper bound of the confidence interval of an odds ratio.

    The bound is computed on the log scale as
    exp(log(OR) + z * sqrt(1/TP + 1/FN + 1/FP + 1/TN)),
    where z is looked up in ``alpha_value_dict`` by ``confidence_level``.

    If any cell is zero or NaN, the same correction used when computing the
    odds ratio is applied to all cells first (NaN -> 0.5, otherwise +0.5).

    Parameters
    ----------
    OR : number
        The odds ratio the interval is built around.
    TP, FN, FP, TN : scalar numbers
        Cells of the 2x2 table.
    confidence_level : int
        Key of ``alpha_value_dict`` (e.g. 90, 95, 99).

    Raises
    ------
    KeyError
        If ``confidence_level`` is not a key of ``alpha_value_dict``.
    """
    counts = [TP, FN, FP, TN]
    if any(n == 0 or np.isnan(n) for n in counts):
        counts = [0.5 if np.isnan(n) else n + 0.5 for n in counts]
    tp, fn, fp, tn = counts

    log_or = np.log(OR)
    z_value = alpha_value_dict[confidence_level]
    std_err = np.sqrt(1/tp + 1/fn + 1/fp + 1/tn)  # standard error of log(OR)

    return np.exp(log_or + z_value*std_err)


#INTERPRETATION OF DOR
#The diagnostic odds ratio ranges from zero to infinity, although for useful tests it is greater than one, 
#and higher diagnostic odds ratios are indicative of better test performance.

#Diagnostic odds ratios less than one indicate that the test can be improved by simply inverting the outcome 
#of the test – the test is in the wrong direction, 

#while a diagnostic odds ratio of exactly one means that the test is equally likely to predict a positive 
#outcome whatever the true condition – the test gives no information.


## Derive confusion matrix data

Next, I create the dataframe that holds all the information needed to build the visualization.

In [6]:
def _safe_ratio(numerator, denominator, default=0):
    """Return numerator/denominator, or ``default`` when the denominator is 0."""
    return default if denominator == 0 else numerator / denominator


def confusion_matrix_data(Yy, Yn, Ny, Nn):
    """Build the data and colour scale behind the confusion matrix chart.

    Naming convention: an upper-case letter is the actual class, a lower-case
    letter is the predicted class, so Yy = TP, Yn = FN, Ny = FP, Nn = TN.

    Parameters
    ----------
    Yy, Yn, Ny, Nn : scalar numbers
        The four cells of the confusion matrix.

    Returns
    -------
    CM : pandas.DataFrame
        Two columns, ``label`` and ``value``, one row per metric:
        the raw cells; the conditional rates given the actual class
        ('y|Y', 'n|Y', 'n|N', 'y|N') and given the predicted class
        ('Y|y', 'N|y', 'N|n', 'Y|n'); the marginal totals ('Y', 'N', 'y', 'n')
        and their share of all observations ('Y*', 'N*', 'y*', 'n*'); the
        odds ratio with its 90/95/99 lower and upper bounds; the constant
        reference '1'; accuracy and F1 with their complements
        ('ACC', 'ACC-', 'F1', 'F1-').
    colours : alt.Scale
        Colour scale mapping every label to a named colour.

    Notes
    -----
    Any ratio whose denominator is zero is reported as 0; the complements
    'ACC-' and 'F1-' are reported as 1 in that case. This includes the ratios
    over the grand total, so an all-zero matrix no longer raises
    ZeroDivisionError.
    """
    actual_pos = Yy + Yn
    actual_neg = Ny + Nn
    pred_pos = Yy + Ny
    pred_neg = Yn + Nn
    total = Yy + Yn + Ny + Nn

    # The odds ratio is needed by every confidence bound: compute it once.
    OR = odds_ratio(Yy, Yn, Ny, Nn)

    # F1 is the harmonic mean of recall (y|Y) and precision (Y|y); it is
    # undefined (reported as 0) when there are no true positives.
    if Yy == 0 or actual_pos == 0 or pred_pos == 0:
        f1 = 0
    else:
        recall = Yy / actual_pos
        precision = Yy / pred_pos
        f1 = 2 * (recall * precision) / (recall + precision)

    # Label -> value, kept side by side so the two can never get misaligned.
    metrics = {
        'Yy': Yy, 'Yn': Yn, 'Ny': Ny, 'Nn': Nn,
        'y|Y': _safe_ratio(Yy, actual_pos),
        'n|Y': _safe_ratio(Yn, actual_pos),
        'n|N': _safe_ratio(Nn, actual_neg),
        'y|N': _safe_ratio(Ny, actual_neg),
        'Y|y': _safe_ratio(Yy, pred_pos),
        'N|y': _safe_ratio(Ny, pred_pos),
        'N|n': _safe_ratio(Nn, pred_neg),
        'Y|n': _safe_ratio(Yn, pred_neg),
        'Y': actual_pos, 'N': actual_neg, 'y': pred_pos, 'n': pred_neg,
        'Y*': _safe_ratio(actual_pos, total),
        'N*': _safe_ratio(actual_neg, total),
        'y*': _safe_ratio(pred_pos, total),
        'n*': _safe_ratio(pred_neg, total),
        # lower confidence bounds at different confidence levels
        'OR_lci90': odds_ratio_lower_ci(OR, Yy, Yn, Ny, Nn, 90),
        'OR_lci95': odds_ratio_lower_ci(OR, Yy, Yn, Ny, Nn, 95),
        'OR_lci99': odds_ratio_lower_ci(OR, Yy, Yn, Ny, Nn, 99),
        # diagnostic odds ratio
        'OR': OR,
        # upper confidence bounds at different confidence levels
        'OR_uci90': odds_ratio_upper_ci(OR, Yy, Yn, Ny, Nn, 90),
        'OR_uci95': odds_ratio_upper_ci(OR, Yy, Yn, Ny, Nn, 95),
        'OR_uci99': odds_ratio_upper_ci(OR, Yy, Yn, Ny, Nn, 99),
        # reference value: an odds ratio of 1
        '1': 1,
        'ACC': _safe_ratio(Yy + Nn, total),
        'ACC-': _safe_ratio(Yn + Ny, total, default=1),
        'F1': f1,
        'F1-': 1 - f1,
    }

    CM = pd.DataFrame({'label': list(metrics), 'value': list(metrics.values())})

    # Label -> colour, in the order the scale domain is declared.
    palette = {
        'Yy': 'snow', 'Ny': 'snow', 'Yn': 'snow', 'Nn': 'snow',
        'y|Y': 'forestgreen', 'n|Y': 'palegreen', 'n|N': 'powderblue', 'y|N': 'cadetblue',
        'Y|y': 'forestgreen', 'N|y': 'cadetblue', 'N|n': 'powderblue', 'Y|n': 'palegreen',
        'Y': 'goldenrod', 'N': 'gold', 'y': 'goldenrod', 'n': 'gold',
        'Y*': 'goldenrod', 'N*': 'gold',
        'y*': 'goldenrod', 'n*': 'gold',
        'OR_lci90': 'dodgerblue', 'OR_lci95': 'deepskyblue', 'OR_lci99': 'lightskyblue',
        'OR': 'blue',
        'OR_uci90': 'dodgerblue', 'OR_uci95': 'deepskyblue', 'OR_uci99': 'lightskyblue',
        '1': 'darkorange',
        'ACC': 'goldenrod', 'ACC-': 'gold', 'F1': 'goldenrod', 'F1-': 'gold',
    }
    colours = alt.Scale(domain=list(palette), range=list(palette.values()))

    return CM, colours


In [7]:
# Example: inspect the derived metrics (and colour scale) for a sample matrix
confusion_matrix_data(Yy=168, Yn=67, Ny=71, Nn=776)

(       label       value
 0         Yy  168.000000
 1         Yn   67.000000
 2         Ny   71.000000
 3         Nn  776.000000
 4        y|Y    0.714894
 5        n|Y    0.285106
 6        n|N    0.916175
 7        y|N    0.083825
 8        Y|y    0.702929
 9        N|y    0.297071
 10       N|n    0.920522
 11       Y|n    0.079478
 12         Y  235.000000
 13         N  847.000000
 14         y  239.000000
 15         n  843.000000
 16        Y*    0.217190
 17        N*    0.782810
 18        y*    0.220887
 19        n*    0.779113
 20  OR_lci90   20.055298
 21  OR_lci95   18.869868
 22  OR_lci99   16.768806
 23        OR   27.405508
 24  OR_uci90   37.449549
 25  OR_uci95   39.802178
 26  OR_uci99   44.789225
 27         1    1.000000
 28       ACC    0.872458
 29      ACC-    0.127542
 30        F1    0.708861
 31       F1-    0.291139,
 Scale({
   domain: ['Yy', 'Ny', 'Yn', 'Nn', 'y|Y', 'n|Y', 'n|N', 'y|N', 'Y|y', 'N|y', 'N|n', 'Y|n', 'Y', 'N', 'y', 'n', 'Y*', 'N*', 'y*', 'n

## Create confusion matrix chart

In [69]:
def cf_v_bar(CM, colours, label_list, sort_order, w_factor, h_factor, sf):
    """Return a vertical, normalized stacked bar for the given labels.

    Parameters
    ----------
    CM : pandas.DataFrame
        Frame with ``label`` and ``value`` columns.
    colours : alt.Scale
        Colour scale applied to the ``label`` field.
    label_list : list of str
        Labels whose rows make up the segments of the bar.
    sort_order : str
        Sort direction passed to ``alt.Order`` for stacking the segments.
    w_factor, h_factor : number
        Width and height of the chart in units of ``sf``; the bar itself is
        as thick as the chart is wide.
    sf : number
        Scaling factor multiplied into both size factors.
    """
    width_px = w_factor*sf
    height_px = h_factor*sf
    segments = CM.loc[CM['label'].isin(label_list)]

    encoding = dict(
        y=alt.Y('sum(value)', stack='normalize', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort=sort_order),
        tooltip=['value'],
    )

    chart = alt.Chart(segments).mark_bar(size=width_px).encode(**encoding)
    return chart.properties(width=width_px, height=height_px)

def cf_h_bar(CM, colours, label_list, sort_order, w_factor, h_factor, sf):
    """Return a horizontal, normalized stacked bar for the given labels.

    Parameters
    ----------
    CM : pandas.DataFrame
        Frame with ``label`` and ``value`` columns.
    colours : alt.Scale
        Colour scale applied to the ``label`` field.
    label_list : list of str
        Labels whose rows make up the segments of the bar.
    sort_order : str
        Sort direction passed to ``alt.Order`` for stacking the segments.
    w_factor, h_factor : number
        Width and height of the chart in units of ``sf``; the bar itself is
        as thick as the chart is tall.
    sf : number
        Scaling factor multiplied into both size factors.
    """
    width_px = w_factor*sf
    height_px = h_factor*sf
    segments = CM.loc[CM['label'].isin(label_list)]

    encoding = dict(
        x=alt.X('sum(value)', stack='normalize', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort=sort_order),
        tooltip=['value'],
    )

    chart = alt.Chart(segments).mark_bar(size=height_px).encode(**encoding)
    return chart.properties(width=width_px, height=height_px)


def cf_text(CM, label, format, font_size, w_factor, dy_factor, sf):
    """Return a square text chart showing the value stored under ``label``.

    Parameters
    ----------
    CM : pandas.DataFrame
        Frame with ``label`` and ``value`` columns.
    label : str
        Label of the row whose value is displayed.
    format : str
        Number format string handed to ``alt.Text``.
    font_size : number
        Font size of the text mark.
    w_factor : number
        Side of the chart in units of ``sf``; it sets both width and height.
    dy_factor : number
        Currently unused; kept so existing callers keep working.
    sf : number
        Scaling factor multiplied into ``w_factor``.
    """
    side_px = w_factor*sf
    row = CM.loc[CM['label']==label]

    label_mark = alt.Chart(row).mark_text(fontSize=font_size, color='black')
    value_text = alt.Text('sum(value)', format=format)

    return label_mark.encode(text=value_text).properties(width=side_px, height=side_px)


def confusion_matrix_chart(Yy, Yn, Ny, Nn):
    """Return the decorated confusion matrix chart for a 2x2 table.

    The chart is a 3x3 grid (the four cell counts separated by rate bars, with
    an accuracy bar in the centre) framed by a header row (F1 bar, predicted
    class shares, odds ratio value) and by two side columns (actual class
    shares on the left, odds ratio with its confidence bounds on the right).

    Parameters
    ----------
    Yy, Yn, Ny, Nn : scalar numbers
        Confusion matrix cells: upper case is the actual class, lower case
        the predicted one (Yy = TP, Yn = FN, Ny = FP, Nn = TN).
    """
    # Scaling factor: every chart dimension below is a multiple of it
    sf = 15

    # Metrics table and colour scale shared by all the sub-charts
    CM, colours = confusion_matrix_data(Yy, Yn, Ny, Nn)

    def count_cell(label):
        # Large integer count shown in one corner of the 3x3 grid
        return cf_text(CM, label=label, format='.0f', font_size=36,
                       w_factor=10, dy_factor=5, sf=sf)

    def v_bar(labels, order, w, h):
        return cf_v_bar(CM, colours, label_list=labels, sort_order=order,
                        w_factor=w, h_factor=h, sf=sf)

    def h_bar(labels, order, w, h):
        return cf_h_bar(CM, colours, label_list=labels, sort_order=order,
                        w_factor=w, h_factor=h, sf=sf)

    # --- 3x3 grid, first row: TP | rates given actual Y | FN
    text_Yy = count_cell('Yy')
    bar_Y = v_bar(['n|Y','y|Y'], 'descending', 2, 10)
    text_Yn = count_cell('Yn')

    # --- second row: rates given predicted y | accuracy | rates given predicted n
    bar_y = h_bar(['Y|y','N|y'], 'ascending', 13, 2)
    bar_a = v_bar(['ACC','ACC-'], 'descending', 3, 2)
    bar_n = h_bar(['N|n','Y|n'], 'ascending', 13, 2)

    # --- third row: FP | rates given actual N | TN
    text_Ny = count_cell('Ny')
    bar_N = v_bar(['n|N','y|N'], 'ascending', 2, 10)
    text_Nn = count_cell('Nn')

    # --- frame
    bar_L = v_bar(['Y*','N*'], 'ascending', 2, 25)      # left: actual class shares
    bar_0 = v_bar(['F1','F1-'], 'ascending', 2, 2)      # top-left corner: F1
    bar_T = h_bar(['y*','n*'], 'descending', 25, 2)     # top: predicted class shares
    text_R = cf_text(CM, label='OR', format='.1f', font_size=12,
                     w_factor=2, dy_factor=1, sf=sf)    # top-right corner: odds ratio

    # Right: odds ratio, its confidence bounds and the reference value 1 as dots
    or_labels = ['1','OR_lci90','OR_lci95','OR_lci99','OR','OR_uci90','OR_uci95','OR_uci99']
    or_rows = CM.loc[CM['label'].isin(or_labels)]
    or_dots = alt.Chart(or_rows).mark_circle(opacity=0.8, stroke='black', strokeWidth=1, size=10*sf)
    bar_R = or_dots.encode(
        y=alt.Y('value', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort='descending'),
        tooltip=['value']
    ).properties(width=2*sf, height=33*sf)

    # --- assemble: `|` concatenates horizontally, `&` vertically
    header_row = bar_0 | bar_T | text_R
    grid = (text_Yy | bar_Y | text_Yn) & (bar_y | bar_a | bar_n) & (text_Ny | bar_N | text_Nn)
    body_row = bar_L | grid | bar_R

    return header_row & body_row



## Instantiate the confusion matrix chart

In [70]:
# Cell counts of the example confusion matrix
Yy = 168  # TP
Yn = 67   # FN
Ny = 71   # FP
Nn = 776  # TN

confusion_matrix_chart(Yy, Yn, Ny, Nn)

Poi, se arrivo effettivamente a fare il tool con i vari widget, può diventare tutto più interattivo, anche nella lettura delle metriche (cioè non solo con la label posizionando il cursore, ma evidenziando e cose così).