In [None]:
import os, sys, re, json, random, importlib
import numpy as np
import pandas as pd
from collections import OrderedDict
from tqdm import tqdm
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import logomaker as lm
from venn import venn
from venn import generate_petal_labels, draw_venn
from scipy.cluster import hierarchy
from util import *
from CAMInterp import CAMInterp
import warnings
warnings.filterwarnings('ignore')

In [None]:
# arguments
mhc_seq_filename = '../data/MHCI/MHCI_res182_seq.json'
# ScoreCAM result directories for decoy_1, decoy_6, ..., decoy_86 (every 5th index)
mask_dirname = ['../cam_result/res182_decoy5_CNN_1_1_train_hit/decoy_%d/mhc_2/ScoreCAM/'%i for i in range(1, 87, 5)]
df_filename = '../prediction/train_hit/res182_decoy5_CNN_1_1_18/prediction.csv'
output_dir = '../analysis/CAMInterp/res182_decoy5_CNN_1_1'
# makedirs also creates missing parent directories (os.mkdir would raise)
# and is a no-op when the directory already exists.
os.makedirs(output_dir, exist_ok=True)
pred_basename = 'score'
pred_threshold = 0.9

In [None]:
# Interpretation object built from the MHC sequence file, the CAM mask
# directories and the prediction table configured above; binders are
# selected from the `pred_basename` column using `pred_threshold`.
# NOTE(review): the exact use of these arguments lives in CAMInterp — confirm.
interp = CAMInterp(
    mhc_seq_filename,
    mask_dirname,
    df_filename,
    output_dir,
    pred_basename=pred_basename,
    pred_threshold=pred_threshold,
)

In [None]:
# Residue analysis: CAM and importance cut-offs, passed positionally.
cam_threshold, importance_threshold = 0.4, 0.4
interp.ResidueAnalysis(cam_threshold, importance_threshold)

In [None]:
# Cluster analysis: linkage method and distance metric forwarded to CAMInterp;
# per-MHC plots are disabled.
method, metric = 'average', 'euclidean'
interp.ClusterAnalysis(method, metric, plot_each_mhc=False)

## Pairwise motif distance

In [None]:
# Build one information-content motif matrix per allele from the binding
# peptides (bind == 1) of the training set. Each peptide is reduced to its
# first and last `interp.submotif_len` residues; alleles with fewer than
# `interp.min_sample_num` binders are skipped.
df = pd.read_csv('../data/raw/dataframe/train_hit.csv', index_col=0)
alleles = df['mhc'].unique()
motif_dict = dict()

aa_columns = list(interp.aa_str)
for allele in alleles:
    seqs = df.loc[(df['mhc'] == allele) & (df['bind'] == 1), 'sequence']
    if len(seqs) < interp.min_sample_num:
        continue
    seqs = seqs.apply(lambda x: x[:interp.submotif_len] + x[-interp.submotif_len:])
    seqlogo_df = lm.alignment_to_matrix(sequences=seqs, to_type="information", characters_to_ignore="XU")
    # logomaker only emits columns for characters that occur, so align every
    # matrix to the same residue columns (absent residue -> 0.0). Unexpected
    # characters are kept after the standard ones, in their original order.
    # reindex keeps a float dtype; concatenating onto an empty frame did not.
    extra_cols = [c for c in seqlogo_df.columns if c not in aa_columns]
    temp_df = seqlogo_df.reindex(columns=aa_columns + extra_cols).fillna(0.0).astype(float)
    motif_dict[allele] = temp_df.to_numpy()

In [None]:
# motif pairwise distance
# Each allele becomes one column holding its flattened motif matrix with
# entries below `threshold` zeroed. Rows that are zero for every allele are
# dropped, and the cosine distance between allele columns is expanded into a
# square, allele-labelled matrix.

from scipy.spatial.distance import pdist, squareform

threshold = 0

motif_df = pd.DataFrame({
    allele_name: np.where(matrix < threshold, 0, matrix).flatten()
    for allele_name, matrix in motif_dict.items()
})

motif_df = motif_df.loc[(motif_df != 0).any(axis=1), :]
motif_dist = squareform(pdist(motif_df.T, metric='cosine'))
motif_dist_df = pd.DataFrame(motif_dist, index=motif_df.columns, columns=motif_df.columns)

