In [143]:
import numpy, pandas, pathlib

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from skops.io import load

from sklearn.utils import resample

from misc import construct_line

pandas.options.display.max_columns=100

pathlib.Path('pdf/figure-3').mkdir(exist_ok=True)

In [130]:
best_model = {}
X={}
Y={}
Z={}
very_major_errors = {}
major_errors = {}

In [131]:
for model in ['LR', 'NN', 'XB']:
    best_model[model] = load('models/'+model.lower()+'.skops', trusted=True)

# load the training dataset

X['train']={}
Y['train']={}
Z['train']={}
with open('data/ds-train.npy', 'rb') as f:
    Y['train']['input'] = numpy.load(f)
    X['train']['input'] = numpy.load(f)
    Z['train']['input'] = numpy.load(f, allow_pickle=True)

# load the results for the training dataset
results = pandas.read_csv('results-training.csv')
results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
0,LR,train,78.8,6.200717,84.357576,5.1144,82.48406,7.276335,180,38,49,197,"{""C"": 1.0, ""penalty"": ""l1"", ""solver"": ""libline...",19.044039,0
1,NN,train,78.916667,5.14795,83.682792,5.282488,81.760028,7.781499,218,0,0,246,"{""activation"": ""logistic"", ""alpha"": 0.01, ""hid...",inf,0
2,XB,train,79.216667,5.468115,85.781926,4.776355,83.502833,7.88022,192,26,43,203,"{""learning_rate"": 0.05, ""max_depth"": 4, ""min_c...",34.862254,0
3,SP,train,97.96748,,95.412844,,,,208,10,5,241,,1002.56,0


