In [None]:
import numpy as np
import regex
import pandas as pd
import random
import pickle
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()
import os
import matplotlib.style as style
import matplotlib.cm as mplcm
import matplotlib.colors as colors
import collections
import scipy.stats as ss

In [None]:
def save_obj(obj, name, directory='obj/'):
    """Pickle ``obj`` to ``<directory><name>.pkl``.

    ``name`` may contain sub-folders (e.g. ``'random_round_2/saturation_dict'``). Any
    missing parent directories are created first, so the write does not fail with
    FileNotFoundError.

    Parameters
    ----------
    obj : object
        Any picklable object.
    name : str
        Path of the pickle relative to ``directory``, without the ``.pkl`` suffix.
    directory : str, default ``'obj/'``
        Prefix prepended to ``name`` by plain string concatenation, as before.
        NOTE(review): load_obj in this notebook reads from a different (absolute)
        directory by default. Confirm that the two locations are meant to differ.
    """
    path = directory + name + '.pkl'
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name, directory='/Users/h.xia/Desktop/Griffith_Lab/R_shiny_visualization/anchor_analysis/obj/saturation_analysis/'):
    """Unpickle and return the object stored at ``<directory><name>.pkl``.

    Parameters
    ----------
    name : str
        Path of the pickle relative to ``directory``, without the ``.pkl`` suffix
        (e.g. ``'random_round_3/saturation_dict'``).
    directory : str
        Prefix prepended to ``name``. The default is the original author's
        machine-specific location; pass another directory to run the notebook elsewhere.

    Only load trusted files: unpickling can execute arbitrary code.
    """
    with open(directory + name + '.pkl', 'rb') as f:
        return pickle.load(f)

## Every single-residue substitution of each input peptide (any number of peptides).
amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
def all_possible_mutations(peptides):
    """Return all single-position substitutions of every peptide in ``peptides``.

    Each position of each peptide is replaced in turn by each residue of ``amino_acids``,
    in that list's order. The output is therefore ordered by peptide, then position, then
    residue, and a peptide of length L contributes ``20 * L`` entries. Substituting a
    residue by itself is not skipped, so a peptide made of these 20 residues also appears
    unchanged L times.
    """
    return [
        peptide[:pos] + residue + peptide[pos + 1:]
        for peptide in peptides
        for pos in range(len(peptide))
        for residue in amino_acids
    ]

def three_mutation_sequence_match(seq1, seq2):
    """Return True if ``seq2`` contains a substring within at most 2 substitutions of ``seq1``.

    Only substitutions are permitted by the ``{s<=2}`` constraint (no insertions or
    deletions), and an exact occurrence also matches. random_choice_3_mutations uses
    this to reject a candidate that is fewer than three substitutions away from an
    already selected peptide.
    """
    # search() stops at the first hit. findall(overlapped=True) collected every match
    # just to test for emptiness. seq1 is escaped so it is always matched literally.
    return regex.search('(' + regex.escape(seq1) + '){s<=2}', seq2) is not None

### Generating FASTA files

In [None]:
## HLA-A*02:01 is the allele chosen for the saturation analysis.
## Recorded (allele, peptide length, peptide count) values:
##   HLA-A*02:01  8-mer  1524
##   HLA-A*02:01  9-mer  7383
##   HLA-A*02:01 10-mer  6939
##   HLA-A*02:01 11-mer  4175
combined_peptide_database = load_obj("combined_peptide_database")
hla_a_0201_data = combined_peptide_database['HLA-A*02:01']
## Subsample sizes; the largest (1000) is later used as the reference for the correlations.
sizes = [1, 5, 10, 20, 50, 100, 200, 500, 1000]
def random_choice_3_mutations(input_data, total, max_tries=2000, rng=None):
    """Randomly pick up to ``total`` peptides that are pairwise >= 3 substitutions apart.

    Peptides are drawn without replacement from ``input_data``. A drawn peptide is kept
    only if ``three_mutation_sequence_match`` finds no already-kept peptide within two
    substitutions of it.

    Parameters
    ----------
    input_data : iterable of str
        Candidate peptides. It is not modified; a copy is consumed.
    total : int
        Target number of peptides.
    max_tries : int, default 2000
        Maximum number of draws (kept or rejected) before giving up.
    rng : random.Random, optional
        Source of randomness. Defaults to the module-level ``random`` functions; pass
        ``random.Random(seed)`` for a reproducible subsample.

    Returns
    -------
    list of str
        Kept peptides in draw order. The list is shorter than ``total`` if the draw
        budget or the candidate pool ran out.
    """
    chooser = random if rng is None else rng
    remaining = list(input_data)
    peptide_list = []
    tries = 0
    while len(peptide_list) < total:
        if tries >= max_tries:
            print('Exceeded Maximum tries, length of peptide list is: '+str(len(peptide_list)))
            break
        if not remaining:
            # Every candidate has been drawn. random.choice would raise IndexError on an
            # empty list, e.g. when total exceeds the number of usable peptides.
            print('Ran out of candidate peptides, length of peptide list is: '+str(len(peptide_list)))
            break
        peptide = chooser.choice(remaining)
        remaining.remove(peptide)
        if not any(three_mutation_sequence_match(kept, peptide) for kept in peptide_list):
            peptide_list.append(peptide)
        tries += 1
    return peptide_list

