In [None]:
import os, sys, copy, re, random
import json
from collections import OrderedDict
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation
import sklearn.metrics as metrics
from util import *
from scipy.stats.mstats import ttest_rel, ttest_ind

# Downsampling Performance

### Single Model

In [None]:
class DownsamplingPerformance():
    """Track how prediction metrics evolve as more ``decoy_*`` prediction columns are averaged.

    On construction the prediction CSV is loaded, optionally restricted to the
    given row labels, and for every prefix of the sorted ``decoy_<n>`` columns
    the row-wise mean prediction is scored against the ``bind`` column with
    ``CalculateMetrics``. Calling the instance writes the metrics to
    ``MetricsDict.json`` and saves one line plot per metric in ``output_dir``.

    Parameters
    ----------
    prediction_file : str
        CSV file read with ``index_col=0``; must contain a ``bind`` column and
        columns named ``decoy_<int>``.
    decoy_fold_per_model : int
        Stored on the instance; not used by any method of this class.
    output_dir : str
        Directory for the JSON file and the figures; created if missing.
    index : sequence of row labels, optional
        When given and non-empty, only these rows are kept (``df.loc[index]``)
        and the frame index is reset.

    Attributes
    ----------
    metrics_dict : OrderedDict
        Maps the largest decoy column number included in the average to the
        result of ``CalculateMetrics(bind, mean_prediction)``.
        NOTE(review): ``CalculateMetrics`` is defined outside this file; it is
        expected to return a mapping with at least the keys 'AUC', 'AUC0.1',
        'AP' and 'PPV', because ``__call__`` plots exactly those.
    """
    def __init__(self, prediction_file, decoy_fold_per_model, output_dir, index=None):
        self.df = pd.read_csv(prediction_file, index_col=0)
        self.decoy_fold_per_model = decoy_fold_per_model
        self.output_dir = output_dir
        # makedirs also creates missing parent directories and tolerates an existing one
        os.makedirs(self.output_dir, exist_ok=True)

        # Explicit None/length check: ``if index:`` is ambiguous for numpy arrays
        # and pandas Index objects. An empty selection keeps all rows, as before.
        if index is not None and len(index) > 0:
            self.df = self.df.loc[index].reset_index(drop=True)

        # metrics_dict
        bind = self.df['bind']
        decoy_list = sorted(int(col.split('_')[1]) for col in self.df.columns if 'decoy' in col)
        self.metrics_dict = OrderedDict()
        for n_used in tqdm(range(1, len(decoy_list) + 1)):
            # Average the predictions of the first ``n_used`` decoy columns
            cols = ['decoy_%d' % decoy for decoy in decoy_list[:n_used]]
            pred = self.df[cols].mean(axis=1)
            self.metrics_dict[decoy_list[n_used - 1]] = CalculateMetrics(bind, pred)


    def __call__(self):
        """Write ``MetricsDict.json`` and save one plot per metric into ``output_dir``."""
        # save metrics_dict (the context manager closes the file even if dumping fails)
        with open('%s/MetricsDict.json'%self.output_dir, 'w') as handle:
            json.dump(self.metrics_dict, handle)

        # plot metrics
        metrics_list = ['AUC', 'AUC0.1', 'AP', 'PPV']
        for m in metrics_list:
            self._plot_metrics(m)


    def _plot_metrics(self, metrics_name, savefig=True, figsize=(3.5,3.5), dpi=600, fontsize=10, linewidth=1.5):
        """Plot one metric against the decoy number and optionally save it as PNG.

        Parameters
        ----------
        metrics_name : str
            One of 'AUC', 'AUC0.1', 'AP', 'PPV'; selects the value read from
            every entry of ``self.metrics_dict`` and the plot title.
        savefig : bool
            When True the figure is written to ``<output_dir>/Metrics_<metrics_name>.png``.
        figsize, dpi, fontsize, linewidth
            Figure size, resolution, text size and line width.
        """
        metrics_label = {'AUC': 'AUC-ROC', 'AP': 'AUC-PRC', 'AUC0.1': 'AUC-ROC 0.1', 'PPV': 'PPV'}
        fig = plt.figure(figsize=figsize, dpi=dpi)
        x = list(self.metrics_dict.keys())
        y = [self.metrics_dict[decoy][metrics_name] for decoy in x]
        sns.lineplot(x=x, y=y, linewidth=linewidth)
        plt.title(metrics_label[metrics_name], fontsize=fontsize)
        # x holds the decoy numbers and y the metric values
        plt.xlabel('decoy number', fontsize=fontsize)
        plt.ylabel('score', fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        fig.tight_layout()
        if savefig:
            fig.savefig('%s/Metrics_%s.png'%(self.output_dir, metrics_name))

In [None]:
# Arguments

select_decoy = True                      # restrict the validation rows before scoring
downsampling_factors = [1,5,10,15,30]    # one prediction run is evaluated per factor
downsampling_outdir = '../analysis/performance/downsampling'
# makedirs creates missing parent directories too and does not fail when the
# directory already exists (os.mkdir raised FileNotFoundError on a fresh checkout).
os.makedirs(downsampling_outdir, exist_ok=True)

In [None]:
# Valid dataframe index for 30n decoys

if not select_decoy:
    # No restriction: downstream code uses every row of the prediction file
    index = None
else:
    valid_df = pd.read_csv('../data/raw/dataframe/valid.csv', index_col=0)
    source = valid_df['source']

    # All MS rows are kept; 15 random decoys are drawn per MS row
    ms_index = list(valid_df[source == 'MS'].index)
    num = len(ms_index) * 15

    # Fixed random_state keeps the sampled decoys identical between runs
    random_decoy_df = valid_df[source.str.contains('random_decoy')]
    index = (
        ms_index
        + list(valid_df[source == 'assay'].index)
        + list(valid_df[source == 'protein_decoy'].index)
        + list(random_decoy_df.sample(n=num, random_state=0).index)
    )

In [None]:
# Downsampling results
# For every factor: locate its prediction file, compute the metrics and save
# the JSON file plus the per-metric figures.
for factor in downsampling_factors:
    # NOTE(review): 93 is presumably the total number of decoy folds — TODO confirm
    model_num = 93 // factor
    prediction_file = '../prediction/valid/res182_decoy%d_CNN_1_1_%d/tmp_prediction.csv' % (factor, model_num)
    output_dir = '%s/res182_decoy%d_CNN_1_1' % (downsampling_outdir, factor)
    DP = DownsamplingPerformance(
        prediction_file=prediction_file,
        decoy_fold_per_model=factor,
        output_dir=output_dir,
        index=index,
    )
    DP()
    print("Factor %d Complete" % factor)

### Comparison between models

In [None]:
# Arguments

# Where the per-factor MetricsDict.json files live and where the figure is saved
downsampling_outdir = '../analysis/performance/downsampling'

# Factors compared against each other, and the decoy numbers shown on the x axis
downsampling_factors = [1, 5, 10, 15, 30]
comparison_decoy_num_list = [30, 60, 90]

# One panel is drawn per metric
metrics_list = ['AP', 'AUC']

# Figure appearance
figsize, dpi = (4.5, 3.5), 600
fontsize, linewidth = 10, 1.5

In [None]:
# Load the saved metrics of every downsampling factor once; the files do not
# depend on the metric being plotted, so this is hoisted out of the metric loop.
metrics_dict = dict()
for factor in downsampling_factors:
    file = '%s/res182_decoy%d_CNN_1_1/MetricsDict.json'%(downsampling_outdir, factor)
    with open(file, 'r') as handle:
        metrics_dict[factor] = json.load(handle)

# One panel per metric (atleast_1d keeps the loop valid for a single metric)
fig, ax = plt.subplots(1, len(metrics_list), figsize=figsize, dpi=dpi)
panels = np.atleast_1d(ax)

for panel, metrics_name in zip(panels, metrics_list):
    # Comparison
    # JSON object keys are strings, hence str(...).
    # NOTE(review): the key decoy_num-factor+1 presumably is the decoy column
    # number at which a run with this factor has accumulated decoy_num decoys —
    # TODO confirm against the decoy_<n> column naming of the prediction files.
    comparison_list = [
        {
            'downsampling_factor': 'factor_%d'%factor,
            'decoy_num': decoy_num,
            metrics_name: metrics_dict[factor][str(decoy_num-factor+1)][metrics_name],
        }
        for decoy_num in comparison_decoy_num_list
        for factor in downsampling_factors
    ]
    comparison_df = pd.DataFrame(comparison_list)

    # Plot
    sns.lineplot(data=comparison_df, x='decoy_num', y=metrics_name, hue='downsampling_factor',
                 linewidth=linewidth, ax=panel)
    panel.set_title(metrics_name, fontsize=fontsize)
    panel.set_xlabel(None)
    panel.set_ylabel(None)
    panel.set_xticks(list(comparison_decoy_num_list))
    panel.set_xticklabels(['%dn'%decoy_num for decoy_num in comparison_decoy_num_list], fontsize=fontsize)
    for tick_label in panel.get_yticklabels():
        tick_label.set_fontsize(fontsize)

    # Keep the handles for the shared figure legend and drop the per-panel legend
    h, l = panel.get_legend_handles_labels()
    panel.get_legend().remove()

fig.tight_layout()

# Shared legend above the panels.
# Older seaborn releases put a title pseudo-entry first in the handle list;
# newer ones return only the hue levels. Align the labels with whichever
# handle list was returned so no line gets a shifted label.
factor_labels = [str(factor) for factor in downsampling_factors]
legend_kwargs = dict(fontsize=fontsize,
                     ncol=len(downsampling_factors) + 1, loc="upper center",
                     bbox_to_anchor=(0, 1, 1, 0.1),
                     columnspacing=1, handlelength=0.5, handletextpad=0.2, borderpad=0.2)
if len(h) == len(factor_labels) + 1:
    l = ["Downsampling Factor"] + factor_labels
    leg = fig.legend(h, l, **legend_kwargs)
else:
    l = factor_labels
    leg = fig.legend(h, l, title="Downsampling Factor", **legend_kwargs)

# Invisible full-figure axes that only carry the common x-axis label
fig.add_subplot(111, frame_on=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel("Decoy Number", fontsize=fontsize)

fig.savefig('%s/DownsamplingComparison.png'%downsampling_outdir, bbox_inches='tight')

# Model Performance

In [None]:
# load df
# Test-set prediction table; the first CSV column becomes the index.
# NOTE(review): later cells read the columns 'bind', 'mhc' and one column per
# tool name from this frame.

test_file = '../prediction/test/prediction.csv'
test_df = pd.read_csv(test_file, index_col=0)

In [None]:
# arguments

with_mixmhcpred = False         # include MixMHCpred2.1 in the comparison
previous_convert_dict = False   # reuse a previously saved AlleleMetrics.json

if with_mixmhcpred:
    tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0', 'MixMHCpred2.1']
    # Only rows that have a MixMHCpred2.1 prediction can be compared across all tools
    test_df = test_df[test_df['MixMHCpred2.1'].notna()]
    output_dir = '../analysis/performance/tool_comparison/with_mixmhcpred/'
else:
    tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0']
    output_dir = '../analysis/performance/tool_comparison/without_mixmhcpred/'

# makedirs creates the missing 'tool_comparison' parent as well; os.mkdir
# raised FileNotFoundError when it did not exist yet.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# plot AUC-ROC comparing with other tools
# One ROC curve per tool in tool_list, computed on test_df['bind'] versus the
# tool's prediction column; the figure is saved as ComparisonROC.png.

figsize = (3.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

fig = plt.figure(figsize=figsize, dpi=dpi)
plt.title('ROC', fontsize=fontsize)

for col in tool_list:
    pred = test_df[col]
    fpr, tpr, _ = metrics.roc_curve(test_df['bind'], pred)
    auc = metrics.auc(fpr, tpr)
    plt.plot(fpr, tpr, label='%s,AUC=%.3f'%(col, auc), linewidth=linewidth)

plt.legend(loc = 'lower right', fontsize=fontsize, handletextpad=0.2, borderpad=0.2)
# Chance-level diagonal (drawn after the legend call, so it gets no legend entry)
plt.plot([0, 1], [0, 1], '--', color='black')
# Optional marker of the FPR = 0.1 cut-off:
#plt.plot([0.1, 0.1], [0, 1], '--', color='red')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate', fontsize=fontsize)
plt.xlabel('False Positive Rate', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Finalise the layout and save BEFORE showing: plt.show() may close the figure
# (backend dependent), and the displayed figure should match the saved one.
fig.tight_layout()
fig.savefig('%s/ComparisonROC.png'%output_dir)
plt.show()

In [None]:
# plot AUC-PRC comparing with the other tools in tool_list
# One precision-recall curve per tool, computed on test_df['bind'] versus the
# tool's prediction column; the figure is saved as ComparisonPRC.png.

figsize = (3.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

fig = plt.figure(figsize=figsize, dpi=dpi)
plt.title('PRC', fontsize=fontsize)

for col in tool_list:
    pred = test_df[col]
    precision, recall, _ = metrics.precision_recall_curve(test_df['bind'], pred)
    # The legend reports average precision, so name the value accordingly
    ap = metrics.average_precision_score(test_df['bind'], pred)
    # NOTE(review): precision is drawn on the x axis and recall on the y axis
    # (axis labels below agree with this); many PRC plots use the transposed
    # orientation — confirm this is intended.
    plt.plot(precision, recall, label='%s,AP=%.3f'%(col, ap), linewidth=linewidth)

plt.legend(loc = 'lower left', fontsize=fontsize, handletextpad=0.2, borderpad=0.2)
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('Recall', fontsize=fontsize)
plt.xlabel('Precision', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Finalise the layout and save BEFORE showing: plt.show() may close the figure
# (backend dependent), and the displayed figure should match the saved one.
fig.tight_layout()
fig.savefig('%s/ComparisonPRC.png'%output_dir)
plt.show()

In [None]:
# get allele metrics
# Builds convert_dict: metric name -> list of {'allele', 'value', 'method'}
# records, either reloaded from AlleleMetrics.json or recomputed and saved.

if previous_convert_dict:
    # Context manager closes the file handle deterministically
    with open('%s/AlleleMetrics.json'%output_dir, 'r') as handle:
        convert_dict = json.load(handle)

else:
    # NOTE(review): CalculateAlleleMetrics comes from util; the loops below
    # treat its result as {allele: {metric_name: value}} — confirm.
    allele_metrics_dict = dict()
    for col in tool_list:
        allele_metrics_dict[col] = CalculateAlleleMetrics(test_df['mhc'], test_df['bind'], test_df[col])

    # Reshape {method: {allele: {metric: value}}} into per-metric record lists.
    # A metric name other than these four raises KeyError.
    convert_dict = {'AUC': [], 'AUC0.1': [], 'AP': [], 'PPV': []}
    for method, allele_metrics in allele_metrics_dict.items():
        # 'allele_values' avoids clobbering the notebook-level name metrics_dict
        for allele, allele_values in allele_metrics.items():
            for metric, val in allele_values.items():
                convert_dict[metric].append({'allele': allele, 'value': val, 'method': method})

    with open('%s/AlleleMetrics.json'%output_dir, 'w') as handle:
        json.dump(convert_dict, handle)

In [None]:
# plot metrics by allele
# One figure per metric: per-allele values of every tool as strip + violin
# plot, annotated with independent t-tests of the first tool against each other tool.

figsize, dpi = (3.5,3.5), 600
fontsize, linewidth = 10, 1.5

metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']

# Pair the first tool in tool_list with each of the remaining tools
box_pairs = [(tool_list[0], other_tool) for other_tool in tool_list[1:]]

for metric in metric_list:
    temp_df = pd.DataFrame(convert_dict[metric])
    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.stripplot(x='method', y='value', data=temp_df, ax=ax, s=2)
    sns.violinplot(x='method', y='value', data=temp_df, ax=ax, color=".8")

    ax.set_title(metric)
    ax.set_xlabel('')
    ax.set_ylabel('Value')
    # Drop automatic ticks above 1.0
    ax.set_yticks([tick for tick in ax.get_yticks() if tick <= 1.0])

    # Uniform font size for the title, axis labels and tick labels
    text_items = [ax.title, ax.xaxis.label, ax.yaxis.label,
                  *ax.get_xticklabels(), *ax.get_yticklabels()]
    for item in text_items:
        item.set_fontsize(fontsize)
    for item in ax.get_xticklabels():
        item.set_rotation(45)

    test_results = add_stat_annotation(
        ax=ax, data=temp_df, x='method', y='value',
        box_pairs=box_pairs, test='t-test_ind', comparisons_correction=None,
        text_format='star', loc='inside',
        fontsize=fontsize, linewidth=linewidth,
        line_offset_to_box=0.15,
    )

    fig.savefig('%s/AlleleGroup%s.png'%(output_dir, metric), bbox_inches='tight')

## Unobserved Alleles
Unobserved alleles = set(alleles in the testing dataset) - set(alleles in the training hit dataset with # > 100)

In [None]:
# Per-allele metrics of the comparison without MixMHCpred, as written by the
# "get allele metrics" cell.
output_dir = '../analysis/performance/tool_comparison/without_mixmhcpred/'
# Context manager closes the file handle deterministically
with open('%s/AlleleMetrics.json'%output_dir, 'r') as handle:
    convert_dict = json.load(handle)

figsize = (3.5,3.5)
dpi = 600
fontsize = 10
linewidth = 1.5

metric_list = ['AUC', 'AUC0.1', 'AP', 'PPV']

# Alleles tagged 'Unobserved Allele' in the observed-vs-unobserved plots
rare_alleles = ['A*02:05', 'A*11:02', 'A*24:07', 'A*33:03', 'A*34:01',
                'A*34:02', 'A*36:01', 'A*74:01', 'B*07:04', 'B*13:01',
                'B*13:02', 'B*15:10', 'B*35:07', 'B*37:01', 'B*38:02',
                'B*40:06', 'B*52:01', 'B*55:01', 'B*55:02', 'B*58:02',
                'C*03:02', 'C*04:03', 'C*07:04', 'C*08:01', 'C*14:03']

# Subset used for the between-tool comparison.
# NOTE(review): presumably the alleles unobserved by every compared tool — confirm.
consensus_rare_alleles = ['A*24:07', 'A*34:01', 'A*34:02', 'A*36:01',
                          'B*07:04', 'B*35:07', 'B*38:02', 'B*40:06',
                          'C*03:02', 'C*04:03', 'C*14:03']

In [None]:
# unobserved alleles v.s. observed alleles
# One figure per metric: MHCfovea per-allele values split into observed and
# unobserved alleles, annotated with an independent t-test between the groups.

box_pairs=[('Observed Allele', 'Unobserved Allele')]

# No earlier cell creates this directory, so savefig would fail on a fresh checkout
rare_allele_outdir = '../analysis/performance/rare_allele'
os.makedirs(rare_allele_outdir, exist_ok=True)

for metric in metric_list:
    temp_df = pd.DataFrame(convert_dict[metric])
    # .copy() gives an independent frame, so adding the 'tag' column below does
    # not write into a slice of the original (SettingWithCopyWarning).
    temp_df = temp_df[temp_df['method']=='MHCfovea'].copy()
    temp_df['tag'] = 'Observed Allele'
    temp_df.loc[temp_df['allele'].isin(rare_alleles), 'tag'] = 'Unobserved Allele'

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.stripplot(x='tag', y='value', data=temp_df, ax=ax, s=2)
    sns.violinplot(x='tag', y='value', data=temp_df, ax=ax, color=".8")

    ax.set_title(metric)
    ax.set_xlabel('')
    ax.set_ylabel('Value')
    # Drop automatic ticks above 1.0
    ax.set_yticks([tick for tick in ax.get_yticks() if tick <= 1.0])
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(fontsize)

    test_results = add_stat_annotation(ax=ax, data=temp_df, x='tag', y='value',
                                       box_pairs=box_pairs, test='t-test_ind', comparisons_correction=None,
                                       text_format='star', loc='inside',
                                       fontsize=fontsize, linewidth=linewidth,
                                       line_offset_to_box=0.15)

    fig.savefig('%s/RareAllele%s.png'%(rare_allele_outdir, metric), bbox_inches='tight')

In [None]:
# comparison with other tools on unobserved alleles
# One figure per metric: per-allele values of every tool restricted to
# consensus_rare_alleles, annotated with independent t-tests of the first tool
# against each other tool.

tool_list = ['MHCfovea', 'NetMHCpan4.1', 'MHCflurry2.0']
# Pair the first tool with each of the remaining tools
box_pairs = [(tool_list[0], other_tool) for other_tool in tool_list[1:]]

# Make sure the output directory exists even when this cell is run on its own
rare_allele_outdir = '../analysis/performance/rare_allele'
os.makedirs(rare_allele_outdir, exist_ok=True)

for metric in metric_list:
    temp_df = pd.DataFrame(convert_dict[metric])
    temp_df = temp_df[temp_df['allele'].isin(consensus_rare_alleles)]
    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.stripplot(x='method', y='value', data=temp_df, ax=ax, s=2)
    sns.violinplot(x='method', y='value', data=temp_df, ax=ax, color=".8")

    ax.set_title(metric)
    ax.set_xlabel('')
    ax.set_ylabel('Value')
    # Drop automatic ticks above 1.0
    ax.set_yticks([tick for tick in ax.get_yticks() if tick <= 1.0])
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(fontsize)
    for item in ax.get_xticklabels():
        item.set_rotation(45)

    test_results = add_stat_annotation(ax=ax, data=temp_df, x='method', y='value',
                                       box_pairs=box_pairs, test='t-test_ind', comparisons_correction=None,
                                       text_format='star', loc='inside',
                                       fontsize=fontsize, linewidth=linewidth,
                                       line_offset_to_box=0.15)

    fig.savefig('%s/RareAlleleGroup%s.png'%(rare_allele_outdir, metric), bbox_inches='tight')