In [None]:
import sys
# Make modules under notebooks/scripts/ importable. The path is resolved relative
# to the current working directory at run time.
# NOTE(review): no import from this directory is visible in this notebook —
# confirm it is still needed.
sys.path.append("notebooks/scripts/")

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

## Configure plot style, then define inputs, outputs, and parameters

In [None]:
# Start from seaborn's "ticks" theme; the rcParams below are applied afterwards
# so they override the theme's defaults.
sns.set_style("ticks")

mpl.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Saved files and on-screen figures at resolutions suitable for
    # presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Text weights and sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Do not route text rendering through an external LaTeX installation.
    "text.usetex": False,
})

In [None]:
# Within/between-group distance summaries: the unprefixed tables feed the
# influenza figure, the *_sars_* tables feed the SARS-CoV-2 figure.
(
    within_between_df_training,
    within_between_df_test,
    within_between_df_sars_training,
    within_between_df_sars_test,
) = (
    pd.read_csv(snakemake.input.within_between_df_training),
    pd.read_csv(snakemake.input.within_between_df_test),
    pd.read_csv(snakemake.input.within_between_df_sars_training),
    pd.read_csv(snakemake.input.within_between_df_sars_test),
)

In [None]:
# Image paths the Snakemake rule expects this notebook to write.
png_chart_flu, png_chart_sars = snakemake.output.flu_png, snakemake.output.sars_png

In [None]:
# Map each value of the influenza tables' `group` column to its y-axis label.
flu_labels_to_axis_labels = dict([
    ("clade_membership", "Nextstrain clade"),
    ("pca_label", "PCA cluster"),
    ("mds_label", "MDS cluster"),
    ("t-sne_label", "t-SNE cluster"),
    ("umap_label", "UMAP cluster"),
    ("genetic_label", "genetic distance cluster"),
])

In [None]:
# Largest right-hand error-bar end (mean + std) in the training table, rounded up.
flu_training_upper_limit = int(np.ceil(
    within_between_df_training["mean"]
    .add(within_between_df_training["std"])
    .max()
))

In [None]:
# Largest right-hand error-bar end (mean + std) in the test table, rounded up.
flu_test_upper_limit = int(np.ceil(
    within_between_df_test["mean"]
    .add(within_between_df_test["std"])
    .max()
))

In [None]:
# Shared x-axis limit for both influenza panels, with one unit of padding.
flu_upper_limit = 1 + max(flu_test_upper_limit, flu_training_upper_limit)

In [None]:
# Two panels (training and test seasons) that share the y-axis of grouping methods.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.5, 4), sharey=True)

# Row order comes from the training "within" rows. It is reversed so the first
# group in the CSV is drawn at the top of each panel (y increases upward).
flu_groups = within_between_df_training.loc[
    within_between_df_training["comparison"] == "within",
    "group",
].tolist()[::-1]
y_positions = np.arange(len(flu_groups))

flu_panels = (
    (ax1, within_between_df_training, "H3N2 HA 2016-2018"),
    (ax2, within_between_df_test, "H3N2 HA 2018-2020"),
)
# Within/between points are nudged vertically so their error bars don't overlap.
flu_comparisons = (
    ("within", -0.05, "blue"),
    ("between", 0.05, "orange"),
)

for ax, panel_df, title in flu_panels:
    for comparison, offset, color in flu_comparisons:
        # Look rows up by group name rather than trusting CSV row order, so each
        # point lands on the row of its own label. A group missing from this
        # panel raises a KeyError instead of silently shifting every label.
        comparison_df = (
            panel_df[panel_df["comparison"] == comparison]
            .set_index("group")
            .loc[flu_groups]
        )
        ax.errorbar(
            comparison_df["mean"],
            y_positions + offset,
            xerr=comparison_df["std"],
            fmt="o",
            color=color,
            label=comparison,
            capsize=2,
        )
    ax.set_xlim(0, flu_upper_limit)
    ax.set_title(title)
    ax.set_xlabel("Pairwise nucleotide distance")

# The y-axis is shared, so ticks and labels set on ax1 also apply to ax2.
ax1.set_yticks(y_positions)
# Fall back to the raw group name instead of a "nan" tick label for unmapped groups.
ax1.set_yticklabels([flu_labels_to_axis_labels.get(group, group) for group in flu_groups])
ax1.set_ylabel("Group")