In [None]:
## Draw the subsamples: for every subsample size and every peptide length, pick peptides
## that are pairwise at least three substitutions apart.
saturation_dict = {}
for subsample_size in sizes:
    by_length = {}
    saturation_dict[subsample_size] = by_length
    for peptide_length in hla_a_0201_data:
        by_length[peptide_length] = random_choice_3_mutations(hla_a_0201_data[peptide_length], total=subsample_size)
        print('Done with: ' + str(peptide_length) + '_' + str(subsample_size))

In [None]:
## Reuse the subsamples saved for round 3; this replaces the dictionary drawn in the previous cell.
## To persist a fresh draw instead: save_obj(saturation_dict, 'random_round_5/saturation_dict')
saturation_dict = load_obj(name='random_round_3/saturation_dict')

In [None]:
## Write one FASTA per (subsample size, peptide length) containing every single-residue
## mutant of the subsampled peptides; records are numbered from 1.
os.makedirs("saturation_analysis", exist_ok=True)
for i in saturation_dict.keys():
    for j in saturation_dict[i].keys():
        mutated_set = all_possible_mutations(saturation_dict[i][j])
        fasta_path = "saturation_analysis/HLA_A_0201_anchor_"+str(j)+"mer_input_"+str(i)+".fa"
        # `with` guarantees the file is closed even if a write fails.
        with open(fasta_path, 'w') as fasta_file:
            for m, n in enumerate(mutated_set, start=1):
                fasta_file.write(">"+str(m)+'\n')
                fasta_file.write(n+'\n')

### Analysis functions

In [None]:
### For finding the sequences that are at most one mutation away
def one_mutation_sequence_match(seq1, seq2):
    """Return True if ``seq2`` contains a substring within at most one substitution of ``seq1``.

    Only substitutions are permitted by the ``{s<=1}`` constraint (no insertions or
    deletions). Zero substitutions, i.e. an exact occurrence, also matches.
    """
    # search() stops at the first hit. findall(overlapped=True) collected every match
    # just to test for emptiness. seq1 is escaped so it is always matched literally.
    return regex.search('(' + regex.escape(seq1) + '){s<=1}', seq2) is not None
### For finding the sequences that are exact matches to original
def exact_mutation_sequence_match(seq1, seq2):
    """Return True if ``seq1`` occurs verbatim as a substring of ``seq2``.

    This is what the former fuzzy pattern with zero allowed substitutions tested. A plain
    substring check is used instead, so characters in ``seq1`` are never treated as regex
    syntax and no pattern has to be compiled.
    """
    return seq1 in seq2

### For generating the og_epitopes list from mutated sequences:
def create_og_epitope_list(input_data, query_mutations):
    """Return the set of peptides from ``input_data`` present in ``query_mutations['Epitope Seq']``.

    Parameters
    ----------
    input_data : iterable of str
        Candidate original peptides.
    query_mutations : pandas.DataFrame (or any mapping with an 'Epitope Seq' entry)
        Table whose 'Epitope Seq' column is searched.

    Returns
    -------
    set of str
    """
    # Build the lookup set once. It used to be rebuilt for every peptide, which made
    # the loop O(len(input_data) * len(query_mutations)).
    query_seqs = set(query_mutations['Epitope Seq'])
    return {peptide for peptide in input_data if peptide in query_seqs}

