#  Dependencies

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix, roc_curve, auc
from typing import List
import seaborn as sns
import ruamel.yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#  Path to the YAML configuration file that is opened and parsed below
config_name = './model/config.yaml'

In [None]:
#  Parse the YAML configuration into `config`.
#  The loader is created before the try block so the name is always bound, and the
#  exception class is taken from the ruamel.yaml package: a YAML() loader instance
#  has no `YAMLError` attribute, so `yaml.YAMLError` would itself fail on a parse error.
yaml = ruamel.yaml.YAML()
with open(config_name, 'r') as stream:
    try:
        config = yaml.load(stream)
    except ruamel.yaml.error.YAMLError as exc:
        #  Best effort: report the parse error; `config` stays undefined in that case
        print(exc)

In [None]:
#  The make_confusion_matrix function is adapted (with amendments) from https://github.com/DTrimarchi10/confusion_matrix/blob/master/cf_matrix.py

def _safe_ratio(numerator, denominator):
    # Ratio used for the percentages and summary statistics; yields 0.0 instead of
    # dividing by zero (e.g. an empty matrix or a class that was never predicted).
    return float(numerator) / float(denominator) if denominator else 0.0


def make_confusion_matrix(cf=None,
                          group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=False,
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize=None,
                          cmap=None,
                          title=None,
                          pngName='image'):
    '''
    Plot a confusion matrix as an annotated Seaborn heatmap and save it as '<pngName>.png'.

    Arguments
    ---------
    cf:            Confusion matrix to plot (2-D array-like, e.g. the output of
                   sklearn's confusion_matrix). Required.
    group_names:   List of strings that represent the labels row by row to be shown in each square.
                   Ignored unless its length equals the number of cells in cf.
    categories:    List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
    count:         If True, show the raw number in the confusion matrix. Default is True.
    percent:       If True, show each cell as a percentage of all observations. Default is True.
    cbar:          If True, show the color bar. The cbar values are based off the values in the confusion matrix.
                   Default is False.
    xyticks:       If True, show x and y ticks. Default is True.
    xyplotlabels:  If True, show 'True label' and 'Predicted label' on the figure. Default is True.
    sum_stats:     If True, display summary statistics below the figure. Default is True.
                   For a 2x2 matrix, precision/recall/F1 are computed with cf[1, 1] as the
                   true-positive cell, i.e. the positive class is expected at index 1.
    figsize:       Tuple representing the figure size. Default will be the matplotlib rcParams value.
    cmap:          Colormap passed to seaborn.heatmap. Default is None (seaborn's own default).
                   See http://matplotlib.org/examples/color/colormaps_reference.html
    title:         Title for the heatmap. Default is None.
    pngName:       File name (without extension) of the saved figure. Default is 'image'.

    Raises
    ------
    ValueError:    If cf is None.
    '''
    if cf is None:
        raise ValueError("make_confusion_matrix requires a confusion matrix (cf)")
    cf = np.asarray(cf)
    total = np.sum(cf)

    # CODE TO GENERATE TEXT INSIDE EACH SQUARE
    blanks = [''] * cf.size

    # Explicit None check so that array-like group_names do not break on truthiness
    if group_names is not None and len(group_names) == cf.size:
        group_labels = ["{}\n".format(value) for value in group_names]
    else:
        group_labels = blanks

    if count:
        group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
    else:
        group_counts = blanks

    if percent:
        group_percentages = ["{0:.2%}".format(_safe_ratio(value, total)) for value in cf.flatten()]
    else:
        group_percentages = blanks

    box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
    box_labels = np.asarray(box_labels).reshape(cf.shape)

    # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
    if sum_stats:
        # Accuracy is sum of diagonal divided by total observations
        accuracy = _safe_ratio(np.trace(cf), total)

        # If it is a binary confusion matrix, show some more stats
        if len(cf) == 2:
            # Metrics for binary confusion matrices (positive class at index 1).
            # The local is named f1 so it does not shadow sklearn's f1_score import.
            precision = _safe_ratio(cf[1, 1], np.sum(cf[:, 1]))
            recall = _safe_ratio(cf[1, 1], np.sum(cf[1, :]))
            f1 = _safe_ratio(2 * precision * recall, precision + recall)
            stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
                accuracy, precision, recall, f1)
        else:
            stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
    else:
        stats_text = ""

    # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
    if figsize is None:
        # Get default figure size if not set
        figsize = plt.rcParams.get('figure.figsize')

    if not xyticks:
        # Do not show categories if xyticks is False
        categories = False

    # MAKE THE HEATMAP VISUALIZATION
    plt.figure(figsize=figsize)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar,
                xticklabels=categories, yticklabels=categories)

    if xyplotlabels:
        plt.ylabel('True label')
        # The summary statistics are appended to the x-axis label so they render below the plot
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)

    if title:
        plt.title(title)

    plt.savefig(pngName + '.png', dpi=300, bbox_inches='tight')

#  Helper function for display_hex_colors
def apply_formatting(col, hex_colors):
    """Return one background-colour CSS rule per cell of ``col``.

    ``col`` is matched by its ``name`` against the entries of ``hex_colors``;
    the first entry equal to the name is used as the colour for every cell.
    Returns ``None`` when no entry matches.
    """
    for candidate in hex_colors:
        if col.name == candidate:
            css_rule = f'background-color: {candidate}'
            # Same rule repeated once for every value in the column
            return [css_rule] * len(col.values)
    return None