In [None]:
# For each locus letter: clustered heatmap of the motif distance matrix
# restricted to that locus, then the motif logos in dendrogram row order.
for hla in ['A', 'B', 'C']:
    # NOTE(review): plain substring match on the allele name — an allele name
    # that merely contains the letter (e.g. an 'HLA-' prefix) would also match.
    select_col = [i for i in motif_dist_df.columns if hla in i]
    if len(select_col) < 2:
        # Hierarchical clustering needs at least two observations; skip the
        # locus instead of letting sns.clustermap raise.
        print('Skip %s: %d allele(s) with a motif'%(hla, len(select_col)))
        continue
    temp_df = motif_dist_df.loc[select_col, select_col]
    # NOTE(review): clustermap computes cosine distances between the ROWS of
    # this distance matrix (distance profiles) to build its linkage; it does
    # not feed motif_dist_df itself into the linkage.
    g = sns.clustermap(temp_df,
                       method='average',
                       metric='cosine',
                       cbar_pos=None,
                       xticklabels=False,
                       yticklabels=True,
                       dendrogram_ratio=0.1,
                       figsize=(5, 5))
    _ = g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), fontsize = 5)
    # Save the clustermap's own figure rather than whatever pyplot considers current.
    g.savefig('%s/Pdist_%s'%(interp.output_dir, hla), bbox_inches='tight', dpi=interp.dpi)
    allele_order = [temp_df.columns[i] for i in g.dendrogram_row.reordered_ind]
    interp._motif_plot(allele_order, motif_dict, figfile='%s/PdistMotif_%s'%(interp.output_dir, hla))

## Random sequence: in-silico MHC mutation analysis on random peptides