def find_median_score(epitope_list, all_epitope_data):
    """Return the 'Median Score' of each epitope in ``epitope_list``, in iteration order.

    Parameters
    ----------
    epitope_list : iterable of str
        Epitope sequences to look up. It is iterated once, so a set is fine; the
        returned list follows that iteration order.
    all_epitope_data : pandas.DataFrame
        Table with 'Epitope Seq' and 'Median Score' columns.

    Returns
    -------
    list
        One score per epitope.

    Raises
    ------
    ValueError
        If an epitope is absent from the table, or its rows carry more than one
        distinct median score.
    """
    # Group once instead of scanning the whole table for every epitope.
    scores_by_seq = {
        seq: set(scores)
        for seq, scores in all_epitope_data.groupby('Epitope Seq', sort=False)['Median Score']
    }
    score_list = []
    for epitope in epitope_list:
        distinct_scores = scores_by_seq.get(epitope, set())
        if len(distinct_scores) != 1:
            raise ValueError(
                'Expected exactly one distinct Median Score for epitope ' + str(epitope)
                + ', found ' + str(len(distinct_scores)))
        (score,) = distinct_scores
        score_list.append(score)
    return score_list
        

### For creating the dictionary matching all original epitopes to their mutation and scores
def create_mutation_dictionary(og_epitopes, all_mutated_epitopes, all_mutated_median):
    """Map each original epitope to the scores of every sequence within one substitution of it.

    Parameters
    ----------
    og_epitopes : iterable of str
        Original epitopes.
    all_mutated_epitopes : iterable of str
        Scored sequences, e.g. the 'Epitope Seq' column.
    all_mutated_median : iterable of numbers
        Scores positionally aligned with ``all_mutated_epitopes``, e.g. 'Median Score'.

    Returns
    -------
    dict
        ``{original: {sequence: median}}``. Matching uses one_mutation_sequence_match, so
        a sequence equal to the original is included too. If a sequence occurs in several
        rows, the score of the last row wins.
    """
    # Pair sequences with scores by position. Scores used to be read by label with a
    # positional key, which is only right for a default RangeIndex. The old unused Series
    # copy, rebuilt by .drop() on every match, is gone as well.
    mutants = list(zip(all_mutated_epitopes, all_mutated_median))
    mutation_dict = {}
    for og in og_epitopes:
        mutation_dict[og] = {
            seq: median
            for seq, median in mutants
            if one_mutation_sequence_match(og, seq)
        }
    return mutation_dict


amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

### Per epitope and per position: sum over all 20 residues of (mutant score / original score)
def calculate_sum_diff(og_epitopes, og_scores, all_mutant_dict):
    """For each original epitope, sum the mutant-to-original score ratios at every position.

    Parameters
    ----------
    og_epitopes : iterable of str
        Original epitopes.
    og_scores : iterable of numbers
        Original scores, aligned with ``og_epitopes``; used as divisors.
    all_mutant_dict : dict
        Maps each original epitope to ``{sequence: score}``. It must contain every
        single-residue substitution, including the unchanged residue, or KeyError is raised.

    Returns
    -------
    dict
        ``{epitope: [[position, ratio_sum], ...]}``, one entry per position in order.
    """
    sum_dict = {}
    for epitope, og_score in zip(og_epitopes, og_scores):
        mutant_scores = all_mutant_dict[epitope]
        sum_dict[epitope] = [
            [pos, sum(mutant_scores[epitope[:pos] + residue + epitope[pos + 1:]] / og_score
                      for residue in amino_acids)]
            for pos in range(len(epitope))
        ]
    return sum_dict

### for calculating the overall per-position sums across all original epitopes
def overall_pos_sum_diff(sum_diff_results, pos_length):
    """Add up the per-position scores of every epitope into one list of length ``pos_length``.

    Parameters
    ----------
    sum_diff_results : dict
        Values are lists of ``[position, score]`` pairs, as produced by calculate_sum_diff.
    pos_length : int
        Number of positions (peptide length).

    Returns
    -------
    list
        ``overall[pos]`` = sum of the scores at ``pos`` over all epitopes. Raises
        IndexError if a position is ``>= pos_length``.
    """
    overall_score = [0] * pos_length
    for position_scores in sum_diff_results.values():
        # Iterate the stored pairs themselves. The previous loop ran over range(len(key)),
        # which only worked because each key happened to be the epitope string.
        for pos, score in position_scores:
            overall_score[pos] += score
    return overall_score

