In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
import seaborn as sns

%matplotlib inline

In [None]:
# Apply seaborn's "ticks" style to the figures drawn below.
sns.set_style("ticks")

In [None]:
# Input and output paths come from the `snakemake` object, which is not defined
# in this notebook -- presumably injected by Snakemake when it runs the
# notebook; confirm against the workflow rule.
tMRCAs_file = snakemake.input.tmrcas
embedding_file = snakemake.input.embedding

# Destination of the figure saved by the last cell.
distances_figure = snakemake.output.distances_figure

In [None]:
# Load the tab-separated tMRCA table, then give the "median_median" column the
# more descriptive name used throughout the rest of the notebook.
df = pd.read_csv(tMRCAs_file, sep="\t")
df = df.rename(columns={"median_median": "median_tMRCA"})

In [None]:
# Preview the first rows of the tMRCA table.
df.head()

In [None]:
# Sorted list of the distinct values in the "reference_strain" column.
strains = sorted(df["reference_strain"].unique())

In [None]:
# Sorted list of the distinct values in the "name" column.
other_strains = sorted(df["name"].unique())

In [None]:
# Number of distinct reference strains.
len(strains)

In [None]:
# Number of distinct strains in the "name" column.
len(other_strains)

In [None]:
# (rows, columns) of the tMRCA table before it is re-indexed.
df.shape

In [None]:
# Index the table by the (reference_strain, name) pair so a single pair can be
# looked up with `df.loc[(a, b), column]`.
# Note: this rebinds `df`; re-running the cell without reloading the table
# fails because the two columns have already been moved into the index.
df = df.set_index(["reference_strain", "name"])

In [None]:
# Preview the table now that it is indexed by (reference_strain, name).
df.head()

In [None]:
# Display the indexed table itself (not just the head).
df

In [None]:
# Spot-check: look up the median tMRCA of one hard-coded (reference_strain, name)
# pair. NOTE(review): the strain names are literal, so this cell only works for
# a table that contains this exact pair.
df.loc[("AbuDhabi_UAE_30_2014|KP209309|human|2014-04-19", "Riyadh-KKUH-291||human|2014-05-06"), "median_tMRCA"]

In [None]:
# Collect the median tMRCA for every unordered pair of reference strains.
# Pairs are visited as (i, j) with i < j over the sorted `strains` list, i.e.
# in the same order as the condensed vector returned by scipy's `pdist`.
# A pair that is absent from the table is recorded as NaN.
tMRCAs = []
for i, strain_i in enumerate(strains):
    for strain_j in strains[i + 1:]:
        try:
            tMRCA = df.loc[(strain_i, strain_j), "median_tMRCA"]
        except KeyError:
            # `np.NaN` was removed in NumPy 2.0 and raises AttributeError there;
            # `np.nan` is the spelling supported by every NumPy version.
            tMRCA = np.nan

        tMRCAs.append(tMRCA)

In [None]:
# Number of strain pairs collected by the loop above.
len(tMRCAs)

In [None]:
# Convert the collected pairwise values to a NumPy array (rebinds `tMRCAs`).
tMRCAs = np.array(tMRCAs)

In [None]:
# Inspect the pairwise tMRCA array.
tMRCAs

In [None]:
# Name of the embedding method; used in the x-axis label of the final figure.
embedding_name = "MDS"

In [None]:
# Columns of the embedding file that are loaded and used for the Euclidean distances.
embedding_columns = ["mds1", "mds2", "mds3", "mds4", "mds5", "mds6"]

In [None]:
# Load only the strain name and the selected embedding components, using the
# strain name as the row index.
embedding = pd.read_csv(
    embedding_file,
    usecols=["strain", *embedding_columns],
    index_col="strain",
)

In [None]:
# Preview the first rows of the embedding table.
embedding.head()

In [None]:
# (strains, embedding components) loaded from the embedding file.
embedding.shape

In [None]:
# `pdist` returns the pairwise Euclidean distances as a condensed vector in row
# order: (0, 1), (0, 2), ..., (1, 2), ...  The tMRCA vector was built over the
# sorted `strains` list, so order the embedding rows by `strains` first;
# otherwise the two vectors only line up if the embedding file happens to be
# sorted the same way. A strain missing from the embedding raises KeyError here
# instead of silently pairing distances with the wrong tMRCA values.
euclidean_distances = pdist(embedding.loc[strains].values)

In [None]:
# Inspect the condensed pairwise distance vector.
euclidean_distances

In [None]:
# Length of the distance vector; it must equal len(tMRCAs) for the `distances`
# frame built below to be constructible.
euclidean_distances.shape

In [None]:
# Histogram of the median tMRCA values in the table (displayed only, not saved).
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=120)
ax.hist(df["median_tMRCA"])

# Plain string: the original f-string had no placeholders.
ax.set_xlabel("median tMRCA")
# NOTE(review): each row of `df` is a (reference_strain, name) pair, so the
# counts look like strain pairs rather than strains -- confirm the label.
ax.set_ylabel("number of strains")

In [None]:
# Put the two pairwise vectors side by side, one row per strain pair.
distances = pd.DataFrame(dict(tMRCA=tMRCAs, euclidean=euclidean_distances))

In [None]:
# Preview the paired tMRCA / Euclidean distance table.
distances.head()

In [None]:
# Scatter plot of embedding distance against median tMRCA for every strain
# pair, with a LOWESS trend line drawn in orange; saved to `distances_figure`.
fig, ax = plt.subplots(figsize=(6, 6), dpi=120)

sns.regplot(
    data=distances,
    x="euclidean",
    y="tMRCA",
    ax=ax,
    lowess=True,
    scatter_kws=dict(alpha=0.25),
    line_kws=dict(color="orange"),
)

# Anchor the y-axis at zero and label both axes.
ax.set_ylim(bottom=0)
ax.set_ylabel("median tMRCA")
ax.set_xlabel(f"Euclidean distance ({embedding_name})")

sns.despine()

plt.tight_layout()
plt.savefig(distances_figure, dpi=200)