In [81]:
import mhcflurry, seaborn, numpy, pandas, pickle, sklearn, collections, scipy, time
import mhcflurry.dataset
import fancyimpute, locale
from matplotlib import pyplot


import sklearn.metrics
import sklearn.cross_validation
%matplotlib inline


def print_full(x):
    pandas.set_option('display.max_rows', len(x))
    print(x)
    pandas.reset_option('display.max_rows')

In [82]:
max_ic50 = 50000
data_dir = "../data/"

In [83]:
all_train_data = mhcflurry.dataset.Dataset.from_csv(data_dir + "bdata.2009.mhci.public.1.txt")

In [84]:
def make_scores(ic50_y, ic50_y_pred, sample_weight=None, threshold_nm=500):     
    y_pred = mhcflurry.regression_target.ic50_to_regression_target(ic50_y_pred, max_ic50)
    try:
        auc = sklearn.metrics.roc_auc_score(ic50_y <= threshold_nm, y_pred, sample_weight=sample_weight)
    except ValueError:
        auc = numpy.nan
    try:
        f1 = sklearn.metrics.f1_score(ic50_y <= threshold_nm, ic50_y_pred <= threshold_nm, sample_weight=sample_weight)
    except ValueError:
        f1 = numpy.nan
    try:
        tau = scipy.stats.kendalltau(ic50_y_pred, ic50_y)[0]
    except ValueError:
        tau = numpy.nan
    
    return dict(
        auc=auc,
        f1=f1,
        tau=tau,
    )  

In [85]:
models = pandas.read_csv("../data/validation_models.csv", converters={'layer_sizes': eval})
models["layer_size"] = [x[0] for x in models.layer_sizes]
del models["activation"]
models

Unnamed: 0,dropout_probability,embedding_output_dim,fraction_negative,impute,layer_sizes,layer_size
0,0.5,32,0.2,True,[64],64
1,0.5,32,0.2,True,[64],64
2,0.5,32,0.2,True,[64],64
3,0.5,32,0.2,True,[64],64
4,0.5,32,0.2,True,[64],64
5,0.5,32,0.2,True,[64],64
6,0.5,32,0.2,True,[64],64
7,0.5,32,0.2,True,[64],64
8,0.5,32,0.2,True,[64],64
9,0.5,32,0.2,True,[64],64


In [86]:
def name_model(row):
    size = "big" if row.embedding_output_dim == 32 else "small"
    pieces = [size]
    if row.dropout_probability > 0:
        pieces.append("dropout")
    if row.impute:
        pieces.append("impute")
    return " ".join(pieces)

models["num"] = models.index
models["name"] = [name_model(row) for (_, row) in models.iterrows()]
models

Unnamed: 0,dropout_probability,embedding_output_dim,fraction_negative,impute,layer_sizes,layer_size,num,name
0,0.5,32,0.2,True,[64],64,0,big dropout impute
1,0.5,32,0.2,True,[64],64,1,big dropout impute
2,0.5,32,0.2,True,[64],64,2,big dropout impute
3,0.5,32,0.2,True,[64],64,3,big dropout impute
4,0.5,32,0.2,True,[64],64,4,big dropout impute
5,0.5,32,0.2,True,[64],64,5,big dropout impute
6,0.5,32,0.2,True,[64],64,6,big dropout impute
7,0.5,32,0.2,True,[64],64,7,big dropout impute
8,0.5,32,0.2,True,[64],64,8,big dropout impute
9,0.5,32,0.2,True,[64],64,9,big dropout impute


In [87]:
model_groups = models.groupby("name").num.unique()
model_groups

name
big dropout impute    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Name: num, dtype: object

In [88]:
validation_df_with_mhcflurry_results = pandas.read_csv("../data/validation_predictions_full.csv")
validation_df_with_mhcflurry_results

