In [32]:
import numpy, pandas, pathlib

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib_venn import venn3

from skops.io import load

# Output directory for every figure written by this notebook. parents=True so
# that a fresh checkout without a pdf/ directory does not raise
# FileNotFoundError; exist_ok=True keeps the cell safe to re-run.
pathlib.Path('pdf/figure-4').mkdir(parents=True, exist_ok=True)

### Setup

First, let's load the trained models, the test dataset, and the performance of the models on the test and validation datasets, which was calculated and written to disk by the previous notebook.

In [33]:
# Load the three trained models saved by the previous notebook.
# NOTE(review): trusted=True lets skops instantiate any type stored in the file,
# so only use it on model files produced by this project; newer skops releases
# may require an explicit list of trusted types instead - confirm against the
# installed version.
best_model = {}
for model in ['LR', 'NN', 'XB']:
    best_model[model] = load('models/' + model.lower() + '.skops', trusted=True)

# Load the test dataset. The arrays must be read back in exactly the order they
# were written: labels (Y), model inputs (X), then mutation identifiers (Z).
X = {'test': {}}
Y = {'test': {}}
Z = {'test': {}}
with open('data/ds-test.npy', 'rb') as f:
    Y['test']['input'] = numpy.load(f)
    X['test']['input'] = numpy.load(f)
    # presumably Z is an object array of mutation strings, hence allow_pickle.
    # NOTE(review): allow_pickle=True can execute code from a crafted file; only
    # acceptable because this file is produced by our own pipeline.
    Z['test']['input'] = numpy.load(f, allow_pickle=True)

# Later cells boolean-index Z with masks built from Y and the predictions on X,
# so all three must describe the same rows.
assert len(Y['test']['input']) == len(X['test']['input']) == len(Z['test']['input'])

# Very major / major errors per model; filled in below.
errors = {'vme': {}, 'me': {}}

# Load the per-model, per-dataset performance summary written by the previous
# notebook (despite the file name it holds the test, validation and mic rows,
# not the training dataset).
results = pandas.read_csv('results-test.csv')
results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
0,LR,test,78.640777,,70.103093,,82.684416,,68,29,22,81,,8.633229,0
1,LR,validation-samples,97.557471,,43.845535,,80.006802,,545,698,68,2716,,31.186162,0
2,LR,validation-samples-noU,98.615917,,58.116481,,85.212409,,469,338,32,2280,,98.864645,0
3,LR,validation-mutations,97.419355,,50.0,,87.243402,,22,22,4,151,,37.75,0
4,LR,mic,100.0,,14.285714,,68.0,,1,6,0,50,,inf,0
5,NN,test,76.699029,,67.010309,,79.861876,,65,32,24,79,,6.686198,0
6,NN,validation-samples,94.755747,,48.833467,,77.079895,,607,636,146,2638,,17.244615,0
7,NN,validation-samples-noU,96.237024,,63.07311,,82.534983,,509,298,87,2225,,43.682982,0
8,NN,validation-mutations,96.774194,,52.272727,,81.480938,,23,21,5,150,,32.857143,0
9,NN,mic,96.0,,28.571429,,69.714286,,2,5,2,48,,9.6,0


Having done that, it is fairly simple to calculate the Very Major Errors (resistant mutations predicted susceptible) and the Major Errors (susceptible mutations predicted resistant).

In [34]:
# Very Major Error (VME): true label 1 (resistant) but predicted 0 (susceptible).
# Major Error (ME):       true label 0 (susceptible) but predicted 1 (resistant).
# Errors are kept as *sets* of mutation identifiers from Z, so the counts printed
# are distinct mutations rather than rows.
truly_resistant = Y['test']['input'] == 1
truly_susceptible = Y['test']['input'] == 0

for model in ['LR', 'NN', 'XB']:
    y_predicted = best_model[model].predict(X['test']['input'])

    missed_resistant = truly_resistant & (y_predicted == 0)
    false_resistant = truly_susceptible & (y_predicted == 1)

    errors['vme'][model] = set(Z['test']['input'][missed_resistant])
    errors['me'][model] = set(Z['test']['input'][false_resistant])

    print(model, '| very major errors =', len(errors['vme'][model]), '| major errors =', len(errors['me'][model]))