In [None]:
class Mutation():
    """In-silico MHC mutation analysis over a fixed peptide set.

    Loads MHC sequences, a peptide table with its pre-encoded dataset and a
    trained model. It then predicts every peptide against a (possibly
    mutated) MHC sequence and summarises the predicted binders as logomaker
    information matrices.
    """

    def __init__(self, mhc_seq_len, sub_motif_len,
                 mhc_seq_file, df_file, dataset_file, batch_size,
                 model_file, model_state_file, input_dim, cuda, data_num=None):
        """
        Args:
            mhc_seq_len: length passed to OneHotEncoder for the MHC sequence.
            sub_motif_len: number of residues kept from each peptide end.
            mhc_seq_file: JSON file mapping allele name -> MHC sequence string.
            df_file: peptide CSV (needs a 'sequence' column; first column is the index).
            dataset_file: torch-saved dataset, assumed aligned row-by-row with df_file.
            batch_size: DataLoader batch size.
            model_file: path of the Python file defining MHCModel, EpitopeModel
                and CombineModel; it is imported as a dotted module name.
            model_state_file: checkpoint with a 'model_state_dict' entry.
            input_dim: passed to MHCModel and EpitopeModel.
            cuda: use the CUDA device when True, otherwise the CPU.
            data_num: if truthy, randomly subsample this many peptides.
        """
        self.mhc_seq_len = mhc_seq_len
        self.sub_motif_len = sub_motif_len

        # Data
        with open(mhc_seq_file, 'r') as f:
            self.mhc_seq_dict = json.load(f)
        self.df = pd.read_csv(df_file, index_col=0)
        # NOTE(review): torch.load unpickles the file — only load trusted data.
        self.dataset = torch.load(dataset_file)
        if data_num:
            self.df = self.df.sample(n=data_num)
            # Assumes the CSV index labels are positional indices into the
            # dataset (0..N-1) — TODO confirm.
            self.dataset = [self.dataset[i] for i in self.df.index]
            self.df = self.df.reset_index(drop=True)
        self.batch_size = batch_size
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False)

        # Model
        self.device = torch.device('cuda' if cuda else 'cpu')
        # 'a/b/model.py' -> 'a.b.model'. splitext strips only the final
        # extension, so a dot elsewhere in the path does not truncate it.
        module_name = os.path.splitext(model_file)[0].replace('/', '.')
        module = importlib.import_module(module_name)
        self.model = module.CombineModel(module.MHCModel(input_dim), module.EpitopeModel(input_dim))
        model_state_dict = torch.load(model_state_file, map_location=self.device)
        self.model.load_state_dict(model_state_dict['model_state_dict'])
        self.model.to(self.device)

    def Mutate2Seqlogo(self, allele, mutate, threshold, motif_side):
        """Predict binders for a (mutated) allele and return their logo matrix.

        Args:
            allele: key into the MHC sequence dictionary.
            mutate: None/empty for the original sequence, a single
                (position, residue) tuple, or a list of such tuples.
            threshold: peptides with prediction strictly above it are binders.
            motif_side: 'N', 'C', or anything else for both ends.
        Returns:
            logomaker information matrix of the binders' sub-motifs.
        """
        seq = self.mhc_seq_dict[allele]
        if mutate:
            seq = self._mutate_seq(seq, mutate)
        encode = OneHotEncoder(seq, self.mhc_seq_len, False)
        preds = self._predict(encode)
        idx = np.where(preds > threshold)[0]
        return self._get_seqlogo_df(idx, motif_side)

    def MutateAllResidues(self, allele, pred_threshold, mhc_pos_list, motif_side, motif_pos, figfile=None):
        """Draw logos for the original allele and for each single-residue knock-out.

        Each position in `mhc_pos_list` is replaced by '.' on its own.
        Returns a dict mapping 'origin' / 'mutate_<pos>' to the logo row
        `motif_pos` as {character: information}. The figure is saved to
        `figfile` when one is given.
        """
        information_dict = dict()
        n_rows = len(mhc_pos_list) + 1
        # squeeze=False keeps `axes` indexable even with a single row.
        fig, axes = plt.subplots(n_rows, squeeze=False, figsize=(5, n_rows*2))
        axes = axes[:, 0]

        # original sequence first, then one '.' substitution per position
        conditions = [('origin', None)] + [('mutate_%d'%pos, [(pos, '.')]) for pos in mhc_pos_list]
        for ax, (key, mutate) in zip(axes, conditions):
            seqlogo_df = self.Mutate2Seqlogo(allele, mutate, pred_threshold, motif_side)
            lm.Logo(seqlogo_df, color_scheme='skylign_protein', ax=ax)
            information_dict[key] = seqlogo_df.iloc[motif_pos].to_dict()

        if figfile:
            fig.savefig(figfile)

        return information_dict

    def MutationHeatmap(self, alleles, pred_threshold, mhc_pos_list, motif_side):
        """Euclidean distance between original and mutant logo matrices.

        `mhc_pos_list` is either a list of positions, with one single-residue
        mutation each (result columns are the positions), or a list of position
        groups (list/tuple), each mutated together (result columns are
        0..len-1). Returns a float DataFrame indexed by allele.
        """
        from scipy.spatial.distance import cdist

        def is_group(p):
            return isinstance(p, (list, tuple))

        grouped = len(mhc_pos_list) > 0 and is_group(mhc_pos_list[0])
        columns = range(len(mhc_pos_list)) if grouped else mhc_pos_list
        pdist_df = pd.DataFrame(columns=columns, index=alleles)

        for allele in alleles:
            # original
            origin = self.Mutate2Seqlogo(allele, None, pred_threshold, motif_side).to_numpy().flatten()

            # mutate
            mutants = []
            for pos in mhc_pos_list:
                positions = list(pos) if is_group(pos) else [pos]
                mutates = [(p, '.') for p in positions]
                seqlogo_df = self.Mutate2Seqlogo(allele, mutates, pred_threshold, motif_side)
                mutants.append(seqlogo_df.to_numpy().flatten())

            # pairwise (euclidean) distance: original vs. each mutant
            if mutants:
                pdist_df.loc[allele] = cdist([origin], mutants)[0]

            print('%s Complete'%allele)

        return pdist_df.astype(float)

    def _predict(self, mhc_encode):
        """Score every peptide in the dataloader against one encoded MHC.

        Returns a 1-D numpy array in dataloader order (empty if no batches).
        """
        self.model.eval()
        batch_preds = []
        with torch.no_grad():
            for x, _ in self.dataloader:
                num = x.shape[0]
                epitope_encode = x.to(self.device).float()
                # repeat the single MHC encoding once per peptide in the batch
                mhc_encode_tile = torch.FloatTensor(np.tile(mhc_encode, (num, 1, 1))).to(self.device)
                pred = self.model(mhc_encode_tile, epitope_encode).to('cpu')
                batch_preds.append(pred.view(-1).numpy())
        # one concatenate instead of repeated np.append (quadratic copying)
        return np.concatenate(batch_preds) if batch_preds else np.array([])

    def _mutate_seq(self, seq, mutates):
        """Return `seq` with residues substituted.

        `mutates` is a single (position, residue) tuple or an iterable of them.
        """
        # A bare (int, str) pair is one mutation: wrap it so the loop does not
        # try to unpack the integer position itself.
        if isinstance(mutates, tuple) and len(mutates) == 2 and isinstance(mutates[0], (int, np.integer)):
            mutates = [mutates]
        seq_list = list(seq)
        for pos, mut in mutates:
            seq_list[pos] = mut
        return ''.join(seq_list)

    def _get_seqlogo_df(self, idx, side='both'):
        """Logomaker information matrix of the peptides at positional rows `idx`.

        side 'N' keeps the first `sub_motif_len` residues, 'C' the last ones,
        anything else the concatenation of both. 'X' and 'U' are ignored.
        """
        seqs = self.df.iloc[idx]['sequence']
        if side == 'N':
            seqs = seqs.apply(lambda x: x[:self.sub_motif_len])
        elif side == 'C':
            seqs = seqs.apply(lambda x: x[-self.sub_motif_len:])
        else:
            seqs = seqs.apply(lambda x: x[:self.sub_motif_len] + x[-self.sub_motif_len:])
        seqlogo_df = lm.alignment_to_matrix(sequences=seqs, to_type='information', characters_to_ignore="XU")
        return seqlogo_df

