In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# Apply seaborn's "ticks" style to the figures drawn below.
sns.set_style("ticks")

In [None]:
# Matplotlib defaults shared by the figures in this notebook.
mpl.rcParams.update({
    # Display figures at a reasonable default size.
    "figure.figsize": (6, 4),

    # Disable top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,

    # Display and save figures at higher resolution for presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 300,

    # Display text at sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
})

In [None]:
# Minimum reduction in error (in nucleotides) that an additional component
# must provide to count as an improvement for a given embedding (PCA or MDS).
# Below, this threshold is compared against the absolute differences between
# consecutive summarized MAE values.
minimum_distance_improvement = 1.0

In [None]:
# Name of the aggregation used to summarize error ("mae") across replicates and
# train/test splits. It is passed to the groupby `aggregate` call for each
# method below.
error_summary_statistic = "median"

## Load data

Load exhaustive grid search data. For each possible embedding method, corresponding method parameters, and HDBSCAN distance threshold, we produced an embedding for training and validation data (using 2-fold validation with 3 repeats for N=6 cross-validation iterations per parameter combination), assigned clusters to each embedding, and evaluated how well all pairs of strains in the data were assigned to the same or different cluster compared to predetermined clade assignments.

In [None]:
# Read the grid search results from the Snakemake input named `table`.
grid_table_path = snakemake.input.table
grid = pd.read_csv(grid_table_path)

In [None]:
grid.head()

In [None]:
list(grid.columns)

## Identify optimal method parameter values

Find the parameters for each method that minimize the mean absolute error (MAE) across all replicates.

In [None]:
# Columns that identify a method and its parameter values in the grid search
# table: the method name first, followed by the parameter columns.
grid_columns = ["method"] + [
    "components",
    "perplexity",
    "learning_rate",
    "nearest_neighbors",
    "min_dist",
]

In [None]:
grid

In [None]:
# Mean and standard deviation of MAE for every method/parameter combination.
# dropna=False keeps groups where some of the grouping columns are missing.
(
    grid
    .groupby(grid_columns, dropna=False)["mae"]
    .agg(["mean", "std"])
    .reset_index()
)

### PCA

In [None]:
# Keep only the PCA rows, then drop every column that contains missing values.
pca_rows = grid.query("method == 'pca'")
pca_grid = pca_rows.dropna(axis="columns")

In [None]:
# Store the number of components as integers.
pca_grid = pca_grid.assign(components=pca_grid["components"].astype(int))

In [None]:
pca_grid.shape

In [None]:
pca_grid

In [None]:
# Summarize MAE for each number of PCA components, ordered from the largest
# summarized error to the smallest.
pca_mae_by_components = pca_grid.groupby(["method", "components"])["mae"]
pca_accuracy = (
    pca_mae_by_components
    .aggregate(error_summary_statistic)
    .reset_index()
    .sort_values("mae", ascending=False)
)

In [None]:
pca_accuracy

In [None]:
# Summarized PCA errors as a NumPy array, in the same order as `pca_accuracy`.
pca_mae = pca_accuracy["mae"].to_numpy()

In [None]:
np.abs(np.diff(pca_mae)) >= minimum_distance_improvement

In [None]:
# Positions where consecutive summarized PCA errors differ by at least the
# minimum improvement.
pca_mae_changes = np.abs(np.diff(pca_mae))
pca_has_improvement = pca_mae_changes >= minimum_distance_improvement
pca_indices_with_improvement = np.where(pca_has_improvement)[0]

In [None]:
# Use the entry right after the last qualifying improvement, or the first
# entry when no step qualifies.
best_pca_mae_index = (
    pca_indices_with_improvement[-1] + 1
    if len(pca_indices_with_improvement) > 0
    else 0
)

In [None]:
best_pca_mae_index

In [None]:
# `best_pca_mae_index` is a position in the MAE-sorted `pca_mae` array, so look
# up the matching row by position. The index labels of `pca_accuracy` were
# assigned before the sort by MAE, so a label-based `.loc` lookup only selects
# the same row when the label order happens to match the sorted order.
best_pca_components = pca_accuracy["components"].iloc[best_pca_mae_index]
pca_best_accuracy = pca_accuracy[
    pca_accuracy["components"] == best_pca_components
].copy()

In [None]:
pca_best_accuracy

In [None]:
# Label the selected PCA parameters with the Snakemake wildcards for this run.
pca_best_accuracy = pca_best_accuracy.assign(
    virus=snakemake.wildcards.virus,
    recombination_rate=snakemake.wildcards.recombination_rate,
)

In [None]:
pca_best_accuracy

In [None]:
# Save the selected PCA parameters without the DataFrame index.
pca_parameters_path = snakemake.output.pca_parameters
pca_best_accuracy.to_csv(pca_parameters_path, index=False)

### MDS

In [None]:
# Keep only the MDS rows, then drop every column that contains missing values.
mds_rows = grid.query("method == 'mds'")
mds_grid = mds_rows.dropna(axis="columns")