LR | very major errors = 22 | major errors = 29
NN | very major errors = 24 | major errors = 32
XB | very major errors = 23 | major errors = 24


First, let's draw some Venn diagrams to see whether the VME/ME mutations differ between the models or are largely shared.

In [35]:
# One three-way Venn diagram per error type, showing how far the sets of
# misclassified mutations overlap between the three models.
# NOTE(review): 'LinReg' labels the LR model; the models are used as binary
# classifiers (.predict compared with 0/1), so if LR is logistic regression this
# label may be misleading - confirm.
venn_colours = ('#e41a1c', '#377eb8', '#4daf4a')
for error_type in ('vme', 'me'):
    fig = plt.figure(figsize=(6, 6))
    model_errors = [errors[error_type][name] for name in ('LR', 'NN', 'XB')]
    axes = venn3(model_errors, set_labels=['LinReg', 'MLP', 'XGB'], set_colors=venn_colours)
    fig.savefig('pdf/figure-4/fig-4-venn-' + error_type + '.pdf', bbox_inches="tight")
    plt.close()

The Venn diagrams show that 15 VMEs and 15 MEs are common to all three models.

In [36]:
# For each error type, list the mutations misclassified by *all three* models
# and print their residue numbers as a selection string that can be pasted into
# VMD. (The original reused `i` for the outer loop, the mutations and the
# residues; distinct names avoid that shadowing.)
for error_type in ['vme', 'me']:

    if error_type == 'vme':
        print("Very Major Errors (R sample predicted S)")
    elif error_type == 'me':
        print("Major Errors (S sample predicted R)")

    all3 = errors[error_type]['XB'] & errors[error_type]['LR'] & errors[error_type]['NN']

    # Sorted so the printed list is identical between runs: iteration order of a
    # set of strings depends on the per-process hash seed.
    print(len(all3), sorted(all3))

    # Mutation ids look like 'A30S' (reference residue, position, mutant
    # residue), so the position is everything between the first and last char.
    # NOTE(review): assumes single-character residue codes; other formats
    # (e.g. indels) would make int() fail.
    resids = sorted({int(mutation[1:-1]) for mutation in all3})

    # keep the trailing space of the original 'resid 11 17 ... ' format
    line = 'resid ' + ''.join(str(resid) + ' ' for resid in resids)
    print("The string below is intended for pasting into VMD to create a Graphical Representation")
    print(line + '\n')

Very Major Errors (R sample predicted S)
15 {'N118Y', 'A30S', 'N112S', 'T76S', 'P115A', 'S18T', 'S88T', 'S32I', 'D110N', 'E127D', 'R29C', 'A38S', 'E174G', 'V93L', 'H82L'}
The string below is intended for pasting into VMD to create a Graphical Representation
resid 18 29 30 32 38 76 82 88 93 110 112 115 118 127 174 

Major Errors (S sample predicted R)
15 {'G17V', 'A28P', 'A46S', 'L35Q', 'L172Q', 'G17S', 'S32R', 'G23A', 'F58Y', 'L35R', 'A79T', 'N11I', 'N11H', 'L19V', 'A178D'}
The string below is intended for pasting into VMD to create a Graphical Representation
resid 11 17 19 23 28 32 35 46 58 79 172 178 



Plot bar charts of the different performance metrics for each of the different datasets.

In [38]:
# One small bar chart per (metric, dataset) comparing the trained models;
# SuspectPZA rows (model == 'SP') are excluded here.
colour = '#888888'
for metric in ['sensitivity', 'specificity', 'roc_auc', 'diagnostic_odds_ratio']:
    for dataset in ['test', 'validation-samples', 'validation-samples-noU', 'validation-mutations']:
        subset = results[(results.dataset == dataset) & (results.model != 'SP')]
        x = subset.model
        y = subset[metric + '_mean']
        # A missing std (NaN) means there is no spread to show; treating it as 0
        # keeps label positions finite.
        e = subset[metric + '_std'].fillna(0)

        fig = plt.figure(figsize=(2.2, 3.5))
        axes = plt.gca()
        for side in ('top', 'right', 'left'):
            axes.spines[side].set_visible(False)
        axes.get_yaxis().set_visible(False)
        # NOTE(review): a fixed 0-100 axis suits percentages; the diagnostic odds
        # ratio is unbounded, so values above 100 would be clipped.
        axes.set_ylim([0, 100])
        axes.bar(x, y, edgecolor=colour, color='None', linewidth=2)

        if e.sum() > 0:
            axes.errorbar(x, y, yerr=e, fmt='.', color=colour, linewidth=2)

        # Label each bar with its *mean*, placed just above the error bar
        # (e is zero when there is no std). The original printed mean+std.
        for (i, j, k) in zip(x, y, e):
            axes.text(i, j + k + 2, '%.1f' % j, ha='center', color=colour)

        fig.savefig('pdf/figure-4/fig-4-' + dataset + '-' + metric + '.pdf', bbox_inches="tight")
        plt.close(fig)