#  Shows the given colours side by side on screen so they can be compared
def display_hex_colors(hex_colors: List[str]):
    """Display a one-row table whose columns are named and shaded by each hex colour."""
    swatches = pd.DataFrame(hex_colors).T
    swatches.columns = hex_colors
    # Blank out the cell text; only the column header and background colour remain
    swatches.iloc[0, 0:len(hex_colors)] = ""

    def shade(column):
        return apply_formatting(column, hex_colors)

    styled = swatches.style.apply(shade)
    display(styled)

#  Draws the ROC curve for a model from a dataframe of its predictions
#  (columns 'Label' and 'prob'), e.g. one read from a prediction csv.
#  Also plots perfect/random classifiers to compare
def drawROC(df,
            path_to_test=None,
            pngName='ROC image'):
    """Plot the ROC curve of a model and save it as a PNG.

    Arguments
    ---------
    df:            Dataframe with the true labels in 'Label' and the predicted
                   scores in 'prob'.
    path_to_test:  Unused; kept (now optional) so existing callers keep working.
    pngName:       Output file name. '.png' is appended unless the name already
                   ends with it.
    """
    fpr, tpr, _thresholds = roc_curve(df['Label'], df['prob'])
    roc_auc = auc(fpr, tpr)

    plt.plot([0, 1], [0, 1], 'k--', label='Random Guess', linewidth=3)
    # The perfect estimator is drawn as two segments: up the y axis, then along the top
    plt.plot([0, 1], [1, 1], 'k-', label='Perfect Estimator')
    plt.plot([0, 0], [0, 1], 'k-')
    # Report the area under the curve in the legend instead of discarding it
    plt.plot(fpr, tpr, color='#054d54',
             label='ROC curve (AUC = {:0.3f})'.format(roc_auc), linewidth=3)
    plt.xlabel('False Positive Rate or (1 - Specificity)')
    plt.ylabel('True Positive Rate or (Sensitivity)')
    plt.legend(loc='lower right')

    # Avoid saving 'name.png.png' when the caller already supplied the extension
    file_name = pngName if pngName.lower().endswith('.png') else pngName + '.png'
    plt.savefig(file_name, dpi=300, bbox_inches='tight')

def drawConfusion(colour='#002733',
                  pngName='Confusion Matrix',
                  cf=None):
    """Draw a labelled binary confusion matrix and save it as '<pngName>.png'.

    Arguments
    ---------
    colour:   Hex colour the light-to-dark heatmap palette is built from.
    pngName:  File name (without extension) of the saved figure.
    cf:       2x2 confusion matrix ordered with the positive class first
              (rows/columns [1, 0]), matching the TP/FN/FP/TN cell labels used here.

    Raises
    ------
    ValueError: If cf is None.
    """
    if cf is None:
        raise ValueError("drawConfusion requires a confusion matrix (cf)")

    # Build the palette from the requested colour rather than a fixed one
    palette = sns.light_palette(colour, as_cmap=True, n_colors=4)

    labels = ['TP', 'FN', 'FP', 'TN']
    categories = ['1', '0']
    # NOTE(review): make_confusion_matrix derives its Precision/Recall/F1 summary from
    # cf[1, 1]; with the [1, 0] ordering labelled here that cell is TN, so the summary
    # text below the plot describes class 0 - confirm this is intended.
    make_confusion_matrix(cf,
                          group_names=labels,
                          categories=categories,
                          cbar=False,
                          cmap=palette,
                          percent=True,
                          pngName=pngName)
    
#  Computes, prints and returns every descriptive stat of a model
def descriptiveStats(df):
    """Compute classification metrics from a predictions dataframe.

    Reads the true labels from 'Label', the predicted classes from 'y_hat' and
    the predicted scores from 'prob'. Prints each metric and returns them as the
    tuple (accuracy, precision, recall, f1, roc_auc, confusion matrix).
    """
    y_true = df['Label']
    y_pred = df['y_hat']

    # All metrics are computed before anything is printed
    results = (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        f1_score(y_true, y_pred),
        roc_auc_score(df['Label'], df['prob']),
        # labels=[1,0] puts the positive class first; swap to [0,1] for the other orientation
        confusion_matrix(y_true, y_pred, labels=[1,0]),
    )

    headings = ('accuracy: ', 'precision: ', 'recall: ', 'f1: ', 'roc_auc: ', 'matrix: \n')
    for heading, value in zip(headings, results):
        print(heading, value)

    return results

In [None]:
#  Enter your HEX colour codes here to preview the colour palette on screen
hex_list = ['#002733', '#054d54', '#1b8489', '#ef9f8d', '#fccec0']
display_hex_colors(hex_list)

In [None]:
#  Reading Albumin 'prediction' data
path = './datasets/model_validation/Albumin/pred-albert-base-v2_5e-05_0.1372.csv'
df = pd.read_csv(path)

#  If you need only one descriptive model stat, index the returned tuple;
#  index 5 is the confusion matrix (all stats are still printed by the call)
cf_matrix = descriptiveStats(df)[5]

In [None]:
#  Drawing ROC for Albumin
#  The keyword must be `path_to_test` (drawROC has no `path_to_model_test` parameter),
#  and pngName is given without '.png' because drawROC adds the extension itself.
drawROC(df=df,
        path_to_test=path,
        pngName='Albumin ROC')

In [None]:
#  Confusion matrix formatted in a more informative way (cell labels, counts and percentages)
drawConfusion(colour='#002733',
              pngName='Confusion Matrix',
              cf=cf_matrix)