In [None]:
import os, sys, re, json, random, pickle
import numpy as np
import pandas as pd
from collections import OrderedDict
import matplotlib.pyplot as plt
import seaborn as sns
import logomaker as lm
from statannot import add_stat_annotation
from CAMInterp import CAMInterp
from MHCInterp import MHCInterp
from scipy.stats import ttest_ind, t
import warnings
warnings.filterwarnings('ignore')

# Input/output locations -- fill these in before running the notebook.
dataframe_dir = ''  # holds train_hit.csv, valid.csv and benchmark.csv
allele_expansion_dir = ''  # one sub-directory per expansion, each with a motif.npy
summarization_dir = '' # the output directory of Analysis-summarization.ipynb
cam_analysis_dir = '' # the output directory of Analysis-ScoreCAM.ipynb
performance_dir = '' # the working directory of Analysis-performance.ipynb
output_dir = ''

# makedirs also creates missing parent directories and is a no-op when the
# directory already exists, so re-running this cell is safe.
os.makedirs(output_dir, exist_ok=True)

In [None]:
def WelchTest(x1, x2, confidence=0.95):
    """Welch's (unequal-variance) two-sample t-test of mean(x1) - mean(x2).

    Parameters
    ----------
    x1, x2 : array-like of numbers
        The two samples (lists, numpy arrays and pandas Series all work).
        NaN values are dropped, so the sample size, mean and variance are all
        computed on the same observations.
    confidence : float, default 0.95
        Confidence level of the two-sided interval reported as 'lb'/'ub'.

    Returns
    -------
    dict with keys
        'n'       : [n1, n2] sample sizes
        'm'       : [mean1, mean2]
        'sd'      : [sd1, sd2] sample standard deviations (ddof=1)
        'df'      : Welch-Satterthwaite degrees of freedom
        'psd'     : standard error of the difference of means
        'tstat'   : t statistic
        'delta'   : mean1 - mean2
        'pvalue'  : two-sided p-value
        'lb', 'ub': confidence-interval bounds for delta
    """
    # Global NumPy display setting, kept so notebook output looks as before.
    np.set_printoptions(precision=2)

    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    # pandas' mean/var skip NaN while .size counts it; dropping NaN up front
    # keeps n, mean and variance consistent with each other.
    x1 = x1[~np.isnan(x1)]
    x2 = x2[~np.isnan(x2)]

    n1, n2 = x1.size, x2.size
    m1, m2 = np.mean(x1), np.mean(x2)
    v1, v2 = np.var(x1, ddof=1), np.var(x2, ddof=1)

    # squared standard errors of each mean
    sem1_sq = v1 / n1
    sem2_sq = v2 / n2
    pooled_se = np.sqrt(sem1_sq + sem2_sq)
    delta = m1 - m2
    tstat = delta / pooled_se
    # Welch-Satterthwaite approximation of the degrees of freedom
    df = (sem1_sq + sem2_sq) ** 2 / (sem1_sq ** 2 / (n1 - 1) + sem2_sq ** 2 / (n2 - 1))

    # two-sided t-test
    p = 2 * t.cdf(-abs(tstat), df)

    # two-sided confidence interval for delta (0.975 quantile for the default 95%)
    t_crit = t.ppf((1 + confidence) / 2, df)
    lb = delta - t_crit * pooled_se
    ub = delta + t_crit * pooled_se

    return {
        'n': [n1, n2],
        'm': [m1, m2],
        'sd': [np.sqrt(v1), np.sqrt(v2)],
        'df': df,
        'psd': pooled_se,
        'tstat': tstat,
        'delta': delta,
        'pvalue': p,
        'lb': lb,
        'ub': ub,
    }


def PrintStatDF(df):
    """Lay out Welch-test results as a table with one column per sample.

    Each row of ``df`` (records produced by ``WelchTest`` plus 'pair1'/'pair2'
    labels) contributes two columns named ``'<label>_<row index>'``: one for
    the 'pair1' sample and one for the 'pair2' sample. The per-sample
    statistics ('n', 'm', 'sd') differ between the two columns; the
    pair-level statistics are repeated in both.
    """
    stat_rows = ['n', 'm', 'sd', 'df', 'psd', 'tstat', 'pvalue', 'lb', 'ub']
    shared_keys = ['df', 'psd', 'tstat', 'pvalue', 'lb', 'ub']
    table = pd.DataFrame(index=stat_rows)
    for row_idx, record in df.iterrows():
        shared = [record[key] for key in shared_keys]
        for side, label_key in enumerate(('pair1', 'pair2')):
            column = '{}_{}'.format(record[label_key], row_idx)
            per_sample = [record['n'][side], record['m'][side], record['sd'][side]]
            table[column] = per_sample + shared
    return table

