In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import pathlib, json
import torch
import os
import pandas as pd
import numpy as np
import pandas as pd
import pickle, sys
import warnings
from tqdm import tqdm
from sklearn.model_selection import KFold

sys.path.append('../')
sys.path.append('../../')

from recipes.dataset import MCPASDataset

# from analysis_util import display_mat_from_ind
# from analysis_util import convert_len

from Bio.PDB import *

from pdb_util import get_chain_list, calc_dist, remove_HOH
# get_structure_from_id, from_str_to_chain_names, get_residues_from_names
# from pdb_util import get_cdrs_from_anarci
# from pdb_util import get_seqs_from_residues
warnings.filterwarnings(action='once')


def display_mat_from_ind(ind, wantdisplay=False):
    """Build per-head, labelled attention DataFrames for one cached training sample.

    Relies on module-level names that are defined in other cells:
    ``ATTENTION_MATRIX_DICT``, ``DF_TCR``, ``MAXLENGTH_A``, ``MAXLENGTH_B``,
    ``max_len_epitope`` and (only when ``wantdisplay`` is true) ``px`` and
    IPython's ``display``.

    Args:
        ind: Key into ``ATTENTION_MATRIX_DICT`` and positional row index into
            ``DF_TCR`` (used with ``.iloc``).
        wantdisplay: When true, every per-head frame is rendered with
            ``px.imshow``.

    Returns:
        ``(attn_output_weights1_list, attn_output_weights2_list)``: two lists
        with one DataFrame per head (4 heads are hard-coded).
        In the ``attn_output_weights2`` frames rows are peptide labels and
        columns are TCR labels; in the ``attn_output_weights1`` frames rows are
        TCR labels and columns are peptide labels. Labels look like
        ``"<char>_<position>"``.
    """
    attn_output_weights1, attn_output_weights2, abseq_with_comma, peptide, ypred, sign = ATTENTION_MATRIX_DICT[ind]
    aseq = DF_TCR.iloc[ind]['tcra']
    bseq = DF_TCR.iloc[ind]['tcrb']
    # NOTE(review): `abseq` is never read below; `ypred` and `sign` are unused too.
    abseq = convert_len(aseq, MAXLENGTH_A) + convert_len(bseq, MAXLENGTH_B)
    # Overwrites the cached `peptide` with the '8'-padded peptide taken from DF_TCR.
    peptide = convert_len(DF_TCR.iloc[ind]['peptide'], max_len_epitope)


    attn_output_weights2_list = []
    for head_i in range(4):
        a = attn_output_weights2[head_i]
        dfa = pd.DataFrame(a)
        # Insert a tiny non-zero "delimiter" column at position 27 so that it
        # survives the all-zero-column filter on the next line.
        # NOTE(review): 27 is hard-coded while MAXLENGTH_A is used elsewhere — confirm.
        dfa.insert(27, "delimiter", [0.1**9 for _ in range(len(dfa))])
        # Keep only columns that contain at least one non-zero value.
        dfa = dfa.loc[:, ((dfa!=0).sum()!=0).values]
        # Assumes the number of surviving columns equals len(abseq_with_comma).
        dfa.columns = [f'{c}_{i}' for i,c in enumerate(list(abseq_with_comma))]
        # Keep one row per real (non-'8') peptide character.
        dfa = dfa.head(len(peptide.replace('8','')))
        dfa.index = list(peptide.replace('8',''))
        dfa.index = [f'{ind}_{i}' for i,ind in enumerate(dfa.index)]
        if wantdisplay:
            display(px.imshow(dfa, width=800, height=480))
        attn_output_weights2_list.append(dfa)

    # Padded alpha + ':' + padded beta; used to label the rows before padding is removed.
    abseq_index = convert_len(aseq, MAXLENGTH_A) + ':' + convert_len(bseq, MAXLENGTH_B)

    attn_output_weights1_list = []
    for head_i in range(4):
        a = attn_output_weights1[head_i]
        # Transpose, insert the delimiter as a column, transpose back:
        # the net effect is a delimiter *row* at position 27.
        dfa = pd.DataFrame(a).T
        dfa.insert(27, "delimiter", [0.1**9 for _ in range(len(dfa))])
        dfa = dfa.T
        dfa = dfa.loc[:, ((dfa!=0).sum()!=0).values]
        # Assumes the row count equals len(abseq_index) — TODO confirm.
        dfa.index = [f'{c}_{i}' for i,c in enumerate(list(abseq_index))]

        # Assumes the surviving column count equals len(peptide) (padded) — TODO confirm.
        dfa.columns = [f'{ind}_{i}' for i,ind in enumerate(peptide)]
        # Labels that start with the pad character look like '8_<pos>'; drop them.
        selector_columns = [c for c in dfa.columns if '8_' not in c]
        selector_index = [c for c in dfa.index if '8_' not in c]
        dfa = dfa.loc[selector_index]
        # Re-number the rows using the unpadded "alpha:beta" string.
        dfa.index = [f'{c}_{i}' for i,c in enumerate(abseq_with_comma)]
        dfa = dfa[selector_columns]
        if wantdisplay:
            display(px.imshow(dfa, width=800, height=480))
        attn_output_weights1_list.append(dfa)
    return attn_output_weights1_list, attn_output_weights2_list

def convert_len(seq, maxlen):
    """Return ``seq`` truncated or right-padded with ``'8'`` to length ``maxlen``.

    Args:
        seq: Sequence string.
        maxlen: Target length.

    Returns:
        ``seq[:maxlen]`` when ``seq`` is at least ``maxlen`` long, otherwise
        ``seq`` followed by enough ``'8'`` characters to reach ``maxlen``.
    """
    missing = maxlen - len(seq)
    if missing <= 0:
        return seq[:maxlen]
    # '8' is the padding symbol used throughout this notebook.
    return seq + '8' * int(missing)
    
import warnings
warnings.filterwarnings('ignore')


In [3]:
def remove_UNK_and_take_subset(df_train):
    """Drop rows whose ``tcra`` is the literal ``'UNK'`` and print label counts.

    Prints the ``sign`` value counts of the filtered frame, and again after
    de-duplicating on ``tcrb``.

    Args:
        df_train: DataFrame with at least ``tcra``, ``tcrb`` and ``sign`` columns.

    Returns:
        The filtered DataFrame (original index preserved).
    """
    keep_mask = df_train['tcra'] != 'UNK'
    df_nounk = df_train[keep_mask]

    counts_all = df_nounk.sign.value_counts().to_dict()
    print("df_nounk.value_counts() \n\t ", counts_all)

    unique_beta = df_nounk.drop_duplicates('tcrb')
    counts_unique_beta = unique_beta.sign.value_counts().to_dict()
    print("df_nounk.drop_duplicates('tcrb') \n\t ", counts_unique_beta)

    return df_nounk


def get_df(datapath):
    """Load a pickled collection of records and wrap it in a DataFrame.

    Args:
        datapath: Path to a pickle file whose content ``pd.DataFrame`` accepts.

    Returns:
        ``pd.DataFrame`` built from the unpickled object.

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on trusted files.
    """
    # The context manager closes the handle; the original left it open.
    with open(datapath, "rb") as fh:
        return pd.DataFrame(pickle.load(fh))

def get_df_from_path(p_list):
    """Load every pickle in ``p_list`` with ``get_df`` and stack the results.

    Args:
        p_list: Iterable of pickle file paths.

    Returns:
        One DataFrame with all rows concatenated in input order and a fresh
        0..n-1 index.
    """
    frames = []
    for path in p_list:
        frames.append(get_df(path))
    stacked = pd.concat(frames)
    return stacked.reset_index(drop=True)