ax1.legend(
    frameon=False,
    loc="upper right",
)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(png_chart_flu)

In [None]:
# SARS-CoV-2 groups to plot, mapped to their y-axis labels. Groups absent from
# this mapping are filtered out before plotting.
sars_labels_to_axis_labels = dict([
    ("Nextstrain_clade", "Nextstrain clade"),
    ("Nextclade_pango_collapsed", "Pango"),
    ("pca_label_for_Nextstrain_clade", "PCA cluster"),
    ("mds_label_for_Nextstrain_clade", "MDS cluster"),
    ("t-sne_label_for_Nextstrain_clade", "t-SNE cluster"),
    ("umap_label_for_Nextstrain_clade", "UMAP cluster"),
    ("genetic_label_for_Nextstrain_clade", "genetic distance cluster"),
])

In [None]:
# Keep only the SARS-CoV-2 training rows whose group has an axis label.
within_between_df_sars_training_to_plot = within_between_df_sars_training.loc[
    lambda frame: frame["group"].isin(list(sars_labels_to_axis_labels))
]

In [None]:
# Keep only the SARS-CoV-2 test rows whose group has an axis label.
within_between_df_sars_test_to_plot = within_between_df_sars_test.loc[
    lambda frame: frame["group"].isin(list(sars_labels_to_axis_labels))
]

In [None]:
# Largest right-hand error-bar end (mean + std) among plotted training rows, rounded up.
sars_training_upper_limit = int(np.ceil(
    within_between_df_sars_training_to_plot["mean"]
    .add(within_between_df_sars_training_to_plot["std"])
    .max()
))

In [None]:
# Largest right-hand error-bar end (mean + std) among plotted test rows, rounded up.
sars_test_upper_limit = int(np.ceil(
    within_between_df_sars_test_to_plot["mean"]
    .add(within_between_df_sars_test_to_plot["std"])
    .max()
))

In [None]:
# Shared x-axis limit for both SARS-CoV-2 panels, with one unit of padding.
sars_upper_limit = 1 + max(sars_test_upper_limit, sars_training_upper_limit)

In [None]:
# Two panels (training and test periods) that share both axes.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.5, 4), sharex=True, sharey=True)

# Row order comes from the training "within" rows. It is reversed so the first
# group in the CSV is drawn at the top of each panel (y increases upward).
sars_groups = within_between_df_sars_training_to_plot.loc[
    within_between_df_sars_training_to_plot["comparison"] == "within",
    "group",
].tolist()[::-1]
y_positions = np.arange(len(sars_groups))

sars_panels = (
    (ax1, within_between_df_sars_training_to_plot, "SARS-CoV-2 2020-2022"),
    (ax2, within_between_df_sars_test_to_plot, "SARS-CoV-2 2022-2023"),
)
# Within/between points are nudged vertically so their error bars don't overlap.
sars_comparisons = (
    ("within", -0.1, "blue"),
    ("between", 0.1, "orange"),
)

for ax, panel_df, title in sars_panels:
    for comparison, offset, color in sars_comparisons:
        # Look rows up by group name rather than trusting CSV row order, so each
        # point lands on the row of its own label. A group missing from this
        # panel raises a KeyError instead of silently shifting every label.
        comparison_df = (
            panel_df[panel_df["comparison"] == comparison]
            .set_index("group")
            .loc[sars_groups]
        )
        ax.errorbar(
            comparison_df["mean"],
            y_positions + offset,
            xerr=comparison_df["std"],
            fmt="o",
            color=color,
            label=comparison,
            capsize=2,
        )
    ax.set_title(title)
    ax.set_xlabel("Pairwise nucleotide distance")

# Both axes are shared, so the limit, ticks, and labels set on ax1 also apply to ax2.
ax1.set_xlim(0, sars_upper_limit)
ax1.set_yticks(y_positions)
# Every plotted group has a label: the *_to_plot frames were filtered by this mapping.
ax1.set_yticklabels([sars_labels_to_axis_labels[group] for group in sars_groups])
ax1.set_ylabel("Group")

ax1.legend(
    frameon=False,
    loc=(0.6, 0.8),
)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(png_chart_sars)