## Loading data

In [None]:
hla = 'B'

# MHC sequences
mhc_seq_filename = '../data/MHCI_res182_seq.json'
with open(mhc_seq_filename, 'r') as fh:
    mhc_seq_dict = json.load(fh)

# dataset
train_df = pd.read_csv('%s/train_hit.csv'%dataframe_dir, index_col=0)
valid_df = pd.read_csv('%s/valid.csv'%dataframe_dir, index_col=0)
test_df = pd.read_csv('%s/benchmark.csv'%dataframe_dir, index_col=0)

# hyper-motif cluster
nside_df = pd.read_csv('%s/%s_NsideDF.csv'%(summarization_dir, hla), index_col=0)
cside_df = pd.read_csv('%s/%s_CsideDF.csv'%(summarization_dir, hla), index_col=0)

# MHCInterp
submotif_len = 4
with open('%s/ResidueSelection.json'%cam_analysis_dir, 'r') as fh:
    position_dict = json.load(fh)

# Merge the per-directory motif dictionaries. Sorting makes the merge order
# (and therefore which entry wins for a duplicated key) reproducible, and
# update() avoids rebuilding the whole dict on every iteration.
mhc_motif_dict = dict()
for sub_dir in sorted(os.listdir(allele_expansion_dir)):
    # NOTE: allow_pickle=True can execute arbitrary code -- only load trusted files.
    d = np.load('{}/{}/motif.npy'.format(allele_expansion_dir, sub_dir), allow_pickle=True)[()]
    mhc_motif_dict.update(d)

tmp = 'tmp/'
os.makedirs(tmp, exist_ok=True)

mhc_interp = MHCInterp(mhc_seq_dict, mhc_motif_dict, submotif_len, position_dict, tmp)

In [None]:
multicluster_groups = list()
minor_frequency = 0.1
minimal_num = 25


def _major_cluster_groups(group_count_df, minor_frequency, minimal_num):
    """Return the column labels (HLA groups) that have more than one major cluster.

    For each column, an entry counts as major when it exceeds
    ``min(minor_frequency * column total, minimal_num)``. Rows are presumably
    clusters -- confirm against the *GroupCount.csv files.
    """
    threshold = (group_count_df.sum(axis=0) * minor_frequency).clip(upper=minimal_num)
    major_counts = (group_count_df - threshold > 0).sum(axis=0)
    return major_counts[major_counts > 1].index.tolist()


# `gene` / `terminus` instead of `hla` / `side`, so this cell does not
# overwrite globals that other cells set and use.
for gene in ['A', 'B', 'C']:
    for terminus in ['Nside', 'Cside']:
        group_count_df = pd.read_csv('%s/%s_%sGroupCount.csv'%(summarization_dir, gene, terminus), index_col=0)
        multicluster_groups += _major_cluster_groups(group_count_df, minor_frequency, minimal_num)

multicluster_groups = sorted(set(multicluster_groups))
print(multicluster_groups)

## HLA group polymorphism

In [None]:
# group dict: HLA group (allele name up to the first ':') -> list of its alleles,
# in the order they appear in mhc_seq_dict
group_dict = dict()
for allele in mhc_seq_dict:
    group = allele.split(':')[0]
    group_dict.setdefault(group, []).append(allele)

In [None]:
# group polymorphism df: for every HLA group, -sum(p * log p) over each row of
# the group's seqlogo table (one value per position), i.e. the per-position
# entropy if the seqlogo values are residue frequencies -- TODO confirm.
positions = list(range(182))
group_polymorphism_dict = dict()
allele_num_dict = dict()
for group, alleles in group_dict.items():
    allele_num_dict[group] = len(alleles)
    # NOTE(review): assumed to be a positions x residues table (one row per position).
    seqlogo_df = mhc_interp._mhc_seqlogo(alleles, positions)
    # 0 * log(0) yields NaN, which pandas' sum() skips, i.e. it contributes 0.
    group_polymorphism_dict[group] = -(seqlogo_df * np.log(seqlogo_df)).sum(axis=1).values