def split_and_get_first(a):
    """Return the part of ``a`` before the first comma, or ``a`` itself if it has none."""
    return a.split(',')[0] if ',' in a else a


In [None]:
# Amino-acid code table: upper-cased abbreviation -> one-letter code.
# NOTE(review): absolute, user-specific path; the notebook is not portable
# until this is made relative or configurable.
AACODES = pd.read_csv('/Users/kyoheikoyama/workspace/tcrpred/analysis/aa_codes.csv')

AACODES['Abbreviation'] = AACODES['Abbreviation'].apply(lambda a: a.upper())
AACODES_DICT = dict(zip(AACODES['Abbreviation'], AACODES['1 letter abbreviation']))

# Which of the 5 cross-validation folds to analyse.
KFOLD_I = 0

# Timestamp identifying the training run whose cached attention is loaded.
dt = "20211029_011638" #logdf.checkp[0].split('/')[-1]
print('dt', dt)
OUT_DIR = '../../../tcr_attention_cachedir'
# Each value unpacks as:
# attn_output_weights1, attn_output_weights2, abseq_with_comma, peptide, ypred, sign
# NOTE: pickle.load executes arbitrary code; only load trusted cache files.
with open(f'{OUT_DIR}/{dt}_attention_matrix_dict.pickle', 'rb') as fh:
    ATTENTION_MATRIX_DICT = pickle.load(fh)

p_list = [f"../../../external_data/ERGO-II/Samples/vdjdb_train_samples.pickle",
          f"../../../external_data/ERGO-II/Samples/mcpas_train_samples.pickle", 
          '../../../external_data/ERGO-II/Samples/vdjdb_test_samples.pickle',
          '../../../external_data/ERGO-II/Samples/mcpas_test_samples.pickle',
         ]

df_all = get_df_from_path(p_list)

print("df_all(train).value_counts() \n\t", df_all.sign.value_counts().to_dict())
kf = KFold(n_splits=5, shuffle=True, random_state=2)
# Materialise the folds and pick the configured one.
train_index, valid_index = list(kf.split(df_all))[KFOLD_I]
# df_all has a fresh 0..n-1 index, so label-based .loc matches the positional fold indices.
df_train, df_valid = df_all.loc[train_index], df_all.loc[valid_index]

# DF_TCR is the training fold without 'UNK' alpha chains; display_mat_from_ind reads it.
DF_TCR = df_train = remove_UNK_and_take_subset(df_train)
dataset_train, dataset_valid = MCPASDataset(df_train), MCPASDataset(df_valid)

print('len(dataset_train),', len(dataset_train), len(df_train))


# SCEPTRE

In [None]:
pd.read_csv('/Users/kyoheikoyama/Downloads/sceptre_result_v2.csv')['pdb_id'].unique()

In [None]:
# Load the SCEPTRE export and keep only the chain / epitope / PDB columns used here.
df_sceptre_result = pd.read_csv('../../sceptre_result.csv')

df_sceptre_result = df_sceptre_result[[
    'chain1_type', 'chain2_type', 'chain1_cdr3_seq_calculated', 'chain2_cdr3_seq_calculated', 
    'epitope_seq', 'epitope_accession_IRI', 'epitope_organism_IRI', 
    'pdb_id', 'tcr_c1_pdb_chain', 'tcr_c2_pdb_chain', 'mhc_c1_pdb_chain', 'mhc_c2_pdb_chain', 'e_pdb_chain', 'pdb_cell_contact_area', 'chain1_cdr3_pdb_pos', 'chain2_cdr3_pdb_pos',
    'calc_e_residues', 'calc_e_tcr_residues', 'calc_e_mhc_residues', 'calc_tcr_e_residues', 'calc_tcr_mhc_residues', 'calc_mhc_e_residues', 'calc_mhc_tcr_residues', 'calc_e_contact_area', 'calc_cell_contact_area'
                                      ]]
# Append the row-wise result of `tcr_a_or_b` as extra column(s).
# NOTE(review): `tcr_a_or_b` is not defined or imported anywhere in the visible
# notebook, so this cell raises NameError on a fresh kernel — confirm where it lives.
df_sceptre_result = pd.concat([
    df_sceptre_result,
    pd.DataFrame(df_sceptre_result.apply(tcr_a_or_b, axis=1))], axis=1)


In [None]:
print(len(df_sceptre_result))

# Attention based analysis

In [None]:
# DF_TCR[DF_TCR['peptide'] == PEPSEQ].head(3)

In [None]:
MAXLENGTH_A, MAXLENGTH_B, max_len_epitope = 28, 28, 25

In [None]:
import plotly.express as px

In [None]:
# attn_output_weights1_list, attn_output_weights2_list = display_mat_from_ind(ind)

In [None]:
# os.system(f'aws s3 cp s3://sg-playground-kkoyama-temp/tcrpred/tcr_attention_cachedir/df_pep_{PEPSEQ}.parquet ../../../tcr_attention_cachedir/')

# df_pep = pd.read_parquet(f'../../../tcr_attention_cachedir/df_pep_{PEPSEQ}.parquet')



# PDB based analysis

## Look at the "auth id"

In [None]:
# Turn the file names in ./pdb/ into upper-case PDB IDs by removing the
# 'pdb' and '.ent' substrings from each name.
PDBENTRIES = []
for entry_name in os.listdir('./pdb/'):
    pdb_id = entry_name.replace('pdb','')
    pdb_id = pdb_id.replace('.ent','')
    PDBENTRIES.append(pdb_id.upper())

print('len(PDBENTRIES)', len(PDBENTRIES))

In [None]:
print('len(PDBENTRIES)', len(PDBENTRIES))

In [None]:
# %%time

import  multiprocessing as mp
saved_list =  [a.replace('.ent','').replace('pdb','').upper() for a in os.listdir('pdb')]


def pickleload(p):
    """Unpickle the file at path ``p`` and close the handle afterwards.

    NOTE: pickle.load executes arbitrary code; only load trusted files.
    """
    with open(p, "rb") as fh:
        return pickle.load(fh)


# NOTE(review): absolute, user-specific cache paths — not portable.
path = "/Users/kyoheikoyama/workspace/tcrpred/analysis/analysis1_alldata/../DICT_PDB_Result/DICT_PDBID_2_CHAINNAMES.pickle"
if os.path.exists(path):
    # Fast path: reuse the cached chain names and residues.
    DICT_PDBID_2_CHAINNAMES = pickleload(path)
    DICT_PDBID_2_RESIDUES = pickleload("/Users/kyoheikoyama/workspace/tcrpred/analysis/analysis1_alldata/../DICT_PDB_Result/DICT_PDBID_2_RESIDUES.pickle")
    # Each value is a collection of residue lists; remove_HOH is applied to every list.
    DICT_PDBID_2_RESIDUES = {k:[remove_HOH(v) for v in vv] for k,vv in DICT_PDBID_2_RESIDUES.items()}