### Analysis for 8-, 9-, 10- and 11-mers
#### Single random round: set `round_num` in the initialization cell below (a later cell loops over rounds 1-5 automatically)

In [None]:
## Initialization for a single random round (edit round_num to choose the round)
round_num = 2
lengths = [8, 9, 10, 11]
sizes = [1, 5, 10, 20, 50, 100, 200, 500, 1000]
## Accumulators filled by the analysis cell below
overall_pos_scores_dict = {}
og_epitope = {}
## Peptide database restricted to the allele under study
combined_peptide_database = load_obj("combined_peptide_database")
hla_a_0201_data = combined_peptide_database['HLA-A*02:01']

In [None]:
## For the selected round and each peptide length: compute the normalised per-position sum of
## mutant/original score ratios for every subsample size, save it, then save the Pearson
## correlation of each size with the largest size (the ground truth).
saturation_dict = load_obj('random_round_'+str(round_num)+'/saturation_dict')
for n_mer in lengths:
    print("Processing length: "+ str(n_mer))
    # Start every length from empty state. These used to be reset only at the end of an
    # iteration. Re-running this cell after an interruption could therefore satisfy the
    # "already done" check with another length's results.
    overall_pos_scores_dict = {}
    og_epitope = {}

    ## Read each pVACbind output once; it is used both to find the originals and to score them.
    epitope_tables = {}
    for i in sizes:
        epitope_tables[i] = pd.read_table('./saturation_analysis/round_'+str(round_num)+'/output/pvacbind_'+str(n_mer)+'_'+str(i)+'_output/MHC_Class_I/ANCHOR.all_epitopes.tsv', delimiter='\t')

    ## First find the original epitopes: database peptides present in the output, keeping
    ## only those that belong to this size's subsample.
    for i in sizes:
        og_epitope[i] = create_og_epitope_list(hla_a_0201_data[n_mer], epitope_tables[i])
        sampled = set(saturation_dict[i][n_mer])
        og_epitope[i].difference_update([p for p in og_epitope[i] if p not in sampled])
        print(len(og_epitope[i]))

    ## Then for each size, calculate the corresponding normalized sum of ratios
    for i in sizes:
        all_epitope_data = epitope_tables[i]
        mutation_dictionary = create_mutation_dictionary(og_epitope[i], all_epitope_data['Epitope Seq'], all_epitope_data['Median Score'])
        og_scores = find_median_score(og_epitope[i], all_epitope_data)
        sum_diff_pos = calculate_sum_diff(og_epitope[i], og_scores, mutation_dictionary)
        # NOTE(review): this divides by the requested size i, not len(og_epitope[i]). Pearson r
        # is unaffected by this positive scaling, but the saved per-position scores are.
        overall_pos_scores_dict[i] = [j/i for j in overall_pos_sum_diff(sum_diff_pos, n_mer)]
        print("Done: "+ str(i))

    ### Save the dictionary generated in corresponding folder
    save_obj(overall_pos_scores_dict, 'random_round_'+str(round_num)+'/overall_pos_scores_'+str(n_mer)+'_mer_dict')

    ### Pearson correlation between each subsample and the ground truth (the largest size)
    standard = overall_pos_scores_dict[max(sizes)]
    overall_pos_corr = {}
    for i in sizes:
        overall_pos_corr[i], p_val = ss.pearsonr(overall_pos_scores_dict[i], standard)
    save_obj(overall_pos_corr, 'random_round_'+str(round_num)+'/overall_pos_'+str(n_mer)+'_mer_corr')

In [None]:
## Run the per-length saturation analysis for every random round.
lengths = [8, 9, 10, 11]
sizes = [1,5,10,20,50,100,200,500,1000]
# The peptide database does not depend on the round, so load it once rather than per round.
combined_peptide_database = load_obj("combined_peptide_database")
hla_a_0201_data = combined_peptide_database['HLA-A*02:01']

