In [None]:
import numpy as np
import pandas as pd
import mdtraj
import pickle
import matplotlib.pyplot as plt
import os
import scipy.stats as ss
import random
from collections import Counter
import seaborn as sns
import itertools

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and 3.8 removed the
# old 'seaborn' name, so prefer the new name whenever this matplotlib provides it.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def save_obj(obj, name):
    """Pickle ``obj`` to ``obj/<name>.pkl`` relative to the working directory.

    The ``obj/`` directory must already exist; it is not created here.
    The highest pickle protocol supported by this interpreter is used.
    """
    pickle_path = 'obj/' + name + '.pkl'
    with open(pickle_path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """Return the object unpickled from ``obj/<name>.pkl``.

    Only load files you produced yourself: ``pickle.load`` can execute
    arbitrary code embedded in an untrusted pickle.
    """
    pickle_path = 'obj/' + name + '.pkl'
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)
    
# Anchor prediction scores; later indexed as ["HLA-" + allele][peptide length]
all_anchor_overall_pos_score_dict = load_obj(name='all_anchor_overall_pos_score_dict_r4')

In [None]:
def chain_split(table):
    """Group the atom rows of a topology DataFrame by chain and residue.

    Parameters
    ----------
    table : pandas.DataFrame
        One row per atom with at least ``chainID``, ``resSeq`` and ``resName``
        columns (e.g. the frame returned by ``Topology.to_dataframe()``). The row
        index label is recorded as the atom identifier; ``calculate_distance``
        later passes these labels to ``mdtraj.compute_distances`` as atom indices.

    Returns
    -------
    dict_by_chain : dict
        chainID -> list of residue names, one per distinct ``resSeq``, in order
        of first appearance.
    resSeq_tracker : dict
        chainID -> list of distinct ``resSeq`` values in order of first
        appearance (parallel to ``dict_by_chain``).
    serial_Name : dict
        chainID -> {resSeq -> list of row index labels (atoms) of that residue}.

    Notes
    -----
    A ``resSeq`` that reappears later in the same chain (non-contiguously) is
    merged into the residue first seen with that number.
    """
    dict_by_chain = {}
    resSeq_tracker = {}
    serial_Name = {}
    # Set mirror of resSeq_tracker for O(1) membership tests; the list itself is
    # kept because it is returned and its order matters to callers.
    seen_resSeq = {}
    for atom_label, chain_id, res_seq, res_name in zip(
            table.index, table['chainID'], table['resSeq'], table['resName']):
        serial_Name.setdefault(chain_id, {}).setdefault(res_seq, []).append(atom_label)
        chain_seen = seen_resSeq.setdefault(chain_id, set())
        if res_seq not in chain_seen:
            # First atom of a residue not seen before in this chain.
            chain_seen.add(res_seq)
            dict_by_chain.setdefault(chain_id, []).append(res_name)
            resSeq_tracker.setdefault(chain_id, []).append(res_seq)
    return dict_by_chain, resSeq_tracker, serial_Name

def calculate_distance(HLA, peptide, traj, criteria='avg'):
    """Summarise, for each peptide atom, its distances to every HLA atom.

    Parameters
    ----------
    HLA : dict
        resSeq -> list of atom indices of the HLA chain (one per-chain entry of
        ``chain_split``'s third return value); all these atoms are pooled.
    peptide : dict
        resSeq -> list of atom indices of the peptide chain.
    traj : mdtraj.Trajectory
        Structure containing both chains; only frame 0 is used.
    criteria : {'min', 'avg', 'max'}
        Reduction applied over the HLA-atom distances of each peptide atom.

    Returns
    -------
    init_dict : dict
        peptide atom index -> list of its distances (frame 0) to each HLA atom,
        in the order of ``HLA``'s flattened values.
    return_dict : dict
        peptide resSeq -> {peptide atom index -> reduced distance}. Residues
        with no atoms are omitted.

    Raises
    ------
    ValueError
        If ``criteria`` is not one of 'min', 'avg', 'max' (previously this
        silently returned placeholder zeros).
    """
    reducers = {'min': np.min, 'avg': np.mean, 'max': np.max}
    if criteria not in reducers:
        raise ValueError("criteria must be one of 'min', 'avg', 'max', got %r" % (criteria,))
    reduce_fn = reducers[criteria]

    # Flatten the HLA atom lists once instead of once per peptide atom.
    hla_atoms = list(itertools.chain.from_iterable(HLA.values()))
    group_list = [[pep_atom, hla_atom]
                  for res in peptide
                  for pep_atom in peptide[res]
                  for hla_atom in hla_atoms]
    distances = mdtraj.compute_distances(traj, group_list)

    # distances is indexed [frame][pair]; only the first frame is used.
    init_dict = {atom: [] for res in peptide for atom in peptide[res]}
    for (pep_atom, _hla_atom), dist in zip(group_list, distances[0]):
        init_dict[pep_atom].append(dist)

    return_dict = {res: {atom: reduce_fn(init_dict[atom]) for atom in peptide[res]}
                   for res in peptide if peptide[res]}
    return init_dict, return_dict

