In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
import seaborn as sns

%matplotlib inline

In [None]:
# Apply seaborn's "ticks" style to every figure drawn in this notebook.
sns.set_style(style="ticks")

## Define inputs, outputs, and parameters

In [None]:
# Input files handed to this notebook by Snakemake: the pairwise tMRCA table
# and one table per embedding method.
tMRCAs_file = snakemake.input.tmrcas
pca_path, mds_path, tsne_path, umap_path = (
    snakemake.input.embedding_pca,
    snakemake.input.embedding_mds,
    snakemake.input.embedding_tsne,
    snakemake.input.embedding_umap,
)

# Where the final distance-vs-tMRCA figure is written.
distances_figure = snakemake.output.distances_figure

## Load data

### tMRCAs from recombination network

In [None]:
# Read the tab-separated tMRCA table.
df = pd.read_csv(tMRCAs_file, sep="\t")

# Rename "median_median" to "median_tMRCA", the column name used throughout
# the rest of this notebook.
df = df.rename(columns={"median_median": "median_tMRCA"})

In [None]:
# Preview the first rows of the tMRCA table (after renaming the median column).
df.head()

In [None]:
# Distinct reference strains in sorted order. This order fixes the pair
# ordering of the tMRCA vector built below.
strains = sorted(df["reference_strain"].unique())

In [None]:
# Distinct strains from the "name" column, in sorted order.
other_strains = sorted(df["name"].unique())

In [None]:
# Number of distinct reference strains.
len(strains)

In [None]:
# Number of distinct strains in the "name" column.
len(other_strains)

In [None]:
# (rows, columns) of the tMRCA table.
df.shape

In [None]:
# Index rows by (reference strain, other strain) so that a single pair can be
# fetched with one tuple lookup through `.loc`.
df = df.set_index(["reference_strain", "name"])

In [None]:
# Preview the table now indexed by (reference_strain, name).
df.head()

In [None]:
# Display the indexed tMRCA table (long frames are truncated according to
# pandas' display settings).
df

In [None]:
# Collect the median tMRCA for every unordered pair of reference strains.
# Pairs are visited as (0, 1), (0, 2), ..., (1, 2), ... over the sorted
# `strains` list. This is the same row-major upper-triangle order used by the
# condensed vectors returned by `scipy.spatial.distance.pdist`, so entry k here
# corresponds to entry k of a pdist result computed over rows ordered like
# `strains`.
tMRCAs = []
for i, strain_i in enumerate(strains):
    for strain_j in strains[i + 1:]:
        try:
            tMRCA = df.loc[(strain_i, strain_j), "median_tMRCA"]
        except KeyError:
            # Pair is absent from the table: record NaN so positions stay
            # aligned with the pdist vectors. Use `np.nan`, since the `np.NaN`
            # alias was removed in NumPy 2.0.
            tMRCA = np.nan
        tMRCAs.append(tMRCA)

In [None]:
# Number of reference-strain pairs collected: len(strains) choose 2.
len(tMRCAs)

In [None]:
# Turn the per-pair list into a NumPy array for the combined table below.
tMRCAs = np.asarray(tMRCAs)

In [None]:
# Median tMRCA per reference-strain pair (NaN where the pair was missing).
tMRCAs

### PCA

In [None]:
# Load the PCA embedding, indexed by strain.
pca = pd.read_csv(pca_path, index_col="strain")

# Drop every column whose name contains "label"; the remaining columns are
# what the distance computation uses.
pca = pca[[column for column in pca.columns if "label" not in column]]

In [None]:
# Preview the PCA embedding with label columns removed.
pca.head()

In [None]:
# Condensed pairwise Euclidean distances in the PCA embedding.
# `pdist` enumerates pairs in row order, whereas `tMRCAs` follows the sorted
# `strains` list. Reorder (and subset) the rows to `strains` first so both
# vectors line up pair-by-pair. `.loc` raises KeyError if a reference strain is
# missing from the embedding instead of silently misaligning the pairs.
pca_distances = pdist(pca.loc[strains].values)

In [None]:
# Condensed Euclidean distances between strains in the PCA embedding (one entry per pair i < j).
pca_distances

### MDS

In [None]:
# Load the MDS embedding, indexed by strain.
mds = pd.read_csv(mds_path, index_col="strain")

# Drop every column whose name contains "label"; the remaining columns are
# what the distance computation uses.
mds = mds[[column for column in mds.columns if "label" not in column]]

In [None]:
# Preview the MDS embedding with label columns removed.
mds.head()

