In [None]:
import os, sys, re, json, random, pickle, copy
import numpy as np
import pandas as pd
from collections import OrderedDict
import matplotlib.pyplot as plt
import seaborn as sns
import logomaker as lm
from MHCInterp import MHCInterp
import warnings
warnings.filterwarnings('ignore')

# User-supplied locations; fill these in before running the notebook.
allele_expansion_dir = ''  # directory whose sub-directories each hold a motif.npy
cam_analysis_dir = '' # result from Analysis-ScoreCAM.ipynb
output_dir = ''  # receives interpretation.pkl and the saved figures
# makedirs also creates missing parent directories, and exist_ok=True makes
# re-running this cell safe when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

## Loading Data

In [None]:
# presumably maps allele name -> 182-residue MHC-I sequence; verify against the JSON file
with open('../data/MHCI_res182_seq.json', 'r') as seq_file:
    mhc_seq_dict = json.load(seq_file)

# Merge the motif dictionaries stored under each sub-directory of
# allele_expansion_dir into one dict.
mhc_motif_dict = dict()

# Sorted so that, if two files share a key, which one wins does not depend on
# the arbitrary ordering of os.listdir.
for sub_dir in sorted(os.listdir(allele_expansion_dir)):
    sub_path = os.path.join(allele_expansion_dir, sub_dir)
    if not os.path.isdir(sub_path):
        # Stray files (e.g. .DS_Store) cannot contain a motif.npy; skip them.
        continue
    # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    # Indexing with [()] unwraps the 0-d array that np.load returns here
    # (presumably a dict saved with np.save -- confirm).
    d = np.load(os.path.join(sub_path, 'motif.npy'), allow_pickle=True)[()]
    mhc_motif_dict.update(d)

submotif_len = 4

with open('%s/ResidueSelection.json'%cam_analysis_dir, 'r') as position_file:
    position_dict = json.load(position_file)

## Clustering
Fig. 4 and Supplementary Fig. 6-7

### Pre-pdist + Agglomerative Clustering

In [None]:
# Build the interpretation helper from the loaded sequences, motifs, sub-motif
# length, residue-selection dict and output directory. Every later analysis
# cell in this notebook calls methods on this object.
interp = MHCInterp(mhc_seq_dict, mhc_motif_dict, submotif_len, position_dict, output_dir)

In [None]:
# --- settings forwarded to interp.Clustering ---
noise_threshold = 0

# clustering algorithm and its options
clustering_method = 'Agglomerative'
clustering_kwargs = dict(
    Agglomerative_affinity='cosine',
    Agglomerative_linkage='complete',
    Agglomerative_distance_threshold=None,
    Agglomerative_n_clusters=None,  # overwritten per (HLA, terminus) before each run
)

# no dimensionality reduction
reduction_method, reduction_kwargs = None, {}

# pairwise-distance options
pre_pdist = True
metric, method = 'cosine', 'complete'

# plotting / caching switches
highlight = False
load_file = False
turn_off_label = False

In [None]:
# [hla, terminus, cluster number]
args = [['A','N',8],
        ['A','C',8],
        ['B','N',8],
        ['B','C',5],
        ['C','N',6],
        ['C','C',2]]

# interpretation dict: inputs plus the per-(HLA, terminus) clustering results
interp_dict = dict()
# deep copies so the saved snapshot is independent of the in-memory inputs
interp_dict['seq'] = copy.deepcopy(mhc_seq_dict)
interp_dict['motif'] = copy.deepcopy(mhc_motif_dict)
interp_dict['important_positions'] = position_dict['selected']
interp_dict['cluster'] = dict()
interp_dict['hyper_motif'] = dict()
interp_dict['allele_signature'] = dict()

for hla, side, n_clusters in args:
    # Only the cluster count changes between runs. Note that this mutates the
    # shared clustering_kwargs dict defined in the configuration cell.
    clustering_kwargs['Agglomerative_n_clusters'] = n_clusters
    labels, hyper_motif, allele_signature = interp.Clustering(hla,
                                                              side,
                                                              noise_threshold,
                                                              clustering_method,
                                                              clustering_kwargs,
                                                              reduction_method=reduction_method,
                                                              reduction_kwargs=reduction_kwargs,
                                                              pre_pdist=pre_pdist,
                                                              metric=metric,
                                                              method=method,
                                                              highlight=highlight,
                                                              load_file=load_file,
                                                              turn_off_label=turn_off_label)

    # results are keyed as '<hla>_<terminus>', e.g. 'B_N'
    group_key = '%s_%s'%(hla, side)
    interp_dict['cluster'][group_key] = labels
    interp_dict['hyper_motif'][group_key] = hyper_motif
    interp_dict['allele_signature'][group_key] = allele_signature

