In [None]:
import csv
import os
import pickle
from collections import Counter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import upsetplot
import pylab as pl
import seaborn as sns; sns.set(color_codes=True)
from scipy.spatial import distance
from scipy.cluster import hierarchy
from scipy.stats import spearmanr
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering, KMeans

def save_obj(obj, name, directory='obj'):
    """Pickle ``obj`` to ``<directory>/<name>.pkl`` with the highest pickle protocol.

    Parameters
    ----------
    obj : any picklable object
        Object to persist.
    name : str
        File stem; ``.pkl`` is appended.
    directory : str, default 'obj'
        Target directory. It is created if missing, so a fresh checkout
        does not fail with FileNotFoundError.
    """
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, name + '.pkl'), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name, directory='obj'):
    """Unpickle and return the object stored in ``<directory>/<name>.pkl``.

    Parameters
    ----------
    name : str
        File stem; ``.pkl`` is appended.
    directory : str, default 'obj'
        Directory holding the pickle file.

    Security: ``pickle.load`` can execute arbitrary code while loading.
    Only load files this project produced itself, never untrusted input.
    """
    with open(os.path.join(directory, name + '.pkl'), 'rb') as f:
        return pickle.load(f)

In [None]:
## Load the two pickled inputs (from obj/) in order: positional anchor scores, then the peptide database
all_anchor_overall_pos_score_dict_r4, combined_peptide_database_round_4 = (
    load_obj(pickle_name)
    for pickle_name in ('all_anchor_overall_pos_score_dict_r4', 'combined_peptide_database_round_4')
)

In [None]:
## Database coverage summary:
##   hla_count     - number of alleles in the database
##   comb_count_1  - allele/length combinations with at least 1 peptide
##   comb_count_10 - allele/length combinations with at least 10 peptides
##   hla_suff      - alleles having at least 3 lengths with >= 10 peptides
hla_count = hla_suff = comb_count_1 = comb_count_10 = 0
for peptides_by_length in combined_peptide_database_round_4.values():
    hla_count += 1
    sizes = [len(peptides) for peptides in peptides_by_length.values()]
    comb_count_1 += sum(size >= 1 for size in sizes)
    sufficient_lengths = sum(size >= 10 for size in sizes)
    comb_count_10 += sufficient_lengths
    if sufficient_lengths >= 3:
        hla_suff += 1

print(hla_count, comb_count_1, comb_count_10, hla_suff)

In [None]:
## Supplemental Figure data: per-length lists of peptides-per-allele, plus peptide totals
peptide_lengths = (8, 9, 10, 11)
distribution = {length: [] for length in peptide_lengths}
counts = {length: 0 for length in peptide_lengths}
counts["total"] = 0
for peptides_by_length in combined_peptide_database_round_4.values():
    for length, peptides in peptides_by_length.items():
        n_peptides = len(peptides)
        distribution[length].append(n_peptides)
        counts[length] += n_peptides
        counts["total"] += n_peptides
counts

In [None]:
## Supplemental Figure: per-length histograms (log-spaced bins) of peptides collected per HLA allele
plt.rc('font', family='sans-serif')
fig, ax = plt.subplots(2, 2, figsize=(7,5))
bin_num = 25
for i, j, k in [(0,0,8),(0,1,9),(1,0,10),(1,1,11)]:
    # Zero-peptide allele/length combinations cannot sit on a log axis, and
    # log10(0) = -inf would break the bin edges, so plot positive counts only.
    values = [v for v in distribution[k] if v > 0]
    logbins = np.logspace(np.log10(np.min(values)), np.log10(np.max(values)), bin_num)

    sns.distplot(values, ax=ax[i,j], kde=False, bins=logbins, rug=True)
    ax[i,j].set_xscale('log')
    ax[i,j].set_title('Length '+str(k)+" (N="+str(counts[k])+")")
    # Keep matplotlib's default log-scale tick labels. Hard-coded labels are
    # assigned by tick position and do not follow the actual tick values.
fig.text(0.5, 0.01, 'Number of Peptides per HLA Allele', ha='center')
# distplot with kde=False (and norm_hist=False) draws raw counts, not a density
fig.text(0.005, 0.5, 'Number of HLA Alleles', va='center', rotation='vertical')
plt.suptitle("Distribution of collected peptides for all "+str(len(combined_peptide_database_round_4))+" HLA alleles")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#plt.savefig("Distribution of collected peptides for all 328 HLA alleles.png", dpi=300)
plt.show()

## Clustering and trends

In [None]:
## Normalize each allele/length positional score vector so its entries sum to 1.
## normalized_dict[length][hla] holds the vector; normalized_data[length] stacks the
## same vectors row-wise, in the same allele order as normalized_data_labels[length].
normalized_dict = {length: {} for length in (8, 9, 10, 11)}
normalized_data = {length: [] for length in (8, 9, 10, 11)}
normalized_data_labels = {length: [] for length in (8, 9, 10, 11)}
for hla, scores_by_length in all_anchor_overall_pos_score_dict_r4.items():
    for length, scores in scores_by_length.items():
        normalized_scores = scores / sum(scores)
        normalized_dict[length][hla] = normalized_scores
        normalized_data[length].append(normalized_scores)
        normalized_data_labels[length].append(hla)

