In [None]:
import os, sys, re, json, random, importlib
import numpy as np
import pandas as pd
from collections import OrderedDict
from tqdm import tqdm
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import logomaker as lm
from scipy.cluster import hierarchy
from util import *
from CAMInterp import CAMInterp
import warnings
warnings.filterwarnings('ignore')

# Input/output locations -- fill these in before running the notebook.
cam_result_dir = ''     # root folder holding the decoy_<i>/... CAM mask directories
train_predict_dir = ''  # folder containing prediction.csv
dataframe_dir = ''      # folder containing train_hit.csv
output_dir = ''         # folder that receives the generated figures

# makedirs creates missing parent directories as well and is a no-op when the
# directory already exists, so no separate isdir() check is needed.
os.makedirs(output_dir, exist_ok=True)

# CAM analysis
Fig. 3

In [None]:
# Inputs for the CAM interpretation
pred_basename = 'score'
pred_threshold = 0.9

mhc_seq_filename = '../data/MHCI_res182_seq.json'
df_filename = '{}/prediction.csv'.format(train_predict_dir)

# One ScoreCAM mask directory per selected decoy index (1, 6, ..., 86)
allele_mask_dirname = [
    '{}/decoy_{}/mhc_2/ScoreCAM/'.format(cam_result_dir, decoy_id)
    for decoy_id in range(1, 87, 5)
]
epitope_mask_dirname = [
    '{}/decoy_{}/epitope_0/ScoreCAM'.format(cam_result_dir, decoy_id)
    for decoy_id in range(1, 87, 5)
]

In [None]:
# Build the CAM interpretation object from the configured paths and
# prediction settings
interp = CAMInterp(
    mhc_seq_filename,
    allele_mask_dirname,
    epitope_mask_dirname,
    df_filename,
    output_dir,
    pred_basename=pred_basename,
    pred_threshold=pred_threshold,
)

In [None]:
# Residue-level analysis: thresholds first, then figure geometry
cam_threshold = 0.4
importance_threshold = 0.4

square_figsize = (3.5, 3.5)
barplot_figsize = (8, 2.1)

interp.ResidueAnalysis(
    cam_threshold,
    importance_threshold,
    barplot_figsize=barplot_figsize,
    square_figsize=square_figsize,
)

In [None]:
# Cluster analysis: linkage settings first, then figure geometry
method = 'average'
metric = 'euclidean'

epitope_figsize = (3.5, 3.5)
allele_figsize = (10, 2)

interp.ClusterAnalysis(
    method,
    metric,
    allele_figsize=allele_figsize,
    epitope_figsize=epitope_figsize,
)

# Pairwise distance
Supplementary Fig. 5

In [None]:
def _motif_plot(alleles, motif_dict, dpi=600, figfile=None):
    """Draw one sequence logo per allele, stacked vertically in one figure.

    Parameters
    ----------
    alleles : sequence of str
        Alleles to draw, top to bottom; each must be a key of ``motif_dict``.
    motif_dict : dict
        Maps allele -> 2-D array with one column per amino acid, in
        ``'ACDEFGHIKLMNPQRSTVWY'`` order (rows are presumably motif
        positions -- confirm against the code that fills the dict).
    dpi : int, optional
        Resolution of the created figure.
    figfile : str, optional
        When given (and non-empty), the figure is saved to this path.

    Returns
    -------
    None
    """
    aa_str = 'ACDEFGHIKLMNPQRSTVWY'
    allele_num = len(alleles)
    if allele_num == 0:
        # plt.subplots rejects a zero-row grid, and there is nothing to draw.
        return

    # squeeze=False keeps the axes in a 2-D array even for a single allele;
    # without it plt.subplots returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(allele_num, figsize=(0.8, allele_num*0.2), dpi=dpi, squeeze=False)
    for allele, ax in zip(alleles, axes[:, 0]):
        seqlogo_df = pd.DataFrame(motif_dict[allele], columns=list(aa_str))
        lm.Logo(seqlogo_df, ax=ax, color_scheme="skylign_protein")
        # Logos are shown as bare thumbnails: no ticks, hairline frame.
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ['top', 'bottom', 'left', 'right']:
            ax.spines[side].set_linewidth(0.1)

    fig.tight_layout()
    if figfile:
        fig.savefig(figfile)

