In [46]:
import numpy, pandas, pathlib

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib_venn import venn3

from skops.io import load

# create the output directory for the Figure 4 PDFs; parents=True so a fresh checkout
# without a pdf/ directory does not raise FileNotFoundError
pathlib.Path('pdf/figure-4').mkdir(parents=True, exist_ok=True)

### Setup

First, let's load the trained models, the test dataset, and the previously computed performance metrics of the models (the results table covers the test, validation and MIC datasets).

In [92]:
# load the trained models (one skops file per model under models/)
# NOTE(review): trusted=True accepts every type stored in the file without checking it,
# so only use it on model files you produced yourself. Newer skops releases may require
# an explicit list of trusted types instead (see skops.io.get_untrusted_types) - TODO confirm
best_model = {}
for model in ['LR', 'NN', 'XB']:
    best_model[model] = load(f'models/{model.lower()}.skops', trusted=True)

# load the test dataset: the .npy file holds three arrays written one after another and
# read back in the same order - labels (Y), features (X), then mutation names (Z)
X = {'test': {}}
Y = {'test': {}}
Z = {'test': {}}
with open('data/ds-test.npy', 'rb') as f:
    Y['test']['input'] = numpy.load(f)
    X['test']['input'] = numpy.load(f)
    # allow_pickle is needed for object arrays; only load files from a trusted source
    Z['test']['input'] = numpy.load(f, allow_pickle=True)

# per-model sets of misclassified mutations, filled in below:
# 'vme' = very major errors, 'me' = major errors
errors = {'vme': {}, 'me': {}}

# load the previously computed performance metrics. Despite the file name, the table has
# rows for the test, validation and mic datasets (see the 'dataset' column).
# index_col=0 reuses the saved index instead of adding a spurious 'Unnamed: 0' column.
results = pandas.read_csv('results-test.csv', index_col=0)
results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio
0,LR,test,78.640777,,70.103093,,82.684416,,68,29,22,81,,8.633229
1,LR,validation,97.557471,,43.845535,,80.006802,,545,698,68,2716,,31.186162
2,LR,mic,100.0,,14.285714,,68.0,,1,6,0,50,,inf
3,NN,test,76.699029,,67.010309,,79.861876,,65,32,24,79,,6.686198
4,NN,validation,94.755747,,48.833467,,77.079895,,607,636,146,2638,,17.244615
5,NN,mic,96.0,,28.571429,,69.714286,,2,5,2,48,,9.6
6,XB,test,77.669903,,75.257732,,82.764488,,73,24,23,80,,10.57971
7,XB,validation,97.413793,,44.328238,,80.775215,,551,692,72,2712,,29.991811
8,XB,mic,100.0,,14.285714,,66.0,,1,6,0,50,,inf
9,SP,test,95.145631,,97.938144,,,,95,2,5,98,,931.0


Having done that, it is straightforward to find the Very Major Errors (VMEs: resistant mutations predicted susceptible) and the Major Errors (MEs: susceptible mutations predicted resistant) made by each model on the test dataset.

In [93]:
# For each model, collect the (unique) mutations it misclassifies on the test set:
#   VME - truly resistant (label 1) but predicted susceptible (0)
#   ME  - truly susceptible (label 0) but predicted resistant (1)
truth = Y['test']['input']
mutations = Z['test']['input']

for name in ('LR', 'NN', 'XB'):
    predicted = best_model[name].predict(X['test']['input'])

    missed_resistant = (truth == 1) & (predicted == 0)
    missed_susceptible = (truth == 0) & (predicted == 1)

    errors['vme'][name] = set(mutations[missed_resistant])
    errors['me'][name] = set(mutations[missed_susceptible])

    print(name, len(errors['vme'][name]), len(errors['me'][name]))

LR 22 29
NN 24 32
XB 23 24


First, let's draw some Venn diagrams to see whether the VME/ME mutations differ between the models or are largely shared.

In [95]:
# One Venn diagram per error type, comparing the mutations misclassified by each model
venn_labels = ['LinReg', 'MLP', 'XGB']
venn_colours = ('#e41a1c', '#377eb8', '#4daf4a')

for error_type in ('vme', 'me'):
    fig = plt.figure(figsize=(6, 6))
    per_model = [errors[error_type][m] for m in ('LR', 'NN', 'XB')]
    venn3(per_model, set_labels=venn_labels, set_colors=venn_colours)
    fig.savefig(f'pdf/figure-4/fig-4-venn-{error_type}.pdf', bbox_inches="tight")
    plt.close(fig)