else:
    # Slow path: parse every PDB entry. These helpers were never imported in
    # the notebook (only listed in a comment next to the pdb_util import), so
    # this branch previously raised NameError.
    # Presumably they live in pdb_util — confirm.
    from pdb_util import get_structure_from_id, from_str_to_chain_names, get_residues_from_names

    # DICT_PDBID_2_STRUCTURE = {pdbid:get_structure_from_id(pdbid) for pdbid in tqdm(PDBENTRIES)}
    with mp.Pool(4) as pool:
        structs = pool.map(get_structure_from_id, PDBENTRIES)

    DICT_PDBID_2_STRUCTURE = {pdbid:s for pdbid, s in zip(PDBENTRIES, structs)}

    temp = {pdbid:(from_str_to_chain_names(s)) for pdbid, s in 
                             tqdm(DICT_PDBID_2_STRUCTURE.items())}
    # Drop entries whose first chain name could not be determined.
    DICT_PDBID_2_CHAINNAMES = {k:v for k, v in temp.items() if v[0] is not None}
    DICT_PDBID_2_RESIDUES = {}
    for p, v in tqdm(DICT_PDBID_2_CHAINNAMES.items()):
        a, b, e = v
        # A chain entry may be a comma-separated list; use the first name only.
        a = split_and_get_first(a)
        b = split_and_get_first(b)
        e = split_and_get_first(e)
        s = DICT_PDBID_2_STRUCTURE[p]
        # NOTE(review): unlike the cached branch, remove_HOH is not applied here,
        # and the result is not written back to the cache — confirm intent.
        DICT_PDBID_2_RESIDUES[p]  = get_residues_from_names(s, a, b, e)



In [None]:
print('count_of_pdbids', len(DICT_PDBID_2_RESIDUES))

# Use ANARCI to get the TCR CDRs from the PDB structures

https://github.com/oxpig/ANARCI

In [None]:
%%time
# `get_cdrs_from_anarci` was only referenced in a commented-out import, so this
# cell raised NameError on a fresh kernel. Presumably it lives in pdb_util — confirm.
from pdb_util import get_cdrs_from_anarci

# PDB ID -> (alpha CDR residues, beta CDR residues, epitope residues)
DICT_PDBID_2_CDRS = {}
for p, v in DICT_PDBID_2_RESIDUES.items():
    residues_chain_alpha, residues_chain_beta, epi = v
    residues_chain_cdr_alpha, residues_chain_cdr_beta = get_cdrs_from_anarci(
        p, residues_chain_alpha, residues_chain_beta)
    # The epitope residues are passed through unchanged.
    DICT_PDBID_2_CDRS[p] = (residues_chain_cdr_alpha, residues_chain_cdr_beta, epi)

#### http://www.imgt.org/IMGTScientificChart/Nomenclature/IMGT-FRCDRdefinition.html

# PDB-distance-based Visualization
(Distance Matrix on PDB)

In [None]:
from pdb_util import distance_mat_from_residues

In [None]:
from pdb_util import get_seqs_from_residues

In [None]:
get_seqs_from_residues(*v)

In [None]:
# PDB ID -> distance matrix computed by distance_mat_from_residues from
# (alpha CDR residues, beta CDR residues, epitope residues).
DICT_PDBID_2_DISTANCE = {}

# NOTE: the loop variables `p` and `v` are intentionally left at module level;
# the following cells read them (`DICT_PDBID_2_DISTANCE[p]`, `v[i]`).
# NOTE(review): if the last entry is skipped below, `DICT_PDBID_2_DISTANCE[p]`
# in the next cell raises KeyError.
for p,v in DICT_PDBID_2_CDRS.items():
    a,b,e = v
    # Skip entries where any of the three residue lists is missing or empty.
    if any([vv is None or len(vv)==0 for vv in v]):
        continue
    else:
        DICT_PDBID_2_DISTANCE[p] = distance_mat_from_residues(a, b, e)

In [None]:
DICT_PDBID_2_DISTANCE[p]

In [None]:
for i in range(3):
    print([AACODES_DICT[r.get_resname()]  for r in v[i]])

In [None]:
# for pdbid in DICT_PDBID_2_CDRS.keys():
#     if pdbid not in DICT_PDBID_2_DISTANCE:
#         print(pdbid)
#         print(DICT_PDBID_2_CDRS[pdbid])
        
#         cdr_beta = ''.join(
#             pd.read_csv(f'anarci/{pdbid}_anarci_B.csv')[_IMGT_CDR_POS].values[0].tolist())
#         cdr_alpha = ''.join(
#             pd.read_csv(f'anarci/{pdbid}_anarci_D.csv')[_IMGT_CDR_POS].values[0].tolist())
#         break


In [None]:
len(DICT_PDBID_2_RESIDUES), len(DICT_PDBID_2_CDRS), len(DICT_PDBID_2_DISTANCE)

In [None]:
import plotly.express as px

from pdb_util import make_xlabel

# PDB ID -> long-format distance table with columns peptide / tcr / value,
# sorted by (peptide, tcr). Columns whose label contains ':' are dropped first.
DICT_PDBID_2_MELTDIST = {}
for p, distmat in DICT_PDBID_2_DISTANCE.items():
    separator_columns = [c for c in distmat.columns if ':' in c]
    without_separator = distmat.drop(columns=separator_columns)
    long_form = without_separator.melt(ignore_index=False).reset_index()
    long_form = long_form.rename(columns={'variable':'tcr', 'index':'peptide'})
    distmat_vis = long_form.sort_values(by=['peptide', 'tcr'])
    DICT_PDBID_2_MELTDIST[p] = distmat_vis

In [None]:
# Render the residue-distance heat-map for one PDB entry as a visual sanity check.
# The trailing `break` stops after the first entry of the dict.
for p,distmat in DICT_PDBID_2_DISTANCE.items():
    fig = px.imshow(distmat, 
              x=distmat.columns, 
              y=distmat.index,  
              width=1000, height=1200)
    fig.update_layout(
        title=f"Residue distance in PDB {p}",
        xaxis_title="cdr_alpha:beta",
        yaxis_title="seq_epitope",
        autosize=False
    )
    fig.show()
    break

# Get attention from adhoc sequence

In [None]:
with open(f"../../hpo_params/optuna_best.json", "r") as fp:
    hparams = json.load(fp)

In [None]:
# PDB ID -> result of get_seqs_from_residues(alpha CDR, beta CDR, epitope).
# Entries where any of the three residue lists is None or empty are left out.
DICT_PDBID_2_SEQUENCES_CDR = {}
for p, cdrs in DICT_PDBID_2_CDRS.items():
    is_complete = not any([v is None or len(v)==0 for v in cdrs])
    if is_complete:
        DICT_PDBID_2_SEQUENCES_CDR[p] =  get_seqs_from_residues(*cdrs)


In [None]:
df = pd.DataFrame(DICT_PDBID_2_SEQUENCES_CDR).T
df = df.rename(columns={0:'tcra', 1:'tcrb', 2:'peptide'})
df['sign']=1.0

In [None]:
df.head(3)

In [None]:
df = df[df['peptide'].apply(len)<=25]

In [None]:
len(df)

In [None]:
# df = df.reset_index().rename(columns={'index':'pdbid'})

In [None]:
torch_dataset = MCPASDataset(df)
analysis_loader = torch.utils.data.DataLoader(torch_dataset, batch_size=1)

## Recover model weights

In [None]:
from recipes.utils import get_file_paths
from scripts.attention_extractor import get_attention_weights, Explain_TCRModel
from recipes.model import TCRModel
from ignite.handlers import Checkpoint
from ignite.engine import Events, create_supervised_trainer
# Fixed architecture sizes.
n_tok = 29  # NUM_VOCAB
n_pos1 = 62  # MAX_LEN_AB
n_pos2 = 26  # MAX_LEN_Epitope
n_seg = 3

# Hyper-parameters read from the loaded `hparams` dict.
d_model = hparams['d_model']
d_ff = hparams['d_ff']
n_head = hparams['n_head']
n_local_encoder = hparams['n_local_encoder']
n_global_encoder = hparams['n_global_encoder']
dropout = hparams['dropout']
batch_size = hparams['batch_size']
lr = hparams['lr']