In [None]:
# Build one information-content motif matrix per allele from the training hits.
df = pd.read_csv('%s/train_hit.csv'%dataframe_dir, index_col=0)
min_sample_num = 100   # alleles with fewer binder sequences are skipped
submotif_len = 4       # number of residues kept from each end of a sequence
aa_str = 'ACDEFGHIKLMNPQRSTVWY'
dpi = 600

alleles = df['mhc'].unique()
motif_dict = dict()

for allele in alleles:
    seqs = df.loc[(df['mhc']==allele) & (df['bind']==1), 'sequence']
    if len(seqs) < min_sample_num:
        continue
    # Join the first and last `submotif_len` residues so sequences of
    # different lengths are reduced to strings of the same width.
    seqs = seqs.str[:submotif_len] + seqs.str[-submotif_len:]
    seqlogo_df = lm.alignment_to_matrix(sequences=seqs, to_type="information", characters_to_ignore="XU")
    # Force exactly the 20 columns of `aa_str`, in that order: residues that
    # never occur get 0.0, and any unexpected extra character column is
    # dropped so every allele yields a matrix of the same shape. This
    # replaces concatenating onto an empty frame, which recent pandas
    # versions deprecate.
    seqlogo_df = seqlogo_df.reindex(columns=list(aa_str), fill_value=0.0).fillna(0.0)
    motif_dict[allele] = seqlogo_df.to_numpy()

In [None]:
# motif pairwise distance
from scipy.spatial.distance import pdist, squareform

# Matrix entries below this value are zeroed before distances are computed.
threshold = 0

# Flatten each allele's motif matrix into one vector. The vectors are
# collected first and the frame is built in a single call, because inserting
# columns one at a time into an existing DataFrame fragments it.
flat_motifs = {}
for allele, motif in motif_dict.items():
    arr = motif.copy()  # copy so the matrices stored in motif_dict stay untouched
    arr[arr < threshold] = 0
    flat_motifs[allele] = arr.flatten()
motif_df = pd.DataFrame(flat_motifs)

# Drop feature rows that are zero for every allele.
motif_df = motif_df.loc[(motif_df!=0).any(axis=1), :]

# pdist works on rows, so transpose to get one row per allele; squareform
# expands the condensed result into a full symmetric matrix.
motif_dist = pdist(motif_df.T, metric='cosine')
motif_dist = squareform(motif_dist)
motif_dist_df = pd.DataFrame(motif_dist, columns=motif_df.columns, index=motif_df.columns)

In [None]:
# One clustered distance heatmap, plus logos in the same row order, per locus.
for hla in ['A', 'B', 'C']:
    # NOTE(review): this is a substring test, so it assumes the locus letter
    # appears in an allele name only as the gene name (a prefix such as
    # "HLA-" would make every allele match 'A') -- confirm against the data.
    select_col = [i for i in motif_dist_df.columns if hla in i]
    if len(select_col) < 2:
        # Hierarchical clustering needs at least two alleles; skip this locus
        # instead of letting clustermap raise and abort the remaining loci.
        print('Skipping HLA-{}: fewer than two alleles with a motif'.format(hla))
        continue
    temp_df = motif_dist_df.loc[select_col, select_col]
    # NOTE(review): clustermap derives its own linkage by applying `metric`
    # to the rows of temp_df, i.e. the distance matrix is treated as feature
    # vectors rather than as precomputed distances -- confirm this is intended.
    g = sns.clustermap(temp_df,
                       method='average',
                       metric='cosine',
                       cbar_pos=(.3, -.05, .4, .02),
                       cbar_kws={'orientation': 'horizontal', 'label': 'pairwise distance'},
                       xticklabels=False,
                       yticklabels=True,
                       dendrogram_ratio=0.1,
                       figsize=(5, 5))
    _ = g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), fontsize = 5)
    _ = g.ax_heatmap.set_xlabel('HLA-{} allele'.format(hla))
    plt.savefig('%s/Pdist_%s'%(output_dir, hla), bbox_inches='tight', dpi=dpi)
    # temp_df has identical row and column labels, so the row dendrogram's
    # leaf order maps directly onto allele names.
    allele_order = [temp_df.index[i] for i in g.dendrogram_row.reordered_ind]
    _motif_plot(allele_order, motif_dict, figfile='%s/PdistMotif_%s'%(output_dir, hla))