The Venn diagrams show that there are 15 VMEs and 15 MEs that are common to all three models; these are listed below, together with the residues involved.

In [85]:
# For each error type, list the mutations misclassified by all three models and build a
# VMD selection string ("resid ...") of the residues involved.
headings = {
    'vme': "Very Major Errors (R sample predicted S)",
    'me': "Major Errors (S sample predicted R)",
}

for error_type in ['vme', 'me']:

    print(headings[error_type])

    all3 = errors[error_type]['XB'] & errors[error_type]['LR'] & errors[error_type]['NN']

    # assumes mutation names look like 'D110N' (one-letter reference amino acid, residue
    # number, one-letter alternate), so m[1:-1] is the residue number.
    # Iteration order of a set of strings changes between kernel sessions (hash
    # randomisation), so sort by residue then name to make the printed output reproducible.
    all3_sorted = sorted(all3, key=lambda m: (int(m[1:-1]), m))
    print(len(all3), all3_sorted)

    resids = sorted({int(m[1:-1]) for m in all3})

    print("The string below is intended for pasting into VMD to create a Graphical Representation")
    print('resid ' + ' '.join(str(r) for r in resids) + '\n')


Very Major Errors (R sample predicted S)
15 {'D110N', 'N112S', 'V93L', 'E174G', 'R29C', 'A38S', 'P115A', 'H82L', 'S18T', 'N118Y', 'S88T', 'A30S', 'E127D', 'T76S', 'S32I'}
The string below is intended for pasting into VMD to create a Graphical Representation
resid 18 29 30 32 38 76 82 88 93 110 112 115 118 127 174 

Major Errors (S sample predicted R)
15 {'A28P', 'G17S', 'A46S', 'L35Q', 'G23A', 'N11I', 'L172Q', 'S32R', 'L19V', 'A79T', 'L35R', 'F58Y', 'A178D', 'G17V', 'N11H'}
The string below is intended for pasting into VMD to create a Graphical Representation
resid 11 17 19 23 28 32 35 46 58 79 172 178 



In [96]:
# Bar charts of each metric for the three ML models (the SP row is excluded), one PDF
# per (metric, dataset) pair.
colour = '#888888'
for metric in ['sensitivity', 'specificity', 'roc_auc']:
    for dataset in ['test', 'validation']:
        subset = results[(results.dataset == dataset) & (results.model != 'SP')]
        x = subset.model
        y = subset[metric + '_mean']
        e = subset[metric + '_std']   # empty (NaN) in the current results table

        fig = plt.figure(figsize=(2.2, 3.5))
        axes = plt.gca()
        axes.spines['top'].set_visible(False)
        axes.spines['right'].set_visible(False)
        axes.spines['left'].set_visible(False)
        axes.get_yaxis().set_visible(False)
        axes.set_ylim([0, 100])
        axes.bar(x, y, edgecolor=colour, color='None', linewidth=2)

        if e.sum() > 0:
            axes.errorbar(x, y, yerr=e, fmt='.', color=colour, linewidth=2)

        # Label every bar with its MEAN (not mean+std), 2 units above the bar or above
        # the top of its error bar. A NaN or zero std means there is no error bar to clear.
        lift = e.where(e > 0, 0)
        for (i, j, k) in zip(x, y, lift):
            axes.text(i, j + k + 2, '%.1f' % j, ha='center', color=colour)

        fig.savefig('pdf/figure-4/fig-4-' + dataset + '-' + metric + '.pdf', bbox_inches="tight")
        plt.close(fig)

In [88]:
# Display the results table again.
# NOTE(review): the saved output below looks stale - it has no SP row, unlike the output
# of the loading cell above - so re-run to refresh it.
results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio
0,LR,test,78.640777,,70.103093,,82.684416,,68,29,22,81,,8.633229
1,LR,validation,97.557471,,43.845535,,80.006802,,545,698,68,2716,,31.186162
2,LR,mic,100.0,,14.285714,,68.0,,1,6,0,50,,inf
3,NN,test,76.699029,,67.010309,,79.861876,,65,32,24,79,,6.686198
4,NN,validation,94.755747,,48.833467,,77.079895,,607,636,146,2638,,17.244615
5,NN,mic,96.0,,28.571429,,69.714286,,2,5,2,48,,9.6
6,XB,test,77.669903,,75.257732,,82.764488,,73,24,23,80,,10.57971
7,XB,validation,97.413793,,44.328238,,80.775215,,551,692,72,2712,,29.991811
8,XB,mic,100.0,,14.285714,,66.0,,1,6,0,50,,inf