Unnamed: 0,allele,peptide,length,meas,netmhc,netmhcpan,smmpmbec_cpp,mhcflurry 0,mhcflurry 1,mhcflurry 2,mhcflurry 3,mhcflurry 4,mhcflurry 5,mhcflurry 6,mhcflurry 7,mhcflurry 8,mhcflurry 9
0,H-2-DB,AAACNVATA,9,657.657837,154.881662,711.213514,438.530698,896.371918,246.474087,718.592979,513.409634,1038.322913,410.089782,321.837171,476.360357,99.853278,436.243108
1,H-2-DB,AAFEFVYV,8,30831.879502,6456.542290,785.235635,10351.421667,15380.669503,13018.103870,16184.412481,12565.990353,13729.463449,12586.976303,15411.577054,16145.523804,11191.101172,12071.042841
2,H-2-DB,AAFVNDYSL,9,77.446180,17.458222,7.516229,28.054336,18.898478,19.902942,28.376859,24.462336,24.152721,15.615939,24.594960,18.602091,21.495659,21.526401
3,H-2-DB,AAIANQAAV,9,1.999862,9.638290,9.749896,25.703958,5.217476,4.592659,5.524621,6.383959,7.600944,4.600862,6.629486,5.670683,4.044471,4.594558
4,H-2-DB,AAIANQAVV,9,1.517050,8.550667,8.336812,28.773984,3.951842,3.851975,4.696085,4.735078,4.963316,3.562410,5.235347,4.820519,3.529904,3.496377
5,H-2-DB,AAIENYVRF,9,37.844258,252.348077,114.815362,187.068214,264.862456,391.912534,336.808244,376.763251,171.675278,150.441429,209.546602,381.269284,161.818654,232.514121
6,H-2-DB,AAINFITTM,9,3.155005,199.986187,389.045145,200.909281,68.752051,283.474222,156.760884,74.056941,182.614973,110.513145,153.883917,94.150197,88.539513,81.477160
7,H-2-DB,AAIPAPPPI,9,3243.396173,1059.253725,493.173804,295.120923,499.488283,291.264713,255.700708,336.855817,532.068890,417.036227,247.563730,608.113197,303.386100,283.513621
8,H-2-DB,AAKLNRPPL,9,654.636174,66.374307,77.268059,38.459178,130.049271,277.307559,195.828969,97.725190,277.119932,120.584336,255.880186,198.416331,128.302026,137.078634
9,H-2-DB,AALDMVDAL,9,229.614865,547.015963,597.035287,225.423921,1215.599573,698.941409,875.252839,1565.847844,1916.306665,904.529660,808.521060,517.176599,602.270740,1432.903463


In [90]:
# Extend with ensemble predictions
all_indices = sorted(set.union(*[set(indices) for (name, indices) in model_groups.iteritems()]))
all_indices_impute = sorted(
    set.union(*[set(indices) for (name, indices) in model_groups.iteritems() if 'impute' in name]))
#all_indices_not_impute = sorted(
#    set.union(*[set(indices) for (name, indices) in model_groups.iteritems() if 'impute' not in name]))

for (name, indices) in list(model_groups.iteritems()) + [("all", all_indices),
                                                         ("all impute", all_indices_impute),
                                                        # ("all not impute", all_indices_not_impute)
                                                        ]:
    validation_df_with_mhcflurry_results["mhcflurry ensemble %s" % name] = \
        scipy.stats.mstats.gmean(
            validation_df_with_mhcflurry_results[["mhcflurry %d" % i for i in indices]],
            axis=1)

validation_df_with_mhcflurry_results

