In [None]:
import os, sys, re, json, random
import numpy as np
import pandas as pd
from collections import OrderedDict
import matplotlib.pyplot as plt
import seaborn as sns
import logomaker as lm
from MHCInterp import MHCInterp
import warnings
warnings.filterwarnings('ignore')

## Loading Data

In [None]:
# Load the MHC-I sequence JSON; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open('../data/MHCI/MHCI_res182_seq.json', 'r') as seq_file:
    mhc_seq_dict = json.load(seq_file)

# Merge the motif dictionaries stored under every sub-directory of `dirname`.
# os.listdir returns entries in arbitrary order, so the sub-directories are
# sorted: if two files share a key, the winner is the same on every machine.
dirname = '../prediction/pan_allele/output/'
mhc_motif_dict = dict()
for sub_dir in sorted(os.listdir(dirname)):
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects;
    # only load motif.npy files that were produced by a trusted run.
    # Indexing with [()] unwraps the object stored in the 0-d array.
    d = np.load(os.path.join(dirname, sub_dir, 'motif.npy'), allow_pickle=True)[()]
    mhc_motif_dict.update(d)

submotif_len = 4

# Residue-selection JSON; its 'selected' entry is read by the Analysis cell.
with open('../analysis/CAMInterp/res182_decoy5_CNN_1_1/ResidueSelection.json', 'r') as position_file:
    position_dict = json.load(position_file)

## Clustering

In [None]:
# Default clustering hyper-parameters. Each key is prefixed with the name of
# the algorithm it belongs to (DBSCAN / HDBSCAN / Agglomerative).
clustering_kwargs = dict(
    DBSCAN_eps=3,
    DBSCAN_metric='euclidean',
    DBSCAN_min_samples=5,
    HDBSCAN_min_cluster_size=100,
    HDBSCAN_min_samples=1,
    Agglomerative_affinity='cosine',
    Agglomerative_linkage='average',
    Agglomerative_distance_threshold=None,
    Agglomerative_n_clusters=6,
)

# Default dimensionality-reduction hyper-parameters, prefixed the same way
# (UMAP / TSNE).
reduction_kwargs = dict(
    UMAP_n_neighbors=50,
    UMAP_min_dist=0.1,
    TSNE_perplexity=80,
    TSNE_n_iter=50000,
)

### Pre-pdist + Agglomerative Clustering

In [None]:
output_dir = '../analysis/MHCInterp/res182_decoy5_CNN_1_1/'
# makedirs also creates the missing parent ('../analysis/MHCInterp'), where
# os.mkdir would raise FileNotFoundError; exist_ok=True makes re-running the
# cell a no-op and removes the check-then-create race.
os.makedirs(output_dir, exist_ok=True)

# Interpreter instance used by the agglomerative-clustering cells below.
# NOTE(review): presumably MHCInterp writes its results under output_dir -
# confirm against the MHCInterp module.
interp = MHCInterp(mhc_seq_dict, mhc_motif_dict, submotif_len, position_dict, output_dir)

In [None]:
# --- Settings for the "pre-pdist + agglomerative" clustering run ---

noise_threshold = 0

# Clustering: agglomerative, cosine affinity, complete linkage.
# 'Agglomerative_n_clusters' is a placeholder that the clustering loop fills
# in separately for every (HLA, side) pair.
clustering_method = 'Agglomerative'
clustering_kwargs = dict(
    Agglomerative_affinity='cosine',
    Agglomerative_linkage='complete',
    Agglomerative_distance_threshold=None,
    Agglomerative_n_clusters=None,
)

# No dimensionality reduction for this run.
reduction_method = None
reduction_kwargs = dict()

# Pairwise-distance options forwarded to interp.Clustering.
pre_pdist = True
metric = 'cosine'
method = 'complete'

# Plot / caching switches forwarded to interp.Clustering.
highlight = False
load_file = False
turn_off_label = True

In [None]:
# One [hla, side, n_clusters] entry per clustering run.
args = [
    ['A', 'N', 8], ['A', 'C', 5],
    ['B', 'N', 7], ['B', 'C', 5],
    ['C', 'N', 5], ['C', 'C', 3],
]

# Keyword arguments that are identical for every run.
shared_options = dict(
    reduction_method=reduction_method, reduction_kwargs=reduction_kwargs,
    pre_pdist=pre_pdist, metric=metric, method=method,
    highlight=highlight, load_file=load_file, turn_off_label=turn_off_label,
)

for hla, side, n_clusters in args:
    # The cluster count is the only setting that varies between runs; it is
    # written into the shared kwargs dict right before each call.
    clustering_kwargs['Agglomerative_n_clusters'] = n_clusters
    interp.Clustering(hla, side, noise_threshold,
                      clustering_method, clustering_kwargs,
                      **shared_options)

### tSNE + DBSCAN

In [None]:
output_dir = '../tmp/'
# exist_ok=True makes this idempotent on re-run and avoids the
# check-then-create race of os.path.isdir + os.mkdir.
os.makedirs(output_dir, exist_ok=True)

# Fresh interpreter instance pointing at the scratch output directory.
interp = MHCInterp(mhc_seq_dict, mhc_motif_dict, submotif_len, position_dict, output_dir)

In [None]:
# --- Settings for the "tSNE + DBSCAN" clustering run ---

noise_threshold = 0.2

# 'DBSCAN_eps' is only a placeholder here: the clustering loop overwrites it
# with a per-(HLA, side) value before every call.
clustering_method = 'DBSCAN'
clustering_kwargs = {'DBSCAN_eps': 0,
                     'DBSCAN_metric': 'euclidean',
                     'DBSCAN_min_samples': 5}

reduction_method = 'tSNE'
reduction_kwargs = {'TSNE_perplexity': 80,
                    'TSNE_n_iter': 50000}

