# Compute all pairwise escape profile comparisons within each sample group and between each pair of groups
- Performs the computations for the library batch specified by the `batch` variable below (Note: run time is a couple of hours)
- Results of the computations are stored in a CSV file that can be plotted with the `plot_escape_compare_sims.ipynb` notebook

In [1]:
# Analysis definitions, constants, helper functions
from utils import *

import os
import copy
import warnings
warnings.filterwarnings('ignore')

# Python Optimal Transport package
import ot

In [2]:
# Library batch to analyse ("SPIKE1" or "SPIKE2").
# Samples are later filtered on their "library_batch" value being equal to this,
# and the value is used as the prefix of the output CSV file name.
batch = "SPIKE2"

# Scaled differential selection data variable name.
# It picks the differential selection CSV (see diff_sel_file below) and is the
# name given to that variable inside the dataset.
metric = "smooth_flank_1_enr_diff_sel"

# Data: input CSV files, given as relative paths.
# The peptide table is read with 'peptide_id' as its index column and the sample
# table with 'sample_id'; the remaining three files are read with 'peptide_id'
# as index and have their column labels converted to int.
peptide_table_file = 'data/phage-dms-nhp_peptide_table.csv'
sample_table_file  = 'data/phage-dms-nhp_sample_table.csv'
counts_file        = 'data/phage-dms-nhp_counts.csv'
enrichment_file    = 'data/phage-dms-nhp_enrichment.csv'
diff_sel_file      = f'data/phage-dms-nhp_{metric}.csv'

In [3]:
# Build a single xarray dataset out of the CSV inputs, then restrict it to the
# samples of the selected library batch.

def _read_sample_matrix(path):
    """Read a CSV indexed by peptide_id and convert its column labels to int."""
    frame = pd.read_csv(path, index_col='peptide_id')
    frame.columns = frame.columns.astype(int)
    return frame


def _named_array(frame, second_dim, name, first_dim='peptide_id'):
    """Wrap a DataFrame in a named two-dimensional DataArray."""
    return xr.DataArray(data=frame, dims=[first_dim, second_dim], name=name)


# Annotation tables (one row per peptide / per sample)
peptide_table_df = pd.read_csv(peptide_table_file, index_col='peptide_id')
peptide_table_arr = _named_array(peptide_table_df, 'peptide_metadata', 'peptide_table')
peptide_table_ds = peptide_table_arr.to_dataset()

sample_table_df = pd.read_csv(sample_table_file, index_col='sample_id')
sample_table_arr = _named_array(
    sample_table_df, 'sample_metadata', 'sample_table', first_dim='sample_id'
)
sample_table_ds = sample_table_arr.to_dataset()

# Matrices with peptide_id rows and sample_id columns
counts_df = _read_sample_matrix(counts_file)
counts_arr = _named_array(counts_df, 'sample_id', 'counts')
counts_ds = counts_arr.to_dataset()

enrichment_df = _read_sample_matrix(enrichment_file)
enrichment_arr = _named_array(enrichment_df, 'sample_id', 'enrichment')
enrichment_ds = enrichment_arr.to_dataset()

diff_sel_df = _read_sample_matrix(diff_sel_file)
diff_sel_arr = _named_array(diff_sel_df, 'sample_id', metric)
diff_sel_ds = diff_sel_arr.to_dataset()

phip_ds = xr.merge(
    [peptide_table_ds, sample_table_ds, counts_ds, enrichment_ds, diff_sel_ds]
)

# Keep only the samples whose "library_batch" equals the chosen batch
batch_samples = id_coordinate_subset(phip_ds, where="library_batch", is_equal_to=batch)
ds = phip_ds.loc[{'sample_id': batch_samples}]
ds

In [4]:
# Cost matrix for escape similarity score.
# Produced by get_cost_matrix(), which is not defined in this notebook
# (presumably provided by `from utils import *` -- confirm). It is passed to
# ot.emd2 as the cost-matrix argument inside region_compare.
cost_matrix = get_cost_matrix()

In [5]:
# Compute weights over sites in epitope region of interest
# for performing weighted sum of similarity scores
def get_weights(
    ds,
    sid1,
    sid2,
    loc_start,
    loc_end
):
    """Return per-site weights for comparing two samples over a region.

    For every site from ``loc_start`` to ``loc_end`` (inclusive) the absolute
    values of the ``metric`` variable are summed, separately for each sample.
    Each sample's site sums are normalised so they sum to one, the smaller of
    the two normalised values is taken at every site, and those minima are
    rescaled so that the returned weights sum to one.

    Parameters
    ----------
    ds
        Dataset indexable by ``peptide_id`` and ``sample_id`` that contains
        the variable named by the module-level ``metric``.
    sid1, sid2
        Values matched against the ``sample_ID`` field to select each sample.
    loc_start, loc_end
        First and last site (matched against ``Loc``), both inclusive.

    Returns
    -------
    dict
        Maps each site in the region to its weight.

    Notes
    -----
    None of the divisions is guarded against a zero denominator, so a sample
    with no signal anywhere in the region produces non-finite weights rather
    than an error (assuming the values are numpy floats, as is usual here).
    """
    locs = range(loc_start, loc_end + 1)

    # The sample selections do not depend on the site, so resolve them once
    # instead of on every iteration.
    # NOTE(review): assumes the *_coordinate_subset helpers are side-effect
    # free and return a reusable sequence -- confirm against utils.
    samples1 = sample_id_coordinate_subset(ds, where='sample_ID', is_equal_to=sid1)
    samples2 = sample_id_coordinate_subset(ds, where='sample_ID', is_equal_to=sid2)

    loc_sums1 = []
    loc_sums2 = []
    for loc in locs:
        # One peptide lookup per site, shared by both samples.
        peptides = peptide_id_coordinate_subset(ds, where='Loc', is_equal_to=loc)
        for samples, loc_sums in ((samples1, loc_sums1), (samples2, loc_sums2)):
            subset = ds.loc[dict(peptide_id=peptides, sample_id=samples)]
            values = subset[metric].to_pandas().to_numpy()
            # Total absolute signal of this sample at this site.
            loc_sums.append(np.abs(values).sum())

    # Fraction of each sample's regional signal that falls on each site.
    share1 = np.asarray(loc_sums1) / np.sum(loc_sums1)
    share2 = np.asarray(loc_sums2) / np.sum(loc_sums2)

    # Builtin min() is kept on purpose: unlike np.minimum it does not always
    # propagate NaN, and switching would change results for NaN inputs.
    shared = {loc: min(s1, s2) for loc, s1, s2 in zip(locs, share1, share2)}
    total = sum(shared.values())

    return {loc: value / total for loc, value in shared.items()}

