In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# Start from seaborn's "ticks" theme, then layer project-wide matplotlib
# overrides on top of it in a single update.
sns.set_style("ticks")

mpl.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,

    # Higher resolution for saved figures (300 dpi) and inline display (120 dpi),
    # suitable for presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,

    # Text weights and sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,

    # Render text with matplotlib's own engine rather than LaTeX.
    "text.usetex": False,
})

## Load data

In [None]:
# Load the accuracies CSV from the `accuracies` input path of the `snakemake` object.
# NOTE(review): `snakemake` is not defined in this notebook; presumably it is injected
# by Snakemake when the notebook runs as part of a workflow -- confirm.
df = pd.read_csv(snakemake.input.accuracies)

In [None]:
# Preview the first rows of the loaded table as a quick sanity check.
df.head()

In [None]:
# Display names used in figures, keyed by the identifiers found in the "method" column.
# Insertion order is: pca, mds, t-sne, umap.
name_by_method = dict([
    ("pca", "PCA"),
    ("mds", "MDS"),
    ("t-sne", "t-SNE"),
    ("umap", "UMAP"),
])

In [None]:
# Add a display-name column used as the hue in the figure below. This mutates `df`
# in place. Any "method" value that is not a key of `name_by_method` maps to NaN
# instead of raising.
df["method_name"] = df["method"].map(name_by_method)

## Plot accuracies by method, sequences per group, and replicate

In [None]:
def _plot_accuracy_panel(data, ax, xlabel, hue_order):
    """Draw one panel of normalized VI by number of sequences and method.

    Boxplots (one per method within each x group) are drawn first and the
    individual observations are overlaid as semi-transparent black points.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to plot; must contain the columns "subsample_max_sequences",
        "normalized_vi", and "method_name".
    ax : matplotlib.axes.Axes
        Axes to draw into.
    xlabel : str
        Label for the x-axis of this panel.
    hue_order : list of str
        Order of "method_name" levels. Passing the same list to every panel
        keeps the position and colour of each method identical across panels.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn into.
    """
    shared_kwargs = {
        "x": "subsample_max_sequences",
        "y": "normalized_vi",
        "hue": "method_name",
        "hue_order": hue_order,
        "data": data,
        "dodge": True,
        "ax": ax,
    }

    # fliersize=0 hides the boxplot's outlier markers; the strip plot already
    # shows every observation.
    sns.boxplot(fliersize=0, **shared_kwargs)
    sns.stripplot(alpha=0.5, color="#000000", **shared_kwargs)

    ax.set_ylim(bottom=0)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Cluster distance from\nNextstrain clades\n(normalized VI)")

    return ax


fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=False, sharey=True)

# Fix the hue order once for both panels. Only panel A keeps its legend, so panel B
# must assign the same position and colour to each method; without an explicit
# order each panel would derive it from the row order of its own subset.
method_order = df["method_name"].dropna().unique().tolist()

# Even subsampling
even_df = df.query("subsampling_scheme == 'even'")
ax1 = _plot_accuracy_panel(
    even_df,
    ax1,
    "Number of sequences sampled evenly by geography and time",
    method_order,
)

# Both the boxplot and the strip plot register legend entries. The boxplot is drawn
# first, so the leading entries (one per method) are the coloured box handles.
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles[:len(method_order)],
    labels[:len(method_order)],
    title="Method",
    frameon=False,
    ncol=len(method_order),
    title_fontsize=12,
    handlelength=1,
    handletextpad=0.25,
    columnspacing=0.5,
)

# Random subsampling
random_df = df.query("subsampling_scheme == 'random'")
ax2 = _plot_accuracy_panel(
    random_df,
    ax2,
    "Number of sequences sampled randomly",
    method_order,
)

# Panel A's legend applies to both panels; drop the automatic one on panel B if present.
panel_b_legend = ax2.get_legend()
if panel_b_legend is not None:
    panel_b_legend.remove()

# Annotate panel labels (positions are in figure coordinates).
panel_labels_dict = {
    "weight": "bold",
    "size": 14
}
fig.text(0.03, 0.97, "A", **panel_labels_dict)
fig.text(0.03, 0.49, "B", **panel_labels_dict)

fig.tight_layout()
fig.savefig(snakemake.output.accuracies)

In [None]:
# Minimum and maximum normalized VI for each (method, number of sequences) group.
# The group keys are discarded by reset_index(drop=True), leaving only the
# min column followed by the max column. Groups are not split by subsampling scheme.
min_max = (
    df
    .groupby(["method", "subsample_max_sequences"])
    .agg({"normalized_vi": ["min", "max"]})
    .reset_index(drop=True)
)

In [None]:
# Column 0 holds each group's minimum and column 1 its maximum (the order given to
# `agg` when `min_max` was built), so this tallies how often each max - min range occurs.
(min_max.iloc[:, 1] - min_max.iloc[:, 0]).value_counts()

In [None]:
# Median and standard deviation of normalized VI for every combination of
# method, number of sequences, and subsampling scheme.
(
    df
    .groupby(["method", "subsample_max_sequences", "subsampling_scheme"])
    .agg({"normalized_vi": ["median", "std"]})
)