In [None]:
import os, sys, copy, re, random
import json
from collections import OrderedDict
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation
import sklearn.metrics as metrics
from scipy.stats import ttest_ind, t
from util import *
from IPython.display import display

valid_file = ''
benchmark_file = ''
# The benchmark cell below reads `test_file`, which was otherwise never defined.
# NOTE(review): assumes the benchmark CSV is the test set used there — confirm.
test_file = benchmark_file
work_dir = ''

# Validation dataframe

In [None]:
# valid dataframe
valid_df = pd.read_csv(valid_file, index_col=0)

## index (D-E ratio about 30)
# Keep all MS, assay and protein-decoy rows, plus a reproducible random-decoy sample whose
# size is 15x the number of MS rows.
index = list(valid_df[valid_df['source']=='MS'].index)
num = len(index)*15
index += list(valid_df[valid_df['source']=='assay'].index)
index += list(valid_df[valid_df['source']=='protein_decoy'].index)
# na=False: a missing 'source' must not turn the boolean mask into NaNs.
index += list(valid_df[valid_df['source'].str.contains('random_decoy', na=False)].sample(n=num, random_state=0).index)

# `index` holds row *labels* (taken from `.index`), so select with the label-based `.loc`.
# `.iloc` would treat them as positions and pick the wrong rows unless the labels are exactly 0..N-1.
valid_df = valid_df.loc[index]
bind = valid_df['bind'].to_numpy()

# Hyperparameters

In [None]:
# Prediction columns of the hyperparameter-search models
cols = [name for name in valid_df.columns if 'batch_size' in name]

# Score every model on the validation subset. The batch size is the third '_'-separated
# field of the column name and the learning rate is the last one.
perform_list = []
for col in cols:
    fields = col.split('_')
    row = {'batch_size': fields[2], 'learning_rate': fields[-1]}
    scores = CalculateMetrics(bind, valid_df[col].to_numpy())
    row.update({key: scores[key] for key in ('AUC', 'AUC0.1', 'AP', 'PPV')})
    perform_list.append(row)
perform_df = pd.DataFrame(perform_list)
display(perform_df)

# Downsampling Performance

In [None]:
# Prediction columns of the models trained on downsampled data
cols = [name for name in valid_df.columns if 'DE' in name]

# Score every model on the validation subset. The downsized D-E ratio is the second
# '_'-separated field of the column name and the training D-E ratio is the last one.
perform_list = []
for col in cols:
    fields = col.split('_')
    row = {'DE_downsized': fields[1], 'DE_training': fields[-1]}
    scores = CalculateMetrics(bind, valid_df[col].to_numpy())
    for key in ('AUC', 'AUC0.1', 'AP', 'PPV'):
        row[key] = scores[key]
    perform_list.append(row)
perform_df = pd.DataFrame(perform_list)
display(perform_df)

In [None]:
# Arguments
DE_downsized_list = [1,5,10,15,30]
DE_training_list = [30, 60, 90]

metrics_list = ['AP', 'AUC']