Unnamed: 0,allele,peptide,length,meas,netmhc,netmhcpan,smmpmbec_cpp,mhcflurry 0,mhcflurry 1,mhcflurry 2,mhcflurry 3,mhcflurry 4,mhcflurry 5,mhcflurry 6,mhcflurry 7,mhcflurry 8,mhcflurry 9,mhcflurry ensemble big dropout impute,mhcflurry ensemble all,mhcflurry ensemble all impute
0,H-2-DB,AAACNVATA,9,657.657837,154.881662,711.213514,438.530698,896.371918,246.474087,718.592979,513.409634,1038.322913,410.089782,321.837171,476.360357,99.853278,436.243108,433.020104,433.020104,433.020104
1,H-2-DB,AAFEFVYV,8,30831.879502,6456.542290,785.235635,10351.421667,15380.669503,13018.103870,16184.412481,12565.990353,13729.463449,12586.976303,15411.577054,16145.523804,11191.101172,12071.042841,13721.160510,13721.160510,13721.160510
2,H-2-DB,AAFVNDYSL,9,77.446180,17.458222,7.516229,28.054336,18.898478,19.902942,28.376859,24.462336,24.152721,15.615939,24.594960,18.602091,21.495659,21.526401,21.473709,21.473709,21.473709
3,H-2-DB,AAIANQAAV,9,1.999862,9.638290,9.749896,25.703958,5.217476,4.592659,5.524621,6.383959,7.600944,4.600862,6.629486,5.670683,4.044471,4.594558,5.388706,5.388706,5.388706
4,H-2-DB,AAIANQAVV,9,1.517050,8.550667,8.336812,28.773984,3.951842,3.851975,4.696085,4.735078,4.963316,3.562410,5.235347,4.820519,3.529904,3.496377,4.236900,4.236900,4.236900
5,H-2-DB,AAIENYVRF,9,37.844258,252.348077,114.815362,187.068214,264.862456,391.912534,336.808244,376.763251,171.675278,150.441429,209.546602,381.269284,161.818654,232.514121,251.751831,251.751831,251.751831
6,H-2-DB,AAINFITTM,9,3.155005,199.986187,389.045145,200.909281,68.752051,283.474222,156.760884,74.056941,182.614973,110.513145,153.883917,94.150197,88.539513,81.477160,116.916027,116.916027,116.916027
7,H-2-DB,AAIPAPPPI,9,3243.396173,1059.253725,493.173804,295.120923,499.488283,291.264713,255.700708,336.855817,532.068890,417.036227,247.563730,608.113197,303.386100,283.513621,359.448923,359.448923,359.448923
8,H-2-DB,AAKLNRPPL,9,654.636174,66.374307,77.268059,38.459178,130.049271,277.307559,195.828969,97.725190,277.119932,120.584336,255.880186,198.416331,128.302026,137.078634,170.362071,170.362071,170.362071
9,H-2-DB,AALDMVDAL,9,229.614865,547.015963,597.035287,225.423921,1215.599573,698.941409,875.252839,1565.847844,1916.306665,904.529660,808.521060,517.176599,602.270740,1432.903463,968.799959,968.799959,968.799959


In [91]:
scores_df = collections.defaultdict(list)
predictors = validation_df_with_mhcflurry_results.columns[4:]
pairs = [
    ("overall", validation_df_with_mhcflurry_results)
] + list(validation_df_with_mhcflurry_results.groupby("allele"))

for (allele, grouped) in pairs:
    scores_df["allele"].append(allele)
    scores_df["test_size"].append(len(grouped.meas))
    for predictor in predictors:
        scores = make_scores(grouped.meas, grouped[predictor])
        for (key, value) in scores.items():
            scores_df["%s_%s" % (predictor, key)].append(value)
            
scores_df = pandas.DataFrame(scores_df)
scores_df["train_size"] = [
        len(all_train_data.groupby_allele_dictionary()[a]) if a != 'overall' else numpy.nan
        for a in scores_df.allele
    ]
scores_df.index = scores_df.allele
scores_df

Unnamed: 0_level_0,allele,mhcflurry 0_auc,mhcflurry 0_f1,mhcflurry 0_tau,mhcflurry 1_auc,mhcflurry 1_f1,mhcflurry 1_tau,mhcflurry 2_auc,mhcflurry 2_f1,mhcflurry 2_tau,...,netmhc_f1,netmhc_tau,netmhcpan_auc,netmhcpan_f1,netmhcpan_tau,smmpmbec_cpp_auc,smmpmbec_cpp_f1,smmpmbec_cpp_tau,test_size,train_size
allele,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
overall,overall,0.931281,0.783092,0.584331,0.932004,0.783676,0.585116,0.931926,0.783396,0.585277,...,0.807217,0.586325,0.932637,0.799565,0.581384,0.921343,0.790263,0.564884,26888,
H-2-DB,H-2-DB,0.914909,0.60262,0.642927,0.911671,0.604255,0.632981,0.907149,0.551402,0.632867,...,0.621212,0.600337,0.874574,0.577236,0.574262,0.884187,0.628571,0.571252,564,3216.0
H-2-KB,H-2-KB,0.911338,0.810909,0.592313,0.908091,0.790353,0.587381,0.910429,0.810909,0.592571,...,0.813675,0.573199,0.825565,0.665354,0.486836,0.915994,0.859967,0.589218,558,3407.0
H-2-KD,H-2-KD,0.789779,0.575758,0.37957,0.791356,0.571429,0.375794,0.793844,0.575758,0.381805,...,0.657718,0.403275,0.819189,0.64557,0.390333,0.753692,0.538462,0.365247,229,452.0
HLA-A0101,HLA-A0101,0.908228,0.552147,0.520068,0.91261,0.610169,0.525722,0.918803,0.578035,0.530888,...,0.619565,0.524866,0.894895,0.594286,0.498767,0.832665,0.437811,0.428064,696,3725.0
HLA-A0201,HLA-A0201,0.928945,0.876522,0.625578,0.92989,0.867288,0.626369,0.929562,0.875709,0.627535,...,0.884336,0.635498,0.930479,0.880963,0.637338,0.927358,0.885121,0.626224,2126,9565.0
HLA-A0202,HLA-A0202,0.906042,0.795455,0.615726,0.905765,0.771084,0.625449,0.906874,0.786517,0.616238,...,0.755556,0.627143,0.898697,0.769231,0.62428,0.882206,0.727273,0.606938,126,3919.0
HLA-A0203,HLA-A0203,0.977183,0.952038,0.592257,0.977433,0.950898,0.590096,0.977423,0.944712,0.590287,...,0.948626,0.586911,0.974158,0.944578,0.591463,0.972885,0.946746,0.583908,651,5542.0
HLA-A0206,HLA-A0206,0.906213,0.872682,0.540045,0.903291,0.874534,0.532213,0.906858,0.868159,0.539881,...,0.872902,0.543184,0.910796,0.866258,0.535067,0.904317,0.878282,0.527571,682,4827.0
HLA-A0301,HLA-A0301,0.925704,0.870864,0.595008,0.923318,0.867238,0.595069,0.922716,0.872999,0.59307,...,0.900621,0.629236,0.927287,0.885106,0.61124,0.933966,0.897275,0.610891,811,6141.0