normalized_data = {length: np.stack(rows) for length, rows in normalized_data.items()}

## OVERALL FIGURE

In [None]:
## Heatmap of normalized 9-mer positional scores; rows (alleles) clustered with average linkage
nine_mer_profiles = pd.DataFrame.from_dict(normalized_dict[9], orient='index')
g_heatmap = sns.clustermap(
    nine_mer_profiles,
    col_cluster=False,
    method='average',
    xticklabels=list(range(1, 10)),
    yticklabels=False,
    cmap='YlGnBu_r',
)
ax = g_heatmap.ax_heatmap
ax.set_ylabel("HLA alleles")
ax.set_xlabel("Positions")
#g_heatmap.savefig('Hierarchical_clustering_with_average_linkage_for_raw_positional_data_horizontal_v2.pdf')

In [None]:
def cluster_plot_generating(n, length, draw_n, method='average',
                            cluster_draw=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
                            save_dir="/Users/h.xia/Desktop/Griffith_Lab/Anchor Paper/Main Figure 2"):
    """Cluster normalized anchor profiles for one peptide length and plot selected clusters.

    Runs ``AgglomerativeClustering`` on ``normalized_dict[length]`` and draws one
    panel per selected cluster. Each panel shows every member allele's positional
    profile, the cluster mean (white dashed line) and the member allele names.

    Parameters
    ----------
    n : int
        Number of clusters.
    length : int
        Peptide length whose normalized profiles are clustered.
    draw_n : int
        Number of panels (rows) in the figure. It must be at least the number
        of clusters selected by ``cluster_draw``.
    method : str, default 'average'
        Linkage criterion passed to ``AgglomerativeClustering``.
    cluster_draw : iterable of int
        1-based cluster numbers to draw, drawn in ascending order. Numbers
        outside 1..n are ignored.
    save_dir : str or None
        Base directory; the PDF is written to ``<save_dir>/<length>mer/``.
        Pass ``None`` to skip saving.

    Returns
    -------
    clusters : list of list of (str, array)
        For each sklearn cluster label, the (allele, normalized profile) pairs.
    labels : array
        Cluster label per allele, in ``normalized_data_labels[length]`` order.

    Raises
    ------
    ValueError
        If more clusters are selected than there are panels.

    Relies on the notebook globals ``normalized_dict``, ``normalized_data`` and
    ``normalized_data_labels``.
    """
    model_n = AgglomerativeClustering(n_clusters=n, linkage=method)
    X = [normalized_dict[length][hla] for hla in normalized_dict[length]]
    model_n = model_n.fit(X)

    # labels_ follow the allele order of normalized_data_labels[length] (both built in one loop)
    clusters = [[] for _ in range(n)]
    for idx, label in enumerate(model_n.labels_):
        clusters[label].append((normalized_data_labels[length][idx], normalized_data[length][idx]))
    mean_of_clusters = [np.mean([profile for _, profile in members], axis=0) for members in clusters]

    selected = [c for c in range(1, n + 1) if c in cluster_draw]
    if len(selected) > draw_n:
        raise ValueError("draw_n=%d panels cannot hold %d selected clusters" % (draw_n, len(selected)))

    ## Plotting subplots
    fig, axes = plt.subplots(draw_n, 1, figsize=(10, 14), squeeze=False)
    text_size_s = 5
    text_size_b = 15
    w = 10  # allele names per line in the member list
    positions = range(1, length + 1)
    for ax1, cluster_number in zip(axes[:, 0], selected):
        members = clusters[cluster_number - 1]
        hla_str = "HLA alleles in this cluster: \n"
        for count, (hla_name, profile) in enumerate(members, start=1):
            ax1.plot(positions, profile)
            hla_str += hla_name + ", "
            if count % w == 0:
                hla_str += "\n"
        ax1.set_xlabel("Positions")
        ax1.set_ylabel("Normalized Score")
        # Member list placed just right of the last position (x=10 for 9-mers)
        ax1.text(length + 1, 1, hla_str, size=text_size_s, ha="left", va="top", wrap=True)
        ax1.set_ylim(0, 1)
        ax1.set_xticks(list(positions))
        ax1.plot(positions, mean_of_clusters[cluster_number - 1], 'w--')

    fig.suptitle("Anchor calculations plotted for "+str(n)+" clusters determined \n using hierarchical clustering (linkage method "+method+") for length "+str(length), size=text_size_b, ha='center')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    if save_dir is not None:
        fig.savefig(os.path.join(save_dir, str(length) + "mer",
                                 "Anchor_clusters-length"+str(length)+"-"+str(n)+"-"+method+"clusters_with_mean.pdf"))
    plt.show()
    plt.close(fig)

    return clusters, model_n.labels_