# rows: HLA groups; columns: position indices 0..181
group_polymorphism_df = pd.DataFrame(group_polymorphism_dict).T

# allele number
group_polymorphism_df['num'] = pd.Series(allele_num_dict)

# mean of positions
# Indices (in the same coordinates as `positions`) for the optional
# '34-residue' comparison that is commented out below.
res34_pos = [6, 8, 23, 44, 58, 61, 62, 65, 66, 68, 69, 72, 73, 75, 76, 79, 80, 83, 94,
             96, 98, 113, 115, 117, 142, 146, 149, 151, 155, 157, 158, 162, 166, 170]
group_polymorphism_df['all'] = group_polymorphism_df[positions].mean(axis=1)
group_polymorphism_df['important'] = group_polymorphism_df[mhc_interp.position_dict['selected']].mean(axis=1)
# isin() only labels groups present in the index, so multi-cluster groups
# absent from mhc_seq_dict are ignored instead of being looked up by label.
group_polymorphism_df['multicluster'] = np.where(
    group_polymorphism_df.index.isin(multicluster_groups), 'multi-cluster', 'mono-cluster')
##group_polymorphism_df['34-residue'] = group_polymorphism_df[res34_pos].mean(axis=1)
group_polymorphism_df.to_csv('%s/GroupPolymorphism.csv'%output_dir)
group_polymorphism_df.head()

In [None]:
# Long-format copy of the per-group mean polymorphism for seaborn: two records
# per group ('all' first, then 'important'), in the order of group_polymorphism_df.
temp = pd.DataFrame([
    {'group': group_name,
     'type': pos_type,
     'multicluster': record['multicluster'],
     'polymorphism': record[pos_type]}
    for group_name, record in group_polymorphism_df.iterrows()
    for pos_type in ('all', 'important')
])
temp.head()

In [None]:
# barplot: per-group mean polymorphism ('all' vs 'important' positions), one panel per HLA gene
hla_list = ['A', 'B', 'C']
fig, ax = plt.subplots(len(hla_list), figsize=(8,12), dpi=600)
for i, temp_hla in enumerate(hla_list):
    # Group names start with the gene letter (e.g. 'B*56'); a prefix match
    # cannot pick up another gene's groups the way a substring test could.
    gene_data = temp[temp['group'].str.startswith(temp_hla)]
    sns.barplot(x='group', y='polymorphism', hue='type', data=gene_data, ax=ax[i])
    ax[i].tick_params(axis='x', labelrotation=90)
    ax[i].set_title(temp_hla)

fig.tight_layout()
fig.savefig('%s/PolyCompBarplot.png'%output_dir)

In [None]:
# boxplot: mean polymorphism of mono- vs multi-cluster groups, all vs important positions
data = temp

fig, ax = plt.subplots(1, figsize=(6,3), dpi=600)
sns.boxplot(x='type', y='polymorphism', hue='multicluster', data=data, ax=ax)

box_pairs = [(('all', 'mono-cluster'), ('all', 'multi-cluster')),
             (('important', 'mono-cluster'), ('important', 'multi-cluster'))]
# equal_var=False -> Welch's t-test, matching the WelchTest table below.
test_results = add_stat_annotation(ax=ax, data=data, x='type', y='polymorphism', hue='multicluster',
                                   box_pairs=box_pairs, comparisons_correction=None,
                                   test='t-test_ind', stats_params={'equal_var': False},
                                   text_format='star', loc='outside',
                                   fontsize=mhc_interp.fontsize, linewidth=0.3,
                                   line_offset_to_box=0.05)

stat_dict_list = list()
for p1, p2 in box_pairs:
    sample1 = data.loc[(data['type']==p1[0]) & (data['multicluster']==p1[1]), 'polymorphism']
    sample2 = data.loc[(data['type']==p2[0]) & (data['multicluster']==p2[1]), 'polymorphism']
    stat_dict = WelchTest(sample1, sample2)
    # Each label describes its own box: "<position type>: <cluster type>".
    stat_dict['pair1'] = '{}: {}'.format(p1[0], p1[1])
    stat_dict['pair2'] = '{}: {}'.format(p2[0], p2[1])
    stat_dict_list.append(stat_dict)