# Both models are built from exactly the same keyword arguments.
model_kwargs = dict(
    d_model=d_model, d_ff=d_ff, n_head=n_head,
    n_local_encoder=n_local_encoder, n_global_encoder=n_global_encoder,
    dropout=dropout, scope=4,
    n_tok=n_tok, n_pos1=n_pos1, n_pos2=n_pos2, n_seg=n_seg,
)

# The explain model is constructed first, as before.
explain_model = Explain_TCRModel(**model_kwargs)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

'''Model, optim, trainer'''
model = TCRModel(**model_kwargs)

# Optimizer
optim = torch.optim.Adam(model.parameters(), lr=lr)

# Loss: class weights 1.0 and 15.0 for the two classes.
class_weights = torch.tensor([1.0, 15.0], device=device)
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

trainer = create_supervised_trainer(
    model=model,
    optimizer=optim,
    loss_fn=loss_fn,
    device=device,
)

import subprocess

# Find the checkpoint file(s) that belong to the training run `dt`.
checkpoint_files = get_file_paths('sg-playground-kkoyama-temp', 'tcrpred/checkpoint', '', '.pt')
checkpoint_files_withdt = [f for f in checkpoint_files if dt in f]
if not checkpoint_files_withdt:
    # Fail with a readable message instead of an IndexError on the lookup below.
    raise FileNotFoundError(f'no .pt checkpoint found for run {dt}')
CHECKPOINT_DIR = f's3://sg-playground-kkoyama-temp/tcrpred/checkpoint/{dt}'
CHECKPOINT_PATH = checkpoint_files_withdt[0].split('/')[-1]

## Get the prediction and the best model
to_load = {'model': model, 'optimizer': optim, 'trainer': trainer}

torch_check_point_local = '../../../checkpoint/'
fro = os.path.join(CHECKPOINT_DIR, CHECKPOINT_PATH)
local_checkpoint_file = os.path.join(torch_check_point_local, CHECKPOINT_PATH)
if not os.path.exists(local_checkpoint_file):
    # Download once. An argument list avoids shell interpolation, and
    # check=True surfaces a failed copy here rather than as a confusing
    # "file not found" in torch.load below.
    subprocess.run(['aws', 's3', 'cp', fro, torch_check_point_local], check=True)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(local_checkpoint_file, map_location=torch.device(device))

# NOTE(review): the trainer state gets a 'seed' entry before loading;
# presumably the engine state requires it — confirm.
checkpoint['trainer']['seed'] = 9

Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)


In [None]:
MAXLENGTH_B

In [None]:
def get_mat_from_result_tuple(result_tuple, aseq, bseq, peptide):
    """Turn raw attention output into labelled per-head DataFrames and draw them.

    Relies on module-level names defined in other cells: ``MAXLENGTH_A``,
    ``MAXLENGTH_B``, ``convert_len`` and, importantly, ``axs`` — a 2 x 4 grid
    of matplotlib axes that the caller must create *before* calling this
    function. Row 0 of ``axs`` receives the ``result_tuple[1]`` heat-maps,
    row 1 the ``result_tuple[0]`` heat-maps.

    Args:
        result_tuple: Model output; element 0 and element 1 are indexed by
            head (4 heads are hard-coded). Presumably arrays of attention
            weights — confirm shapes against ``get_attention_weights``.
        aseq: TCR alpha sequence string.
        bseq: TCR beta sequence string.
        peptide: Peptide sequence string ('8' is treated as padding).

    Returns:
        ``(attn_output_weights1_list, attn_output_weights2_list)``: one
        DataFrame per head. The ``weights1`` frames have TCR labels as rows and
        peptide labels as columns; the ``weights2`` frames are the other way
        round. Labels look like ``"<char>_<position>"``.
    """
    #abseq = convert_len(aseq, MAXLENGTH_A)convert_len(bseq, MAXLENGTH_B) 
    attn_output_weights1 = result_tuple[0]
    attn_output_weights2 = result_tuple[1]
    print(f"aseq={aseq}, bseq={bseq}, peptide={peptide}")
    # Unpadded "alpha:beta" string: the final TCR labels come from it.
    abseq_with_comma = f'{aseq}:{bseq}'
    # Padded variant: labels the rows before the padding rows are removed.
    abseq_index = convert_len(aseq, MAXLENGTH_A) + ':' + convert_len(bseq, MAXLENGTH_B) 
    
    attn_output_weights2_list = []
    for head_i in range(4):
#         print('head', head_i, attn_output_weights2[head_i].shape)
        a = attn_output_weights2[head_i]
        dfa = pd.DataFrame(a)
        # Insert a tiny non-zero "delimiter" column so it survives the
        # all-zero-column filter on the next line.
        # NOTE(review): position 27 is hard-coded here, whereas the second loop
        # uses MAXLENGTH_A — confirm which is intended.
        dfa.insert(27, "delimiter", [0.1**9 for _ in range(len(dfa))])
        # Keep only columns with at least one non-zero value.
        dfa = dfa.loc[:, ((dfa!=0).sum()!=0).values]
        # Assumes the surviving column count equals len(abseq_with_comma).
        dfa.columns = list(abseq_with_comma)
        dfa.columns = [f'{c}_{i}' for i,c in enumerate(dfa.columns)]
        # One row per real (non-'8') peptide character.
        dfa = dfa.head(len(peptide.replace('8','')))
        dfa.index = list(peptide.replace('8',''))
        dfa.index = [f'{ind}_{i}' for i,ind in enumerate(dfa.index)]    
        #         print('df.sum(axis=0)', dfa.sum(axis=0))
        #         print('df.sum(axis=1)', dfa.sum(axis=1))
        # display(px.imshow(dfa, width=800, height=480))
        
        # Draw on the caller-provided global axes (top row).
        axs[0,head_i].imshow(dfa, aspect='equal')
        axs[0,head_i].set_xticks(range(len(dfa.columns)))
        axs[0,head_i].set_yticks(range(len(dfa.index)))
        axs[0,head_i].set_xticklabels(dfa.columns)
        axs[0,head_i].set_yticklabels(dfa.index)
        attn_output_weights2_list.append(dfa)
    
    attn_output_weights1_list = []
    for head_i in range(4):
#         print('head', head_i, attn_output_weights1[head_i].shape)
        a = attn_output_weights1[head_i]
        # Transpose, insert the delimiter as a column, transpose back:
        # the net effect is a delimiter *row* at position MAXLENGTH_A.
        dfa = pd.DataFrame(a).T
        dfa.insert(MAXLENGTH_A, "delimiter", [0.1**9 for _ in range(len(dfa))])
        dfa = dfa.T
        dfa = dfa.loc[:, ((dfa!=0).sum()!=0).values]
        # Assumes the row count equals len(abseq_index) — TODO confirm.
        dfa.index = list(abseq_index)
        dfa.index = [f'{c}_{i}' for i,c in enumerate(dfa.index)]
        # Pad or truncate the peptide to however many columns survived.
        dfa.columns = list(convert_len(peptide, len(dfa.columns)))
        dfa.columns = [f'{ind}_{i}' for i,ind in enumerate(dfa.columns)]
        # Padding labels look like '8_<pos>'; drop those rows and columns.
        selector_columns = [c for c in dfa.columns if '8_' not in c]
        selector_index = [c for c in dfa.index if '8_' not in c]
        dfa = dfa.loc[selector_index]
        # Re-number the rows using the unpadded "alpha:beta" string.
        dfa.index = [f'{c}_{i}' for i,c in enumerate(abseq_with_comma)]
        dfa = dfa[selector_columns]
        #         print('df.sum(axis=0)', dfa.sum(axis=0))
        #         print('df.sum(axis=1)', dfa.sum(axis=1))
        # display(px.imshow(dfa, width=800, height=480))
        
        # Draw on the caller-provided global axes (bottom row).
        axs[1,head_i].imshow(dfa, aspect='equal')
        axs[1,head_i].set_xticks(range(len(dfa.columns)))
        axs[1,head_i].set_yticks(range(len(dfa.index)))
        axs[1,head_i].set_xticklabels(dfa.columns)
        axs[1,head_i].set_yticklabels(dfa.index)
        attn_output_weights1_list.append(dfa)
    return attn_output_weights1_list, attn_output_weights2_list

