In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Plain white seaborn background; the rcParams cell below then fine-tunes spines and text.
sns.set_style(style="white")

In [None]:
# Notebook-wide matplotlib defaults for presentation- and manuscript-ready figures.
mpl.rcParams.update({
    # Default figure size in inches.
    "figure.figsize": (6, 4),
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # High resolution for both saved files and on-screen display.
    "savefig.dpi": 300,
    "figure.dpi": 300,
    # Text large enough to read on slides and in manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
})

## Load data

Load exhaustive grid search data. For each embedding method, each combination of its parameters, and each HDBSCAN distance threshold, we produced an embedding for training and validation data (using 2-fold cross-validation with 3 repeats, for N=6 cross-validation iterations per parameter combination), assigned clusters to each embedding, and evaluated how well all pairs of strains in the data were assigned to the same or different clusters compared to predetermined clade assignments. This notebook summarizes the mean squared error (MSE, the `mse` column) between observed and predicted Euclidean distances for each embedding.

In [None]:
# Load the grid-search results. `snakemake` is presumably injected by Snakemake
# when it executes this notebook as part of a rule — TODO confirm.
grid = pd.read_csv(snakemake.input.table)

# Fail fast with a clear message if any column used later in this notebook is
# missing, instead of hitting a KeyError deep inside a later cell.
required_columns = {
    "method",
    "components",
    "perplexity",
    "learning_rate",
    "nearest_neighbors",
    "min_dist",
    "mse",
    "replicate",
    "recombination_rate",
}
missing_columns = required_columns.difference(grid.columns)
assert not missing_columns, f"Grid search table is missing columns: {sorted(missing_columns)}"

In [None]:
# Preview the first rows of the grid-search results.
grid.head()

In [None]:
# Names of every column in the grid-search results.
grid.columns.tolist()

## Identify optimal method parameter values

For each method, find the parameter values that minimize the mean squared error (MSE) averaged across all replicates.

In [None]:
# Key columns that together identify one parameter combination for any method;
# used for grouping and for joining the best parameters back onto the grid.
grid_columns = (
    "method components perplexity learning_rate nearest_neighbors min_dist"
).split()

In [None]:
# Display the grid-search results table.
# NOTE(review): largely redundant with the `grid.head()` preview in the Load data section.
grid

In [None]:
# Mean and standard deviation of MSE for every parameter combination.
# dropna=False keeps groups whose key columns contain NaN (presumably the
# parameters a method does not use); by default pandas would drop those groups.
grid.groupby(grid_columns, dropna=False)["mse"].agg(["mean", "std"]).reset_index()

### PCA

In [None]:
# Rows produced by PCA embeddings only.
pca_grid = grid[grid["method"] == "pca"]

In [None]:
# (rows, columns) of the PCA subset.
pca_grid.shape

In [None]:
# Inspect the PCA rows.
pca_grid

In [None]:
# Mean test MSE for each number of PCA components, averaged over all rows with
# that setting. Despite the "accuracy" name, lower MSE is better, so sort
# ascending to list the best setting first (the original sorted worst-first).
pca_mean_accuracy = (
    pca_grid
    .groupby(["method", "components"])["mse"]
    .mean()
    .reset_index()
    .sort_values("mse", ascending=True)
)

In [None]:
# Mean MSE per PCA setting.
pca_mean_accuracy

In [None]:
# Single PCA setting with the lowest mean MSE, kept as a one-row DataFrame.
pca_best_accuracy = pca_mean_accuracy.sort_values(by="mse").iloc[:1]

In [None]:
# Best PCA parameters (lowest mean MSE).
pca_best_accuracy

In [None]:
# Persist the best PCA parameters for downstream rules.
pca_best_accuracy.to_csv(path_or_buf=snakemake.output.pca_parameters, index=False)

In [None]:
# Per-row test MSE for each number of PCA components: boxes summarize the
# spread and the swarm shows every individual value.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=120)
sns.boxplot(
    data=pca_grid,
    x="components",
    y="mse",
    ax=ax,
    color="#CCCCCC",
    # The swarm below already draws every observation; hide the box plot's own
    # outlier markers so those points are not plotted twice.
    showfliers=False,
)
sns.swarmplot(
    data=pca_grid,
    x="components",
    y="mse",
    ax=ax,
)

# Human-readable x label, consistent with the t-SNE and UMAP figures.
ax.set_xlabel("Number of components")
ax.set_ylabel("Mean squared test error\nobserved and predicted Euclidean distance")
ax.set_ylim(bottom=0)