_ = ax.legend(title='HLA groups')
_ = ax.set_xlabel(None)
# Derive the residue count from the selection actually used for the 'important' column.
n_important = len(mhc_interp.position_dict['selected'])
_ = ax.set_xticklabels(['all positions (182 a.a.)', 'important positions ({} a.a.)'.format(n_important)])

fig.tight_layout()
fig.savefig('%s/PolyCompBoxplot.png'%output_dir)

stat_df = PrintStatDF(pd.DataFrame(stat_dict_list))
stat_df

## Unobserved alleles within multi-cluster groups

In [None]:
with open('%s/without_mixmhcpred/AlleleMetrics.json'%performance_dir, 'r') as fh:
    pf = json.load(fh)

target = "Unobserved" # Rare or Unobserved

unobserved_alleles = ['A*24:07', 'A*33:03', 'A*34:01', 'A*34:02', 'A*36:01', 'B*07:04', 'B*15:10', 'B*35:07',
                      'B*38:02', 'B*40:06', 'B*55:01', 'B*55:02', 'C*03:02', 'C*04:03', 'C*08:01', 'C*14:03']

multicluster_groups = list(group_polymorphism_df[group_polymorphism_df['multicluster']=='multi-cluster'].index)

# Unobserved alleles whose HLA group (name up to the first ':') is multi-cluster,
# and the remaining ones (sorted, so the order is reproducible).
target_alleles = [a for a in unobserved_alleles if a.split(':')[0] in multicluster_groups]
other_alleles = sorted(set(unobserved_alleles) - set(target_alleles))

In [None]:
# Per-allele MHCfovea metrics for the unobserved alleles, labelled by whether
# the allele belongs to a multi-cluster HLA group.
metrics_list = ['AUC', 'AP']
metric_frames = list()

for metrics in metrics_list:
    metric_df = pd.DataFrame(pf[metrics])
    # Filter explicitly on the unobserved alleles defined above; a generic name
    # such as `alleles` would pick up whatever an earlier cell left behind.
    keep = (metric_df['method']=='MHCfovea') & (metric_df['allele'].isin(unobserved_alleles))
    metric_df = metric_df[keep].copy()
    metric_df['multicluster'] = np.where(metric_df['allele'].isin(target_alleles), 'multi-cluster', 'mono-cluster')
    metric_df['metrics'] = metrics
    metric_frames.append(metric_df)

# concat once instead of growing a frame inside the loop
data = pd.concat(metric_frames, axis=0, ignore_index=True)
data = data.sort_values(by=['metrics', 'allele'])
data.to_csv('%s/GroupUnobserved.csv'%output_dir)

In [None]:
# One panel per metric: mono- vs multi-cluster boxplots with a Welch t-test annotation.
fig, ax = plt.subplots(1, 2, figsize=(6,3), dpi=600)
stat_dict_list = list()
for i, metrics in enumerate(metrics_list):
    panel = ax[i]
    temp_data = data[data['metrics'] == metrics]

    sns.boxplot(x='multicluster', y='value', data=temp_data, ax=panel)
    # drop tick marks above 1.0
    panel.set_yticks([tick for tick in panel.get_yticks() if tick <= 1.0])

    box_pairs = [('mono-cluster', 'multi-cluster')]
    test_results = add_stat_annotation(ax=panel, data=temp_data, x='multicluster', y='value',
                                       box_pairs=box_pairs, comparisons_correction=None,
                                       test='t-test_ind', stats_params={'equal_var': False},
                                       text_format='star', loc='outside',
                                       fontsize=mhc_interp.fontsize, linewidth=0.3,
                                       line_offset_to_box=0.05)

    for p1, p2 in box_pairs:
        first = temp_data.loc[temp_data['multicluster'] == p1, 'value']
        second = temp_data.loc[temp_data['multicluster'] == p2, 'value']
        stat_dict = WelchTest(first, second)
        stat_dict['pair1'] = '{}: {}'.format(metrics, p1)
        stat_dict['pair2'] = '{}: {}'.format(metrics, p2)
        stat_dict_list.append(stat_dict)

    panel.set_xlabel(None)
    panel.set_ylabel(metrics)