In [None]:
len(set([s[2] for s in DICT_PDBID_2_SEQUENCES_CDR.values()])), len(DICT_PDBID_2_SEQUENCES_CDR) 
len(set([s[0] for s in DICT_PDBID_2_SEQUENCES_CDR.values()])), len(DICT_PDBID_2_SEQUENCES_CDR) 

In [None]:
# PDB ID -> raw output of get_attention_weights for that complex.
# `df.index` holds the PDB IDs; zip pairs each ID with the matching item of
# `analysis_loader`.
DICT_PDBID_2_model_out = {}

for p, (xx, yy) in zip(df.index, analysis_loader):
    inputs_on_device = [x.to(device) for x in xx]
    result_tuple = get_attention_weights(
        inputs_on_device, model, explain_model=explain_model, device=device)
    DICT_PDBID_2_model_out[p] = result_tuple

In [None]:
# DICT_PDBID_2_model_out

In [None]:
%%time
from matplotlib import pyplot as plt

# PDB ID -> (attn_output_weights1_list, attn_output_weights2_list), i.e. the
# labelled per-head attention frames returned by get_mat_from_result_tuple.
DICT_PDBID_2_Atten12 = {}
for p, (xx,yy) in zip(df.index, analysis_loader):
    
    if p not in DICT_PDBID_2_model_out:
        continue
    print('pdbid =', p)
    # Defined inside the loop but read by later cells as a module-level constant.
    HEAD_COUNT = 4
    # `axs` must exist before get_mat_from_result_tuple runs: it draws on this global.
    fig, axs = plt.subplots(2, 4, figsize=(60,16))
    
    a,b,e = DICT_PDBID_2_SEQUENCES_CDR[p]
    attn_output_weights1_list, attn_output_weights2_list = get_mat_from_result_tuple(
        result_tuple=DICT_PDBID_2_model_out[p], aseq=a, bseq=b, peptide=e)
    DICT_PDBID_2_Atten12[p] = (attn_output_weights1_list, attn_output_weights2_list)
    plt.show()
    print('-'*100)
    # NOTE(review): this `break` stops after the first processed PDB ID, so
    # DICT_PDBID_2_Atten12 holds at most one entry, yet later cells iterate over
    # it to "sum up all the PDB" — confirm whether the break is a debugging leftover.
    break
    

# Attention-based Visualization

In [None]:
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt

In [None]:
# %%writefile -a analysis_util.py
def calc_melt_df(distmat_vis, attn_output_weights1, attn_output_weights2):
    """Join long-format distances with long-format attention for one head.

    TCR labels removed before melting: the literal labels ``':'`` and
    ``'C_0'``, the first label containing ``':'``, its left and right
    neighbours, and the last TCR label.

    Args:
        distmat_vis: Long-format distances with columns ``peptide``, ``tcr``
            and ``value``.
        attn_output_weights1: Attention frame with TCR labels as index and
            peptide labels as columns.
        attn_output_weights2: Attention frame with peptide labels as index and
            TCR labels as columns.

    Returns:
        ``(merged_xy_1, merged_xy_2)``: inner merges of ``distmat_vis`` with the
        melted ``attn_output_weights1`` / ``attn_output_weights2`` on
        ``['peptide', 'tcr']``. ``value_x`` is the distance and ``value_y`` the
        attention weight.
    """
    tcr_labels = attn_output_weights2.columns
    separator = [label for label in tcr_labels if ':' in label][0]
    separator_pos = tcr_labels.get_loc(separator)
    neighbour_positions = [separator_pos, separator_pos-1, separator_pos+1, len(tcr_labels)-1]
    removers = [':', 'C_0'] + [tcr_labels[i] for i in neighbour_positions]

    remove_col = [label for label in tcr_labels if label in removers]
    remove_ind = [label for label in attn_output_weights1.index if label in removers]
    trimmed_2 = attn_output_weights2.drop(columns=remove_col)
    trimmed_1 = attn_output_weights1.drop(index=remove_ind)

    def _to_long(frame):
        # frame: peptide labels as index, TCR labels as columns.
        melted = frame.melt(ignore_index=False).reset_index()
        melted = melted.rename(columns={'variable':'tcr', 'index':'peptide'})
        return melted.sort_values(by=['peptide', 'tcr'])

    attn_output_weights2_vis = _to_long(trimmed_2)
    # weights1 is TCR x peptide, so transpose it into the same orientation first.
    attn_output_weights1_vis = _to_long(trimmed_1.T)

    merge_keys = ['peptide','tcr']
    merged_xy_1 = pd.merge(distmat_vis, attn_output_weights1_vis, on=merge_keys, how='inner')
    merged_xy_2 = pd.merge(distmat_vis, attn_output_weights2_vis, on=merge_keys, how='inner')
    return merged_xy_1, merged_xy_2



In [None]:
# %%writefile -a analysis_util.py

def show_correlation(distmat_vis, merged_xy_1, merged_xy_2, head):
    """Scatter distance against attention for one head and overlay a linear fit.

    Draws on the module-level ``axs`` grid (which must already exist):
    ``axs[0, head]`` shows ``merged_xy_2`` and ``axs[1, head]`` shows
    ``merged_xy_1``. The fit uses ``polyfit`` with degree 1.

    Args:
        distmat_vis: Not used by this function; kept in the signature for callers.
        merged_xy_1: Frame with ``value_x`` (distance) and ``value_y`` (attention).
        merged_xy_2: Same layout as ``merged_xy_1``.
        head: Column of ``axs`` to draw into.

    Returns:
        ``(b1, m1, b2, m2)``: intercept and slope for ``merged_xy_1``, then
        intercept and slope for ``merged_xy_2``.

    Raises:
        AssertionError: If either merged frame is empty.
    """
    assert len(merged_xy_2)!=0
    assert len(merged_xy_1)!=0

    def _fit_and_draw(ax, merged):
        # Fit y = intercept + slope * x, then draw the line and the raw points.
        x = merged['value_x'].values
        y = merged['value_y'].values
        intercept, slope = polyfit(x, y, 1)
        ax.plot(x, intercept + slope * x, '-')
        ax.scatter(x, y)
        ax.set_xlabel('distance in PDB')
        ax.set_ylabel('attention weight value')
        return intercept, slope

    top_ax = axs[0, head]
    b2, m2 = _fit_and_draw(top_ax, merged_xy_2)
    top_ax.set_title(f'correlation b/w distance and attention \n (given a PEP, distribute tcr values to the sum 1.0), head {head} \n y={b2:.3f}+{m2:.3f}* x')

    bottom_ax = axs[1, head]
    b1, m1 = _fit_and_draw(bottom_ax, merged_xy_1)
    bottom_ax.set_title(f'correlation b/w distance and attention \n (given a TCR, distribute pep values to the sum 1.0), head {head} \n y={b1:.3f}+{m1:.3f}* x')
    return (b1,m1,b2,m2)
    


