In [None]:
import sys

# Make helper modules under notebooks/scripts/ importable; the path is relative
# to the working directory the notebook is executed from.
sys.path.extend(["notebooks/scripts/"])

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

## Define inputs, outputs, and parameters

In [None]:
sns.set_style("ticks")

# Presentation/manuscript defaults. Keep this after sns.set_style above, since
# seaborn styles also set axes/spine rcParams that these values override.
mpl.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Higher resolution for on-screen display and for saved figures.
    "figure.dpi": 120,
    "savefig.dpi": 300,
    # Text sizes large enough to read on slides and in print.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 10,
    # Do not render text through an external LaTeX installation.
    "text.usetex": False,
})

In [None]:
# Within/between-group distance summaries for each virus and time period, read
# from the Snakemake rule's named inputs (in this order).
(
    within_between_df_training,
    within_between_df_test,
    within_between_df_sars_training,
    within_between_df_sars_test,
) = (
    pd.read_csv(getattr(snakemake.input, input_name))
    for input_name in (
        "within_between_df_training",
        "within_between_df_test",
        "within_between_df_sars_training",
        "within_between_df_sars_test",
    )
)

In [None]:
# Output figure paths declared by the Snakemake rule.
png_chart_flu, png_chart_sars = snakemake.output.flu_png, snakemake.output.sars_png

In [None]:
# Raw group names in the flu training summary; each should have an entry in the
# axis-label mapping defined in the next cell.
within_between_df_training["group"].unique()

In [None]:
# Human-readable y-axis labels for the raw group names in the flu summaries.
flu_labels_to_axis_labels = dict(
    zip(
        ["clade_membership", "pca_label", "mds_label", "t-sne_label", "umap_label"],
        ["Nextstrain clade", "PCA", "MDS", "t-SNE", "UMAP"],
    )
)

In [None]:
def _plot_within_between(ax, summary_df, title, y_offset=0.05):
    """Draw mean ± std of within- and between-group distances on ``ax``.

    ``summary_df`` needs ``group``, ``comparison`` ("within"/"between"),
    ``mean`` and ``std`` columns. "between" rows are matched to "within" rows
    by ``group``, so both series share a y position regardless of row order.

    Returns the raw group names in plotting order (bottom to top).
    """
    within = summary_df[summary_df["comparison"] == "within"]
    # Align "between" to the "within" order instead of assuming both subsets
    # happen to be sorted identically in the input file.
    between = (
        summary_df[summary_df["comparison"] == "between"]
        .set_index("group")
        .reindex(within["group"])
    )
    # One row per group (no longer assumes equal within/between row counts).
    y_positions = np.arange(len(within))

    # Reverse so the first group in the file is drawn at the top of the panel.
    ax.errorbar(
        within["mean"].to_numpy()[::-1],
        y_positions - y_offset,
        xerr=within["std"].to_numpy()[::-1],
        fmt="o",
        color="blue",
        label="within",
        capsize=2,
    )
    ax.errorbar(
        between["mean"].to_numpy()[::-1],
        y_positions + y_offset,
        xerr=between["std"].to_numpy()[::-1],
        fmt="o",
        color="orange",
        label="between",
        capsize=2,
    )
    ax.set_yticks(y_positions)
    ax.set_xlim(0, 70)
    ax.set_xlabel("Pairwise genetic distance (nucleotides)")
    ax.set_title(title)
    return within["group"].to_numpy()[::-1]


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

flu_groups = _plot_within_between(ax1, within_between_df_training, "Influenza H3N2 HA 2016-2018")
_plot_within_between(ax2, within_between_df_test, "Influenza H3N2 HA 2018-2020")

# Fall back to the raw group name for groups missing from the mapping
# (Series.map would have produced NaN tick labels).
y_labels = [flu_labels_to_axis_labels.get(group, group) for group in flu_groups]
# The y axis is shared, so set labels once, after both panels placed their ticks.
ax1.set_yticklabels(y_labels)

ax1.legend(
    frameon=False,
    loc="upper right",
)
sns.despine()
fig.tight_layout()
fig.savefig(png_chart_flu)