fig.tight_layout()
fig.savefig(snakemake.output.score_by_pca_parameters)

### MDS

In [None]:
# Rows produced by MDS embeddings only.
mds_grid = grid[grid["method"] == "mds"]

In [None]:
# (rows, columns) of the MDS subset.
mds_grid.shape

In [None]:
# Mean test MSE for each number of MDS components, averaged over all rows with
# that setting. Lower MSE is better, so sort ascending to list the best
# setting first (the original sorted worst-first).
mds_mean_accuracy = (
    mds_grid
    .groupby(["method", "components"])["mse"]
    .mean()
    .reset_index()
    .sort_values("mse", ascending=True)
)

In [None]:
# Mean MSE per MDS setting.
mds_mean_accuracy

In [None]:
# Single MDS setting with the lowest mean MSE, kept as a one-row DataFrame.
mds_best_accuracy = mds_mean_accuracy.sort_values(by="mse").iloc[:1]

In [None]:
# Best MDS parameters (lowest mean MSE).
mds_best_accuracy

In [None]:
# Persist the best MDS parameters for downstream rules.
mds_best_accuracy.to_csv(path_or_buf=snakemake.output.mds_parameters, index=False)

In [None]:
# Per-row test MSE for each number of MDS components: boxes summarize the
# spread and the swarm shows every individual value.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=120)
sns.boxplot(
    data=mds_grid,
    x="components",
    y="mse",
    ax=ax,
    color="#CCCCCC",
    # The swarm below already draws every observation; hide the box plot's own
    # outlier markers so those points are not plotted twice.
    showfliers=False,
)
sns.swarmplot(
    data=mds_grid,
    x="components",
    y="mse",
    ax=ax,
)

# Human-readable x label, consistent with the t-SNE and UMAP figures.
ax.set_xlabel("Number of components")
ax.set_ylabel("Mean squared test error\nobserved and predicted Euclidean distance")
ax.set_ylim(bottom=0)

fig.tight_layout()
fig.savefig(snakemake.output.score_by_mds_parameters)

### t-SNE

In [None]:
# Rows produced by t-SNE embeddings only.
tsne_grid = grid[grid["method"] == "t-sne"]

In [None]:
# (rows, columns) of the t-SNE subset.
tsne_grid.shape

In [None]:
# Mean test MSE for each (perplexity, learning rate) pair, averaged over all
# rows with that setting. Lower MSE is better, so sort ascending to list the
# best setting first (the original sorted worst-first).
tsne_mean_accuracy = (
    tsne_grid
    .groupby(["method", "perplexity", "learning_rate"])["mse"]
    .mean()
    .reset_index()
    .sort_values("mse", ascending=True)
)

In [None]:
# Mean MSE per t-SNE setting.
tsne_mean_accuracy

In [None]:
# Single t-SNE setting with the lowest mean MSE, kept as a one-row DataFrame.
tsne_best_accuracy = tsne_mean_accuracy.sort_values(by="mse").iloc[:1]

In [None]:
# Best t-SNE parameters (lowest mean MSE).
tsne_best_accuracy

In [None]:
# Persist the best t-SNE parameters for downstream rules.
tsne_best_accuracy.to_csv(path_or_buf=snakemake.output.tsne_parameters, index=False)

In [None]:
# Test MSE by perplexity, with one box per learning rate.
facet_grid = sns.catplot(
    data=tsne_grid,
    x="perplexity",
    y="mse",
    hue="learning_rate",
    dodge=True,
    kind="box",
    aspect=1.41,
    height=6,
    legend=False,
)

for ax in facet_grid.axes.flat:
    ax.set_xlabel("Perplexity")
    ax.set_ylabel("Mean squared test error\nobserved and predicted Euclidean distance")
    ax.set_ylim(bottom=0)

facet_grid.add_legend(
    title="Learning rate",
    loc="upper right",
)

# add_legend reserves a right-hand margin for the out-of-axes legend. The grid's
# own tight_layout respects that margin, whereas plt.tight_layout() ignores it
# and lets the axes expand underneath the legend.
facet_grid.tight_layout()
# FacetGrid.savefig defaults to bbox_inches="tight", so the legend is not clipped.
facet_grid.savefig(snakemake.output.score_by_tsne_parameters)

### UMAP

In [None]:
# Rows produced by UMAP embeddings only.
umap_grid = grid[grid["method"] == "umap"]

In [None]:
# Preview the first UMAP rows.
umap_grid.head()

In [None]:
# (rows, columns) of the UMAP subset.
umap_grid.shape