pre_pdist = None
metric = 'cosine'
method = 'average'

# `highlight` is passed to interp.Clustering by the loop that follows. It was
# previously defined only in the agglomerative settings cell, so this section
# failed with NameError when run on its own; define it here explicitly with
# the same value.
highlight = False
load_file = False
turn_off_label = True

In [None]:
# One [hla, side, eps] entry per clustering run.
# The last entry used to read ['C', 'B', 2.5]; every other call in this
# notebook passes 'N' or 'C' as the side, so 'B' was a typo for the missing
# HLA-C / 'C' combination.
args = [['A','N',4.5],
        ['A','C',3],
        ['B','N',3],
        ['B','C',3],
        ['C','N',2.5],
        ['C','C',2.5]]

for hla, side, eps in args:
    # eps is the only setting that varies between runs; it replaces the
    # placeholder in the shared kwargs dict right before each call.
    clustering_kwargs['DBSCAN_eps'] = eps
    interp.Clustering(hla, side, noise_threshold,
                      clustering_method, clustering_kwargs,
                      reduction_method=reduction_method, reduction_kwargs=reduction_kwargs,
                      pre_pdist=pre_pdist, metric=metric, method=method,
                      highlight=highlight, load_file=load_file, turn_off_label=turn_off_label)

## Grouping Counts

In [None]:
# Run the allele grouping once for each of the three HLA genes.
for hla in ('A', 'B', 'C'):
    interp.AlleleGrouping(hla)

## Analysis

In [None]:
hla = 'B'
middle_pos = 74

# [group, highlight colour] pairs for the N side and the C side.
nside_args = [[1, '#ccefff'], [0, '#ffe6e6']]
cside_args = [[3, '#ffffcc'], [0, '#ccffdc']]
# Split the selected positions at middle_pos: positions up to and including
# it go to the N side, the rest to the C side.
n_pos = [i for i in position_dict['selected'] if i <= middle_pos]
c_pos = [i for i in position_dict['selected'] if i > middle_pos ]

turn_off_label = True

# N-side groups on their own.
for n_group, n_color in nside_args:
    highlight_pos_dict = {n_color: n_pos}
    interp.Analysis(hla, n_group, None, side='N',
                    turn_off_label=turn_off_label, highlight_pos_dict=highlight_pos_dict)

# C-side groups on their own.
for c_group, c_color in cside_args:
    highlight_pos_dict = {c_color: c_pos}
    interp.Analysis(hla, None, c_group, side='C',
                    turn_off_label=turn_off_label, highlight_pos_dict=highlight_pos_dict)

# Every combination of an N-side group with a C-side group.
for n_group, n_color in nside_args:
    for c_group, c_color in cside_args:
        # Build a fresh mapping for each pair. Previously a single dict was
        # mutated inside the inner loop, so the colour of the previous C group
        # was still present (mapped to c_pos) when the next C group was
        # analysed.
        highlight_pos_dict = {n_color: n_pos, c_color: c_pos}
        interp.Analysis(hla, n_group, c_group,
                        turn_off_label=turn_off_label, highlight_pos_dict=highlight_pos_dict)


## MHC-I sequence logo

In [None]:
output_dir = '../tmp/'
# exist_ok=True makes this idempotent on re-run and avoids the
# check-then-create race of os.path.isdir + os.mkdir.
os.makedirs(output_dir, exist_ok=True)

# Interpreter instance used for the sequence-logo cells; the figure is saved
# under interp.output_dir.
interp = MHCInterp(mhc_seq_dict, mhc_motif_dict, submotif_len, position_dict, output_dir)

In [None]:
# Read the training-hit table (first CSV column is the index) and collect the
# distinct values of its `mhc` column, in order of first appearance.
df = pd.read_csv('../data/raw/dataframe/train_hit.csv', index_col=0)
alleles = list(df['mhc'].unique())

In [None]:
# Sequence logo over all alleles plus one logo per HLA gene, each computed on
# positions 0..181.
seqlogo_dict = dict()
seqlogo_dict['all'] = interp._mhc_seqlogo(alleles, list(range(182)))
for hla in ['A', 'B', 'C']:
    # Select alleles by their gene letter. The previous substring test
    # (`hla in allele`) matched 'A' inside the 'HLA-' prefix, so every
    # 'HLA-x*...' name landed in the A group. Strip an optional 'HLA-' prefix
    # and compare the leading letter instead; names such as 'A*01:01' are
    # selected exactly as before.
    # NOTE(review): assumes allele names start with the gene letter once the
    # optional 'HLA-' prefix is removed - confirm against train_hit.csv.
    hla_alleles = [allele for allele in alleles
                   if re.sub(r'^HLA-', '', allele).startswith(hla)]
    seqlogo_dict[hla] = interp._mhc_seqlogo(hla_alleles, list(range(182)))

In [None]:
# One stacked panel per HLA gene, each showing the gene-specific logo minus
# the logo over all alleles.
hla_list = ['A', 'B', 'C']
ylim = 0.8
fig, ax = plt.subplots(3, 1, figsize=(6, 3), dpi=interp.dpi)

for panel, locus in zip(ax, hla_list):
    diff_seqlogo = seqlogo_dict[locus] - seqlogo_dict['all']
    lm.Logo(diff_seqlogo, color_scheme='skylign_protein', ax=panel)
    # Symmetric y-range so that all panels share the same scale.
    panel.set_ylim(-ylim, ylim)
    # Strip ticks, tick labels and title from the panel.
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_xticklabels([])
    panel.set_yticklabels([])
    panel.set_title(None)

fig.tight_layout()
fig.savefig('%s/MHCseqlogo.png'%interp.output_dir)