In [1]:
import ast
import mhcflurry, seaborn, numpy, pandas, pickle, sklearn, collections, scipy, time
import mhcflurry.dataset
import fancyimpute, locale
from matplotlib import pyplot


import sklearn.metrics
import sklearn.cross_validation
%matplotlib inline


def print_full(x):
    """
    Print ``x`` with every one of its rows shown.

    pandas' ``display.max_rows`` option is raised to ``len(x)`` only for the
    duration of the print. ``option_context`` puts the previous value back
    afterwards, including when printing raises, so a caller's own display
    setting is never lost.

    Parameters
    ----------
    x : object supporting ``len()`` and ``print`` (e.g. a pandas Series or DataFrame)
    """
    with pandas.option_context('display.max_rows', len(x)):
        print(x)

Using Theano backend.


Couldn't import dot_parser, loading of dot files will not be possible.




In [2]:
# Upper IC50 bound handed to ic50_to_regression_target inside make_scores
# (presumably in nM, like make_scores' threshold_nm -- confirm).
max_ic50 = 50000
# Directory holding the input CSV files, relative to the notebook's working directory.
data_dir = "../data/"

In [3]:
# Training dataset; used further down to report the per-allele training-set sizes.
all_train_data = mhcflurry.dataset.Dataset.from_csv(data_dir + "bdata.2009.mhci.public.1.txt")

In [4]:
def make_scores(ic50_y, ic50_y_pred, sample_weight=None, threshold_nm=500):
    """
    Score predicted IC50 values against measured IC50 values.

    Parameters
    ----------
    ic50_y : measured IC50 values
    ic50_y_pred : predicted IC50 values, aligned with ``ic50_y``
    sample_weight : optional weights forwarded to the AUC and F1 computations
        (Kendall's tau is computed unweighted)
    threshold_nm : values ``<= threshold_nm`` are treated as the positive class

    Returns
    -------
    dict with keys ``auc``, ``f1`` and ``tau``. Any metric whose computation
    raises ``ValueError`` is reported as NaN instead.
    """
    def nan_on_value_error(compute):
        # Each metric is evaluated lazily so that a ValueError from one of them
        # turns into NaN without affecting the others.
        try:
            return compute()
        except ValueError:
            return numpy.nan

    # AUC is ranked on the regression-target transform of the predictions
    # (module-level max_ic50 is the transform's upper bound).
    y_pred = mhcflurry.regression_target.ic50_to_regression_target(ic50_y_pred, max_ic50)

    return dict(
        auc=nan_on_value_error(lambda: sklearn.metrics.roc_auc_score(
            ic50_y <= threshold_nm, y_pred, sample_weight=sample_weight)),
        f1=nan_on_value_error(lambda: sklearn.metrics.f1_score(
            ic50_y <= threshold_nm, ic50_y_pred <= threshold_nm, sample_weight=sample_weight)),
        tau=nan_on_value_error(lambda: scipy.stats.kendalltau(ic50_y_pred, ic50_y)[0]),
    )

In [5]:
# Hyperparameters of the validated models.
# layer_sizes is parsed with ast.literal_eval rather than eval: it accepts the
# list literals stored in the file (e.g. "[64]") but cannot execute arbitrary
# expressions that might be present in the CSV.
models = pandas.read_csv(
    data_dir + "validation_models.csv",
    converters={'layer_sizes': ast.literal_eval})
# Flatten the first entry of each layer_sizes list into its own column.
models["layer_size"] = [x[0] for x in models.layer_sizes]
del models["activation"]
models

Unnamed: 0,dropout_probability,embedding_output_dim,fraction_negative,impute,layer_sizes,layer_size
0,0.5,32,0.2,True,[64],64
1,0.5,32,0.2,False,[64],64
2,0.5,32,0.2,True,[64],64
3,0.5,32,0.2,False,[64],64
4,0.5,32,0.2,True,[64],64
5,0.5,32,0.2,False,[64],64
6,0.5,32,0.2,True,[64],64
7,0.5,32,0.2,False,[64],64
8,0.5,32,0.2,True,[64],64
9,0.5,32,0.2,False,[64],64