for round_num in [1,2,3,4,5]:
    print("Now in round: "+str(round_num))
    saturation_dict = load_obj('random_round_'+str(round_num)+'/saturation_dict')
    for n_mer in lengths:
        print("Processing length: "+ str(n_mer))
        # Fresh state for every (round, length). This replaces the former "already done"
        # check, which could never fire because the dict was emptied before each length.
        overall_pos_scores_dict = {}
        og_epitope = {}

        ## Read each pVACbind output once; it is used both to find the originals and to score them.
        epitope_tables = {}
        for i in sizes:
            epitope_tables[i] = pd.read_table('./saturation_analysis/round_'+str(round_num)+'/output/pvacbind_'+str(n_mer)+'_'+str(i)+'_output/MHC_Class_I/ANCHOR.all_epitopes.tsv', delimiter='\t')

        ## First find the original epitopes: database peptides present in the output, keeping
        ## only those that belong to this size's subsample.
        for i in sizes:
            og_epitope[i] = create_og_epitope_list(hla_a_0201_data[n_mer], epitope_tables[i])
            sampled = set(saturation_dict[i][n_mer])
            og_epitope[i].difference_update([p for p in og_epitope[i] if p not in sampled])
            print(len(og_epitope[i]))

        ## Then for each size, calculate the corresponding normalized sum of ratios
        for i in sizes:
            all_epitope_data = epitope_tables[i]
            mutation_dictionary = create_mutation_dictionary(og_epitope[i], all_epitope_data['Epitope Seq'], all_epitope_data['Median Score'])
            og_scores = find_median_score(og_epitope[i], all_epitope_data)
            sum_diff_pos = calculate_sum_diff(og_epitope[i], og_scores, mutation_dictionary)
            # NOTE(review): this divides by the requested size i, not len(og_epitope[i]). Pearson r
            # is unaffected by this positive scaling, but the saved per-position scores are.
            overall_pos_scores_dict[i] = [j/i for j in overall_pos_sum_diff(sum_diff_pos, n_mer)]
            print("Done: "+ str(i))

        ### Save the dictionary generated in corresponding folder
        save_obj(overall_pos_scores_dict, 'random_round_'+str(round_num)+'/overall_pos_scores_'+str(n_mer)+'_mer_dict')

        ### Pearson correlation between each subsample and the ground truth (the largest size)
        standard = overall_pos_scores_dict[max(sizes)]
        overall_pos_corr = {}
        for i in sizes:
            overall_pos_corr[i], p_val = ss.pearsonr(overall_pos_scores_dict[i], standard)
        save_obj(overall_pos_corr, 'random_round_'+str(round_num)+'/overall_pos_'+str(n_mer)+'_mer_corr')

### Plotting Results from Rounds 1-5

In [None]:
## Correlation of each subsample size with the size-1000 ground truth, one series per random round.
## The correlations were computed with scipy.stats.pearsonr, so the labels say Pearson
## (they previously said Spearman).
length = 11
rounds = [1, 2, 3, 4, 5]
fig, ax1 = plt.subplots()
data = []
for round_num in rounds:
    correlation_dict = load_obj('random_round_'+str(round_num)+'/overall_pos_'+str(length)+'_mer_corr')
    subsample_sizes = list(correlation_dict.keys())
    x = np.arange(len(subsample_sizes))
    y = list(correlation_dict.values())
    data.append(correlation_dict)
    ax1.scatter(x, y, label=str(round_num))
    ax1.plot(x, y)
# One explicit tick per subsample size. Relabelling matplotlib's automatic ticks with a dummy
# leading 0 depended on where the locator happened to place its ticks.
ax1.set_xticks(x)
ax1.set_xticklabels(subsample_sizes)
ax1.legend(loc='lower right')
ax1.set_title('Saturation Analysis using Pearson correlation for '+str(length)+'-mer peptides')
ax1.set_xlabel('Subsample Size')
ax1.set_ylabel('Pearson Correlation with Subsample 1000')
#fig.savefig('Saturation Analysis using Pearson correlation for '+str(length)+'-mer peptides')
plt.show()
## Rows: random rounds; columns: subsample sizes.
pd.DataFrame(data, index=rounds)