In [None]:
# Condensed pairwise Euclidean distances in the MDS embedding.
# `pdist` enumerates pairs in row order, whereas `tMRCAs` follows the sorted
# `strains` list. Reorder (and subset) the rows to `strains` first so both
# vectors line up pair-by-pair. `.loc` raises KeyError if a reference strain is
# missing from the embedding instead of silently misaligning the pairs.
mds_distances = pdist(mds.loc[strains].values)

In [None]:
# Condensed Euclidean distances between strains in the MDS embedding (one entry per pair i < j).
mds_distances

### t-SNE

In [None]:
# Load the t-SNE embedding, indexed by strain.
tsne = pd.read_csv(tsne_path, index_col="strain")

# Drop every column whose name contains "label"; the remaining columns are
# what the distance computation uses.
tsne = tsne[[column for column in tsne.columns if "label" not in column]]

In [None]:
# Preview the t-SNE embedding with label columns removed.
tsne.head()

In [None]:
# Condensed pairwise Euclidean distances in the t-SNE embedding.
# `pdist` enumerates pairs in row order, whereas `tMRCAs` follows the sorted
# `strains` list. Reorder (and subset) the rows to `strains` first so both
# vectors line up pair-by-pair. `.loc` raises KeyError if a reference strain is
# missing from the embedding instead of silently misaligning the pairs.
tsne_distances = pdist(tsne.loc[strains].values)

In [None]:
# Condensed Euclidean distances between strains in the t-SNE embedding (one entry per pair i < j).
tsne_distances

### UMAP

In [None]:
# Load the UMAP embedding, indexed by strain.
umap = pd.read_csv(umap_path, index_col="strain")

# Drop every column whose name contains "label"; the remaining columns are
# what the distance computation uses.
umap = umap[[column for column in umap.columns if "label" not in column]]

In [None]:
# Preview the UMAP embedding with label columns removed.
umap.head()

In [None]:
# Condensed pairwise Euclidean distances in the UMAP embedding.
# `pdist` enumerates pairs in row order, whereas `tMRCAs` follows the sorted
# `strains` list. Reorder (and subset) the rows to `strains` first so both
# vectors line up pair-by-pair. `.loc` raises KeyError if a reference strain is
# missing from the embedding instead of silently misaligning the pairs.
umap_distances = pdist(umap.loc[strains].values)

In [None]:
# Condensed Euclidean distances between strains in the UMAP embedding (one entry per pair i < j).
umap_distances

In [None]:
# Length of the condensed vector; it must equal len(tMRCAs) for the combined table below.
umap_distances.shape

In [None]:
# Distribution of median tMRCAs over the rows of the table. Each row is one
# (reference_strain, name) pair, so the counts are of strain pairs, not strains.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=120)

# Drop missing values: with NaNs present, the automatic bin range is not
# finite and histogramming fails.
ax.hist(df["median_tMRCA"].dropna())

ax.set_xlabel("median tMRCA")
ax.set_ylabel("number of strain pairs");

In [None]:
# Combine the per-pair vectors into one table: row k holds the k-th median
# tMRCA next to the k-th condensed distance from each embedding.
distances = pd.DataFrame.from_dict(
    {
        "tMRCA": tMRCAs,
        "pca": pca_distances,
        "mds": mds_distances,
        "t-sne": tsne_distances,
        "umap": umap_distances,
    }
)

In [None]:
# Preview the combined per-pair table of tMRCAs and embedding distances.
distances.head()

## Plot Euclidean distances per embedding by tMRCA

In [None]:
# One panel per embedding: (column in `distances`, label for the y-axis title).
# Panels fill the 2x2 grid in row-major order.
embedding_panels = [
    ("pca", "PCA"),
    ("mds", "MDS"),
    ("t-sne", "t-SNE"),
    ("umap", "UMAP"),
]

fig, axes = plt.subplots(2, 2, figsize=(8, 8), dpi=200)

for ax, (column, label) in zip(axes.flatten(), embedding_panels):
    # Scatter of pairwise embedding distance vs. median tMRCA, with a LOWESS
    # trend line drawn in orange.
    sns.regplot(
        x="tMRCA",
        y=column,
        data=distances,
        lowess=True,
        scatter_kws={"alpha": 0.25},
        line_kws={"color": "orange"},
        ax=ax,
    )

    ax.set_xlabel("median tMRCA")
    ax.set_ylabel(f"Euclidean distance ({label})")

    # Euclidean distances are non-negative, so start the y-axis at zero.
    ax.set_ylim(bottom=0)

sns.despine(fig=fig)

fig.tight_layout()
fig.savefig(distances_figure)