In [97]:
# Sensitivity/specificity bar charts that also include the SP row (highlighted in pink)
# next to the three ML models; saved with a '-suspectpza' suffix.
for metric in ['sensitivity', 'specificity']:
    for dataset in ['test', 'validation']:
        subset = results[results.dataset == dataset]
        x = subset.model
        y = subset[metric + '_mean']
        e = subset[metric + '_std']
        # Colour each bar by model name rather than by row position, so the pink
        # highlight stays on SP whatever order (or number) of rows the dataset has.
        colour = ['pink' if name == 'SP' else '#888888' for name in x]

        fig = plt.figure(figsize=(3.2, 3.5))
        axes = plt.gca()
        axes.spines['top'].set_visible(False)
        axes.spines['right'].set_visible(False)
        axes.spines['left'].set_visible(False)
        axes.get_yaxis().set_visible(False)
        axes.set_ylim([0, 100])
        axes.bar(x, y, edgecolor=colour, color='None', linewidth=2)

        if e.sum() > 0:
            axes.errorbar(x, y, yerr=e, fmt='.', color='#888888', linewidth=2)

        # Label each bar with its mean, 2 units above the bar or above its error bar.
        # A NaN or zero std means there is no error bar to clear.
        lift = e.where(e > 0, 0)
        for (i, j, k, c) in zip(x, y, lift, colour):
            axes.text(i, j + k + 2, '%.1f' % j, ha='center', color=c)

        fig.savefig('pdf/figure-4/fig-4-' + dataset + '-' + metric + '-suspectpza.pdf', bbox_inches="tight")
        plt.close(fig)

In [98]:
# Diagnostic odds ratio (DOR) bar charts for the three ML models (RF and SP rows are
# excluded). There is no fixed y-limit because the DOR is unbounded: it is inf for the mic
# rows, whose FN count is 0, but those rows are not plotted here.
colour = '#888888'
for metric in ['diagnostic_odds_ratio']:
    for dataset in ['test', 'validation']:
        subset = results[(results.dataset == dataset) & ~results.model.isin(['RF', 'SP'])]
        x = subset.model
        y = subset[metric]

        fig = plt.figure(figsize=(2.2, 3.5))
        axes = plt.gca()
        axes.spines['top'].set_visible(False)
        axes.spines['right'].set_visible(False)
        axes.spines['left'].set_visible(False)
        axes.get_yaxis().set_visible(False)
        axes.bar(x, y, edgecolor=colour, color='None', linewidth=2)

        # the DOR has no std column, so just label each bar with its value above the bar
        for (i, j) in zip(x, y):
            axes.text(i, j + 0.1, '%.1f' % j, ha='center', color=colour)

        fig.savefig('pdf/figure-4/fig-4-' + dataset + '-' + metric + '.pdf', bbox_inches="tight")
        plt.close(fig)

In [99]:
# One 2x2 "truth table" (confusion matrix) PDF per row of the results table.
# Columns are the true class (R, S); rows from bottom to top are the predicted class (S, R).

# (lower-left corner, fill colour) of each quadrant, in drawing order
quadrants = [
    ((0, 0), '#e41a1c'),   # FN - resistant predicted susceptible (very major error)
    ((0, 1), '#4daf4a'),   # TP
    ((1, 1), '#fc9272'),   # FP - susceptible predicted resistant (major error)
    ((1, 0), '#4daf4a'),   # TN
]
# (centre of the quadrant, results column whose count is written there)
count_positions = [
    ((0.5, 0.5), 'FN'),
    ((1.5, 0.5), 'TN'),
    ((0.5, 1.5), 'TP'),
    ((1.5, 1.5), 'FP'),
]

for _, row in results.iterrows():

    fig = plt.figure(figsize=(1.5, 1.5))
    axes = fig.gca()

    for corner, face in quadrants:
        axes.add_patch(Rectangle(corner, 1, 1, fc=face, alpha=0.7))

    axes.set_xlim([0, 2])
    axes.set_ylim([0, 2])

    axes.set_xticks([0.5, 1.5], labels=['R', 'S'])
    axes.set_yticks([0.5, 1.5], labels=['S', 'R'])

    for (cx, cy), column in count_positions:
        axes.text(cx, cy, row[column], ha='center', va='center')

    fig.savefig(f"pdf/figure-4/truthtable-{row['dataset']}-{row['model']}.pdf", bbox_inches='tight')
    plt.close(fig)