Repeat, but include the results of SuspectPZA.

In [16]:
# Same bar charts as for the trained models, restricted to sensitivity and
# specificity but including every model in `results` for the dataset, i.e. also
# SuspectPZA (model == 'SP'), whose bar and value label are drawn in pink.
grey = '#888888'
for metric in ['sensitivity', 'specificity']:
    for dataset in ['test', 'validation-samples', 'validation-samples-noU', 'validation-mutations']:
        subset = results[results.dataset == dataset]
        x = subset.model
        y = subset[metric + '_mean']
        e = subset[metric + '_std']
        # Colour by model name rather than by row position, so the SuspectPZA bar
        # is pink whatever the row order or the number of models present.
        colour = ['pink' if name == 'SP' else grey for name in x]

        fig = plt.figure(figsize=(3.2, 3.5))
        axes = plt.gca()
        for side in ('top', 'right', 'left'):
            axes.spines[side].set_visible(False)
        axes.get_yaxis().set_visible(False)
        axes.set_ylim([0, 100])
        axes.bar(x, y, edgecolor=colour, color='None', linewidth=2)

        if e.sum() > 0:
            axes.errorbar(x, y, yerr=e, fmt='.', color=grey, linewidth=2)

        # Label each bar with its mean, lifted above the error bar when the std
        # is positive; a zero or NaN std leaves the label just above the bar.
        for (i, j, err, c) in zip(x, y, e, colour):
            offset = err if err > 0 else 0
            axes.text(i, j + offset + 2, '%.1f' % j, ha='center', color=c)

        fig.savefig('pdf/figure-4/fig-4-' + dataset + '-' + metric + '-suspectpza.pdf', bbox_inches="tight")
        plt.close(fig)

In [18]:
# Draw a 2x2 confusion-matrix ("truth table") figure for every row of results.
# x ticks: R (left) / S (right); y ticks: S (bottom) / R (top).
# TP (top-left) and TN (bottom-right) are green, FN (bottom-left) red and
# FP (top-right) light red.

# (lower-left corner, fill colour) for each unit square, in drawing order
quadrant_fills = [
    ((0, 0), '#e41a1c'),
    ((0, 1), '#4daf4a'),
    ((1, 1), '#fc9272'),
    ((1, 0), '#4daf4a'),
]
# (x centre, y centre, results column) for each count, in drawing order
quadrant_counts = [
    (0.5, 0.5, 'FN'),
    (1.5, 0.5, 'TN'),
    (0.5, 1.5, 'TP'),
    (1.5, 1.5, 'FP'),
]

for idx, row in results.iterrows():
    fig = plt.figure(figsize=(1.5, 1.5))
    axes = plt.gca()

    for corner, fill in quadrant_fills:
        axes.add_patch(Rectangle(corner, 1, 1, fc=fill, alpha=0.7))

    axes.set_xlim([0, 2])
    axes.set_ylim([0, 2])
    axes.set_xticks([0.5, 1.5], labels=['R', 'S'])
    axes.set_yticks([0.5, 1.5], labels=['S', 'R'])

    for x_centre, y_centre, column in quadrant_counts:
        axes.text(x_centre, y_centre, row[column], ha='center', va='center')

    fig.savefig('pdf/figure-4/truthtable-' + row['dataset'] + '-' + row['model'] + '.pdf', bbox_inches='tight')
    plt.close()