In [6]:
def name_model(row):
    """
    Build a short human-readable label for one model configuration.

    The label starts with "big" when ``row.embedding_output_dim`` is 32 and
    "small" otherwise, followed by "dropout" when ``row.dropout_probability``
    is positive and "impute" when ``row.impute`` is truthy, joined by spaces.
    """
    labels = ["big" if row.embedding_output_dim == 32 else "small"]
    optional_tags = (
        ("dropout", row.dropout_probability > 0),
        ("impute", row.impute),
    )
    labels += [tag for (tag, enabled) in optional_tags if enabled]
    return " ".join(labels)

# Tag every model with its row number and its descriptive name.
models["num"] = models.index
model_names = []
for (_, row) in models.iterrows():
    model_names.append(name_model(row))
models["name"] = model_names
models

Unnamed: 0,dropout_probability,embedding_output_dim,fraction_negative,impute,layer_sizes,layer_size,num,name
0,0.5,32,0.2,True,[64],64,0,big dropout impute
1,0.5,32,0.2,False,[64],64,1,big dropout
2,0.5,32,0.2,True,[64],64,2,big dropout impute
3,0.5,32,0.2,False,[64],64,3,big dropout
4,0.5,32,0.2,True,[64],64,4,big dropout impute
5,0.5,32,0.2,False,[64],64,5,big dropout
6,0.5,32,0.2,True,[64],64,6,big dropout impute
7,0.5,32,0.2,False,[64],64,7,big dropout
8,0.5,32,0.2,True,[64],64,8,big dropout impute
9,0.5,32,0.2,False,[64],64,9,big dropout


In [7]:
# For each model name, the distinct model numbers that share it.
models_by_name = models.groupby("name")
model_groups = models_by_name["num"].unique()
model_groups