def calculate_SASA(n_peptide, traj):
    """Return the frame-0 per-residue SASA of every residue in chain ``n_peptide``.

    SASA is computed by ``mdtraj.shrake_rupley(mode='residue')`` over the whole
    structure in ``traj``, so each peptide residue's value reflects burial by
    the other chains. The chain's residue indices are derived by assuming that
    chains occupy consecutive residue-index ranges in chain order.

    Returns a list with one value per residue of the chain.
    """
    chain_sizes = [traj.topology.chain(c).n_residues for c in range(traj.n_chains)]
    # boundaries[c] is the index of the first residue of chain c; the last
    # entry equals the total residue count.
    boundaries = [0]
    for size in chain_sizes:
        boundaries.append(boundaries[-1] + size)
    first_res, end_res = boundaries[n_peptide], boundaries[n_peptide + 1]
    residue_sasa = mdtraj.shrake_rupley(traj, mode='residue')
    return [residue_sasa[0, r] for r in range(first_res, end_res)]

In [None]:
## Initialize dictionaries
# SASA_results: sample id -> per-residue peptide SASA list
# Distance_with_percentile_results: sample id -> per-residue summarised distance list
# sample_hla_query: "HLA-<allele>" -> list of (sample, peptide, peptide_length) tuples
SASA_results, Distance_with_percentile_results, sample_hla_query = {}, {}, {}

In [None]:
## Loop through each entry of the structure list. Each tab-separated line holds:
## HLA allele (without the "HLA-" prefix), sample/PDB id, peptide sequence, peptide length.
## Samples already present in the result dicts are not recomputed, but every line is
## still appended to sample_hla_query -- re-run the initialisation cell before re-running
## this one to avoid duplicate entries.

## threshold for top x closest locations considered as metric (only applies to distance not SASA)
percentile = 50