fig.tight_layout()
fig.savefig('{}/MultiClusterComp{}.png'.format(output_dir, target))
stat_df = PrintStatDF(pd.DataFrame(stat_dict_list))
stat_df

## Examples of multi-cluster HLA groups

In [None]:
def seqlogo_plot(self, seqlogo_df, positions, ax, highlight_pos_dict=None,
                 ylim=1, title=None, turn_off_label=False):
    """Draw a residue sequence logo for ``positions`` onto ``ax``.

    Parameters
    ----------
    self : object
        Only ``self.fontsize`` is read (an ``MHCInterp`` instance in this notebook).
    seqlogo_df : pandas.DataFrame
        Matrix passed to ``logomaker.Logo``; expected to have one row per
        entry of ``positions``.
    positions : sequence of int
        Positions shown, left to right; tick labels display ``position + 1``.
    ax : matplotlib Axes
        Axes to draw on.
    highlight_pos_dict : dict[str, iterable of int], optional
        Maps a highlight colour to the positions to shade. Positions that are
        not in ``positions`` are ignored.
    ylim : float
        Upper y-axis limit.
    title : str, optional
        Axes title.
    turn_off_label : bool
        If True, remove x ticks, x tick labels and the title.
    """
    positions = list(positions)
    logo = lm.Logo(seqlogo_df, color_scheme='skylign_protein', ax=ax)

    ax.set_ylim(0, ylim)
    ax.set_xticks(range(len(positions)))
    ax.set_xticklabels([pos + 1 for pos in positions], rotation=90)
    ax.set_title(title)
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(self.fontsize)
    # x tick labels are drawn slightly smaller than the other text
    for item in ax.get_xticklabels():
        item.set_fontsize(self.fontsize-3)

    ax.set_yticks([])
    ax.set_yticklabels([])
    if turn_off_label:
        ax.set_xticks([])
        ax.set_xticklabels([])
        ax.set_title(None)

    if highlight_pos_dict:
        # position -> logo column; setdefault keeps the first occurrence, as list.index did
        column_of = dict()
        for column, pos in enumerate(positions):
            column_of.setdefault(pos, column)
        for color, highlight_pos in highlight_pos_dict.items():
            for pos in sorted(set(highlight_pos) & set(column_of)):
                logo.highlight_position(p=column_of[pos], color=color)

In [None]:
def _allele_to_group(allele):
    """Return the HLA group of an allele name: the part before the first ':'."""
    return allele.split(':')[0]


hla = 'B'
nside_df = pd.read_csv('%s/%s_NsideDF.csv'%(summarization_dir, hla), index_col=0)
cside_df = pd.read_csv('%s/%s_CsideDF.csv'%(summarization_dir, hla), index_col=0)

# add hla group (the dataset frames are updated in place)
for frame in (train_df, valid_df, test_df):
    frame['hla_group'] = frame.mhc.apply(_allele_to_group)

# cluster tables are indexed by allele; keep only the columns used below
keep_cols = ['label', 'select_label', 'hla_group']
nside_df['hla_group'] = [_allele_to_group(name) for name in nside_df.index]
nside_df = nside_df[keep_cols]
cside_df['hla_group'] = [_allele_to_group(name) for name in cside_df.index]
cside_df = cside_df[keep_cols]

In [None]:
# build group df of training and testing dataset
group = 'B*56'
side = 'C'

side_df = nside_df if side == 'N' else cside_df

# alleles of this group seen in the training or benchmark set (sorted for a stable order)
alleles = sorted(set(train_df.loc[train_df['hla_group']==group, 'mhc'])
                 | set(test_df.loc[test_df['hla_group']==group, 'mhc']))

# `group_records` rather than `group_dict`, which holds the group -> alleles mapping.
group_records = list()
for allele in alleles:
    try:
        label = side_df.loc[allele, 'label']
    except KeyError:
        # allele is absent from the cluster table
        label = -1
    group_records.append({
        'allele': allele,
        'train_num': train_df[train_df['mhc']==allele].shape[0],
        'valid_num': valid_df[(valid_df['mhc']==allele) & (valid_df['source'].isin(['MS', 'assay']))].shape[0],
        'test_num': test_df[(test_df['mhc']==allele) & (test_df['source']=='MS')].shape[0],
        'label': label
    })
