In [1]:
import numpy, pandas, pathlib

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from skops.io import load

pathlib.Path('pdf/figure-3').mkdir(exist_ok=True)

In [3]:
best_model = {}

for model in ['LR', 'NN', 'XB']:
    best_model[model] = load('models/'+model.lower()+'.skops', trusted=True)

# load the training dataset
X={}
Y={}
Z={}
X['train']={}
Y['train']={}
Z['train']={}
with open('data/ds-train.npy', 'rb') as f:
    Y['train']['input'] = numpy.load(f)
    X['train']['input'] = numpy.load(f)
    Z['train']['input'] = numpy.load(f, allow_pickle=True)

very_major_errors = {}
major_errors = {}

# load the results for the training dataset
results = pandas.read_csv('results-training.csv')
results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
0,LR,train,78.8,6.200717,84.357576,5.1144,82.48406,7.276335,180,38,49,197,"{""C"": 1.0, ""penalty"": ""l1"", ""solver"": ""libline...",19.044039,0
1,NN,train,78.916667,5.14795,83.682792,5.282488,81.760028,7.781499,218,0,0,246,"{""activation"": ""logistic"", ""alpha"": 0.01, ""hid...",inf,0
2,XB,train,79.216667,5.468115,85.781926,4.776355,83.502833,7.88022,192,26,43,203,"{""learning_rate"": 0.05, ""max_depth"": 4, ""min_c...",34.862254,0
3,SP,train,97.96748,,95.412844,,,,208,10,5,241,,1002.56,0


In [4]:
for metric in ['sensitivity', 'specificity', 'roc_auc', 'diagnostic_odds_ratio']:
    colour='#888888'
    fig = plt.figure(figsize=(2.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=range(4)
    x=results[(results.model!='SP')].model
    y=results[(results.model!='SP')][metric+'_mean']
    e=results[(results.model!='SP')][metric+'_std']
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color=colour,linewidth=2)
        for (i,j,k) in zip(x,y,e):
            axes.text(i,j+k+2,'%.1f' % j,ha='center',color=colour)

    else:
        for (i,j) in zip(x,y):
            axes.text(i,j+2,'%.1f' % j,ha='center',color=colour)

    fig.savefig('pdf/figure-3/fig-3-train-'+metric+'.pdf', bbox_inches="tight")
    plt.close()

posx and posy should be finite values
posx and posy should be finite values


In [5]:
for metric in ['sensitivity', 'specificity']:
    colour=['#888888','#888888','#888888','pink']
    fig = plt.figure(figsize=(3.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=results.model
    y=results[metric+'_mean']
    e=results[metric+'_std']
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color='#888888',linewidth=2)
        for (i,j,k,c) in zip(x,y,e,colour):
            if k>0:
                axes.text(i,j+k+2,'%.1f' % j,ha='center',color=c)
            else:
                axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    else:
        for (i,j,c) in zip(x,y,colour):                
            axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    fig.savefig('pdf/figure-3/fig-3-train-'+metric+'-suspectpza.pdf', bbox_inches="tight")
    plt.close()

In [6]:
for idx,row in results.iterrows():

    fig = plt.figure(figsize=(1.5, 1.5))
    axes = plt.gca()

    axes.add_patch(Rectangle((0,0),1,1,fc='#e41a1c',alpha=0.7))
    axes.add_patch(Rectangle((0,1),1,1,fc='#4daf4a',alpha=0.7))
    axes.add_patch(Rectangle((1,1),1,1,fc='#fc9272',alpha=0.7))
    axes.add_patch(Rectangle((1,0),1,1,fc='#4daf4a',alpha=0.7))

    axes.set_xlim([0,2])
    axes.set_ylim([0,2])

    axes.set_xticks([0.5,1.5],labels=['R','S'])
    axes.set_yticks([0.5,1.5],labels=['S','R'])

    axes.text(0.5,0.5,row['FN'],ha='center',va='center')
    axes.text(1.5,0.5,row['TN'],ha='center',va='center')
    axes.text(0.5,1.5,row['TP'],ha='center',va='center')
    axes.text(1.5,1.5,row['FP'],ha='center',va='center')

    fig.savefig('pdf/figure-3/truthtable-'+row['dataset']+'-'+row['model']+'.pdf', bbox_inches='tight')
    plt.close()