In [92]:
print_full(scores_df.ix["overall"].sort(inplace=False, ascending=False))

allele                                         overall
test_size                                        26888
netmhcpan_auc                                0.9326371
mhcflurry ensemble all_auc                   0.9325053
mhcflurry ensemble all impute_auc            0.9325053
mhcflurry ensemble big dropout impute_auc    0.9325053
netmhc_auc                                   0.9323441
mhcflurry 7_auc                              0.9322377
mhcflurry 3_auc                              0.9320403
mhcflurry 1_auc                              0.9320037
mhcflurry 2_auc                              0.9319256
mhcflurry 4_auc                               0.931854
mhcflurry 5_auc                              0.9315867
mhcflurry 6_auc                              0.9315867
mhcflurry 9_auc                              0.9315061
mhcflurry 8_auc                              0.9315033
mhcflurry 0_auc                              0.9312809
smmpmbec_cpp_auc                             0.9213434
netmhc_f1 

In [94]:
print_full(scores_df.ix[(scores_df.index != "overall")].mean(0).sort(inplace=False, ascending=False))

train_size                                   2337.490196
test_size                                     527.215686
netmhcpan_auc                                   0.911105
mhcflurry ensemble big dropout impute_auc       0.909867
mhcflurry ensemble all impute_auc               0.909867
mhcflurry ensemble all_auc                      0.909867
mhcflurry 3_auc                                 0.909310
mhcflurry 5_auc                                 0.909286
mhcflurry 9_auc                                 0.909162
mhcflurry 7_auc                                 0.909092
mhcflurry 0_auc                                 0.909077
mhcflurry 2_auc                                 0.909010
mhcflurry 4_auc                                 0.908804
mhcflurry 6_auc                                 0.908730
netmhc_auc                                      0.908603
mhcflurry 1_auc                                 0.908563
mhcflurry 8_auc                                 0.908262
smmpmbec_cpp_auc               

In [99]:
print_full(scores_df.ix[(scores_df.index != "overall") & (scores_df.train_size >= 500)].mean(0).sort(inplace=False, ascending=False))

train_size                                   2656.227273
test_size                                     558.181818
mhcflurry ensemble big dropout impute_auc       0.911697
mhcflurry ensemble all impute_auc               0.911697
mhcflurry ensemble all_auc                      0.911697
mhcflurry 9_auc                                 0.911425
mhcflurry 4_auc                                 0.911022
mhcflurry 7_auc                                 0.910952
mhcflurry 5_auc                                 0.910924
mhcflurry 1_auc                                 0.910866
mhcflurry 0_auc                                 0.910863
mhcflurry 2_auc                                 0.910839
mhcflurry 3_auc                                 0.910836
mhcflurry 6_auc                                 0.910747
mhcflurry 8_auc                                 0.910127
netmhcpan_auc                                   0.909513
netmhc_auc                                      0.908575
smmpmbec_cpp_auc               

In [134]:
scores_df.to_csv("../data/validation_scores.csv", index=False)