with open('cleaned_list_pdb_structures.tsv', 'r') as structure_list:
    for line in structure_list:
        entry = line.strip('\n')
        if not entry.strip():
            # Skip blank lines (e.g. a trailing empty line) instead of failing to unpack them.
            continue
        hla_details, sample, peptide, peptide_length = entry.split("\t")
        print(sample, peptide, peptide_length)
        hla_name = "HLA-"+hla_details
        sample_hla_query.setdefault(hla_name, []).append((sample, peptide, peptide_length))

        if sample in SASA_results or sample in Distance_with_percentile_results:
            print("Already calculated for sample: "+sample)
            continue

        file_name = "pdb_files/"+sample+".pdb"
        # The loaded structure holds all chains (HLA and peptide), not just the HLA.
        hla = mdtraj.load(file_name)
        table, bonds = hla.topology.to_dataframe()
        dict_by_chain, resSeq_tracker, serial = chain_split(table)
        # Default chain layout: chain 0 = HLA heavy chain, chain 2 = peptide.
        n_hla = 0
        n_peptide = 2
        for i in dict_by_chain.keys():
            if i == 0 and not (269 <= len(dict_by_chain[i]) <= 290):
                # NOTE(review): an unexpected chain-0 length is only reported; the sample is
                # still processed with n_hla = 0, and this `break` also skips the
                # peptide-chain check for the chains that follow.
                print(file_name, i, len(dict_by_chain[i]), "ERROR")
                print([(j, len(dict_by_chain[j])) for j in dict_by_chain.keys()])
                break
            if i == 2 and len(dict_by_chain[i]) != int(peptide_length):
                # Chain 2 is not the peptide: use the first chain whose length matches it.
                for chain_id, chain_residues in dict_by_chain.items():
                    if len(chain_residues) == int(peptide_length):
                        n_peptide = chain_id
                        print('NEW n_peptide: ' + str(n_peptide))
                        print('SAMPLE ERROR: ' +sample)
                        break

        ## DISTANCE CALCULATIONS
        HLA_mol = serial[n_hla]
        peptide_mol = serial[n_peptide]
        ## For each peptide atom, calculate the minimum distance between it and any HLA atom
        init, summary_min = calculate_distance(HLA_mol, peptide_mol, hla, criteria='min')

        ## Keep only non-backbone atoms of each peptide residue (set for O(1) lookups)
        not_backbone_ind = set(hla.top.select('not backbone').tolist())

        new_summary_min = {}
        for i in summary_min.keys():
            new_summary_min[i] = {j: dist for j, dist in summary_min[i].items()
                                  if j in not_backbone_ind}
            if not new_summary_min[i]:
                # Residue has no non-backbone atom: fall back to all of its atoms.
                new_summary_min[i] = summary_min[i]

        ## Per residue: mean of the atom distances at or below the `percentile`-th percentile
        min_summary_perc = {}
        for i in new_summary_min:
            residue_distances = list(new_summary_min[i].values())
            # Threshold computed once per residue (it used to be recomputed for every atom).
            threshold = np.percentile(residue_distances, percentile)
            min_summary_perc[i] = np.mean(np.array([k for k in residue_distances if k <= threshold]))

        ## SASA CALCULATIONS
        peptide_sasa = calculate_SASA(n_peptide, hla)

        ## Add calculations to dataset
        SASA_results[sample] = peptide_sasa
        Distance_with_percentile_results[sample] = [min_summary_perc[i] for i in min_summary_perc.keys()]

In [None]:
## Sanity checks: number of processed samples and distribution of peptide-chain lengths
print(len(SASA_results), len(Distance_with_percentile_results))
print(Counter(len(sasa) for sasa in SASA_results.values()))
# Report samples whose peptide has more than 11 residues, together with that length
# (previously printed a stale/undefined `j` instead of the length).
for sample_id, distances in Distance_with_percentile_results.items():
    if len(distances) > 11:
        print(sample_id, len(distances))

## Generate input files for pvacbind

In [None]:
def _peptides_by_length(entries):
    """Group peptides from (sample, peptide, peptide_length) tuples by length.

    Only lengths 8-11 are kept (other lengths are skipped rather than raising).
    The result always has keys 8, 9, 10, 11 in ascending order, each mapping to
    a possibly empty list of peptides in input order.
    """
    grouped = {length: [] for length in (8, 9, 10, 11)}
    for _sample, peptide, peptide_length in entries:
        length = int(peptide_length)
        if length in grouped:
            grouped[length].append(peptide)
    return grouped


# Group once and use the same grouping for both the FASTA files and the query TSV.
# Previously the query loop used `break` where the FASTA loop used `continue`, so any
# peptide > 11 residues silently dropped later lengths from the query file only.
peptides_per_allele = {hla_name: _peptides_by_length(entries)
                       for hla_name, entries in sample_hla_query.items()}

## One FASTA per (allele, length), records numbered from 1
for hla_name, all_peptides in peptides_per_allele.items():
    for length, peptides in all_peptides.items():
        if peptides:
            fasta_path = "pvacbind_files/anchor_validation_pvacbind_"+hla_name+"-"+str(length)+".fa"
            with open(fasta_path, "w+") as fasta_file:
                for i, j in enumerate(peptides):
                    fasta_file.write(">"+str(i+1)+'\n')
                    fasta_file.write(j+'\n')

