In [None]:
#!/usr/bin/env python

import sys
import time
from collections import defaultdict
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cosine
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import pairwise_kernels

In [None]:
start_time = time.time()
# Compute similarity matrices for the subregion electrostatics of each PDB
# structure; they are plotted further down as a matrix heatmap.

## Load the data ##

pdbs = []
potentials = {}
data_folder = Path('data')
with open(data_folder / 'all_spike_strs_regions_pot.csv', 'r') as f:
    # Remove the line terminator before splitting so that the last column
    # name does not carry a trailing newline.
    header = next(f).rstrip('\n').split(',')
    print('column names:', header)
    # Nested mapping: first CSV column -> {second CSV column -> float value}.
    data = defaultdict(dict)
    for line in f:
        mm = line.rstrip('\n').split(',')

        # Only rows with exactly three fields are used; any other row
        # (e.g. a blank line) is skipped.
        if len(mm) == 3:
            key_AG, key_region, potential = mm
            data[key_AG][key_region] = float(potential)

# Convert to a plain dict so that looking up a missing key later raises
# KeyError instead of silently inserting an empty entry.
data = dict(data)

In [None]:
# Outer keys of `data` (the `key_AG` values) become rows after the transpose,
# inner keys (the region names) become columns; absent combinations are NaN.
df = pd.DataFrame(data).T
df

Only 98 antigens have 21 regions defined:

In [None]:
# Number of rows for each count of non-missing region values,
# listed from the most complete rows downwards.
(
    df.count(axis='columns')
      .value_counts()
      .sort_index(ascending=False)
)

In [None]:
# Select region_1 ... region_19 explicitly by name. A label slice
# ('region_1':'region_19') depends on the order of the columns and would
# silently pick a different set if they are not in numeric order.
region_columns = [f'region_{i}' for i in range(1, 20)]
# Keep only the rows that have a value for every selected region.
df = df.loc[:, region_columns].dropna()
df

Continue with the remaining antigens, i.e. those with complete data.

In [None]:
def lower_triangle(df):
    """Keep only the entries strictly below the main diagonal of ``df``.

    The diagonal and everything above it are replaced by NaN, so for a
    symmetric pairwise matrix every unordered pair of labels is kept once.

    Parameters
    ----------
    df : pandas.DataFrame
        Two-dimensional table, e.g. a square pairwise matrix.

    Returns
    -------
    pandas.DataFrame
        Same shape, index and columns as ``df``; masked entries are NaN.
    """
    # True strictly below the diagonal (k=-1 excludes the diagonal itself).
    below_diagonal = np.tril(np.ones(df.shape, dtype=bool), k=-1)
    mask = pd.DataFrame(below_diagonal, index=df.index, columns=df.columns)
    return df.where(mask)

In [None]:
# Pairwise distance between all rows of `df`, computed once per metric.
# 'jaccard' is excluded as it's for binary data.
# NOTE(review): 'l2' and 'l1' look like aliases of 'euclidean' and 'manhattan'
# in scikit-learn, so those distances may be counted twice — confirm intended.
metrics = ['cosine', 'euclidean', 'l2', 'manhattan', 'l1', 'hamming', 'chebyshev']
dict_dist = {}
for _metric in metrics:
    distances = pairwise_distances(X=df, metric=_metric)
    square = pd.DataFrame(distances, index=df.index, columns=df.index)
    # Keep each unordered pair once, then flatten to one value per pair.
    dict_dist[_metric] = lower_triangle(square).stack()
df_metrics = pd.DataFrame(dict_dist)
df_metrics

## Normalization

$z = \frac{x - \min(X)}{\max(X) - \min(X)}$

where
- $x$: a single distance value of a metric
- $X$: the set of distance values for a single metric
- $z$: a single *normalized* distance value of a metric


In [None]:
# Summary statistics for every metric column; the 'min' and 'max' rows
# are read from this table for the normalization.
stats_metrics = df_metrics.describe()
stats_metrics

In [None]:
# Per-metric minimum and maximum, read from the summary table.
X_min, X_max = (stats_metrics.loc[row] for row in ('min', 'max'))

In [None]:
# Min-max scale each metric column: z = (x - min) / (max - min).
# A metric whose values are all identical has zero range and becomes NaN.
value_range = X_max - X_min
df_metrics_normalized = df_metrics.sub(X_min).div(value_range)
df_metrics_normalized

## Plotting the mean metrics heatmap

In [None]:
# Average the normalized metrics of every pair (NaN entries are skipped),
# then pivot the pair index back into a label-by-label matrix.
pair_means = df_metrics_normalized.mean(axis='columns')
mean_metrics = pair_means.unstack()
mean_metrics

In [None]:
# Set up the matplotlib figure.
# Larger tick labels keep the labels readable on the big canvas; note that
# matplotlib.rc changes the defaults for every figure created afterwards.
matplotlib.rc('xtick', labelsize=16)
matplotlib.rc('ytick', labelsize=16)

fig, ax = plt.subplots(figsize=(30, 20))

ax = sns.heatmap(mean_metrics, annot=False, cmap="RdBu_r", ax=ax)
# Title and axis labels so the figure can be read on its own.
ax.set_title('Mean normalized pairwise distance', fontsize=24)
ax.set_xlabel('antigen', fontsize=20)
ax.set_ylabel('antigen', fontsize=20);