# The context manager flushes and closes the file before later cells read it back.
with open('%s/interpretation.pkl'%output_dir, 'wb') as pkl_file:
    pickle.dump(interp_dict, pkl_file)

## Grouping Counts

In [None]:
# Run the allele grouping once per HLA gene. Only the frames from the last
# gene remain bound to N_group_df / C_group_df after the loop.
for hla in ('A', 'B', 'C'):
    grouping_result = interp.AlleleGrouping(hla)
    N_group_df, C_group_df = grouping_result

## Combination of N-terminus and C-terminus
Fig. 5 and Supplementary Fig. 8

In [None]:
# HLA gene analysed in the rest of this section
hla = 'B'
# NOTE(review): pickle.load can execute arbitrary code -- only load a file this notebook wrote.
# The context manager closes the file handle once loading finishes.
with open('%s/interpretation.pkl'%output_dir, 'rb') as pkl_file:
    interp_dict = pickle.load(pkl_file)

### heatmap of training and testing dataset

In [None]:
def build_heatmap_df(df, N_terminus, C_terminus):
    """Count alleles for every (C-terminus, N-terminus) cluster combination.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'N_terminus' and 'C_terminus'.
    N_terminus : sequence
        N-terminus cluster labels; they become the columns of the result.
    C_terminus : sequence
        C-terminus cluster labels; they become the index of the result.

    Returns
    -------
    pandas.DataFrame
        Float frame indexed by sorted C-terminus labels with sorted N-terminus
        labels as columns. Each cell holds the number of rows of ``df`` with
        that label pair; pairs that never occur are 0.
    """
    # Start from zeros so label pairs absent from df need no special handling.
    heatmap_df = pd.DataFrame(0.0, index=C_terminus, columns=N_terminus)

    # One pass over df instead of one boolean scan per cell.
    pair_counts = df.groupby(['N_terminus', 'C_terminus']).size()
    for (n, c), count in pair_counts.items():
        # Pairs outside the requested labels are ignored, as they never
        # corresponded to a cell of the frame.
        if n in heatmap_df.columns and c in heatmap_df.index:
            # .at writes straight into the frame; chained indexing such as
            # heatmap_df.loc[c][n] = ... is silently lost under Copy-on-Write.
            heatmap_df.at[c, n] = count

    heatmap_df = heatmap_df.sort_index()
    heatmap_df = heatmap_df[sorted(heatmap_df.columns)]
    return heatmap_df.astype(float)

In [None]:
# alleles in the training dataset whose name contains the selected gene letter
train_count = pd.read_csv('../data/train_dataset_count.csv', index_col=0)
train_alleles = [allele for allele in train_count.index.to_list() if hla in allele]

# alleles in the testing dataset whose name contains the selected gene letter
test_count = pd.read_csv('../data/test_dataset_count.csv', index_col=0)
test_alleles = [allele for allele in test_count.index.to_list() if hla in allele]

# label frame: one row per allele with its N- and C-terminus cluster labels
n_cluster_key = '%s_N'%hla
c_cluster_key = '%s_C'%hla
label_df = pd.DataFrame.from_dict(interp_dict['cluster'][n_cluster_key], orient='index', columns=['N_terminus'])
c_label_df = pd.DataFrame.from_dict(interp_dict['cluster'][c_cluster_key], orient='index', columns=['C_terminus'])
# column assignment aligns on the allele index
label_df['C_terminus'] = c_label_df['C_terminus']
# allele group = the part of the allele name before the first ':'
label_df['group'] = label_df.index.to_series().apply(lambda name: name.split(':')[0])

# heatmap frames share the full set of cluster labels so their axes line up
n_cluster_ids = label_df['N_terminus'].unique()
c_cluster_ids = label_df['C_terminus'].unique()
heatmap_all_df = build_heatmap_df(label_df, n_cluster_ids, c_cluster_ids)
heatmap_train_df = build_heatmap_df(label_df.loc[train_alleles], n_cluster_ids, c_cluster_ids)
heatmap_test_df = build_heatmap_df(label_df.loc[test_alleles], n_cluster_ids, c_cluster_ids)

In [None]:
# plot one annotated count heatmap per allele set
heatmap_panels = [("heatmap of all alleles", heatmap_all_df),
                  ("heatmap of training alleles", heatmap_train_df),
                  ("heatmap of testing alleles", heatmap_test_df)]

for panel_title, panel_df in heatmap_panels:
    print(panel_title)
    sns.heatmap(panel_df, cmap='Blues', linewidths=0.3, cbar=False, annot=True, fmt='g')
    plt.show()

### combination dataframe

In [None]:
# Combine the N- and C-terminus results for the selected gene. The following
# cells use the result's two-level index and its `groups` column.
select_comb_df = interp.Combining(hla, interp_dict)