In [None]:
# Store the number of components as integers.
mds_grid = mds_grid.assign(components=mds_grid["components"].astype(int))

In [None]:
mds_grid.shape

In [None]:
mds_grid

In [None]:
# Summarize MAE for each number of MDS components, ordered from the largest
# summarized error to the smallest.
mds_mae_by_components = mds_grid.groupby(["method", "components"])["mae"]
mds_accuracy = (
    mds_mae_by_components
    .aggregate(error_summary_statistic)
    .reset_index()
    .sort_values("mae", ascending=False)
)

In [None]:
mds_accuracy

In [None]:
mds_accuracy["mae"].values

In [None]:
# Summarized MDS errors as a NumPy array, in the same order as `mds_accuracy`.
mds_mae = mds_accuracy["mae"].to_numpy()

In [None]:
np.abs(np.diff(mds_mae)) >= minimum_distance_improvement

In [None]:
# Positions where consecutive summarized MDS errors differ by at least the
# minimum improvement.
mds_indices_with_improvement = np.where(
    np.abs(np.diff(mds_mae)) >= minimum_distance_improvement
)[0]

# Use the entry right after the last qualifying improvement. When no step
# qualifies, fall back to the first entry (as the PCA selection does) instead
# of raising an IndexError on the empty result.
if len(mds_indices_with_improvement) > 0:
    best_mds_mae_index = mds_indices_with_improvement[-1] + 1
else:
    best_mds_mae_index = 0

In [None]:
mds_mae[best_mds_mae_index]

In [None]:
# `best_mds_mae_index` is a position in the MAE-sorted `mds_mae` array, so look
# up the matching row by position. The index labels of `mds_accuracy` were
# assigned before the sort by MAE, so a label-based `.loc` lookup only selects
# the same row when the label order happens to match the sorted order.
best_mds_components = mds_accuracy["components"].iloc[best_mds_mae_index]
mds_best_accuracy = mds_accuracy[
    mds_accuracy["components"] == best_mds_components
].copy()

In [None]:
mds_best_accuracy

In [None]:
# Label the selected MDS parameters with the Snakemake wildcards for this run.
mds_best_accuracy = mds_best_accuracy.assign(
    virus=snakemake.wildcards.virus,
    recombination_rate=snakemake.wildcards.recombination_rate,
)

In [None]:
mds_best_accuracy

In [None]:
# Save the selected MDS parameters without the DataFrame index.
mds_parameters_path = snakemake.output.mds_parameters
mds_best_accuracy.to_csv(mds_parameters_path, index=False)

### t-SNE

In [None]:
# Keep only the t-SNE rows, then drop every column that contains missing values.
tsne_rows = grid.query("method == 't-sne'")
tsne_grid = tsne_rows.dropna(axis="columns")

In [None]:
# Store perplexity values as floats.
tsne_grid = tsne_grid.assign(perplexity=tsne_grid["perplexity"].astype(float))

In [None]:
# Store learning rate values as floats.
tsne_grid = tsne_grid.assign(learning_rate=tsne_grid["learning_rate"].astype(float))

In [None]:
tsne_grid.shape

In [None]:
tsne_grid

In [None]:
# Summarize MAE for each t-SNE perplexity and learning rate combination,
# ordered from the largest summarized error to the smallest.
tsne_mae_by_parameters = tsne_grid.groupby(
    ["method", "perplexity", "learning_rate"]
)["mae"]
tsne_mean_accuracy = (
    tsne_mae_by_parameters
    .aggregate(error_summary_statistic)
    .reset_index()
    .sort_values("mae", ascending=False)
)

In [None]:
tsne_mean_accuracy

In [None]:
# Select the t-SNE parameter combination with the lowest summarized error.
# Copy the one-row result so that columns added to it later are written to an
# independent DataFrame, as is done for the PCA and MDS selections.
tsne_best_accuracy = tsne_mean_accuracy.sort_values(
    "mae",
    ascending=True,
).head(1).copy()

In [None]:
tsne_best_accuracy

In [None]:
# Label the selected t-SNE parameters with the Snakemake wildcards for this run.
tsne_best_accuracy = tsne_best_accuracy.assign(
    virus=snakemake.wildcards.virus,
    recombination_rate=snakemake.wildcards.recombination_rate,
)

In [None]:
tsne_best_accuracy

In [None]:
# Save the selected t-SNE parameters without the DataFrame index.
tsne_parameters_path = snakemake.output.tsne_parameters
tsne_best_accuracy.to_csv(tsne_parameters_path, index=False)

### UMAP

In [None]:
# Keep only the UMAP rows, then drop every column that contains missing values.
umap_rows = grid.query("method == 'umap'")
umap_grid = umap_rows.dropna(axis="columns")

In [None]:
# Store nearest neighbor values as floats.
umap_grid = umap_grid.assign(
    nearest_neighbors=umap_grid["nearest_neighbors"].astype(float)
)

In [None]:
umap_grid.head()

In [None]:
umap_grid.shape