In [None]:
def position_sorted(values):
    """Flip ``"<aa>_<pos>"`` labels to ``"<pos>_<aa>"``, de-duplicate and sort.

    The sort is a plain string sort, so ``"10_A"`` comes before ``"2_C"``.
    """
    flipped = set()
    for label in values:
        parts = label.split('_')
        flipped.add(parts[1] + '_' + parts[0])
    return sorted(flipped)

# Slopes of the linear fits (attention vs. distance) collected over PDB IDs and heads.
m1_list = []
m2_list = []

HEAD_COUNT = 4
for p in DICT_PDBID_2_model_out.keys():
    a,b,e = DICT_PDBID_2_SEQUENCES_CDR[p]
    print((p,a,b,e))

    # NOTE(review): get_mat_from_result_tuple draws on the global `axs`, which
    # at this point is still the grid left over from an earlier cell; the new
    # grid is only created two statements below — confirm this is intended.
    attn_output_weights1_list, attn_output_weights2_list = get_mat_from_result_tuple(
        result_tuple=DICT_PDBID_2_model_out[p], aseq=a, bseq=b, peptide=e)

    # Fresh grid for the correlation plots drawn by show_correlation.
    fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
    for head in range(HEAD_COUNT):
        distmat_vis = DICT_PDBID_2_MELTDIST[p]
        # Sanity checks: both attention frames and the distance table must
        # describe the same TCR positions.
        assert all(attn_output_weights1_list[head].index == attn_output_weights2_list[head].columns)
        no_comma_position_sorted = [c for c in position_sorted(attn_output_weights1_list[head].index) if ':' not in c]
        assert position_sorted(distmat_vis['tcr'].values) == no_comma_position_sorted
        attn_output_weights1, attn_output_weights2 = attn_output_weights1_list[head], attn_output_weights2_list[head]
        merged_xy_1, merged_xy_2 = calc_melt_df(distmat_vis, attn_output_weights1, attn_output_weights2)
        b1,m1,b2,m2 = show_correlation(distmat_vis, merged_xy_1, merged_xy_2, head)
        m1_list.append(m1)
        m2_list.append(m2)
    plt.show()
    # NOTE(review): stops after the first PDB ID, so the slope lists hold only
    # HEAD_COUNT values — confirm whether this is a debugging leftover.
    break

In [None]:
pd.Series(m1_list).hist(bins=100)

In [None]:
pd.Series(m1_list).describe()

In [None]:
pd.Series(m2_list).hist(bins=100)

## Example of correlation visualization

In [None]:
DICT_PDBID_2_SEQUENCES_CDR.keys()

In [None]:
p = '3MV7'
HEAD_COUNT = 4

a,b,e = DICT_PDBID_2_SEQUENCES_CDR[p]
print((p,a,b,e))

def position_sorted(values):
    """Flip ``"<aa>_<pos>"`` labels to ``"<pos>_<aa>"``, de-duplicate and sort.

    NOTE(review): this repeats a definition from an earlier cell; the later
    one silently shadows the earlier one.
    """
    split_labels = (label.split('_') for label in values)
    flipped = {parts[1] + '_' + parts[0] for parts in split_labels}
    return sorted(flipped)

fig, axs = plt.subplots(2, 4, figsize=(60,16))    
attn_output_weights1_list, attn_output_weights2_list = get_mat_from_result_tuple(
    result_tuple=DICT_PDBID_2_model_out[p], aseq=a, bseq=b, peptide=e)
plt.show()

fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
for head in range(HEAD_COUNT):
    distmat_vis = DICT_PDBID_2_MELTDIST[p]
    assert all(attn_output_weights1_list[head].index == attn_output_weights2_list[head].columns)
    no_comma_position_sorted = [c for c in position_sorted(attn_output_weights1_list[head].index) if ':' not in c]
    assert position_sorted(distmat_vis['tcr'].values) == no_comma_position_sorted
    attn_output_weights1, attn_output_weights2 = attn_output_weights1_list[head], attn_output_weights2_list[head]
    merged_xy_1, merged_xy_2 = calc_melt_df(distmat_vis, attn_output_weights1, attn_output_weights2)
    show_correlation(distmat_vis, merged_xy_1, merged_xy_2, head)
plt.show()

# Sum up all the PDB

## Only 2 sigma attention

In [None]:
# attention = pickle.load(open('attention_GILGFVFTL.pickle','rb'))

# A value counts as "strong" attention when it exceeds mean + N * std of its
# own merged frame.
# NOTE(review): the section heading says "2 sigma" but the code uses 5.0 — confirm.
STRONG_ATTENTION_N_SIGMA = 5.0

# head index -> list of per-PDB frames that contain only strong-attention rows
givenTCRdistributePEP_by_head = {}
givenPEPdistributeTCR_by_head = {}

for hi in tqdm(range(HEAD_COUNT)):
    givenPEPdistributeTCR_by_head[hi] = []
    givenTCRdistributePEP_by_head[hi] = []

    for p, (a1_by_head, a2_by_head) in DICT_PDBID_2_Atten12.items():
        distmat_vis = DICT_PDBID_2_MELTDIST[p]
        a1, a2 = calc_melt_df(distmat_vis, a1_by_head[hi], a2_by_head[hi])
        a1['pdbid'] = p
        a2['pdbid'] = p

        threshold1 = a1['value_y'].mean() + STRONG_ATTENTION_N_SIGMA * a1['value_y'].std()
        threshold2 = a2['value_y'].mean() + STRONG_ATTENTION_N_SIGMA * a2['value_y'].std()
        temp1 = a1[a1['value_y'] > threshold1]
        temp2 = a2[a2['value_y'] > threshold2]

        # (A `temp1[['tcr', 'pdbid']].drop_duplicates()` statement whose result
        # was discarded has been removed; it had no effect.)
        givenTCRdistributePEP_by_head[hi].append(temp1)
        givenPEPdistributeTCR_by_head[hi].append(temp2)

fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
# Iterate over HEAD_COUNT (instead of a literal 4) so it matches the grid above.
for hi in range(HEAD_COUNT):
    df1_strong_atten = pd.concat(givenTCRdistributePEP_by_head[hi])
    df2_strong_atten = pd.concat(givenPEPdistributeTCR_by_head[hi])
    # show_correlation ignores its first argument; `distmat_vis` is whatever
    # the loop above left behind.
    show_correlation(distmat_vis, df1_strong_atten, df2_strong_atten, hi)


## Only 2 sigma and PosPred Attention

In [None]:
# PDB IDs whose last model-output element is above 0.5.
# Presumably that element is the predicted binding score — confirm.
strong_ids = [k for k, v in DICT_PDBID_2_model_out.items() if v[-1] > 0.5]
# attention = pickle.load(open('attention_GILGFVFTL.pickle','rb'))

# A value counts as "strong" attention when it exceeds mean + N * std of its
# own merged frame.
# NOTE(review): the section heading says "2 sigma" but the code uses 5.0 — confirm.
STRONG_ATTENTION_N_SIGMA = 5.0

# head index -> list of per-PDB frames that contain only strong-attention rows
givenTCRdistributePEP_by_head = {}
givenPEPdistributeTCR_by_head = {}