## Query table: one row per (allele, length) that has a FASTA file, lengths ascending
with open("pvacbind_files/anchor_validation_pvacbind_query.tsv", "w+") as query_file:
    query_file.write("HLA"+"\t"+"length"+'\n')
    for hla_name, all_peptides in peptides_per_allele.items():
        for length, peptides in all_peptides.items():
            if peptides:
                query_file.write(hla_name+"\t"+str(length)+'\n')

## Plotting with pvacbind results

In [None]:
# pvacbind output; columns used below: 'Epitope Seq', 'HLA Allele', 'Best Score', 'Median Score'
binding_data = pd.read_csv("./validation_analysis_pvacbind_output_combined.txt", sep='\t')
# Headerless structure list: columns 0-3 = allele (no "HLA-" prefix), PDB id, peptide, length.
# NOTE(review): the structure-processing loop reads 'cleaned_list_pdb_structures.tsv' from the
# working directory, whereas this reads it from 'pdb_files/' -- confirm both are the same file.
sample_data = pd.read_csv("pdb_files/cleaned_list_pdb_structures.tsv", sep='\t', header=None)
# Populated column by column in the next cell
sample_data_with_binding = pd.DataFrame()

In [None]:
## Per-structure binding scores, plus Spearman correlations between each structural metric
## (per-residue SASA / distance) and the predicted anchor scores, with a shuffled baseline.
sample_data_with_binding['HLA Allele'] = sample_data[0]
sample_data_with_binding['PDB_ID'] = sample_data[1]
sample_data_with_binding['Peptide'] = sample_data[2]
sample_data_with_binding['Length'] = sample_data[3]
best_binding = []
median_binding = []
SASA_vs_Predicted = []
SASA_vs_Predicted_pval = []
Distance_vs_Predicted = []
Distance_vs_Predicted_pval = []
randomized_SASA = []
randomized_Distance = []

for _, row in sample_data_with_binding.iterrows():
    hla_name = "HLA-" + row['HLA Allele']
    # First pvacbind row for this (peptide, allele); raises IndexError if there is none.
    matched_entry = binding_data.loc[(binding_data['Epitope Seq'] == row['Peptide']) & (binding_data['HLA Allele'] == hla_name)].iloc[0]
    best_binding.append(matched_entry['Best Score'])
    median_binding.append(matched_entry['Median Score'])
    sasa_calculation = SASA_results[row['PDB_ID']]
    distance_calculation = Distance_with_percentile_results[row['PDB_ID']]
    prediction_calculation = all_anchor_overall_pos_score_dict[hla_name][len(row['Peptide'])]
    corr1, p1 = ss.spearmanr(sasa_calculation, prediction_calculation)
    corr2, p2 = ss.spearmanr(distance_calculation, prediction_calculation)
    SASA_vs_Predicted.append(corr1)
    SASA_vs_Predicted_pval.append(p1)
    Distance_vs_Predicted.append(corr2)
    Distance_vs_Predicted_pval.append(p2)
    # Shuffle COPIES: random.shuffle works in place, and shuffling the stored lists would
    # corrupt SASA_results / Distance_with_percentile_results for later rows with the same
    # PDB id and for every later cell that plots or exports them.
    shuffled_sasa = list(sasa_calculation)
    shuffled_distance = list(distance_calculation)
    random.shuffle(shuffled_sasa)
    random.shuffle(shuffled_distance)
    ran_corr1, _ = ss.spearmanr(shuffled_sasa, prediction_calculation)
    ran_corr2, _ = ss.spearmanr(shuffled_distance, prediction_calculation)
    randomized_SASA.append(ran_corr1)
    randomized_Distance.append(ran_corr2)


sample_data_with_binding['Best_Score'] = best_binding
sample_data_with_binding['Median_Score'] = median_binding
sample_data_with_binding['Distance'] = Distance_vs_Predicted
sample_data_with_binding['Distance_Pval'] = Distance_vs_Predicted_pval
sample_data_with_binding['SASA'] = SASA_vs_Predicted
sample_data_with_binding['SASA_Pval'] = SASA_vs_Predicted_pval
sample_data_with_binding['Random SASA'] = randomized_SASA
sample_data_with_binding['Random Distance'] = randomized_Distance