In [None]:
# adjust the annotation of heatmap
N_class_num = len(select_comb_df.index.get_level_values(0).unique())
C_class_num = len(select_comb_df.index.get_level_values(1).unique())
# Take an explicit object-dtype copy: the array from to_numpy() may share
# memory with select_comb_df (and is read-only under pandas Copy-on-Write),
# so editing it in place would either corrupt the frame or raise.
labels = select_comb_df['groups'].to_numpy(dtype=object, copy=True).reshape(N_class_num, C_class_num)

# Hand-written multi-line replacements for crowded cells, keyed by
# (N-terminus row, C-terminus column) of `labels`.
label_overrides = {
    'B': {(0, 2): 'B54  B55  B56\nB59  B78',
          (0, 3): 'B07  B08  B35\nB42  B56  B67\nB81  B82',
          (3, 0): 'B15  B57\nB58  B78',
          (6, 2): 'B40  B41  B45\nB49  B50',
          (6, 3): 'B40  B41\nB44  B47'},
    'C': {(2, 0): 'C02  C03\nC12  C16',
          (2, 1): 'C03  C15\nC16  C17'},
}

# genes without an entry (e.g. 'A') keep their labels unchanged
for (n_idx, c_idx), text in label_overrides.get(hla, {}).items():
    labels[n_idx][c_idx] = text

In [None]:
# combination heatmap of all alleles
# cluster ids drawn on the N (x) and C (y) axes for each gene
if hla == 'A':
    n_range, c_range = [0, 1, 2, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5]
elif hla == 'B':
    n_range, c_range = [0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4]
elif hla == 'C':
    n_range, c_range = [1, 2, 3, 4, 5], [0, 1]

# figure size scales with the number of clusters on each axis
fig_size = ((len(n_range)+1.5)*2, len(c_range)+3)
fig, ax = plt.subplots(1, 1, figsize=fig_size, dpi=600)
sns.set(font_scale=1.3)

# colour by log10(count); empty cells give -inf, which is mapped to 0
log_counts = np.log10(heatmap_all_df.loc[c_range, n_range]).replace(-np.inf, 0)
# labels is indexed (N, C); transpose to match the (C rows, N columns) heatmap
cell_text = labels[n_range][:, c_range].T
colorbar_opts = {'orientation': 'vertical', 'pad': 0.01, 'fraction': 0.03,
                 'label': '$log_{10}$(allele number)'}
sns.heatmap(log_counts, cbar=True, cbar_kws=colorbar_opts,
            linewidths=0.3, cmap='Blues', annot=cell_text, fmt='', ax=ax)

# tick labels are hidden; the large label padding leaves blank space beside the axes
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xlabel('N-terminal hyper-motifs of HLA-%s'%hla, labelpad=80, fontsize=20)
ax.set_ylabel('C-terminal hyper-motifs of HLA-%s'%hla, labelpad=120, fontsize=20)

fig.tight_layout()
fig.savefig('%s/%s_CombinationHeatmap.png'%(output_dir, hla))
plt.show()

### demonstration

In [None]:
# Demonstration for HLA-B.
hla = 'B'
# Each dict maps an integer key to a hex colour string
# (presumably cluster id -> highlight colour; confirm against MHCInterp.Demo).
N_targets = {0: '#ccefff', 6: '#ffe6e6'}
C_targets = {2: '#ffffcc', 3: '#ccffdc'}
interp.Demo(hla, interp_dict, N_targets, C_targets)

## Comparison between HLA sequences
Supplementary Fig. 4

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load a file this notebook wrote.
# The context manager closes the file handle once loading finishes.
with open('%s/interpretation.pkl'%output_dir, 'rb') as pkl_file:
    interp_dict = pickle.load(pkl_file)

train_count = pd.read_csv('../data/train_dataset_count.csv', index_col=0)
train_alleles = train_count.index.to_list()

genes = ['A', 'B', 'C']
# the 182 sequence positions are split into two halves, one figure each
positions = {'Prev': list(range(91)), 'Post': list(range(91,182))}

for pos_name, pos in positions.items():
    # sequence logo over all training alleles, passed to each panel as diff_df
    background_seqlogo = interp._mhc_seqlogo(train_alleles, pos)
    fig, ax = plt.subplots(len(genes), 1, figsize=(10, 2*len(genes)), dpi=600)
    for i, gene in enumerate(genes):
        # alleles whose name contains the gene letter
        alleles = [j for j in train_alleles if gene in j]
        # only the bottom panel keeps its labels
        is_last_panel = (i == len(genes) - 1)
        interp._mhcseq_plot(alleles, pos, ax[i], ylim=1, title=None,
                            diff_df=background_seqlogo, turn_off_label=not is_last_panel)

    fig.tight_layout()
    fig.savefig('{}/HLASeqlogo{}.png'.format(output_dir, pos_name))