In [None]:
# Mutation-analysis configuration.
# NOTE(review): these paths are relative to a different working directory than
# the CAM paths at the top of the notebook ('../data/...') — confirm the cwd.
mhc_seq_len = 182
sub_motif_len = 4
mhc_seq_file = 'data/MHCI/res182_seq.json'
df_file = 'data/random_peptide/random.csv'
dataset_file = 'data/random_peptide/random_onehot.pt'
batch_size = 2048
model_file = 'gitlab/kohan/model/res182_CNN_8.py'
model_state_file = 'result/single/res182_clf_downsampling_onehot_CNN_81_1.20201005212128/model/model_best.tar'
input_dim = 21
# Use the GPU only when one is available instead of hard-coding True,
# which fails on CPU-only machines.
cuda = torch.cuda.is_available()

# data_num: number of random peptides subsampled for the analysis
mutation = Mutation(mhc_seq_len, sub_motif_len, mhc_seq_file,
                    df_file, dataset_file, batch_size,
                    model_file, model_state_file, input_dim, cuda, data_num=100000)

In [None]:
# Region-level mutation heatmap: every residue of each MHC region is replaced
# by '.' together, and the binders' C-terminal logo is compared with the
# original allele's.
# NOTE(review): `groups` is not defined in any visible cell — this relies on
# hidden kernel state.
alleles = [member for group in groups for member in group]

# half-open [start, stop) residue ranges of the MHC sequence
mutate_range_list = [(7,13), (40,46), (60,70), (70,82), (93,100),
                     (111,117), (140,147), (150,158), (161,167), (176, 182)]
mutate_list = [list(range(start, stop)) for start, stop in mutate_range_list]

pred_threshold, motif_side = 0.9, 'C'

pdist_df = mutation.MutationHeatmap(alleles, pred_threshold, mutate_list, motif_side)

In [None]:
# Single-position mutation heatmap over every residue in `position_set`.
# NOTE(review): neither `groups` nor `position_set` is defined in any visible
# cell — this relies on hidden kernel state.
alleles = [member for group in groups for member in group]

pred_threshold, motif_side = 0.9, 'C'
mhc_pos_list = sorted(position_set)

pdist_df = mutation.MutationHeatmap(alleles, pred_threshold, mhc_pos_list, motif_side)

In [None]:
# Display the allele groups used above.
# NOTE(review): `groups` is not defined in any visible cell (hidden kernel state).
groups

In [None]:
# Heatmap of the most recently computed original-vs-mutant distances (10x10 in).
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(pdist_df.astype(float), ax=ax)

In [None]:
# Same heatmap at matplotlib's default figure size.
sns.heatmap(data=pdist_df.astype(float))