figsize = (4.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

# The D-E ratios in perform_df are strings parsed from column names, in whatever order the
# columns appear in valid_df. Pin both orders explicitly:
#  - hue levels follow DE_downsized_list, because the shared legend below is relabelled with it;
#  - each training ratio is drawn at the x position of its tick label in DE_training_list.
plot_df = perform_df.copy()
plot_df['DE_downsized'] = plot_df['DE_downsized'].astype(int).astype(str)
plot_df['DE_training_pos'] = plot_df['DE_training'].astype(int).map(
    {ratio: pos for pos, ratio in enumerate(DE_training_list)})
assert plot_df['DE_training_pos'].notna().all(), 'DE_training value missing from DE_training_list'
hue_order = [str(ratio) for ratio in DE_downsized_list]

# figure
fig, ax = plt.subplots(1, 2, figsize=figsize, dpi=dpi)

for current_ax, metrics_name in enumerate(metrics_list):
    # Plot
    sns.lineplot(data=plot_df, x='DE_training_pos', y=metrics_name, hue='DE_downsized',
                 hue_order=hue_order, linewidth=linewidth, ax=ax[current_ax])
    ax[current_ax].set_title(metrics_name, fontsize=fontsize)
    ax[current_ax].set_xlabel(None)
    ax[current_ax].set_ylabel(None)
    ax[current_ax].set_xticks(list(range(len(DE_training_list))))
    ax[current_ax].set_xticklabels(['%d'%i for i in DE_training_list], fontsize=fontsize)
    for i in ax[current_ax].get_yticklabels():
        i.set_fontsize(fontsize)

    # Keep the hue handles for the shared figure legend, then drop the per-axes legend.
    h, l = ax[current_ax].get_legend_handles_labels()
    ax[current_ax].get_legend().remove()

fig.tight_layout()

# Handles are in hue_order, so they line up with DE_downsized_list.
l = DE_downsized_list
leg = fig.legend(h, l, fontsize=fontsize, title="D-E ratio in each downsized dataset",
                 ncol=5, loc="upper center", bbox_to_anchor=(0.05, 1.05, 1, 0.1), 
                 columnspacing=1, handlelength=0.5, handletextpad=0.2, borderpad=0.2)

# Invisible full-figure axes, used only to place a shared x label under both panels.
fig.add_subplot(111, frame_on=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel("D-E ratio in the training dataset", fontsize=fontsize)

# Model Performance

In [None]:
def WelchTest(x1, x2, confidence=0.95):
    """Welch's unequal-variance two-sample t-test with a confidence interval on the mean difference.

    NaN values are dropped from each sample before anything is computed, so the sample sizes,
    means and variances all refer to the same observations.

    Parameters
    ----------
    x1, x2 : array-like
        The two samples (pandas Series or numpy arrays in this notebook).
    confidence : float, default 0.95
        Coverage of the two-sided confidence interval on ``mean(x1) - mean(x2)``.

    Returns
    -------
    dict
        n, m, sd : [x1, x2] sample sizes, means and standard deviations (ddof=1).
        df : Welch-Satterthwaite degrees of freedom.
        psd : standard error of the mean difference.
        tstat, pvalue : t statistic and two-sided p-value.
        delta : mean(x1) - mean(x2).
        lb, ub : lower / upper bounds of the confidence interval on delta.
    """
    # NOTE(review): this changes numpy's global print options for the whole kernel; kept
    # unchanged in case other cells rely on it.
    np.set_printoptions(precision=2)

    a = np.asarray(x1, dtype=float)
    b = np.asarray(x2, dtype=float)
    a = a[~np.isnan(a)]
    b = b[~np.isnan(b)]

    n1 = a.size
    n2 = b.size
    m1 = np.mean(a)
    m2 = np.mean(b)
    v1 = np.var(a, ddof=1)
    v2 = np.var(b, ddof=1)

    pooled_se = np.sqrt(v1 / n1 + v2 / n2)
    delta = m1 - m2

    tstat = delta / pooled_se
    # Welch-Satterthwaite approximation of the degrees of freedom
    df = (v1 / n1 + v2 / n2)**2 / (v1**2 / (n1**2 * (n1 - 1)) + v2**2 / (n2**2 * (n2 - 1)))

    # two-sided t-test
    p = 2 * t.cdf(-abs(tstat), df)

    # two-sided confidence interval: delta +/- t_{(1+confidence)/2, df} * SE
    quantile = t.ppf((1 + confidence) / 2, df)
    lb = delta - quantile * pooled_se
    ub = delta + quantile * pooled_se

    return {
        'n': [n1, n2],
        'm': [m1, m2],
        'sd': [np.sqrt(v1), np.sqrt(v2)],
        'df': df,
        'psd': pooled_se,
        'tstat': tstat,
        'delta': delta,
        'pvalue': p,
        'lb': lb,
        'ub': ub
    }


def PrintStatDF(df):
    """Pivot a table of WelchTest results into one column per (group, row, metric).

    Each row of ``df`` must carry the WelchTest keys plus ``'pair'`` (``'<g1>-<g2>'``) and
    ``'metric'``. For every row, two columns are added, named ``'<group>_<row label>_<metric>'``.
    The first holds group 1's n/m/sd and the second holds group 2's. Both share the pair-level
    df/psd/tstat/pvalue/lb/ub values.
    """
    stat_keys = ['n', 'm', 'sd', 'df', 'psd', 'tstat', 'pvalue', 'lb', 'ub']
    table = pd.DataFrame(index=stat_keys)
    for row_label, record in df.iterrows():
        left, right = record['pair'].split('-')
        shared = [record[key] for key in ('df', 'psd', 'tstat', 'pvalue', 'lb', 'ub')]
        for side, group in enumerate((left, right)):
            per_group = [record['n'][side], record['m'][side], record['sd'][side]]
            table['{}_{}_{}'.format(group, row_label, record['metric'])] = per_group + shared
    return table


def ViolinPlotStat(df, xcol, ycol, box_pairs, figfile,
                   title=None, xlabel=None, ylabel=None, ytick_limit=1.0, xtick_rotate=45,
                   figsize=(3.5, 3.5), dpi=600, fontsize=10, linewidth=1.5):
    """Strip + violin plot of ``ycol`` grouped by ``xcol``, with Welch t-test annotations.

    The figure is saved to ``figfile`` and left open, so the caller can ``plt.show()`` it.

    Parameters
    ----------
    df : DataFrame
        Long-format data with one row per observation.
    xcol, ycol : str
        Grouping column and value column.
    box_pairs : list of (group, group)
        Pairs of ``xcol`` levels to compare.
    figfile : str
        Output path of the saved figure.
    ytick_limit : float
        Automatically placed y ticks above this value are removed.

    Returns
    -------
    list of dict
        One ``WelchTest`` result per pair, each with an extra ``'pair'`` key ``'<p1>-<p2>'``.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.stripplot(x=xcol, y=ycol, data=df, ax=ax, s=2)
    sns.violinplot(x=xcol, y=ycol, data=df, ax=ax, color='.8', cut=0)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    kept_ticks = [tick for tick in ax.get_yticks() if tick <= ytick_limit]
    ax.set_yticks(kept_ticks)

    text_items = [ax.title, ax.xaxis.label, ax.yaxis.label]
    text_items.extend(ax.get_xticklabels())
    text_items.extend(ax.get_yticklabels())
    for text_item in text_items:
        text_item.set_fontsize(fontsize)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(xtick_rotate)

    # Star annotations: unequal-variance t-test, no multiple-comparison correction.
    add_stat_annotation(ax=ax, data=df, x=xcol, y=ycol,
                        box_pairs=box_pairs, comparisons_correction=None,
                        test='t-test_ind', stats_params={'equal_var': False},
                        text_format='star', loc='inside',
                        fontsize=fontsize, linewidth=linewidth,
                        line_offset_to_box=0.15)

    # Full statistics per pair, computed separately from the on-figure annotation.
    results = []
    for first, second in box_pairs:
        group_first = df.loc[df[xcol] == first, ycol]
        group_second = df.loc[df[xcol] == second, ycol]
        result = WelchTest(group_first, group_second)
        result['pair'] = '{}-{}'.format(first, second)
        results.append(result)

    fig.savefig(figfile, bbox_inches='tight')
    return results

In [None]:
# load df

test_df = pd.read_csv(test_file, index_col=0)

# arguments
with_mixmhcpred = True
# Reuse AlleleMetrics.json from a previous run when it exists; otherwise (re)compute it.
previous_convert_dict = True

# with/without MixMHCpred2.1
if with_mixmhcpred:
    tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0', 'MixMHCpred2.1']
    # Compare all tools on the same rows: drop peptides MixMHCpred2.1 did not score.
    test_df = test_df[~test_df['MixMHCpred2.1'].isna()]
    output_dir = '%s/with_mixmhcpred/'%work_dir
else:
    tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0']
    output_dir = '%s/without_mixmhcpred/'%work_dir
# makedirs also creates missing parents and tolerates an existing directory on re-runs.
os.makedirs(output_dir, exist_ok=True)

allele_metrics_file = '%s/AlleleMetrics.json'%output_dir

# get allele metrics
if previous_convert_dict and os.path.isfile(allele_metrics_file):
    with open(allele_metrics_file, 'r') as fin:
        convert_dict = json.load(fin)
else:
    # tool -> allele -> {metric: value}
    allele_metrics_dict = dict()
    for col in tool_list:
        allele_metrics_dict[col] = CalculateAlleleMetrics(test_df['mhc'], test_df['bind'], test_df[col])

    # Reshape to {metric: [{'allele', 'value', 'method'}, ...]} for the per-allele plots.
    convert_dict = dict({'AUC':list(), 'AUC0.1':list(), 'AP':list(), 'PPV':list()})
    for method, allele_metrics in allele_metrics_dict.items():
        for allele, metrics_dict in allele_metrics.items():
            for metric, val in metrics_dict.items():
                convert_dict[metric].append({'allele': allele, 'value': val, 'method': method})

    with open(allele_metrics_file, 'w') as fout:
        json.dump(convert_dict, fout)

## Comparison between tools

### Overall benchmark

In [None]:
# Overall metrics of each tool on the whole benchmark set (one row per tool)

y = test_df['bind'].to_numpy()
metrics_dict = {tool: CalculateMetrics(y, test_df[tool].to_numpy()) for tool in tool_list}

pd.DataFrame(metrics_dict).T

In [None]:
# plot AUC-ROC comparing with other tools

figsize = (3.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

fig = plt.figure(figsize=figsize, dpi=dpi)
plt.title('ROC', fontsize=fontsize)

for col in tool_list:
    pred = test_df[col]
    fpr, tpr, roc_thresholds = metrics.roc_curve(test_df['bind'], pred)
    auc = metrics.auc(fpr, tpr)
    plt.plot(fpr, tpr, label='%s,AUC=%.3f'%(col, auc), linewidth=linewidth)

plt.legend(loc = 'lower right', fontsize=fontsize, handletextpad=0.2, borderpad=0.2)
# chance diagonal
plt.plot([0, 1], [0, 1], '--', color='black')
# Optional vertical guide at FPR = 0.1:
#plt.plot([0.1, 0.1], [0, 1], '--', color='red')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate', fontsize=fontsize)
plt.xlabel('False Positive Rate', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Lay out and save *before* plt.show(). With the inline backend, show() renders the figure
# immediately, so layout applied afterwards never reaches the displayed image.
fig.tight_layout()
fig.savefig('%s/ComparisonROC.png'%output_dir)
plt.show()

In [None]:
# plot AUC-PRC comparing with other tools

figsize = (3.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

fig = plt.figure(figsize=figsize, dpi=dpi)
plt.title('PRC', fontsize=fontsize)

for col in tool_list:
    pred = test_df[col]
    precision, recall, thresholds = metrics.precision_recall_curve(test_df['bind'], pred)
    ap = metrics.average_precision_score(test_df['bind'], pred)
    # Note: precision is on the x axis and recall on the y axis (matching the axis labels below).
    plt.plot(precision, recall, label='%s,AP=%.3f'%(col, ap), linewidth=linewidth)

    # search max f1_score
    # precision/recall have one more entry than thresholds: the final point (precision=1,
    # recall=0) has no threshold, so it is excluded. Points with precision = recall = 0 get
    # F1 = 0 instead of 0/0 = NaN, which np.argmax would otherwise select.
    p_pts = precision[:-1]
    r_pts = recall[:-1]
    denom = p_pts + r_pts
    f1_scores = np.divide(2 * r_pts * p_pts, denom,
                          out=np.zeros_like(denom, dtype=float), where=denom > 0)
    best = np.argmax(f1_scores)
    print('Tool: ', col)
    print('Best threshold: ', thresholds[best])
    print('Best F1-Score: ', f1_scores[best])

plt.legend(loc = 'lower left', fontsize=fontsize, handletextpad=0.2, borderpad=0.2)
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('Recall', fontsize=fontsize)
plt.xlabel('Precision', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Lay out and save before show() so the displayed and saved figures match.
fig.tight_layout()
fig.savefig('%s/ComparisonPRC.png'%output_dir)
plt.show()

### By allele

In [None]:
# Per-allele metric distributions: the first tool against each of the other tools

box_pairs = [(tool_list[0], other_tool) for other_tool in tool_list[1:]]

metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']
stat_frames = []
for metric in metric_list:
    allele_values = pd.DataFrame(convert_dict[metric])
    stat_list = ViolinPlotStat(allele_values, 'method', 'value', box_pairs,
                               '{}/AlleleGroup{}.png'.format(output_dir, metric), ylabel=metric)
    plt.show()
    metric_stats = pd.DataFrame(stat_list)
    metric_stats['metric'] = metric
    stat_frames.append(metric_stats)

# Build the combined table once instead of growing it inside the loop.
stat_df = PrintStatDF(pd.concat(stat_frames))
display(stat_df)

## Unobserved Alleles

In [None]:
convert_dict_file = '{}/without_mixmhcpred/AlleleMetrics.json'.format(work_dir)
with open(convert_dict_file, 'r') as fin:
    convert_dict = json.load(fin)

output_dir = '{}/unobserved'.format(work_dir)
# Figures and the metrics CSV of this section are written here; create the folder on a fresh run.
os.makedirs(output_dir, exist_ok=True)

metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']

unobserved_alleles = ['A*24:07', 'A*33:03', 'A*34:01', 'A*34:02', 'A*36:01', 'B*07:04', 'B*15:10', 'B*35:07',
                      'B*38:02', 'B*40:06', 'B*55:01', 'B*55:02', 'C*03:02', 'C*04:03', 'C*08:01', 'C*14:03']

# Subset of unobserved alleles used for the cross-tool comparison.
# NOTE(review): an earlier note said "MixMHCpred2.1 misses B*35:07", yet B*35:07 is listed here.
# Rows MixMHCpred2.1 did not score are dropped before the with_mixmhcpred metrics are computed,
# so an allele it misses contributes no rows. Confirm which alleles are intended.
consensus_rare_alleles = ['A*24:07', 'A*34:01', 'A*34:02', 'A*36:01',
                          'B*07:04', 'B*35:07', 'B*38:02', 'B*40:06',
                          'C*03:02', 'C*04:03', 'C*14:03']

In [None]:
# observed vs. unobserved alleles (MHCfovea only)

# box pairs
box_pairs = [('Observed Allele', 'Unobserved Allele')]

# by metrics
metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']
metrics_df = pd.DataFrame()
stat_frames = []
for metric in metric_list:
    temp_df = pd.DataFrame(convert_dict[metric])
    temp_df = temp_df[temp_df['method']=='MHCfovea']
    temp_df = temp_df.sort_values(by='allele')

    if metrics_df.empty:
        # First metric: one row per allele, tagged by membership in unobserved_alleles.
        metrics_df = temp_df.rename(columns={'value': metric})
        metrics_df['tag'] = 'Observed Allele'
        metrics_df.loc[metrics_df['allele'].isin(unobserved_alleles), 'tag'] = 'Unobserved Allele'
    else:
        # Align on the allele name rather than on the row index. Index alignment is correct only
        # as long as every metric list in convert_dict keeps the same method/allele order.
        metrics_df[metric] = metrics_df['allele'].map(temp_df.set_index('allele')['value'])

    stat_list = ViolinPlotStat(metrics_df, 'tag', metric, box_pairs,
                               '{}/UnobservedAllele{}.png'.format(output_dir, metric), ylabel=metric, xtick_rotate=0)
    plt.show()
    metric_stats = pd.DataFrame(stat_list)
    metric_stats['metric'] = metric
    stat_frames.append(metric_stats)

metrics_df.to_csv('{}/UnobservedAlleleMetrics.csv'.format(output_dir))
stat_df = PrintStatDF(pd.concat(stat_frames))
display(stat_df)

In [None]:
# comparison with other tools on common unobserved alleles

convert_dict_file = '{}/with_mixmhcpred/AlleleMetrics.json'.format(work_dir)
with open(convert_dict_file, 'r') as fin:
    convert_dict = json.load(fin)

# box pairs: MHCfovea against each of the other tools
tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0', 'MixMHCpred2.1']
box_pairs = [(tool_list[0], other_tool) for other_tool in tool_list[1:]]

# by metrics
metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']
stat_frames = []
for metric in metric_list:
    temp_df = pd.DataFrame(convert_dict[metric])
    temp_df = temp_df[temp_df['allele'].isin(consensus_rare_alleles)]
    stat_list = ViolinPlotStat(temp_df, 'method', 'value', box_pairs,
                               '{}/CommonUnobservedAllele{}.png'.format(output_dir, metric), ylabel=metric)
    plt.show()
    metric_stats = pd.DataFrame(stat_list)
    metric_stats['metric'] = metric
    stat_frames.append(metric_stats)

# Build the combined table once instead of growing it inside the loop.
stat_df = PrintStatDF(pd.concat(stat_frames))
display(stat_df)