In [None]:
def make_subplot_sars(df, ax, nextstrain_or_pango, label):  # 'Nextstrain_clade' or 'Nextclade_pango_collapsed'
    """Plot within- vs between-group distance summaries for one SARS-CoV-2 clade definition.

    Parameters
    ----------
    df : pandas.DataFrame
        Summary table with ``group``, ``comparison`` ("within"/"between"),
        ``mean`` and ``std`` columns.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    nextstrain_or_pango : str
        Substring (matched literally, not as a regex) that selects the rows of
        ``df`` whose ``group`` belongs to one clade definition, e.g.
        ``"Nextstrain_clade"`` or ``"Nextclade_pango_collapsed"``.
    label : str
        Prefix for the legend entries ("<label> within", "<label> between").

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``.

    Raises
    ------
    ValueError
        If no "within" rows have a group containing ``nextstrain_or_pango``.
    """
    subset = df[df["group"].str.contains(nextstrain_or_pango, regex=False, na=False)]
    within = subset[subset["comparison"] == "within"]
    if within.empty:
        raise ValueError(
            f"No 'within' rows with a group containing {nextstrain_or_pango!r}"
        )
    # Match "between" rows to "within" rows by group so both series share y
    # positions even if the two subsets are ordered differently in the file.
    between = (
        subset[subset["comparison"] == "between"]
        .set_index("group")
        .reindex(within["group"])
    )

    y_positions = np.arange(len(within))

    # Reverse so the first group listed in ``df`` is drawn at the top.
    ax.errorbar(
        within["mean"].to_numpy()[::-1],
        y_positions,
        xerr=within["std"].to_numpy()[::-1],
        fmt="o",
        color="blue",
        label=label + " within",
        capsize=2,
    )
    ax.errorbar(
        between["mean"].to_numpy()[::-1],
        y_positions + 0.2,
        xerr=between["std"].to_numpy()[::-1],
        fmt="o",
        color="orange",
        label=label + " between",
        capsize=2,
    )

    suffix = "_for_" + str(nextstrain_or_pango)
    tick_labels = [group.replace(suffix, "") for group in within["group"]][::-1]
    # The top-most row (first group in ``df``) is labeled as the clade membership itself.
    tick_labels[-1] = "clade_membership"

    # Pin ticks to the plotted rows. Labeling automatically located ticks
    # (padded with a leading "") let labels drift off their rows.
    ax.set_yticks(y_positions)
    ax.set_yticklabels(tick_labels)
    ax.set_xlim(0, 70)

    sns.despine(ax=ax)

    return ax

In [None]:
# Panel layout: rows = clade definition used to group sequences,
# columns = time period (training vs. test).
# Row spec: (substring identifying the groups, legend prefix, y-axis label).
sars_clade_definitions = [
    ("Nextstrain_clade", "Nextstrain clade", "Nextstrain Clade"),
    ("Nextclade_pango_collapsed", "Pango", "Nextclade Pango"),
]
# Column spec: (summary table, panel title).
sars_periods = [
    (within_between_df_sars_training, "SARS-CoV-2 2020/2022"),
    (within_between_df_sars_test, "SARS-CoV-2 2022/2023"),
]

# Design notes (open TODOs):
# - late: each method represented twice
# - clade membership should have both Nextstrain clade and Pango lineages
# - share row and col (2 by 2 figure) - left Nextstrain clade, right Pango
# - generate both within_between dataframes for different clade membership definitions

fig, axes = plt.subplots(2, 2, figsize=(9, 11), dpi=120, sharex=True, sharey=True)
for row, (group_key, legend_label, row_label) in enumerate(sars_clade_definitions):
    for col, (summary_df, period_title) in enumerate(sars_periods):
        panel = axes[row][col]
        make_subplot_sars(summary_df, panel, group_key, legend_label)
        if row == 0:
            panel.set_title(period_title)
        if col == 0:
            panel.set_ylabel(row_label)
        if row == len(sars_clade_definitions) - 1:
            panel.set_xlabel("Pairwise genetic distance (nucleotides)")
    # One legend per row, placed just outside the right-hand panel.
    axes[row][-1].legend(
        frameon=False,
        bbox_to_anchor=(1.0, 1.0),
        loc="upper left",
    )

fig.subplots_adjust(hspace=0.0)
sns.despine()
# The legends sit outside the axes (anchored at x=1.0), so crop the saved image
# to everything drawn; otherwise they can be clipped at the figure's right edge.
fig.savefig(png_chart_sars, bbox_inches="tight")

# TODO:
# - Make the x axis bigger; move the legend under the x axis of one panel (oriented differently) and add y-axis padding to each panel.
# - Replace raw variable names in tick labels with readable ones (e.g. pca_label -> PCA).
# - Replace the "clade_membership" placeholder label with the actual clade membership definition.
# - Use "Nextclade_pango_collapsed" in the legend instead of just "Pango".
# - Add an x-axis label: "Pairwise genetic distance (nucleotides)".