for hi in tqdm(range(HEAD_COUNT)):
    givenPEPdistributeTCR_by_head[hi] = []
    givenTCRdistributePEP_by_head[hi] = []

    for p, (a1_by_head, a2_by_head) in DICT_PDBID_2_Atten12.items():
        # Only complexes from `strong_ids` contribute.
        if p not in strong_ids:
            continue

        distmat_vis = DICT_PDBID_2_MELTDIST[p]
        a1, a2 = calc_melt_df(distmat_vis, a1_by_head[hi], a2_by_head[hi])
        a1['pdbid'] = p
        a2['pdbid'] = p

        threshold1 = a1['value_y'].mean() + STRONG_ATTENTION_N_SIGMA * a1['value_y'].std()
        threshold2 = a2['value_y'].mean() + STRONG_ATTENTION_N_SIGMA * a2['value_y'].std()
        temp1 = a1[a1['value_y'] > threshold1]
        temp2 = a2[a2['value_y'] > threshold2]

        # (A `temp1[['tcr', 'pdbid']].drop_duplicates()` statement whose result
        # was discarded has been removed; it had no effect.)
        givenTCRdistributePEP_by_head[hi].append(temp1)
        givenPEPdistributeTCR_by_head[hi].append(temp2)

fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
# Iterate over HEAD_COUNT (instead of a literal 4) so it matches the grid above.
for hi in range(HEAD_COUNT):
    df1_strong_atten = pd.concat(givenTCRdistributePEP_by_head[hi])
    df2_strong_atten = pd.concat(givenPEPdistributeTCR_by_head[hi])
    # show_correlation ignores its first argument; `distmat_vis` is whatever
    # the loop above left behind.
    show_correlation(distmat_vis, df1_strong_atten, df2_strong_atten, hi)



## All attention values 

In [None]:
# attention = pickle.load(open('attention_GILGFVFTL.pickle','rb'))

# head index -> list of per-PDB merged (distance, attention) frames, unfiltered
givenTCRdistributePEP_by_head = {}
givenPEPdistributeTCR_by_head = {}

for hi in tqdm(range(HEAD_COUNT)):
    givenPEPdistributeTCR_by_head[hi] = per_pep_frames = []
    givenTCRdistributePEP_by_head[hi] = per_tcr_frames = []

    for p, (a1_by_head, a2_by_head) in DICT_PDBID_2_Atten12.items():
        distmat_vis = DICT_PDBID_2_MELTDIST[p]
        a1, a2 = calc_melt_df(distmat_vis, a1_by_head[hi], a2_by_head[hi])
        # Tag every row with its PDB ID so frames stay distinguishable after concat.
        a1['pdbid'] = p
        a2['pdbid'] = p

        per_tcr_frames.append(a1)
        per_pep_frames.append(a2)

fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
for hi in range(4):
    # `df1` / `df2` stay at module level: later cells read the last head's frames.
    df1 = pd.concat(givenTCRdistributePEP_by_head[hi])
    df2 = pd.concat(givenPEPdistributeTCR_by_head[hi])
    show_correlation(distmat_vis, df1, df2, hi)


## 2 Group Comparison

In [None]:
print(df1_strong_atten.value_x.describe(), df1.value_x.describe())

In [None]:
print(df1_strong_atten.value_y.describe(), df1.value_y.describe())

## Other

In [None]:
df1['pair_tcr_pdbid'] = df1[['tcr','pdbid']].apply(tuple, axis=1)
# .isin(df1_strong_atten[['tcr','pdbid']].apply(tuple, axis=1).apply(tuple))

In [None]:
temp_df1_strong_atten = df1_strong_atten[['tcr','pdbid']].drop_duplicates() #.apply(tuple, axis=1)
temp_df1_strong_atten['pair_tcr_pdbid'] = temp_df1_strong_atten.apply(tuple, axis=1)
temp_df1_strong_atten['is_strong_atten'] = 1

In [None]:
pd.merge(df1, temp_df1_strong_atten, on=['pair_tcr_pdbid'], how='left').drop_duplicates('pair_tcr_pdbid') #['is_strong_atten'].value_counts()

## Logs of attention values

In [None]:
fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
for hi in range(4):
    df1 = pd.concat(givenTCRdistributePEP_by_head[hi])
    df1['value_y'] = df1['value_y'].apply(np.log)
    df2 = pd.concat(givenPEPdistributeTCR_by_head[hi])
    df2['value_y'] = df2['value_y'].apply(np.log)
    show_correlation(distmat_vis, df1, df2, hi)


In [None]:
fig, axs = plt.subplots(2, HEAD_COUNT, figsize=(60,16))
for hi in range(4):
    df1 = pd.concat(givenTCRdistributePEP_by_head[hi])
    df1['value_x'] = df1['value_x'].apply(np.log)
    df1['value_y'] = df1['value_y'].apply(np.log)
    df2 = pd.concat(givenPEPdistributeTCR_by_head[hi])
    df2['value_x'] = df2['value_x'].apply(np.log)
    df2['value_y'] = df2['value_y'].apply(np.log)
    show_correlation(distmat_vis, df1, df2, hi)


# Analysis by head after summing up values

In [None]:
df1

In [None]:
# NOTE(review): every iteration overwrites `df`, so only the frame for the last
# head (3) survives; the trailing "#= 2" suggests a single head was once chosen.
# NOTE(review): this also rebinds `df`, which earlier cells use as the PDB
# sequence table paired with `analysis_loader`; re-running those cells after
# this one will no longer work as before.
for head in range(4): #= 2
    df = pd.concat(givenTCRdistributePEP_by_head[head])
    

In [None]:
df.head()

In [None]:
print(df['value_y'].describe(percentiles=np.arange(0,1,0.1)))
percentile80 = df['value_y'].describe(percentiles=np.arange(0,1,0.1))['80%']

In [None]:
df['is_strong_attention'] = df['value_y'] > percentile80
df['aa'] = df['tcr'].str.split('_').apply(lambda x: x[0])
df['pos'] = df['tcr'].str.split('_').apply(lambda x: x[1])

In [None]:
df.groupby(by=['aa', 'is_strong_attention']).agg(len).iloc[:,:1].plot.bar(figsize=(20,2))

In [None]:
df.groupby(by=['pos', 'is_strong_attention']).agg(len).iloc[:,:1].plot.bar(figsize=(20,2))

# PDB Command

In [None]:
def add_commnad(Pymol_COMMAND, text):
    """Return ``Pymol_COMMAND`` with ``text`` and a trailing newline appended."""
    extended = Pymol_COMMAND + text
    extended = extended + '\n'
    return extended

def _chain_ids(chain_names):
    """Split a chain-name string such as 'A' or 'A, B' into its individual chain IDs."""
    # 'A'.split(', ') is ['A'], so single- and multi-chain names share one code path.
    return chain_names.split(', ')


def _residue_number(residue):
    """Return element [3][1] of ``residue.get_full_id()`` (the residue number used for `resi`)."""
    return residue.get_full_id()[3][1]


def _cdr3_commands(selection, colour, chain_names, positions):
    """Build the select / show-sticks / colour lines for one CDR3 loop on every listed chain."""
    first, last = positions[0], positions[-1]
    lines = []
    for chain in _chain_ids(chain_names):
        region = f'(chain {chain} and resi {first}:{last})'
        lines += [
            f'sel {selection}, {region};',
            'set cartoon_side_chain_helper, on',
            f'show sticks, {selection};',
            f'color {colour}, {region};',
        ]
    return lines


def _attention_commands(selection_prefix, chain_names, hits):
    """Build the select / colour-pink lines for every (head, residue number) hit on every chain."""
    lines = []
    for chain in _chain_ids(chain_names):
        for head_i, resi in hits:
            selection = f'{selection_prefix}{head_i}'
            lines.append(f'sel {selection}, (chain {chain} and resi {resi});')
            lines.append(f'color pink, {selection};')
    return lines


