In [None]:
#!/usr/bin/env python

import sys
import time
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cosine

In [None]:
start_time = time.time()
# Load the electrostatic potential of every subregion of each PDB; later cells
# turn these into pairwise dissimilarity matrices and plot them as heatmaps.

## Load the data ##

# NOTE(review): pdbs and potentials are never filled in this notebook.
pdbs = []
potentials = {}
data_folder = Path('data')
with open(data_folder / 'all_spike_strs_regions_pot.csv', 'r') as f:
    # Drop the line terminator so the last column name is not printed with '\n'.
    header = next(f).rstrip('\r\n').split(',')
    print('column names:', header)
    # data[first column][second column] -> float(third column)
    data = defaultdict(dict)
    n_skipped = 0
    for line in f:
        mm = line.rstrip('\r\n').split(',')
        if len(mm) == 3:
            key_AG, key_region, potential = mm
            data[key_AG][key_region] = float(potential)
        elif line.strip():
            # Count (rather than silently drop) non-blank rows without exactly 3 fields.
            n_skipped += 1
    if n_skipped:
        print('skipped {} malformed line(s)'.format(n_skipped))

data = dict(data)

In [None]:
# One row per outer key of `data` (antigen/PDB id), one column per region key;
# regions missing for a PDB become NaN. (pandas is imported in the top cell.)
df = pd.DataFrame(data).T
df

Number of regions with a defined potential per PDB (all 8 PDBs have 20 regions defined):

In [None]:
# For each number of defined (non-NaN) regions, how many PDBs have that many.
df.count(axis=1).value_counts().sort_index(ascending=False)

In [None]:
# Keep exactly region_1 ... region_20, selected by name. A label slice
# ('region_1':'region_20') takes whatever columns lie between those two labels,
# and the column order comes from the CSV row order, so it can pick the wrong set.
region_columns = ['region_{}'.format(i) for i in range(1, 21)]
# Drop PDBs that lack a potential for any of these regions.
df = df.loc[:, region_columns].dropna()
df

Continue with the remaining antigens, i.e. those with a complete set of regions.

In [None]:
def lower_triangle(df, k=-1):
    """Mask a matrix so that only entries below the k-th diagonal remain.

    Used on symmetric pairwise-distance matrices so that each unordered pair
    is kept once: every entry on or above the diagonal (for the default
    ``k=-1``) is replaced by NaN, and the callers' ``.stack()`` then drops
    those NaN cells.

    Parameters
    ----------
    df : pandas.DataFrame
        Matrix to mask; its index and columns are preserved.
    k : int, default -1
        Diagonal offset passed to ``numpy.tril``: ``-1`` excludes the main
        diagonal, ``0`` keeps it.

    Returns
    -------
    pandas.DataFrame
        Same shape and labels as ``df``, with entries above the k-th diagonal
        set to NaN.
    """
    # Positional boolean mask; passing an ndarray to where() avoids any label
    # alignment between the mask and df.
    keep = np.tril(np.ones(df.shape, dtype=bool), k)
    return df.where(keep)

In [None]:
# Pairwise distances between PDBs using all region potentials as features:
# one column per metric, one row per unordered PDB pair (lower triangle only).
# NOTE(review): sklearn treats 'l2' as an alias of 'euclidean' and 'l1' as an
# alias of 'manhattan', so those distances get double weight when the metrics
# are averaged later. 'hamming' counts differing coordinates, which for
# real-valued potentials is probably identical for every pair.
metrics = ['cosine', 'euclidean', 'l2', 'manhattan', 'l1', 'hamming', 'chebyshev'] # 'jaccard' excluded as it's for binary data
dict_dist = {}
for _metric in metrics:
    full_distances = pd.DataFrame(
        pairwise_distances(X=df, metric=_metric),
        index=df.index,
        columns=df.index,
    )
    dict_dist[_metric] = lower_triangle(full_distances).stack()
df_metrics = pd.DataFrame(dict_dist)
df_metrics

## Normalization

$$ z = \frac{x - \min(X)}{\max(X) - \min(X)} $$

where
- $x$: a single distance value of a metric
- $X$: the set of distance values of that metric over all PDB pairs
- $z$: a single *normalized* distance value of a metric


In [None]:
# Summary statistics per metric column: count, mean, std, min, quartiles, max.
stats_metrics = df_metrics.describe(percentiles=[0.25, 0.5, 0.75])
stats_metrics

In [None]:
# Per-metric minimum and maximum over all PDB pairs, used for min-max scaling.
# For numeric columns these match describe()'s 'min'/'max' rows; check the
# dtype explicitly instead of guessing at "dtype errors".
assert all(pd.api.types.is_numeric_dtype(dtype) for dtype in df_metrics.dtypes), \
    'df_metrics has non-numeric columns; min/max would not be numeric'
X_min = df_metrics.min()
X_max = df_metrics.max()
X_max

In [None]:
# Min-max scale each metric column to [0, 1]. A metric that is constant over
# all pairs has X_max == X_min, so its column becomes 0/0 = NaN.
df_metrics_normalized = df_metrics.sub(X_min).div(X_max - X_min)
df_metrics_normalized

## Plotting the mean metrics heatmap

In [None]:
# Average the normalized metrics for each PDB pair (NaN metrics are skipped),
# then pivot the pair index back into a matrix: rows hold the later PDB of each
# pair and columns the earlier one (in df order); cells above the diagonal are NaN.
mean_metrics = df_metrics_normalized.mean(axis='columns').unstack()
mean_metrics

In [None]:
# Heatmap of the mean normalized dissimilarity for every PDB pair; the NaN
# cells of the upper triangle are masked (left blank) by seaborn.
# Global tick-label sizes; these rc settings also apply to later pyplot figures.
matplotlib.rc('xtick', labelsize=16)
matplotlib.rc('ytick', labelsize=16)

fig, ax = plt.subplots(figsize=(30, 20))

sns.heatmap(mean_metrics, annot=True, fmt='0.2f', cmap="RdBu_r", ax=ax, annot_kws={"size": 28})
# Enlarge this figure's tick labels only (instead of re-setting the label text).
ax.tick_params(axis='both', labelsize=28)
ax.set_title('Mean normalized dissimilarity (all regions)', fontsize=28)
plt.show()


## Histogram of the similarity matrix values (normalized distances: higher = more dissimilar)

In [None]:
# Distribution of the pairwise mean normalized values. The masked upper
# triangle of mean_metrics is NaN padding, so drop it explicitly instead of
# relying on plt.hist to skip it.
pair_values = mean_metrics.to_numpy().flatten()
pair_values = pair_values[~np.isnan(pair_values)]

plt.hist(pair_values)
plt.ylabel('Count', fontsize=18)
# The values are normalized pairwise *distances*: higher means more dissimilar.
plt.xlabel('Mean normalized dissimilarity', fontsize=18)

plt.show()

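### Compute and plot 'Super' similarity: average each PDB's scores against all other PDBs and select those whose average is above the 0.6 threshold

Note: the scores are normalized *distances*, so a high average means a PDB is, on average, *less* similar to the others.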

In [None]:
# Total of each mean_metrics row (one entry per PDB that has a row; NaN cells count as 0).
rows_sum = mean_metrics.sum(axis='columns')
rows_sum


In [None]:
# Total of each mean_metrics column (one entry per PDB that has a column).
columns_sum = mean_metrics.sum(axis='index')
columns_sum

In [None]:
# Turn the row and column totals into two vectors with one slot per PDB.
# mean_metrics (lower triangle) has no row for the first PDB and no column for
# the last one, so prepend a 0 to the row sums and append a 0 to the column sums.
# This positional alignment only holds if the shared PDBs appear in the same
# order in both, so check that instead of assuming it.
assert list(rows_sum.index[:-1]) == list(columns_sum.index[1:]), \
    'row/column labels of mean_metrics are not in matching order'
row_sum = np.concatenate(([0], rows_sum.to_numpy()))  # leading zero: first PDB has no row
column_sum = np.concatenate((columns_sum.to_numpy(), [0]))  # trailing zero: last PDB has no column
row_sum, column_sum

Add the two arrays together (giving each PDB's total over all of its pairs), average per PDB, and plot a histogram.

In [None]:
# Average score of each PDB over all of its pairs.
# row_sum/column_sum have one slot per PDB (len == number of PDBs), but each
# PDB is paired with every *other* PDB, i.e. len - 1 partners; dividing by len
# would bias every average low.
# NOTE(review): mean_metrics holds normalized *distances* (higher = more
# dissimilar), so "> 0.6" selects the PDBs furthest from the others on
# average; confirm this is the intended meaning of "super similar".
n_partners = len(row_sum) - 1
super_similarity = (row_sum + column_sum) / n_partners

plt.hist(super_similarity)
plt.ylabel('Count', fontsize=18)
plt.xlabel('Super similarity score', fontsize=18)

plt.axvline(x=0.6, lw=3, c='tab:orange')

plt.title('Super similarity histogram', fontsize=18)

plt.show()

Get the indices of the super-similar PDBs

In [None]:
# Positions (in super_similarity) of the PDBs whose score exceeds the 0.6 threshold.
pdb_indices = [
    position
    for position, score in enumerate(super_similarity)
    if score > 0.6
]

pdb_indices

Get the PDB IDs of the super-similar PDBs

In [None]:
# PDB ids labelling the rows of mean_metrics (the first PDB only appears as a column).
row_labels = mean_metrics.index.to_numpy(copy=True)
row_labels

In [None]:
# PDB ids labelling the columns of mean_metrics (the last PDB only appears as a row).
column_labels = mean_metrics.columns.to_numpy(copy=True)
column_labels

In [None]:
# The first PDB has no row in mean_metrics, so prepend it (the first column
# label) to the row labels; this matches the slot order of row_sum/column_sum.
first_pdb = column_labels[0]
all_pdb_ids = np.insert(row_labels, obj=0, values=first_pdb)
all_pdb_ids

In [None]:
# Look up the ids of the selected PDBs by position.
super_similar_pdbs = all_pdb_ids[np.asarray(pdb_indices, dtype=int)]
super_similar_pdbs

## Region by region comparison

In [None]:
# Which region to inspect, and the residue number at which region 1 starts
# (used in the plot title of the next cell).
region = 20
residues = 330

# Same pairwise-distance table as for all regions, but using only this
# region's potential (one feature per PDB).
# NOTE(review): with a single feature, 'cosine' only reflects whether two
# potentials share a sign and 'hamming' only whether they differ, so both are
# close to uninformative here.
region_values = df[['region_{}'.format(region)]].to_numpy()
metrics = ['cosine', 'euclidean', 'l2', 'manhattan', 'l1', 'hamming', 'chebyshev'] # 'jaccard' excluded as it's for binary data
dict_dist = {}
for _metric in metrics:
    full_distances = pd.DataFrame(
        pairwise_distances(X=region_values, metric=_metric),
        index=df.index,
        columns=df.index,
    )
    dict_dist[_metric] = lower_triangle(full_distances).stack()
df_metrics = pd.DataFrame(dict_dist)
df_metrics

In [None]:
# Min-max normalize each metric over this region's pairs, average across
# metrics, and pivot back into a lower-triangular PDB x PDB matrix.
stats_metrics = df_metrics.describe()
X_min, X_max = stats_metrics.loc['min'], stats_metrics.loc['max']
df_metrics_normalized = df_metrics.sub(X_min).div(X_max - X_min)
mean_metrics = df_metrics_normalized.mean(axis='columns').unstack()

# Set up the matplotlib figure (global tick-label size for subsequent plots).
for tick_axis in ('xtick', 'ytick'):
    matplotlib.rc(tick_axis, labelsize=18)

fig, ax = plt.subplots(figsize=(30, 20))

ax = sns.heatmap(mean_metrics, ax=ax, annot=True, fmt='0.2f', annot_kws={"size": 28}, cmap="RdBu_r")
ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=28)
ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=28)

# Region r spans residues (residues + 10*(r-1)) to (residues + 10*r).
first_residue = residues + region * 10 - 10
last_residue = residues + region * 10
plt.title('Dissimilarity between Region{} Residues {}-{}'.format(region, first_residue, last_residue), fontsize=28)

## Pairwise Comparison

In [None]:
pdb1 = '7lyn'
pdb2 = '7eb5'

regions = range(1, 21)

# One mean normalized dissimilarity per region for the (pdb1, pdb2) pair;
# NaN when the pair has no value for that region.
pairwise_region = []

for region in regions:
    dict_dist = {}
    metrics = ['cosine', 'euclidean', 'l2', 'manhattan', 'l1', 'hamming', 'chebyshev'] # 'jaccard' excluded as it's for binary data
    for _metric in metrics:
        dict_dist[_metric] = pd.DataFrame(pairwise_distances(X=df[['region_{}'.format(region)]].to_numpy(), metric=_metric), index=df.index, columns=df.index)
        dict_dist[_metric] = lower_triangle(dict_dist[_metric]).stack()
    df_metrics = pd.DataFrame(dict_dist)
    stats_metrics = df_metrics.describe()
    X_min = stats_metrics.loc['min']
    X_max = stats_metrics.loc['max']
    df_metrics_normalized = (df_metrics - X_min) / (X_max - X_min)
    mean_metrics = df_metrics_normalized.mean(axis=1).unstack()

    # mean_metrics is lower-triangular: the pair's value sits in only one of
    # [pdb1, pdb2] / [pdb2, pdb1]; the mirrored cell is NaN or the label is
    # missing entirely (first PDB has no row, last has no column). Check that
    # both labels exist and the value is not NaN, and always append exactly
    # one entry so the 20 rows line up with the residue labels below.
    pair_value = np.nan
    for row_pdb, col_pdb in ((pdb1, pdb2), (pdb2, pdb1)):
        if row_pdb in mean_metrics.index and col_pdb in mean_metrics.columns:
            candidate = mean_metrics.at[row_pdb, col_pdb]
            if pd.notna(candidate):
                pair_value = candidate
                break
    pairwise_region.append(pair_value)

# Region r spans residues 330 + 10*(r-1) to 330 + 10*r ('330-340' ... '520-530').
residue_labels = ['{}-{}'.format(330 + 10 * (r - 1), 330 + 10 * r) for r in regions]
pairwise_table = pd.DataFrame({'{} vs {}'.format(pdb1, pdb2): pairwise_region}, index=residue_labels)

# Set up the matplotlib figure
matplotlib.rc('xtick', labelsize=18)
matplotlib.rc('ytick', labelsize=18)

fig, ax = plt.subplots(figsize=(30, 20))

sns.heatmap(pairwise_table, annot=True, fmt='0.2f', cmap="RdBu_r", ax=ax, annot_kws={"size": 36})
ax.tick_params(axis='y', labelrotation=0, labelsize=36)

ax.set_ylabel('Residues', fontsize=36)
ax.set_title('Dissimilarity between {} {}'.format(pdb1, pdb2), fontsize=36)
plt.show()