name
big dropout           [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25...
big dropout impute    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24...
Name: num, dtype: object

In [8]:
# Validation-set predictions; read from data_dir like the other inputs so the
# data location is configured in a single place.
validation_df_with_mhcflurry_results = pandas.read_csv(data_dir + "validation_predictions_full.csv")
validation_df_with_mhcflurry_results

Unnamed: 0,allele,peptide,length,meas,netmhc,netmhcpan,smmpmbec_cpp,mhcflurry 0,mhcflurry 1,mhcflurry 2,...,mhcflurry 54,mhcflurry 55,mhcflurry 56,mhcflurry 57,mhcflurry 58,mhcflurry 59,mhcflurry 60,mhcflurry 61,mhcflurry 62,mhcflurry 63
0,H-2-DB,AAACNVATA,9,657.657837,154.881662,711.213514,438.530698,313.802565,605.426263,395.292236,...,360.145415,239.959211,528.333474,1558.652226,808.809197,216.276775,765.882346,318.346678,555.803209,860.878859
1,H-2-DB,AAFEFVYV,8,30831.879502,6456.542290,785.235635,10351.421667,12620.897081,13963.382225,13924.810174,...,13073.502232,12914.565052,13259.055839,16553.972218,15368.253558,11996.603908,14264.470465,13114.823463,12333.075112,13425.395827
2,H-2-DB,AAFVNDYSL,9,77.446180,17.458222,7.516229,28.054336,19.408797,26.851752,22.741957,...,11.907585,23.064450,12.082066,36.695309,31.888881,16.218054,21.185453,19.990712,13.560549,29.475245
3,H-2-DB,AAIANQAAV,9,1.999862,9.638290,9.749896,25.703958,5.806578,6.472855,6.664925,...,5.351417,5.410978,4.335345,6.807515,7.501682,5.178998,7.829220,4.801642,4.589926,8.606542
4,H-2-DB,AAIANQAVV,9,1.517050,8.550667,8.336812,28.773984,4.966659,4.492635,4.849117,...,3.938183,4.728728,3.744065,4.616987,5.885797,3.987434,5.098143,3.925527,3.534521,6.281922
5,H-2-DB,AAIENYVRF,9,37.844258,252.348077,114.815362,187.068214,475.546334,203.843872,250.378300,...,271.215920,266.066764,149.692068,581.852336,674.091539,331.696926,206.674902,214.138051,232.844997,426.492030
6,H-2-DB,AAINFITTM,9,3.155005,199.986187,389.045145,200.909281,105.729578,94.888619,79.354410,...,97.739623,99.461874,79.858609,105.738579,94.234512,106.659335,115.439559,85.025227,85.966129,101.920974
7,H-2-DB,AAIPAPPPI,9,3243.396173,1059.253725,493.173804,295.120923,271.969162,401.527094,623.375917,...,361.502045,690.827083,298.654552,329.597292,850.876022,258.580355,351.038520,392.850338,613.049068,626.549884
8,H-2-DB,AAKLNRPPL,9,654.636174,66.374307,77.268059,38.459178,228.465656,173.332323,60.580604,...,91.832559,126.020070,132.196259,101.517078,161.281892,128.928011,159.606284,157.680821,100.970577,236.620440
9,H-2-DB,AALDMVDAL,9,229.614865,547.015963,597.035287,225.423921,2198.405790,1101.740698,1647.942340,...,798.535166,1019.671021,601.093263,1101.140115,1213.485529,1240.514970,1035.913059,587.853925,589.922772,1603.083374


In [11]:
# Extend with ensemble predictions
def _pooled_indices(keep):
    """Sorted union of the model numbers of every group whose name satisfies ``keep``."""
    return sorted(set.union(*[
        set(indices) for (name, indices) in model_groups.iteritems() if keep(name)
    ]))

all_indices = _pooled_indices(lambda name: True)
all_indices_impute = _pooled_indices(lambda name: 'impute' in name)
all_indices_not_impute = _pooled_indices(lambda name: 'impute' not in name)

# One ensemble per named model group, plus three pooled ensembles.
ensembles = list(model_groups.iteritems())
ensembles.extend([
    ("all", all_indices),
    ("all impute", all_indices_impute),
    ("all not impute", all_indices_not_impute),
])

for (name, indices) in ensembles:
    # An ensemble prediction is the row-wise geometric mean of its members' predictions.
    member_columns = ["mhcflurry %d" % i for i in indices]
    validation_df_with_mhcflurry_results["mhcflurry ensemble %s" % name] = \
        scipy.stats.mstats.gmean(
            validation_df_with_mhcflurry_results[member_columns],
            axis=1)

validation_df_with_mhcflurry_results

Unnamed: 0,allele,peptide,length,meas,netmhc,netmhcpan,smmpmbec_cpp,mhcflurry 0,mhcflurry 1,mhcflurry 2,...,mhcflurry 59,mhcflurry 60,mhcflurry 61,mhcflurry 62,mhcflurry 63,mhcflurry ensemble big dropout,mhcflurry ensemble big dropout impute,mhcflurry ensemble all,mhcflurry ensemble all impute,mhcflurry ensemble all not impute
0,H-2-DB,AAACNVATA,9,657.657837,154.881662,711.213514,438.530698,313.802565,605.426263,395.292236,...,216.276775,765.882346,318.346678,555.803209,860.878859,478.341442,422.719572,449.671313,422.719572,478.341442
1,H-2-DB,AAFEFVYV,8,30831.879502,6456.542290,785.235635,10351.421667,12620.897081,13963.382225,13924.810174,...,11996.603908,14264.470465,13114.823463,12333.075112,13425.395827,13876.140640,13803.621136,13839.833389,13803.621136,13876.140640
2,H-2-DB,AAFVNDYSL,9,77.446180,17.458222,7.516229,28.054336,19.408797,26.851752,22.741957,...,16.218054,21.185453,19.990712,13.560549,29.475245,19.629567,17.028405,18.282785,17.028405,19.629567
3,H-2-DB,AAIANQAAV,9,1.999862,9.638290,9.749896,25.703958,5.806578,6.472855,6.664925,...,5.178998,7.829220,4.801642,4.589926,8.606542,6.117030,5.873083,5.993815,5.873083,6.117030
4,H-2-DB,AAIANQAVV,9,1.517050,8.550667,8.336812,28.773984,4.966659,4.492635,4.849117,...,3.987434,5.098143,3.925527,3.534521,6.281922,4.578733,4.383984,4.480300,4.383984,4.578733
5,H-2-DB,AAIENYVRF,9,37.844258,252.348077,114.815362,187.068214,475.546334,203.843872,250.378300,...,331.696926,206.674902,214.138051,232.844997,426.492030,273.175836,244.336548,258.354100,244.336548,273.175836
6,H-2-DB,AAINFITTM,9,3.155005,199.986187,389.045145,200.909281,105.729578,94.888619,79.354410,...,106.659335,115.439559,85.025227,85.966129,101.920974,103.891421,105.068753,104.478428,105.068753,103.891421
7,H-2-DB,AAIPAPPPI,9,3243.396173,1059.253725,493.173804,295.120923,271.969162,401.527094,623.375917,...,258.580355,351.038520,392.850338,613.049068,626.549884,400.774208,450.530438,424.924675,450.530438,400.774208
8,H-2-DB,AAKLNRPPL,9,654.636174,66.374307,77.268059,38.459178,228.465656,173.332323,60.580604,...,128.928011,159.606284,157.680821,100.970577,236.620440,128.365361,145.248982,136.546468,145.248982,128.365361
9,H-2-DB,AALDMVDAL,9,229.614865,547.015963,597.035287,225.423921,2198.405790,1101.740698,1647.942340,...,1240.514970,1035.913059,587.853925,589.922772,1603.083374,1124.857663,974.315958,1046.884316,974.315958,1124.857663


In [12]:
# Score every predictor column against the measured values ("meas"), once over
# the whole validation set ("overall") and once per allele.
# columns[4:] skips the leading non-predictor columns
# (assumes these are allele, peptide, length, meas -- TODO confirm against the CSV).
predictors = validation_df_with_mhcflurry_results.columns[4:]
pairs = [
    ("overall", validation_df_with_mhcflurry_results)
] + list(validation_df_with_mhcflurry_results.groupby("allele"))

# Accumulate column -> list of values under its own name, so scores_df is only
# ever bound to the final DataFrame.
score_columns = collections.defaultdict(list)
for (allele, grouped) in pairs:
    score_columns["allele"].append(allele)
    score_columns["test_size"].append(len(grouped.meas))
    for predictor in predictors:
        scores = make_scores(grouped.meas, grouped[predictor])
        for (key, value) in scores.items():
            score_columns["%s_%s" % (predictor, key)].append(value)

scores_df = pandas.DataFrame(score_columns)

# Group the training data by allele once instead of once per row.
train_data_by_allele = all_train_data.groupby_allele_dictionary()
scores_df["train_size"] = [
        len(train_data_by_allele[a]) if a != 'overall' else numpy.nan
        for a in scores_df.allele
    ]
scores_df.index = scores_df.allele
scores_df

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


Unnamed: 0_level_0,allele,mhcflurry 0_auc,mhcflurry 0_f1,mhcflurry 0_tau,mhcflurry 10_auc,mhcflurry 10_f1,mhcflurry 10_tau,mhcflurry 11_auc,mhcflurry 11_f1,mhcflurry 11_tau,...,netmhc_f1,netmhc_tau,netmhcpan_auc,netmhcpan_f1,netmhcpan_tau,smmpmbec_cpp_auc,smmpmbec_cpp_f1,smmpmbec_cpp_tau,test_size,train_size
allele,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
overall,overall,0.932253,0.781061,0.585723,0.931884,0.784018,0.585343,0.931122,0.786562,0.578821,...,0.807217,0.586325,0.932637,0.799565,0.581384,0.921343,0.790263,0.564884,26888,
H-2-DB,H-2-DB,0.906615,0.552632,0.630969,0.912256,0.587719,0.634917,0.913774,0.573991,0.631975,...,0.621212,0.600337,0.874574,0.577236,0.574262,0.884187,0.628571,0.571252,564,3216.0
H-2-KB,H-2-KB,0.911299,0.780488,0.593369,0.909247,0.795539,0.588733,0.907039,0.780669,0.587471,...,0.813675,0.573199,0.825565,0.665354,0.486836,0.915994,0.859967,0.589218,558,3407.0
H-2-KD,H-2-KD,0.78895,0.573529,0.377489,0.788784,0.571429,0.374484,0.787125,0.577778,0.375794,...,0.657718,0.403275,0.819189,0.64557,0.390333,0.753692,0.538462,0.365247,229,452.0
HLA-A0101,HLA-A0101,0.917331,0.576471,0.527686,0.913558,0.584795,0.525529,0.908177,0.589595,0.524694,...,0.619565,0.524866,0.894895,0.594286,0.498767,0.832665,0.437811,0.428064,696,3725.0
HLA-A0201,HLA-A0201,0.930098,0.871772,0.630444,0.930506,0.871523,0.629146,0.930507,0.878238,0.628033,...,0.884336,0.635498,0.930479,0.880963,0.637338,0.927358,0.885121,0.626224,2126,9565.0
HLA-A0202,HLA-A0202,0.908537,0.790698,0.61675,0.909368,0.790698,0.62724,0.907428,0.772727,0.618285,...,0.755556,0.627143,0.898697,0.769231,0.62428,0.882206,0.727273,0.606938,126,3919.0
HLA-A0203,HLA-A0203,0.976839,0.953516,0.593213,0.976922,0.948626,0.594227,0.976787,0.94813,0.601407,...,0.948626,0.586911,0.974158,0.944578,0.591463,0.972885,0.946746,0.583908,651,5542.0
HLA-A0206,HLA-A0206,0.90585,0.871411,0.541627,0.90615,0.868159,0.54365,0.90468,0.872367,0.534772,...,0.872902,0.543184,0.910796,0.866258,0.535067,0.904317,0.878282,0.527571,682,4827.0
HLA-A0301,HLA-A0301,0.921956,0.873807,0.591366,0.921374,0.870588,0.588338,0.920937,0.87605,0.591832,...,0.900621,0.629236,0.927287,0.885106,0.61124,0.933966,0.897275,0.610891,811,6141.0


In [13]:
# Scores over the whole validation set, highest first.
# .loc / sort_values replace the removed .ix / Series.sort APIs. The string-valued
# "allele" entry is dropped so that the sort only ever compares numbers.
print_full(scores_df.loc["overall"].drop("allele").sort_values(ascending=False))

allele                                         overall
test_size                                        26888
mhcflurry 33_auc                             0.9331735
mhcflurry 49_auc                             0.9329895
mhcflurry ensemble all not impute_auc        0.9329674
mhcflurry ensemble big dropout_auc           0.9329674
mhcflurry 9_auc                              0.9329307
mhcflurry ensemble all_auc                   0.9329301
mhcflurry 1_auc                              0.9326766
mhcflurry 3_auc                              0.9326545
netmhcpan_auc                                0.9326371
mhcflurry ensemble big dropout impute_auc    0.9326042
mhcflurry ensemble all impute_auc            0.9326042
mhcflurry 19_auc                             0.9325866
mhcflurry 17_auc                             0.9325477
mhcflurry 32_auc                             0.9324706
mhcflurry 61_auc                             0.9324466
mhcflurry 43_auc                             0.9324336
mhcflurry 

In [14]:
# Mean of each numeric column across the individual alleles ("overall" row excluded), highest first.
# .loc / sort_values replace the removed .ix / Series.sort APIs; numeric_only=True
# states explicitly that the string "allele" column is left out of the mean.
print_full(
    scores_df.loc[(scores_df.index != "overall")]
    .mean(0, numeric_only=True)
    .sort_values(ascending=False))

train_size                                   2337.490196
test_size                                     527.215686
netmhcpan_auc                                   0.911105
mhcflurry 17_auc                                0.910589
mhcflurry ensemble big dropout_auc              0.910511
mhcflurry ensemble all not impute_auc           0.910511
mhcflurry ensemble all_auc                      0.910500
mhcflurry 22_auc                                0.910277
mhcflurry ensemble all impute_auc               0.910060
mhcflurry ensemble big dropout impute_auc       0.910060
mhcflurry 5_auc                                 0.910033
mhcflurry 28_auc                                0.909990
mhcflurry 3_auc                                 0.909934
mhcflurry 0_auc                                 0.909931
mhcflurry 46_auc                                0.909827
mhcflurry 49_auc                                0.909746
mhcflurry 24_auc                                0.909698
mhcflurry 52_auc               

In [15]:
# Same per-allele mean, restricted to alleles with at least 500 training entries.
# .loc / sort_values replace the removed .ix / Series.sort APIs; numeric_only=True
# states explicitly that the string "allele" column is left out of the mean.
print_full(
    scores_df.loc[(scores_df.index != "overall") & (scores_df.train_size >= 500)]
    .mean(0, numeric_only=True)
    .sort_values(ascending=False))

train_size                                   2656.227273
test_size                                     558.181818
mhcflurry 3_auc                                 0.912493
mhcflurry ensemble big dropout_auc              0.912458
mhcflurry ensemble all not impute_auc           0.912458
mhcflurry 29_auc                                0.912436
mhcflurry ensemble all_auc                      0.912366
mhcflurry 59_auc                                0.912298
mhcflurry 37_auc                                0.912253
mhcflurry 17_auc                                0.912210
mhcflurry 5_auc                                 0.912185
mhcflurry 9_auc                                 0.912135
mhcflurry 22_auc                                0.912091
mhcflurry 53_auc                                0.912082
mhcflurry 23_auc                                0.912079
mhcflurry 0_auc                                 0.912076
mhcflurry 25_auc                                0.912066
mhcflurry 1_auc                

In [19]:
# Same per-allele mean, restricted to alleles with fewer than 500 training entries.
# .loc / sort_values replace the removed .ix / Series.sort APIs; numeric_only=True
# states explicitly that the string "allele" column is left out of the mean.
print_full(
    scores_df.loc[(scores_df.index != "overall") & (scores_df.train_size < 500)]
    .mean(0, numeric_only=True)
    .sort_values(ascending=False))

train_size                                   334.000000
test_size                                    332.571429
netmhcpan_auc                                  0.920887
netmhc_auc                                     0.908777
mhcflurry 17_auc                               0.900630
mhcflurry 40_auc                               0.900301
mhcflurry 28_auc                               0.899614
mhcflurry 30_auc                               0.899473
mhcflurry 7_auc                                0.899201
mhcflurry 22_auc                               0.899130
mhcflurry 39_auc                               0.899081
mhcflurry ensemble all_auc                     0.899039
mhcflurry ensemble big dropout impute_auc      0.898738
mhcflurry ensemble all impute_auc              0.898738
mhcflurry ensemble big dropout_auc             0.898553
mhcflurry ensemble all not impute_auc          0.898553
mhcflurry 38_auc                               0.898424
mhcflurry 49_auc                               0

In [16]:
# Persist the score table next to the other data files. The index is omitted
# because it duplicates the "allele" column.
scores_df.to_csv(data_dir + "validation_scores.csv", index=False)