In [6]:
# Compute regional similarity score for a pair of escape profiles
def region_compare(ds, sid1, sid2, loc_start, loc_end):
    """Return the weighted similarity of two samples over a region.

    For every site from ``loc_start`` to ``loc_end`` (inclusive) the two
    samples' data are fetched with ``get_loc_escape_data``, ``ot.emd2``
    computes the transport cost between them under the module-level
    ``cost_matrix``, and the site contributes ``weight / cost`` to the total,
    where the weight comes from ``get_weights``.

    Parameters
    ----------
    ds
        Dataset passed through to ``get_weights`` and ``get_loc_escape_data``.
    sid1, sid2
        Identifiers of the two samples being compared.
    loc_start, loc_end
        First and last site of the region, both inclusive.

    Returns
    -------
    The sum of ``weight / cost`` over the sites of the region (``0`` when the
    range of sites is empty).

    Notes
    -----
    A zero transport cost is not guarded against; the division is performed
    as is.
    """
    weights = get_weights(ds, sid1, sid2, loc_start, loc_end)
    region_sim = 0
    for loc in range(loc_start, loc_end + 1):
        a = get_loc_escape_data(ds, sid1, loc, metric)
        b = get_loc_escape_data(ds, sid2, loc, metric)
        cost = ot.emd2(a, b, cost_matrix)
        # NOTE(review): an earlier version tested whether exactly one of
        # a / b sums to zero, but the branch body was a bare expression with
        # no effect. The dead check has been dropped; results are unchanged.
        # If that case was meant to be handled specially, it still is not.
        region_sim = region_sim + weights[loc] / cost

    return region_sim

## Perform pairwise escape profile comparisons

In [7]:
# One output row per compared pair of samples and epitope region
output_columns = ['sample_ID_1','group_1','sample_ID_2','group_2','epitope_region','similarity']

group_list = [moderna, vaccinated_pigtail, conv_60d, convalescent_rhesus]
desc_list  = ['vaccinated_human', 'vaccinated_macaque', 'convalescent_human', 'convalescent_macaque']
labelled_groups = list(zip(group_list, desc_list))

# Rows are collected in a plain list and turned into a DataFrame once at the
# end; growing a DataFrame one row at a time re-allocates it on every insert.
rows = []
for position, (group1, desc1) in enumerate(labelled_groups):
    # Pair each group with itself and with every later group, so that each
    # unordered pair of groups is visited exactly once.
    for group2, desc2 in labelled_groups[position:]:
        same_group = group1 == group2
        for region in epitope_limits:
            loc_start = epitope_limits[region][0]
            loc_end = epitope_limits[region][1]
            for sid1 in group1:
                for sid2 in group2:
                    # Within a group, compare each unordered pair of distinct
                    # samples only once.
                    if same_group and sid1 >= sid2:
                        continue
                    sim = region_compare(ds, sid1, sid2, loc_start, loc_end)
                    rows.append([sid1, desc1, sid2, desc2, region, sim])

output_df = pd.DataFrame(rows, columns=output_columns)

output_df.to_csv(f'{batch}_escape_compare_sims.csv', index=False, na_rep="NA")
output_df

Unnamed: 0,sample_ID_1,group_1,sample_ID_2,group_2,epitope_region,similarity
0,254,vaccinated_human,256,vaccinated_human,CTDN,1.125129
1,254,vaccinated_human,260,vaccinated_human,CTDN,1.079677
2,254,vaccinated_human,262,vaccinated_human,CTDN,1.087283
3,254,vaccinated_human,266,vaccinated_human,CTDN,1.057653
4,254,vaccinated_human,268,vaccinated_human,CTDN,1.178810
...,...,...,...,...,...,...
3670,202,convalescent_macaque,206,convalescent_macaque,SHH,1.472064
3671,202,convalescent_macaque,208,convalescent_macaque,SHH,1.492371
3672,204,convalescent_macaque,206,convalescent_macaque,SHH,1.568540
3673,204,convalescent_macaque,208,convalescent_macaque,SHH,1.530216