In [None]:
## 8 clusters of 9-mer profiles; draw 1-based clusters 4, 6 and 7 in 3 panels
cluster_test, labels = cluster_plot_generating(n=8, length=9, draw_n=3, cluster_draw=[4, 6, 7])

In [None]:
## Number of distinct alleles having a normalized profile at any length
hla_alleles = [hla for alleles_at_length in normalized_dict.values() for hla in alleles_at_length]
len(set(hla_alleles))

In [None]:
## Raw 9-mer scores per allele; alleles without a length-9 entry are reported and skipped
nine_mer_data = {}
for hla in all_anchor_overall_pos_score_dict_r4:
    try:
        nine_mer_data[hla] = all_anchor_overall_pos_score_dict_r4[hla][9]
    except KeyError:
        # Only a missing length-9 entry is expected. Other errors (e.g. a
        # malformed record) should surface instead of being silently skipped.
        print("no 9mer:" + hla)
nine_mer_df = pd.DataFrame.from_dict(nine_mer_data, orient='index')
nine_mer_norm_data = pd.DataFrame.from_dict(normalized_dict[9], orient='index')
method = "average"  # linkage read as a global by clusters_calc

def clusters_calc(n, length=9):
    """Cluster the normalized profiles of one peptide length into ``n`` groups.

    Uses the module-level linkage ``method`` and the notebook globals
    ``normalized_dict``, ``normalized_data`` and ``normalized_data_labels``.

    Parameters
    ----------
    n : int
        Number of clusters for ``AgglomerativeClustering``.
    length : int, default 9
        Peptide length whose profiles are clustered.

    Returns
    -------
    list of list of (str, array)
        Entry ``k`` holds the (allele, normalized profile) pairs assigned
        sklearn cluster label ``k``.
    """
    model_n = AgglomerativeClustering(n_clusters=n, linkage=method)
    X = [normalized_dict[length][hla] for hla in normalized_dict[length]]
    model_n = model_n.fit(X)
    clusters = [[] for _ in range(n)]
    # labels_ follow the allele order of normalized_data_labels[length] (both built in one loop)
    for idx, label in enumerate(model_n.labels_):
        clusters[label].append((normalized_data_labels[length][idx], normalized_data[length][idx]))
    return clusters

## Colour groups. orange/green come from the 3-cluster solution; red, purple, blue
## and the uncoloured group come from the 8-cluster solution. Indices are 0-based
## sklearn labels (presumably picked to match the figure colours; confirm).
clusters_3 = clusters_calc(3)
orange_hlas, green_hlas = ([hla for hla, _ in clusters_3[idx]] for idx in (1, 2))

clusters_8 = clusters_calc(8)
red_hlas, purple_hlas, blue_hlas, no_color_hlas = (
    [hla for hla, _ in clusters_8[idx]] for idx in (3, 5, 6, 7)
)

## overall_color_dict: HLA allele -> colour label
cluster_color_dict = {}


def add_cluster(color, hla_list):
    """Tag every allele in ``hla_list`` with ``color`` in the global ``cluster_color_dict``.

    Alleles already present are overwritten, so the most recent call wins.
    Returns None.
    """
    cluster_color_dict.update(dict.fromkeys(hla_list, color))
## Assign colours in a fixed order; later groups overwrite earlier ones for shared alleles
color_groups = [
    ("orange", orange_hlas),
    ("green", green_hlas),
    ("red", red_hlas),
    ("purple", purple_hlas),
    ("blue", blue_hlas),
    ("Not colored", no_color_hlas),
]
for group_color, group_hlas in color_groups:
    add_cluster(group_color, group_hlas)

group_sizes = [len(group_hlas) for _, group_hlas in color_groups]
print(*group_sizes)
sum(group_sizes)

In [None]:
## Attach each allele's colour label (row order follows the frame index; KeyError if an allele is uncoloured)
Cluster_color_code = [cluster_color_dict[hla] for hla in nine_mer_norm_data.index]
nine_mer_norm_data['Cluster Color Codes'] = Cluster_color_code

In [None]:
## Export normalized 9-mer profiles with their colour labels as TSV
nine_mer_tsv_path = "Normalized_anchor_predictions_9_mer_with_color_coding.tsv"
nine_mer_norm_data.to_csv(nine_mer_tsv_path, sep="\t")

In [None]:
## One allele-by-position frame per remaining peptide length
eight_mer_norm_data, ten_mer_norm_data, eleven_mer_norm_data = (
    pd.DataFrame.from_dict(normalized_dict[length], orient='index') for length in (8, 10, 11)
)

In [None]:
## Export the 8/10/11-mer normalized profiles as TSV
for frame, out_path in (
    (eight_mer_norm_data, "Normalized_anchor_predictions_8_mer.tsv"),
    (ten_mer_norm_data, "Normalized_anchor_predictions_10_mer.tsv"),
    (eleven_mer_norm_data, "Normalized_anchor_predictions_11_mer.tsv"),
):
    frame.to_csv(out_path, sep="\t")