In [None]:
## Generating Random Subset for plotting
# Group row labels by allele (in order of first appearance), then keep at most 5
# randomly chosen structures per allele to cap each allele's contribution.
index_dictionary = {}
for row_label, allele in zip(sample_data_with_binding.index, sample_data_with_binding['HLA Allele']):
    index_dictionary.setdefault(allele, []).append(row_label)

random_subset = []
for allele, row_labels in index_dictionary.items():
    random_subset += random.sample(row_labels, k=5) if len(row_labels) > 5 else row_labels

In [None]:
# Column groups of sample_data_with_binding: binding scores and correlation statistics
score_metrics = ["Best_Score", "Median_Score"]
correlation_metrics = ["Distance", "SASA"]
# random_subset holds index *labels* collected from the frame's index, so select with .loc;
# .iloc treats them as positions and would pick the wrong rows for a non-RangeIndex.
sample_data_with_binding_subset = sample_data_with_binding.loc[random_subset]

In [None]:
# Export the full per-structure table (binding scores + correlation statistics).
# NOTE(review): the file name contains a space ("and Correlation"); kept as-is in case
# downstream steps rely on this exact name.
sample_data_with_binding.to_csv(path_or_buf="All_Data_with_Binding_and Correlation_Info.tsv", sep='\t')

In [None]:
# Export the per-allele random subset used for the statistics and plots below.
# NOTE(review): the file name contains a space ("and Correlation"); kept as-is in case
# downstream steps rely on this exact name.
sample_data_with_binding_subset.to_csv(path_or_buf="Subset_Data_with_Binding_and Correlation_Info.tsv", sep='\t')

## Welch's t-test: observed vs. randomized correlation distributions

In [None]:
# Welch's t-test (equal_var=False): observed vs. shuffled-baseline Spearman correlations
for observed_col, random_col in (('Distance', 'Random Distance'), ('SASA', 'Random SASA')):
    welch_result = ss.ttest_ind(sample_data_with_binding_subset[observed_col],
                                sample_data_with_binding_subset[random_col],
                                equal_var=False)
    print(welch_result)

In [None]:
# Fraction of sampled structures whose distance-vs-prediction Spearman correlation is <= 0
non_positive_distance = sample_data_with_binding_subset['Distance'] <= 0
int(non_positive_distance.sum()) / len(non_positive_distance)

## Plotting the Spearman Correlation Distribution

In [None]:
# Joint Plot: density (KDE + rug) of per-structure Spearman correlations, observed vs. shuffled.
# sns.distplot is deprecated since seaborn 0.11; kdeplot + rugplot are the supported equivalents
# of distplot(hist=False, rug=True), and an explicit legend() makes the labels visible.
fig, axes = plt.subplots(2,1, figsize=(12,15))

panels = [
    (axes[0],
     [('Distance', "Distance Correlation"), ('Random Distance', "Randomized Correlation")],
     'Distribution of Spearman Correlations between Anchor Prediction & Distance Calculation'),
    (axes[1],
     [('SASA', "SASA Correlation"), ('Random SASA', "Randomized Correlation")],
     'Distribution of Spearman Correlations between Anchor Prediction & SASA Calculation'),
]
for ax, series, title in panels:
    for color_idx, (column, label) in enumerate(series):
        # Drop NaNs (e.g. a Spearman correlation of a constant vector) so the KDE uses finite values.
        values = sample_data_with_binding_subset[column].dropna()
        color = 'C' + str(color_idx)
        sns.kdeplot(x=values, color=color, label=label, ax=ax)
        sns.rugplot(x=values, color=color, ax=ax)
    ax.set_xlabel('Spearman Correlation')
    ax.set_ylabel('Density')
    ax.title.set_text(title)
    ax.legend()
plt.savefig('../../Anchor Paper/Main Figure 3/Distribution of Spearman Correlations between Anchor Prediction & Distance joint SASA Calculation.pdf')

In [None]:
HLA_allele = 'HLA-A*02:01'
length = 9