In [None]:
# Mean test MSE for each (min_dist, nearest_neighbors) pair, averaged over all
# rows with that setting. Lower MSE is better, so sort ascending to list the
# best setting first (the original sorted worst-first).
umap_mean_accuracy = (
    umap_grid
    .groupby(["method", "min_dist", "nearest_neighbors"])["mse"]
    .mean()
    .reset_index()
    .sort_values("mse", ascending=True)
)

In [None]:
# Mean MSE per UMAP setting.
umap_mean_accuracy

In [None]:
# Single UMAP setting with the lowest mean MSE, kept as a one-row DataFrame.
umap_best_accuracy = umap_mean_accuracy.sort_values(by="mse").iloc[:1]

In [None]:
# Best UMAP parameters (lowest mean MSE).
umap_best_accuracy

In [None]:
# Persist the best UMAP parameters for downstream rules.
umap_best_accuracy.to_csv(path_or_buf=snakemake.output.umap_parameters, index=False)

In [None]:
# Test MSE by minimum distance, with one box per number of nearest neighbors.
facet_grid = sns.catplot(
    data=umap_grid,
    x="min_dist",
    y="mse",
    hue="nearest_neighbors",
    dodge=True,
    kind="box",
    aspect=1.41,
    height=6,
    legend=False,
)

for ax in facet_grid.axes.flat:
    ax.set_xlabel("Minimum distance between points")
    ax.set_ylabel("Mean squared test error\nobserved and predicted Euclidean distance")
    ax.set_ylim(bottom=0)

facet_grid.add_legend(
    title="Nearest neighbors",
    loc="upper right",
)

# add_legend reserves a right-hand margin for the out-of-axes legend. The grid's
# own tight_layout respects that margin, whereas plt.tight_layout() ignores it
# and lets the axes expand underneath the legend.
facet_grid.tight_layout()
# FacetGrid.savefig defaults to bbox_inches="tight", so the legend is not clipped.
facet_grid.savefig(snakemake.output.score_by_umap_parameters)

## Summarize test error (MSE) for each method's best parameters

In [None]:
# One row per method: its lowest-MSE parameter combination. Columns for
# parameters a method does not use end up NaN after the concat.
# ignore_index=True replaces the arbitrary (and possibly duplicated) index
# labels inherited from each per-method frame with a clean 0..n-1 index.
best_accuracy = pd.concat(
    [pca_best_accuracy, mds_best_accuracy, tsne_best_accuracy, umap_best_accuracy],
    ignore_index=True,
)

In [None]:
# Best parameter combination for every method.
best_accuracy

In [None]:
# Re-preview the raw grid rows before joining the best parameters onto them.
grid.head()

In [None]:
# Keep only the grid rows that used each method's best parameters, then
# summarize test MSE per method and recombination rate.
#
# Parameters a method does not use are NaN in `best_accuracy`. pandas merge
# matches NaN keys to NaN keys, so a method's rows join only when the grid
# leaves exactly the same parameters missing. Otherwise the inner merge would
# silently drop that method; the assertion below guards against this.
best_parameter_grid = grid.merge(
    best_accuracy,
    on=grid_columns,
    suffixes=["", "_mean"],
    # `best_accuracy` has one row per method, so each grid row matches at most one.
    validate="many_to_one",
)

unmatched_methods = sorted(
    set(best_accuracy["method"]) - set(best_parameter_grid["method"])
)
assert not unmatched_methods, (
    f"No grid rows matched the best parameters for: {unmatched_methods}"
)

grid_summary = best_parameter_grid.groupby([
    "method",
    "recombination_rate",
]).aggregate({
    "replicate": ["count"],
    "mse": ["mean", "std"],
})

In [None]:
# Per-method, per-recombination-rate summary (columns are a (column, statistic) MultiIndex).
grid_summary

In [None]:
# Inspect the two-level column index produced by `aggregate`.
grid_summary.columns

In [None]:
# Flatten each (column, statistic) pair into a single name such as "mse_mean".
new_columns = list(map("_".join, grid_summary.columns))

In [None]:
# Apply the flattened column names and move the group keys out of the index.
grid_summary = (
    grid_summary
    .set_axis(new_columns, axis="columns")
    .reset_index()
)

In [None]:
# Lowest mean test MSE first.
grid_summary = grid_summary.sort_values(by="mse_mean", ascending=True)

In [None]:
# Final summary table, best method first.
grid_summary

In [None]:
# Persist the per-method summary for downstream rules.
grid_summary.to_csv(path_or_buf=snakemake.output.summary_score_by_method, index=False)