# explicit columns keep the frame well-formed even when no allele was found
group_df = pd.DataFrame(group_records, columns=['allele', 'train_num', 'valid_num', 'test_num', 'label'])
group_df = group_df.sort_values(by='allele').reset_index(drop=True)

print(group_df)
print('-----------------------------')
print('label counts')
print(group_df.label.value_counts())

In [None]:
# important positions
positions = mhc_interp.position_dict['selected']

# background seqlogo_df: alleles of the chosen side whose select_label is not -1
alleles = side_df.index[side_df['select_label'] != -1].tolist()
hla_seqlogo_df = mhc_interp._mhc_seqlogo(alleles, positions)

# clusters that the alleles of the chosen group are assigned to
in_group = side_df['hla_group'] == group
group_labels = side_df.loc[in_group, 'label']
clusters = sorted(group_labels.unique())

print('Clusters: ', clusters)
print('Value counts of clusters')
print(group_labels.value_counts())

In [None]:
# Per-group plot settings: which clusters to show and which positions to shade
# (same coordinates as `positions`). Add an entry to plot another group.
MULTICLUSTER_PLOT_PRESETS = {
    'B*15': {'target_clusters': [2, 3, 4],
             'highlight_pos_dict': {'#f2f2f2': [44, 61, 62, 64, 65, 66, 68, 69, 70]}},
    'B*56': {'target_clusters': [2, 3, 4],
             'highlight_pos_dict': {'#f2f2f2': [93, 94, 96, 97, 108, 113, 115]}},
}
# Select by `group` so the settings can never silently belong to a different group.
if group not in MULTICLUSTER_PLOT_PRESETS:
    raise KeyError('No plot preset for group {!r}; add one to MULTICLUSTER_PLOT_PRESETS'.format(group))
target_clusters = MULTICLUSTER_PLOT_PRESETS[group]['target_clusters']
highlight_pos_dict = MULTICLUSTER_PLOT_PRESETS[group]['highlight_pos_dict']

figfile = '%s/MultiCluster_%s%s.png'%(output_dir, group[0], group[2:])

In [None]:
# plot: one row per target cluster -- hyper-motif of the group's alleles in that
# cluster (left) and the residues highlighted by the seqlogo product below (right)
# squeeze=False keeps `ax` 2-D even when there is a single target cluster.
fig, ax = plt.subplots(len(target_clusters), 2, figsize=(6, len(target_clusters)*1.5), dpi=600,
                       gridspec_kw={'width_ratios': [1, 5]}, squeeze=False)

in_group = side_df['hla_group'] == group
for i, cluster in enumerate(target_clusters):
    group_cluster_alleles = side_df.index[in_group & (side_df['label'] == cluster)].tolist()

    # hyper motif of a specific group
    mhc_interp._motif_plot(group_cluster_alleles, side, ax[i][0], turn_off_label=True)

    # allele signature of a specific group: 1 where the group's seqlogo exceeds
    # the background, 0 where it is below (equal entries are 0 already)
    group_seqlogo_df = mhc_interp._mhc_seqlogo(group_cluster_alleles, positions) - hla_seqlogo_df
    group_seqlogo_df[group_seqlogo_df > 0] = 1
    group_seqlogo_df[group_seqlogo_df < 0] = 0

    # allele signature of a corresponding cluster: excess over the background,
    # with negative differences clipped to 0
    cluster_alleles = side_df.index[side_df['select_label'] == cluster].tolist()
    cluster_seqlogo_df = (mhc_interp._mhc_seqlogo(cluster_alleles, positions) - hla_seqlogo_df).clip(lower=0)

    # highlight: cluster excess, kept only where the group also exceeds the background
    seqlogo_df = group_seqlogo_df * cluster_seqlogo_df

    seqlogo_plot(mhc_interp, seqlogo_df, positions, ax[i][1], highlight_pos_dict=highlight_pos_dict,
                 ylim=1, turn_off_label=False)

fig.tight_layout()
fig.savefig(figfile)