# One line per structure of this allele/length: per-position distance (top) and SASA (bottom).
fig, axs = plt.subplots(2, figsize=(10,10))
fig.suptitle("Anchor Validation Data for "+HLA_allele+" at length " + str(length), fontweight='bold', fontsize=13, x=0.55)
x_pos = list(range(1, length + 1))
for j in [i[0] for i in sample_hla_query[HLA_allele] if i[2] == str(length)]:
    axs[0].plot(x_pos, Distance_with_percentile_results[j])
    axs[1].plot(x_pos, SASA_results[j])

# Labels are set once, outside the loop, so they appear even when no structure matches.
# mdtraj reports distances in nm and SASA in nm^2 (the previous Å / Å^2 labels were wrong).
axs[0].set_title('Distances between atoms of peptide at each position to all possible atoms of HLA')
axs[0].set(ylabel='Distance (nm)')
axs[1].set_title('SASA of peptide at each position when in complex with HLA molecule')
axs[1].set(ylabel='SASA ($nm^2$)', xlabel='Positions')

plt.tight_layout(pad=3)
plt.show()

In [None]:
HLA_allele = 'HLA-A*02:01'
length = 9

# One row per structure, one column per 1-based peptide position (derived from `length`
# rather than hard-coded, so the table matches the selected peptides).
# Values are in mdtraj units: nm for distance, nm^2 for SASA.
columns = list(range(1, length + 1))
dataset_0201_distance = []
dataset_0201_sasa = []

for j in [i[0] for i in sample_hla_query[HLA_allele] if i[2] == str(length)]:
    dataset_0201_distance.append(Distance_with_percentile_results[j])
    dataset_0201_sasa.append(SASA_results[j])
dataset_0201_distance_df = pd.DataFrame(dataset_0201_distance, columns=columns)
dataset_0201_sasa_df = pd.DataFrame(dataset_0201_sasa, columns=columns)
dataset_0201_distance_df.to_csv("HLA-A*02:01_pMHC_crystallography_analysis_distance.tsv", sep='\t')
dataset_0201_sasa_df.to_csv("HLA-A*02:01_pMHC_crystallography_analysis_sasa.tsv", sep='\t')

In [None]:
HLA_allele = 'HLA-B*08:01'
length = 9

# One line per structure of this allele/length: per-position distance (top) and SASA (bottom).
fig, axs = plt.subplots(2, figsize=(10,10))
fig.suptitle("Anchor Validation Data for "+HLA_allele+" at length " + str(length), fontweight='bold', fontsize=13, x=0.55)
x_pos = list(range(1, length + 1))
for j in [i[0] for i in sample_hla_query[HLA_allele] if i[2] == str(length)]:
    axs[0].plot(x_pos, Distance_with_percentile_results[j])
    axs[1].plot(x_pos, SASA_results[j])

# Labels are set once, outside the loop, so they appear even when no structure matches.
# mdtraj reports distances in nm and SASA in nm^2 (the previous Å / Å^2 labels were wrong).
axs[0].set_title('Distances between atoms of peptide at each position to all possible atoms of HLA')
axs[0].set(ylabel='Distance (nm)')
axs[1].set_title('SASA of peptide at each position when in complex with HLA molecule')
axs[1].set(ylabel='SASA ($nm^2$)', xlabel='Positions')

plt.tight_layout(pad=3)
plt.show()

In [None]:
HLA_allele = 'HLA-B*08:01'
length = 9

# One row per structure, one column per 1-based peptide position (derived from `length`
# rather than hard-coded, so the table matches the selected peptides).
# Values are in mdtraj units: nm for distance, nm^2 for SASA.
columns = list(range(1, length + 1))
dataset_0801_distance = []
dataset_0801_sasa = []

for j in [i[0] for i in sample_hla_query[HLA_allele] if i[2] == str(length)]:
    dataset_0801_distance.append(Distance_with_percentile_results[j])
    dataset_0801_sasa.append(SASA_results[j])
dataset_0801_distance_df = pd.DataFrame(dataset_0801_distance, columns=columns)
dataset_0801_sasa_df = pd.DataFrame(dataset_0801_sasa, columns=columns)
dataset_0801_distance_df.to_csv("HLA-B*08:01_pMHC_crystallography_analysis_distance.tsv", sep='\t')
dataset_0801_sasa_df.to_csv("HLA-B*08:01_pMHC_crystallography_analysis_sasa.tsv", sep='\t')