In [145]:
for metric in ['sensitivity', 'specificity', 'roc_auc', 'diagnostic_odds_ratio']:
    colour='#888888'
    fig = plt.figure(figsize=(2.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=range(4)
    x=results[(results.model!='SP')].model
    y=results[(results.model!='SP')][metric+'_mean']
    # these are standard deviations so convert to standard error at 95% given n=10
    e=results[(results.model!='SP')][metric+'_std']*1.96/10**0.5
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color=colour,linewidth=2)
        for (i,j,k) in zip(x,y,e):
            axes.text(i,j+k+2,'%.1f' % j,ha='center',color=colour)

    else:
        for (i,j) in zip(x,y):
            axes.text(i,j+2,'%.1f' % j,ha='center',color=colour)

    fig.savefig('pdf/figure-3/fig-3-train-'+metric+'.pdf', bbox_inches="tight")
    plt.close()

posx and posy should be finite values
posx and posy should be finite values


In [146]:
for metric in ['sensitivity', 'specificity']:
    colour=['#888888','#888888','#888888','pink']
    fig = plt.figure(figsize=(3.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=results.model
    y=results[metric+'_mean']
    # these are standard deviations so convert to standard error at 95% given n=10
    e=results[metric+'_std']*1.96/10**0.5
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color='#888888',linewidth=2)
        for (i,j,k,c) in zip(x,y,e,colour):
            if k>0:
                axes.text(i,j+k+2,'%.1f' % j,ha='center',color=c)
            else:
                axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    else:
        for (i,j,c) in zip(x,y,colour):                
            axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    fig.savefig('pdf/figure-3/fig-3-train-'+metric+'-suspectpza.pdf', bbox_inches="tight")
    plt.close()

In [134]:
for idx,row in results.iterrows():

    fig = plt.figure(figsize=(1.5, 1.5))
    axes = plt.gca()

    axes.add_patch(Rectangle((0,0),1,1,fc='#e41a1c',alpha=0.7))
    axes.add_patch(Rectangle((0,1),1,1,fc='#4daf4a',alpha=0.7))
    axes.add_patch(Rectangle((1,1),1,1,fc='#fc9272',alpha=0.7))
    axes.add_patch(Rectangle((1,0),1,1,fc='#4daf4a',alpha=0.7))

    axes.set_xlim([0,2])
    axes.set_ylim([0,2])

    axes.set_xticks([0.5,1.5],labels=['R','S'])
    axes.set_yticks([0.5,1.5],labels=['S','R'])

    axes.text(0.5,0.5,row['FN'],ha='center',va='center')
    axes.text(1.5,0.5,row['TN'],ha='center',va='center')
    axes.text(0.5,1.5,row['TP'],ha='center',va='center')
    axes.text(1.5,1.5,row['FP'],ha='center',va='center')

    fig.savefig('pdf/figure-3/truthtable-'+row['dataset']+'-'+row['model']+'.pdf', bbox_inches='tight')
    plt.close()

### Model validation on `test` set

We can now evaluate the trained models on the `test` dataset.

First let's load the trained models, the datasets and the results for Suspect-PZA


In [136]:
for model in ['LR', 'NN', 'XB']:
    best_model[model] = load('models/'+model.lower()+'.skops', trusted=True)

# load the test dataset
X['test']={}
Y['test']={}
Z['test']={}
with open('data/ds-test.npy', 'rb') as f:
    Y['test']['input'] = numpy.load(f)
    X['test']['input'] = numpy.load(f)
    Z['test']['input'] = numpy.load(f, allow_pickle=True)

# load in the results for SuspectPZA
suspectpza={}
for i in ['test']:
    suspectpza[i]={}
    with open('data/suspectpza-'+i+'.npy', 'rb') as f:
        suspectpza[i]['input'] = numpy.load(f)
        suspectpza[i]['predicted'] = numpy.load(f)
        suspectpza[i]['muts'] = numpy.load(f, allow_pickle=True)
line = []

def validate_model(line, best_model, model_name, X, Y):
        
    Y['test']['predicted'] = best_model.predict(X['test']['input'])
    Y['test']['scores'] = best_model.predict_proba(X['test']['input'])[:,1]
    row = construct_line(model_name , 'test', None, Y['test'], None)
    line.append(row)
    return(line)

for model in ['LR', 'NN', 'XB']:
    line = validate_model(line, best_model[model], model, X, Y)

line.append(construct_line('SP', 'test', None, suspectpza['test'], None))

# load the results for the training dataset
# results = pandas.read_csv('results-training.csv')
# results

In [137]:
def bootstrap_model(line, best_model, model_name, X, Y):
    x={}
    y={}
    for i in range(10):
        x['input'],y['input'] = resample(X['test']['input'], Y['test']['input'])
        y['predicted'] = best_model.predict(x['input'])
        y['scores'] = best_model.predict_proba(x['input'])[:,1]
        row = construct_line(model_name, 'test_'+str(i), None, y, None)
        line.append(row)
    return(line)

In [138]:
for model in ['LR', 'NN', 'XB']:
    line = bootstrap_model(line, best_model[model], model, X, Y)

In [139]:
for model in ['LR', 'NN', 'XB']:
    sens_mean = test_results[(test_results.model==model) & (test_results.dataset!='test')].sensitivity_mean.mean()
    sens_std = 1.96*test_results[(test_results.model==model) & (test_results.dataset!='test')].sensitivity_mean.std()/10**0.5
    spec_mean = test_results[(test_results.model==model) & (test_results.dataset!='test')].specificity_mean.mean()
    spec_std = 1.96*test_results[(test_results.model==model) & (test_results.dataset!='test')].specificity_mean.std()/10**0.5
    roc_mean = test_results[(test_results.model==model) & (test_results.dataset!='test')].roc_auc_mean.mean()
    row = [model, 'bootstrapped', sens_mean, sens_std, spec_mean, spec_std, roc_mean, None, None, None, None, None, None ]
    line.append(row)

LR 78.87218069348816 2.3913554230931937 68.87180565551073 3.0209801803721175
NN 76.02552201717394 2.290590892565823 67.60875329733275 2.8394328346775106
XB 76.3662297547237 2.297805414413918 74.55042185808142 2.717459163756119


In [140]:
test_results = pandas.DataFrame(line, columns=['model', 'dataset', 'sensitivity_mean', 'sensitivity_std', 'specificity_mean', 'specificity_std' ,'roc_auc_mean', 'roc_auc_std','TN','FP','FN','TP', 'model_parameters'])
test_results['diagnostic_odds_ratio_mean'] = (test_results['TN']*test_results['TP'])/(test_results['FN']*test_results['FP'])
test_results['diagnostic_odds_ratio_std'] = 0
test_results.to_csv('results-test.csv', index=False)
test_results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
0,LR,test,78.640777,,70.103093,,82.684416,,68.0,29.0,22.0,81.0,,8.633229,0
1,NN,test,76.699029,,67.010309,,79.861876,,65.0,32.0,24.0,79.0,,6.686198,0
2,XB,test,77.669903,,75.257732,,82.764488,,73.0,24.0,23.0,80.0,,10.57971,0
3,SP,test,95.145631,,97.938144,,,,95.0,2.0,5.0,98.0,,931.0,0
4,LR,test_0,76.595745,,71.698113,,81.011642,,76.0,30.0,22.0,72.0,,8.290909,0
5,LR,test_1,81.308411,,72.043011,,85.629585,,67.0,26.0,20.0,87.0,,11.209615,0
6,LR,test_2,71.296296,,67.391304,,75.925926,,62.0,30.0,31.0,77.0,,5.133333,0
7,LR,test_3,77.227723,,65.656566,,83.058306,,65.0,34.0,23.0,78.0,,6.483376,0
8,LR,test_4,78.070175,,63.953488,,81.599347,,55.0,31.0,25.0,89.0,,6.316129,0
9,LR,test_5,74.107143,,75.0,,81.006494,,66.0,22.0,29.0,83.0,,8.586207,0


In [141]:
for metric in ['sensitivity', 'specificity', 'roc_auc', 'diagnostic_odds_ratio']:
    colour='#888888'
    fig = plt.figure(figsize=(2.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=range(4)
    x=test_results[(test_results.model!='SP') & (test_results.dataset=='bootstrapped')].model
    y=test_results[(test_results.model!='SP') & (test_results.dataset=='bootstrapped')][metric+'_mean']
    e=test_results[(test_results.model!='SP') & (test_results.dataset=='bootstrapped')][metric+'_std']
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color=colour,linewidth=2)
        for (i,j,k) in zip(x,y,e):
            axes.text(i,j+k+2,'%.1f' % j,ha='center',color=colour)

    else:
        for (i,j) in zip(x,y):
            axes.text(i,j+2,'%.1f' % j,ha='center',color=colour)

    fig.savefig('pdf/figure-3/fig-3-test-'+metric+'.pdf', bbox_inches="tight")
    plt.close()

posx and posy should be finite values
posx and posy should be finite values
posx and posy should be finite values
posx and posy should be finite values
posx and posy should be finite values
posx and posy should be finite values


In [142]:
for metric in ['sensitivity', 'specificity']:
    colour=['#888888','#888888','#888888','pink']
    fig = plt.figure(figsize=(3.2, 3.5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.get_yaxis().set_visible(False)
    x=test_results[(test_results.dataset=='bootstrapped')].model
    y=test_results[(test_results.dataset=='bootstrapped')][metric+'_mean']
    e=test_results[(test_results.dataset=='bootstrapped')][metric+'_std']
    axes.set_ylim([0,100])
    axes.bar(x,y, label=y, edgecolor=colour, color='None',linewidth=2)

    if e.sum()>0:
        axes.errorbar(x,y,yerr=e, fmt='.',color='#888888',linewidth=2)
        for (i,j,k,c) in zip(x,y,e,colour):
            if k>0:
                axes.text(i,j+k+2,'%.1f' % j,ha='center',color=c)
            else:
                axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    else:
        for (i,j,c) in zip(x,y,colour):                
            axes.text(i,j+2,'%.1f' % j,ha='center',color=c)

    fig.savefig('pdf/figure-3/fig-3-test-'+metric+'-suspectpza.pdf', bbox_inches="tight")
    plt.close()

In [144]:
for idx,row in test_results[test_results.dataset=='test'].iterrows():

    fig = plt.figure(figsize=(1.5, 1.5))
    axes = plt.gca()

    axes.add_patch(Rectangle((0,0),1,1,fc='#e41a1c',alpha=0.7))
    axes.add_patch(Rectangle((0,1),1,1,fc='#4daf4a',alpha=0.7))
    axes.add_patch(Rectangle((1,1),1,1,fc='#fc9272',alpha=0.7))
    axes.add_patch(Rectangle((1,0),1,1,fc='#4daf4a',alpha=0.7))

    axes.set_xlim([0,2])
    axes.set_ylim([0,2])

    axes.set_xticks([0.5,1.5],labels=['R','S'])
    axes.set_yticks([0.5,1.5],labels=['S','R'])

    axes.text(0.5,0.5,row['FN'],ha='center',va='center')
    axes.text(1.5,0.5,row['TN'],ha='center',va='center')
    axes.text(0.5,1.5,row['TP'],ha='center',va='center')
    axes.text(1.5,1.5,row['FP'],ha='center',va='center')

    fig.savefig('pdf/figure-3/truthtable-'+row['dataset']+'-'+row['model']+'.pdf', bbox_inches='tight')
    plt.close()