In [30]:
import pandas, numpy, pathlib

from skops.io import load

from sklearn.utils import resample

import matplotlib.pyplot as plt
import seaborn

from misc import construct_line

from statsmodels.stats.weightstats import ztest as ztest

pathlib.Path('pdf/figure-5').mkdir(exist_ok=True)

number_of_bootstraps = 10

### Model validation on `validation` and `mic` sets

As discussed in the manuscript we have two further datasets (`validation` & `mic`) to evaluate the models on.

First let's load the trained models, the datasets and the results for Suspect-PZA


In [31]:
best_model = {}

for model in ['LR', 'NN', 'XB']:
    best_model[model] = load('models/'+model.lower()+'.skops', trusted=True)

X={}
Y={}
Z={}

for i in ['validation-samples', 'validation-mutations', 'mic']:
    X[i]={}
    Y[i]={}
    Z[i]={}
    with open('data/ds-'+i+'.npy', 'rb') as f:
        Y[i]['input'] = numpy.load(f)
        X[i]['input'] = numpy.load(f)
        Z[i]['input'] = numpy.load(f, allow_pickle=True)

# load in the results for SuspectPZA
suspectpza={}
for i in ['validation-samples', 'validation-mutations','mic']:
    suspectpza[i]={}
    with open('data/suspectpza-'+i+'.npy', 'rb') as f:
        suspectpza[i]['input'] = numpy.load(f)
        suspectpza[i]['predicted'] = numpy.load(f)
        suspectpza[i]['muts'] = numpy.load(f, allow_pickle=True)
line = []

For comparison later, let's make a subset of `validation-samples` that exclude those mutations which have an inconsistent phenotype, either because too few are in the dataset or they do not consistently test as R or S, suggesting they may be near the breakpoint/ECOFF. 

We need to do this for both the `validation-samples` features and the SuspectPZA results.

In [32]:
df = pandas.read_csv('data/ds-validation-mutations-full.csv')

mask = ~numpy.isin(Z['validation-samples']['input'], df[df.CONSISTENT_PHENOTYPE=='U'].MUTATION.unique())

X['validation-samplesnoU']={}
Y['validation-samplesnoU']={}
Z['validation-samplesnoU']={}

X['validation-samplesnoU']['input']=X['validation-samples']['input'][mask]
Y['validation-samplesnoU']['input']=Y['validation-samples']['input'][mask]
Z['validation-samplesnoU']['input']=Z['validation-samples']['input'][mask]

mask = ~numpy.isin(suspectpza['validation-samples']['muts'],(df[df.CONSISTENT_PHENOTYPE=='U'].MUTATION.unique()))
suspectpza['validation-samplesnoU']={}
suspectpza['validation-samplesnoU']['input']=suspectpza['validation-samples']['input'][mask]
suspectpza['validation-samplesnoU']['predicted']=suspectpza['validation-samples']['predicted'][mask]
suspectpza['validation-samplesnoU']['muts']=suspectpza['validation-samples']['muts'][mask]

In [33]:
df[df.CONSISTENT_PHENOTYPE!='U'].MUTATION.unique()