def get_command(pdbid, n_heads=4, sigma_cutoff=5.0):
    """Build the PyMOL script (one command per line) that visualises one PDB entry.

    The script fetches the structure, shows the CDR3 loops of the beta chain (blue)
    and alpha chain (red) and the epitope chain (yellow) as sticks, draws the
    epitope's polar contacts and pi-pi / pi-cation distances, and colours pink every
    CDR3 residue that receives strong attention in any head.

    A column of a head's attention frame counts as "strong" when any of its values
    exceeds ``mean + sigma_cutoff * std`` computed over the whole frame. Columns are
    mapped positionally onto ``alpha residues + [delimiter] + beta residues``.

    Reads the notebook-level dictionaries ``DICT_PDBID_2_CDRS``,
    ``DICT_PDBID_2_CHAINNAMES`` and ``DICT_PDBID_2_Atten12``.

    Parameters
    ----------
    pdbid : str
        Key into the three dictionaries above; also used in the ``fetch`` command.
    n_heads : int, default 4
        Number of attention heads to scan (the original hard-coded value).
    sigma_cutoff : float, default 5.0
        Number of standard deviations above the mean that marks strong attention
        (the original hard-coded value).

    Returns
    -------
    str
        The newline-terminated PyMOL commands.

    Raises
    ------
    KeyError
        If `pdbid` is missing from any of the dictionaries.
    IndexError
        If a CDR3 residue list is empty, or a flagged attention column has no
        matching residue slot.
    """
    cdr_alpha_residues = DICT_PDBID_2_CDRS[pdbid][0]
    cdr_beta_residues = DICT_PDBID_2_CDRS[pdbid][1]

    beta_positions = [_residue_number(r) for r in cdr_beta_residues]
    alpha_positions = [_residue_number(r) for r in cdr_alpha_residues]

    alpha_chain_names, beta_chain_names, epitope_chain_names = DICT_PDBID_2_CHAINNAMES[pdbid]

    # Only the second set of attention frames is used for the residue highlighting.
    _, attn_output_weights2_list = DICT_PDBID_2_Atten12[pdbid]

    # Column i of an attention frame corresponds to residues_all[i]; the None entry
    # stands for the delimiter between the alpha and beta sequences.
    n_alpha = len(cdr_alpha_residues)
    residues_all = list(cdr_alpha_residues) + [None] + list(cdr_beta_residues)

    # (head index, residue number) pairs, kept together so that a residue can never
    # be paired with the wrong head. The original filtered the beta residues with
    # len(beta) but their head numbers with len(alpha), which misaligned the two
    # lists whenever the CDR3 lengths differed.
    alpha_hits = []
    beta_hits = []
    for head_i in range(n_heads):
        atten = attn_output_weights2_list[head_i]
        flat = atten.values.ravel()
        threshold = flat.mean() + sigma_cutoff * flat.std()
        strong_columns = (atten > threshold).any().values
        for col_i, is_strong in enumerate(strong_columns):
            if not is_strong:
                continue
            if col_i < n_alpha:
                alpha_hits.append((head_i, _residue_number(residues_all[col_i])))
            elif col_i > n_alpha:
                beta_hits.append((head_i, _residue_number(residues_all[col_i])))
            # col_i == n_alpha is the delimiter slot: it has no residue, so skip it.

    lines = [
        f'fetch {pdbid};',
        'set seq_view, 1;',
        'bg_color white;',
        # The four shades below are defined but not used by any command emitted
        # here (they were used by CDR1/CDR2 colouring that is currently disabled).
        'set_color blue60, [ 0, 0, 110 ];',
        'set_color blue30, [ 0, 0, 20 ];',
        'set_color red60, [ 110, 0, 0 ];',
        'set_color red30, [ 20, 0, 0 ];',
    ]

    lines += _cdr3_commands('beta_cdr3', 'blue', beta_chain_names, beta_positions)
    lines += _cdr3_commands('alpha_cdr3', 'red', alpha_chain_names, alpha_positions)

    for chain in _chain_ids(epitope_chain_names):
        lines += [
            f'sel epitope, chain {chain};',
            'show sticks, epitope;',
            f'color yellow, chain {chain};',
            'distance polar-contact, epitope, all, mode=2;',
            'color black, polar-contact;',
            'distance pipi-pication, epitope, all, mode=5;',
            'set_color gray60, [122,122,122];',
            'color gray60, pipi-pication;',
        ]

    # Attention highlights come last so that pink overrides the CDR3 base colours.
    lines += _attention_commands('atten_a_head', alpha_chain_names, alpha_hits)
    lines += _attention_commands('atten_b_head', beta_chain_names, beta_hits)

    return ''.join(line + '\n' for line in lines)

In [None]:
# List the PDB IDs for which attention matrices are available.
DICT_PDBID_2_Atten12.keys()

In [None]:
# Print the generated PyMOL script for one example entry.
PDBID = '5TEZ'
print(get_command(pdbid=PDBID))

# Save the PyMOL commands to files

In [None]:
# Number of PDB entries that will get a command file.
len(DICT_PDBID_2_Atten12)

In [None]:
# Write one PyMOL script per PDB entry to ./pymolcommand/<pdbid>.txt.
output_dir = pathlib.Path('./pymolcommand')
# Create the folder first: opening a file inside a missing folder raises FileNotFoundError.
output_dir.mkdir(parents=True, exist_ok=True)
for pdbid in DICT_PDBID_2_Atten12.keys():
    # Build the script before opening the file so that a failure in get_command
    # does not leave an empty, truncated file behind.
    command_text = get_command(pdbid)
    (output_dir / f'{pdbid}.txt').write_text(command_text)


In [None]:
!cat /Users/kyoheikoyama/workspace/tcrpred/analysis/analysis1_alldata/pymolcommand/3MV7.txt

In [None]:
# NOTE(review): `@<file>` looks like PyMOL's "run this script" syntax rather than a
# shell command, so as a `!` shell escape this line presumably fails. It is likely
# meant to be pasted into the PyMOL prompt (without the leading `!`) — confirm.
# The absolute, user-specific path also will not exist on other machines.
!@/Users/kyoheikoyama/workspace/tcrpred/analysis/analysis1_alldata/pymolcommand/2UWE.txt

# Save the DICT_PDBID_2_* dictionaries

In [None]:
# import pickle
# directory = "../DICT_PDB_Result"
# for name, dic in zip(
#     ["DICT_PDBID_2_Atten12",
#         "DICT_PDBID_2_CDRS",
#         "DICT_PDBID_2_CHAINNAMES",
#         "DICT_PDBID_2_DISTANCE",
#         "DICT_PDBID_2_MELTDIST",
#         "DICT_PDBID_2_model_out",
#         "DICT_PDBID_2_RESIDUES",
#         "DICT_PDBID_2_SEQUENCES_CDR",
# #         "DICT_PDBID_2_STRUCTURE"
#     ],
#     [DICT_PDBID_2_Atten12,
#         DICT_PDBID_2_CDRS,
#         DICT_PDBID_2_CHAINNAMES,
#         DICT_PDBID_2_DISTANCE,
#         DICT_PDBID_2_MELTDIST,
#         DICT_PDBID_2_model_out,
#         DICT_PDBID_2_RESIDUES,
#         DICT_PDBID_2_SEQUENCES_CDR,
# #         DICT_PDBID_2_STRUCTURE
#     ]):
    
#     pickle.dump(dic, open(f'{directory}/{name}.pickle', 'wb'))