In [None]:
# Display the allele groups again (duplicate of the earlier display cell).
# NOTE(review): `groups` is not defined in any visible cell (hidden kernel state).
groups

In [None]:
# Single-allele deep dive: knock out each listed MHC position of B*15:02 and
# track the information of one residue at the last C-terminal position.
pred_threshold = 0.8
motif_side = 'C'
motif_pos = -1

figdir = 'fig'
# savefig does not create directories; make sure the target exists.
os.makedirs(figdir, exist_ok=True)

allele = 'B*15:02'
mutate_positions = [8, 10, 11, 31, 44, 66, 93, 94, 96, 102, 113, 115, 162]

# residue whose information content is plotted for the allele
allele_aa_pairs={
    allele: "Y"
}

name = allele[0] + allele[2:4] + allele[5:7]  # e.g. 'B*15:02' -> 'B1502'
figfile = '%s/%s.png'%(figdir, name)
information_dict = mutation.MutateAllResidues(allele, pred_threshold, mutate_positions, motif_side, motif_pos, figfile)
information_df = pd.DataFrame(information_dict)
fig = plt.figure(figsize=(10,10))
sns.barplot(x=information_df.columns, y=information_df.loc[allele_aa_pairs[allele]])
_ = plt.xticks(rotation=90)
fig.savefig('%s/%s_information.png'%(figdir, name))
print("%s Complete"%allele)

In [None]:
# Re-plot the previous allele's per-condition information, tracking 'L' instead.
allele_aa_pairs = {allele: "L"}
tracked_aa = allele_aa_pairs[allele]
fig = plt.figure(figsize=(10, 10))
sns.barplot(x=information_df.columns, y=information_df.loc[tracked_aa])
_ = plt.xticks(rotation=90)

In [None]:
# Per-allele deep dive on the N-terminal side (motif_pos=1, the second peptide
# residue). Alleles are listed in pairs that share the same mutated MHC
# positions and the same tracked residue.
allele_mutate_pairs = {
    "B*07:02": [8, 23, 44, 62, 66, 70, 96, 102, 112],
    "B*55:01": [8, 23, 44, 62, 66, 70, 96, 102, 112],
    "B*40:01": [8, 23, 44, 62, 66],
    "B*18:01": [8, 23, 44, 62, 66],
    "B*15:10": [8, 10, 23, 62, 66, 76, 112, 115],
    "B*38:01": [8, 10, 23, 62, 66, 76, 112, 115],
    "B*58:01": [23, 64, 65, 66, 76, 79, 80, 81, 82, 162],
    "B*15:17": [23, 64, 65, 66, 76, 79, 80, 81, 82, 162],
    "B*13:01": [23, 44, 45, 62, 66, 76, 79, 80, 81, 82, 112, 115, 144, 162],
    "B*15:01": [23, 44, 45, 62, 66, 76, 79, 80, 81, 82, 112, 115, 144, 162]
}

# residue whose information content is plotted for each allele
allele_aa_pairs = {
    "B*07:02": 'P',
    "B*55:01": 'P',
    "B*40:01": 'E',
    "B*18:01": 'E',
    "B*15:10": 'H',
    "B*38:01": 'H',
    "B*58:01": 'S',
    "B*15:17": 'S',
    "B*13:01": 'Q',
    "B*15:01": 'Q'
}

pred_threshold = 0.9
motif_side = 'N'
motif_pos = 1

figdir = 'fig'
# savefig does not create directories; make sure the target exists.
os.makedirs(figdir, exist_ok=True)

for allele, mutate_positions in allele_mutate_pairs.items():
    name = allele[0] + allele[2:4] + allele[5:7]  # e.g. 'B*07:02' -> 'B0702'
    figfile = '%s/%s.png'%(figdir, name)
    information_dict = mutation.MutateAllResidues(allele, pred_threshold, mutate_positions, motif_side, motif_pos, figfile)
    information_df = pd.DataFrame(information_dict)
    fig = plt.figure(figsize=(10,10))
    sns.barplot(x=information_df.columns, y=information_df.loc[allele_aa_pairs[allele]])
    _ = plt.xticks(rotation=90)
    fig.savefig('%s/%s_information.png'%(figdir, name))
    print("%s Complete"%allele)