array(['A102R', 'A102T', 'A134V', 'A143G', 'A143T', 'A143V', 'A146E',
       'A146P', 'A146T', 'A146V', 'A171V', 'A30V', 'A3E', 'A46E', 'A46T',
       'A79T', 'A79V', 'C138R', 'C14R', 'C72R', 'C72Y', 'D129Y', 'D12A',
       'D12E', 'D12N', 'D136Y', 'D49A', 'D49G', 'D49N', 'D63A', 'D8A',
       'D8E', 'D8G', 'D8N', 'E15G', 'F106S', 'F13I', 'F13V', 'F58V',
       'F81C', 'F81S', 'F94C', 'F94S', 'G105D', 'G105V', 'G108R', 'G124D',
       'G132A', 'G132C', 'G132D', 'G132S', 'G162D', 'G17D', 'G17S',
       'G24D', 'G97C', 'G97D', 'G97R', 'G97S', 'G97V', 'H137P', 'H137Q',
       'H43P', 'H43Y', 'H51D', 'H51P', 'H51Q', 'H51R', 'H51Y', 'H57D',
       'H57L', 'H57P', 'H57Q', 'H57R', 'H57Y', 'H71P', 'H71Q', 'H71Y',
       'H82D', 'H82R', 'I31S', 'I5S', 'I5T', 'I6L', 'I6S', 'I6T', 'I6V',
       'I90T', 'K48N', 'K96E', 'K96N', 'K96R', 'K96T', 'L116R', 'L120P',
       'L120Q', 'L120R', 'L151S', 'L156P', 'L159R', 'L172P', 'L172R',
       'L182F', 'L19P', 'L27P', 'L35R', 'L4S', 'L85P', 'L85R', 'M175I

The below function takes a supplied model, applies it to the features of the different datasets and measures a range of metrics we can use to evaluate their performance.

Note that this uses the separate `construct_line` function which can be found in `misc.py`.

In [34]:
def validate_model(line, best_model, model_name, X, Y):

    for dataset in ['validation-samples', 'validation-samplesnoU', 'validation-mutations', 'mic']: 
        
        Y[dataset]['predicted'] = best_model.predict(X[dataset]['input'])
        Y[dataset]['scores'] = best_model.predict_proba(X[dataset]['input'])[:,1]

        row = construct_line(model_name , dataset, None, Y[dataset], None)
        line.append(row)

    return(line)

In [35]:
for model in ['NN', 'XB', 'LR']:
    line = validate_model(line, best_model[model], model, X, Y)

In [36]:
for i in ['validation-samples', 'validation-samplesnoU', 'validation-mutations', 'mic']:
    line.append(construct_line('SP', i, None, suspectpza[i], None))

In [37]:
def bootstrap_model(line, best_model, model_name, X, Y):
    x={}
    y={}
    for dataset in ['validation-samples', 'validation-samplesnoU', 'validation-mutations', 'mic']:
        print(model_name, dataset)
        for i in range(number_of_bootstraps):
            x['input'],y['input'] = resample(X[dataset]['input'], Y[dataset]['input'])
            y['predicted'] = best_model.predict(x['input'])
            y['scores'] = best_model.predict_proba(x['input'])[:,1]
            row = construct_line(model_name, dataset+'_'+str(i), None, y, None)
            line.append(row)
    return(line)

In [38]:
for model in ['LR', 'NN', 'XB']:
    line = bootstrap_model(line, best_model[model], model, X, Y)

LR validation-samples


LR validation-samplesnoU
LR validation-mutations
LR mic
NN validation-samples
NN validation-samplesnoU
NN validation-mutations
NN mic
XB validation-samples
XB validation-samplesnoU
XB validation-mutations
XB mic


In [39]:
test_results = pandas.DataFrame(line, columns=['model', 'dataset', 'sensitivity_mean', 'sensitivity_std', 'specificity_mean', 'specificity_std' ,'roc_auc_mean', 'roc_auc_std','TN','FP','FN','TP', 'model_parameters'])
test_results[:3]

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters
0,NN,validation-samples,90.517241,,51.649236,,73.731156,,642,601,264,2520,
1,NN,validation-samplesnoU,91.479239,,66.666667,,80.821306,,538,269,197,2115,
2,NN,validation-mutations,94.193548,,52.272727,,77.565982,,23,21,9,146,


In [40]:
line = []
for i in ['validation-samples', 'validation-samplesnoU', 'validation-mutations', 'mic']:
    for model in ['LR', 'NN', 'XB']:
        sens_mean = test_results[(test_results.model==model) & (test_results.dataset!=i) & (test_results.dataset.str.contains(i+"_"))].sensitivity_mean.mean()
        sens_std = 1.96*test_results[(test_results.model==model) & (test_results.dataset!=i) & (test_results.dataset.str.contains(i+"_"))].sensitivity_mean.std()/(number_of_bootstraps**0.5)
        spec_mean = test_results[(test_results.model==model) & (test_results.dataset!=i) & (test_results.dataset.str.contains(i+"_"))].specificity_mean.mean()
        spec_std = 1.96*test_results[(test_results.model==model) & (test_results.dataset!=i) & (test_results.dataset.str.contains(i+"_"))].specificity_mean.std()/(number_of_bootstraps**0.5)
        roc_mean = test_results[(test_results.model==model) & (test_results.dataset!=i) & (test_results.dataset.str.contains(i+"_"))].roc_auc_mean.mean()
        row = [model, 'bootstrapped-'+i, sens_mean, sens_std, spec_mean, spec_std, roc_mean, None, None, None, None, None, None ]
        line.append(row)

extra_rows = pandas.DataFrame(line, columns=['model', 'dataset', 'sensitivity_mean', 'sensitivity_std', 'specificity_mean', 'specificity_std' ,'roc_auc_mean', 'roc_auc_std','TN','FP','FN','TP', 'model_parameters'])
test_results  = pandas.concat([test_results, extra_rows])
test_results[:3]

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters
0,NN,validation-samples,90.517241,,51.649236,,73.731156,,642,601,264,2520,
1,NN,validation-samplesnoU,91.479239,,66.666667,,80.821306,,538,269,197,2115,
2,NN,validation-mutations,94.193548,,52.272727,,77.565982,,23,21,9,146,


In [41]:
extra_rows


Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters
0,LR,bootstrapped-validation-samples,98.104809,0.136916,40.755377,1.232571,79.632668,,,,,,
1,NN,bootstrapped-validation-samples,90.547766,0.296145,51.094362,0.939891,73.441742,,,,,,
2,XB,bootstrapped-validation-samples,97.140088,0.291099,46.145977,1.616154,80.841641,,,,,,
3,LR,bootstrapped-validation-samplesnoU,98.468819,0.155325,56.212887,0.956602,85.116371,,,,,,
4,NN,bootstrapped-validation-samplesnoU,91.820808,0.257823,66.338479,1.546297,80.936538,,,,,,
5,XB,bootstrapped-validation-samplesnoU,97.736557,0.128989,62.890063,0.843872,87.238188,,,,,,
6,LR,bootstrapped-validation-mutations,98.275016,0.46019,46.16129,5.256236,86.226189,,,,,,
7,NN,bootstrapped-validation-mutations,93.685099,1.235595,48.635791,4.88539,75.308746,,,,,,
8,XB,bootstrapped-validation-mutations,97.412368,0.71384,59.727047,3.525336,87.384076,,,,,,
9,LR,bootstrapped-mic,100.0,0.0,0.0,0.0,74.153233,,,,,,


In [42]:
test_results[(test_results.model=='XB') & (test_results.dataset!='validation-samplesnoU') & (test_results.dataset.str.contains('validation-samplesnoU'+"_"))]

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters
106,XB,validation-samplesnoU_0,97.860262,,62.484922,,87.033439,,518,311,49,2241,
107,XB,validation-samplesnoU_1,97.794118,,61.462206,,87.585058,,496,311,51,2261,
108,XB,validation-samplesnoU_2,97.645007,,64.527845,,86.388646,,533,293,54,2239,
109,XB,validation-samplesnoU_3,97.759227,,62.752076,,87.476833,,529,314,51,2225,
110,XB,validation-samplesnoU_4,97.638471,,62.151899,,86.429744,,491,299,55,2274,
111,XB,validation-samplesnoU_5,97.483731,,64.496314,,87.949789,,525,289,58,2247,
112,XB,validation-samplesnoU_6,97.396293,,61.664713,,87.426522,,526,327,59,2207,
113,XB,validation-samplesnoU_7,98.062528,,61.438679,,87.102712,,521,327,44,2227,
114,XB,validation-samplesnoU_8,97.994769,,62.787879,,87.109984,,518,307,46,2248,
115,XB,validation-samplesnoU_9,97.731164,,65.1341,,87.879151,,510,273,53,2283,


In [43]:
for dataset in ['validation-samples', 'validation-samplesnoU', 'validation-mutations', 'mic']:
    for metric in ['sensitivity_mean', 'specificity_mean']:
        for i in ['XB', 'NN', 'LR']:
            for j in ['XB', 'NN', 'LR']:
                if i!=j:
                    a = numpy.array(test_results[(test_results.model==i)& (test_results.dataset.str.contains(dataset+'_'))][metric])
                    b = numpy.array(test_results[(test_results.model==j)& (test_results.dataset.str.contains(dataset+'_'))][metric])
                    assert len(a) == number_of_bootstraps
                    tstat, pvalue = ztest(a,b)
                    if pvalue < 0.05:
                        print(dataset,metric, i,j,"Signficant",pvalue)
        print()

validation-samples sensitivity_mean XB NN Signficant 1.492923143746279e-212
validation-samples sensitivity_mean XB LR Signficant 4.155784157354011e-09
validation-samples sensitivity_mean NN XB Signficant 1.492923143746279e-212
validation-samples sensitivity_mean NN LR Signficant 0.0
validation-samples sensitivity_mean LR XB Signficant 4.155784157354011e-09
validation-samples sensitivity_mean LR NN Signficant 0.0

validation-samples specificity_mean XB NN Signficant 2.129146296974386e-07
validation-samples specificity_mean XB LR Signficant 2.0119005277541832e-07
validation-samples specificity_mean NN XB Signficant 2.129146296974386e-07
validation-samples specificity_mean NN LR Signficant 4.668934629228787e-39
validation-samples specificity_mean LR XB Signficant 2.0119005277541832e-07
validation-samples specificity_mean LR NN Signficant 4.668934629228787e-39

validation-samplesnoU sensitivity_mean XB NN Signficant 0.0
validation-samplesnoU sensitivity_mean XB LR Signficant 1.172104175894

  zstat = (value1 - value2 - diff) / std_diff


In [45]:
# test_results = pandas.DataFrame(line, columns=['model', 'dataset', 'sensitivity_mean', 'sensitivity_std', 'specificity_mean', 'specificity_std' ,'roc_auc_mean', 'roc_auc_std','TN','FP','FN','TP', 'model_parameters'])

# calculate the diagnostic odds ration
# test_results['diagnostic_odds_ratio_mean'] = (test_results['TN']*test_results['TP'])/(test_results['FN']*test_results['FP'])
test_results['diagnostic_odds_ratio_mean'] = 0
test_results['diagnostic_odds_ratio_std'] = 0

# save to disc as a CSV
test_results.to_csv('results-validation.csv', index=False)

test_results[:20]

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
0,NN,validation-samples,90.517241,,51.649236,,73.731156,,642,601,264,2520,,0,0
1,NN,validation-samplesnoU,91.479239,,66.666667,,80.821306,,538,269,197,2115,,0,0
2,NN,validation-mutations,94.193548,,52.272727,,77.565982,,23,21,9,146,,0,0
3,NN,mic,94.0,,42.857143,,63.142857,,3,4,3,47,,0,0
4,XB,validation-samples,97.198276,,46.017699,,80.722592,,572,671,78,2706,,0,0
5,XB,validation-samplesnoU,97.621107,,63.07311,,86.880823,,509,298,55,2257,,0,0
6,XB,validation-mutations,97.419355,,59.090909,,88.284457,,26,18,4,151,,0,0
7,XB,mic,100.0,,14.285714,,67.714286,,1,6,0,50,,0,0
8,LR,validation-samples,98.060345,,40.305712,,79.642377,,501,742,54,2730,,0,0
9,LR,validation-samplesnoU,98.442907,,56.133829,,84.84152,,453,354,36,2276,,0,0


In [248]:
test_results.model.value_counts()

model
NN    16
XB    16
LR    16
SP     4
Name: count, dtype: int64

In [249]:
test_results[test_results.model=='XB']

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
4,XB,validation-samples,97.198276,,46.017699,,80.722592,,572.0,671.0,78.0,2706.0,,0,0
5,XB,validation-samplesnoU,97.621107,,63.07311,,86.880823,,509.0,298.0,55.0,2257.0,,0,0
6,XB,validation-mutations,97.419355,,59.090909,,88.284457,,26.0,18.0,4.0,151.0,,0,0
7,XB,mic,100.0,,14.285714,,67.714286,,1.0,6.0,0.0,50.0,,0,0
32,XB,validation-samples_0,96.872753,,44.016064,,80.022506,,548.0,697.0,87.0,2695.0,,0,0
33,XB,validation-samples_1,96.967535,,46.405229,,81.855968,,568.0,656.0,85.0,2718.0,,0,0
34,XB,validation-samplesnoU_0,97.194213,,61.097852,,86.169707,,512.0,326.0,64.0,2217.0,,0,0
35,XB,validation-samplesnoU_1,97.348161,,61.971831,,86.459065,,484.0,297.0,62.0,2276.0,,0,0
36,XB,validation-mutations_0,99.367089,,63.414634,,87.017598,,26.0,15.0,1.0,157.0,,0,0
37,XB,validation-mutations_1,97.278912,,61.538462,,89.782836,,32.0,20.0,4.0,143.0,,0,0


In [250]:
test_results[(test_results.model=='XB') & (test_results.dataset.str.contains('validation-samples'))]

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters,diagnostic_odds_ratio_mean,diagnostic_odds_ratio_std
4,XB,validation-samples,97.198276,,46.017699,,80.722592,,572.0,671.0,78.0,2706.0,,0,0
5,XB,validation-samplesnoU,97.621107,,63.07311,,86.880823,,509.0,298.0,55.0,2257.0,,0,0
32,XB,validation-samples_0,96.872753,,44.016064,,80.022506,,548.0,697.0,87.0,2695.0,,0,0
33,XB,validation-samples_1,96.967535,,46.405229,,81.855968,,568.0,656.0,85.0,2718.0,,0,0
34,XB,validation-samplesnoU_0,97.194213,,61.097852,,86.169707,,512.0,326.0,64.0,2217.0,,0,0
35,XB,validation-samplesnoU_1,97.348161,,61.971831,,86.459065,,484.0,297.0,62.0,2276.0,,0,0
2,XB,bootstrapped-validation-samples,96.920144,0.092886,45.210647,2.341381,80.939237,,,,,,,0,0
5,XB,bootstrapped-validation-samplesnoU,97.271187,0.150869,61.534842,0.856499,86.314386,,,,,,,0,0


In [44]:
test_results

Unnamed: 0,model,dataset,sensitivity_mean,sensitivity_std,specificity_mean,specificity_std,roc_auc_mean,roc_auc_std,TN,FP,FN,TP,model_parameters
0,NN,validation-samples,90.517241,,51.649236,,73.731156,,642,601,264,2520,
1,NN,validation-samplesnoU,91.479239,,66.666667,,80.821306,,538,269,197,2115,
2,NN,validation-mutations,94.193548,,52.272727,,77.565982,,23,21,9,146,
3,NN,mic,94.000000,,42.857143,,63.142857,,3,4,3,47,
4,XB,validation-samples,97.198276,,46.017699,,80.722592,,572,671,78,2706,
...,...,...,...,...,...,...,...,...,...,...,...,...,...
7,NN,bootstrapped-validation-mutations,93.685099,1.235595,48.635791,4.88539,75.308746,,,,,,
8,XB,bootstrapped-validation-mutations,97.412368,0.71384,59.727047,3.525336,87.384076,,,,,,
9,LR,bootstrapped-mic,100.000000,0.0,0.000000,0.0,74.153233,,,,,,
10,NN,bootstrapped-mic,93.489481,1.794188,43.191198,14.952192,63.003833,,,,,,