In [None]:
# Summarize MAE for each UMAP minimum distance and nearest neighbors
# combination, ordered from the largest summarized error to the smallest.
umap_mae_by_parameters = umap_grid.groupby(
    ["method", "min_dist", "nearest_neighbors"]
)["mae"]
umap_mean_accuracy = (
    umap_mae_by_parameters
    .aggregate(error_summary_statistic)
    .reset_index()
    .sort_values("mae", ascending=False)
)

In [None]:
umap_mean_accuracy

In [None]:
# Select the UMAP parameter combination with the lowest summarized error.
# Copy the one-row result so that columns added to it later are written to an
# independent DataFrame, as is done for the PCA and MDS selections.
umap_best_accuracy = umap_mean_accuracy.sort_values(
    "mae",
    ascending=True,
).head(1).copy()

In [None]:
umap_best_accuracy

In [None]:
# Label the selected UMAP parameters with the Snakemake wildcards for this run.
umap_best_accuracy = umap_best_accuracy.assign(
    virus=snakemake.wildcards.virus,
    recombination_rate=snakemake.wildcards.recombination_rate,
)

In [None]:
umap_best_accuracy

In [None]:
# Save the selected UMAP parameters without the DataFrame index.
umap_parameters_path = snakemake.output.umap_parameters
umap_best_accuracy.to_csv(umap_parameters_path, index=False)

In [None]:
# Upper y-axis limit for the figure below: the largest observed MAE rounded up
# to a whole number, plus one unit of headroom.
largest_mae = grid["mae"].max()
upper_ylim = int(np.ceil(largest_mae) + 1)

In [None]:
# Compare MAE across each method's parameter values in a 2x2 grid of boxplots.
# The panels share a y-axis, so the statements that set y-limits are kept in
# panel order (PCA, MDS, t-SNE, UMAP).
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(8, 8),
    dpi=200,
    sharey=True,
)

# Text and legend settings used by more than one panel.
error_axis_label = "Mean absolute test error observed\nand predicted genetic distance"
components_axis_label = "Number of components"
legend_settings = {
    "loc": "lower left",
    "frameon": False,
    "ncol": 3,
    "handletextpad": 0.5,
    "title_fontsize": 12,
}

# PCA: error by number of components.
ax1 = sns.boxplot(x="components", y="mae", data=pca_grid, color="#CCCCCC", ax=ax1)
ax1.set_title("PCA")
ax1.set_xlabel(components_axis_label)
ax1.set_ylabel(error_axis_label)
ax1.set_ylim(bottom=-2, top=upper_ylim)

# MDS: error by number of components.
ax2 = sns.boxplot(x="components", y="mae", data=mds_grid, color="#CCCCCC", ax=ax2)
ax2.set_title("MDS")
ax2.set_xlabel(components_axis_label)
ax2.set_ylabel("")
ax2.set_ylim(bottom=-2, top=upper_ylim)

# t-SNE: error by perplexity, with one box per learning rate.
ax3 = sns.boxplot(
    x="perplexity",
    y="mae",
    hue="learning_rate",
    data=tsne_grid,
    dodge=True,
    ax=ax3,
)
ax3.set_title("t-SNE")
ax3.legend(title="Learning rate", **legend_settings)
ax3.set_xlabel("Perplexity")
ax3.set_ylabel(error_axis_label)
ax3.set_ylim(bottom=-4, top=upper_ylim)

# UMAP: error by minimum distance, with one box per nearest neighbors value.
ax4 = sns.boxplot(
    x="min_dist",
    y="mae",
    hue="nearest_neighbors",
    data=umap_grid,
    dodge=True,
    ax=ax4,
)
ax4.set_title("UMAP")
ax4.legend(title="Nearest neighbors", **legend_settings)
ax4.set_xlabel("Minimum distance between points")
ax4.set_ylabel("")
ax4.set_ylim(bottom=-4, top=upper_ylim)

# Annotate panel labels.
panel_labels_dict = dict(weight="bold", size=14)
panel_label_positions = (
    (0.005, 0.97, "A"),
    (0.55, 0.97, "B"),
    (0.005, 0.47, "C"),
    (0.55, 0.47, "D"),
)
for label_x, label_y, label_text in panel_label_positions:
    fig.text(label_x, label_y, label_text, **panel_labels_dict)

# Tighten the layout before writing the figure to the Snakemake output path.
fig.tight_layout()
plt.savefig(snakemake.output.scores_by_parameters)

## Find best accuracy per method

In [None]:
# Stack the selected parameters for all four methods into one table.
best_accuracy_by_method = [
    pca_best_accuracy,
    mds_best_accuracy,
    tsne_best_accuracy,
    umap_best_accuracy,
]
best_accuracy = pd.concat(best_accuracy_by_method)

In [None]:
best_accuracy

In [None]:
# Write the per-method summary with the run identifiers first, then the method
# and parameter columns, then the error.
output_columns = ["virus", "recombination_rate", *grid_columns, "mae"]
summary_path = snakemake.output.summary_score_by_method
best_accuracy.to_csv(summary_path, columns=output_columns, index=False)