In [None]:
# Add notebooks/scripts/ to the module search path.
# NOTE(review): the path is relative to the kernel's working directory, and no
# cell visible in this notebook imports a module from it — confirm it is still needed.
import sys
sys.path.append("notebooks/scripts/")

## Define inputs, outputs, and parameters

In [None]:
import matplotlib.pyplot as plt 

In [None]:
# Input table path supplied by Snakemake (`snakemake` is injected by the workflow runner).
# NOTE(review): `df_path` is not read by any cell visible here; the statistics
# tables are loaded from hard-coded relative paths instead — confirm which is intended.
df_path = snakemake.input.dataframe

In [None]:
# Output PNG path declared by the Snakemake rule.
# NOTE(review): no cell visible here writes a figure to `png_chart`
# (e.g. via fig.savefig) — confirm which figure should be saved.
png_chart = snakemake.output.png

In [None]:
# Load the within/between statistics table of each analysis directory.
# The plotting code below reads the columns "group", "comparison", "mean" and "std".
# Paths are relative to the kernel's working directory.
import pandas as pd
# Seasonal influenza: "training" build and the 2018-2020 "test" build.
within_between_df_training = pd.read_csv("../seasonal-flu-nextstrain/results/full_within_between_stats.csv")
within_between_df_test = pd.read_csv("../seasonal-flu-nextstrain-2018-2020/results/full_within_between_stats.csv")
# SARS-CoV-2: "training" build and the 2022-2023 "test" build.
within_between_df_sars_training = pd.read_csv("../sars-cov-2-nextstrain/results/full_within_between_stats.csv")
within_between_df_sars_test = pd.read_csv("../sars-cov-2-nextstrain-2022-2023/results/full_within_between_stats.csv")

In [None]:
# Preview the SARS-CoV-2 2022-2023 statistics table.
within_between_df_sars_test

In [None]:
def make_subplot(df, ax):
    """Draw within- and between-group mean distances (with std error bars) on ``ax``.

    Rows of ``df`` are split by the "comparison" column into "within" and
    "between" rows. Each is plotted in reverse row order, so the first row of
    ``df`` ends up at the top of the axes. "between" points are offset by 0.1
    on the y axis so they do not sit on top of the "within" points.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns "group", "comparison", "mean" and "std", with
        the same number of "within" and "between" rows.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, for chaining.
    """
    within = df[df["comparison"] == "within"][::-1]
    between = df[df["comparison"] == "between"][::-1]
    y_positions = np.arange(len(within))

    ax.errorbar(within["mean"], y_positions, xerr=within["std"], fmt='o', color="blue", label="within", capsize=2)
    ax.errorbar(between["mean"], y_positions + 0.1, xerr=between["std"], fmt='o', color="orange", label="between", capsize=2)

    # Pin one tick per plotted row before labelling. Labelling the automatic
    # ticks (and padding with an empty label) only lines up when the locator
    # happens to choose integer positions.
    ax.set_yticks(y_positions)
    ax.set_yticklabels(list(within["group"]))
    ax.set_xlim(0, 70)

    sns.despine()
    return ax

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Two stacked panels sharing the x axis: one influenza dataset per row.
fig, ax = plt.subplots(2, 1, figsize=(6, 9), dpi=120, sharex='col')

# (axes, statistics table, y-axis label) for each panel, top to bottom.
flu_panels = [
    (ax[0], within_between_df_training, "H3N2 HA Influenza 2016/2018"),
    (ax[1], within_between_df_test, "H3N2 HA Influenza 2018/2020"),
]
for panel_ax, panel_df, panel_label in flu_panels:
    make_subplot(panel_df, panel_ax)
    panel_ax.set_ylabel(panel_label)

# Only the top panel carries the legend.
ax[0].legend(frameon=False, bbox_to_anchor=(0.8, 1.0), loc="upper left")

# Remove the vertical gap between the two panels.
plt.subplots_adjust(hspace=.0)
sns.despine()

# TODO: place the early and late panels side by side instead of below each other

In [None]:
def make_subplot_sars(df, ax, nextstrain_or_pango, label):
    """Draw within/between mean distances for one clade definition on ``ax``.

    Only rows whose "group" value contains ``nextstrain_or_pango`` are plotted.
    The selected rows are split by the "comparison" column and drawn in reverse
    row order; "between" points are offset by 0.2 on the y axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns "group", "comparison", "mean" and "std".
    ax : matplotlib.axes.Axes
        Axes to draw on.
    nextstrain_or_pango : str
        Pattern searched for in "group" (the notebook passes 'Nextstrain_clade'
        or 'Nextclade_pango_collapsed'). The "_for_<pattern>" suffix is removed
        from the tick labels.
    label : str
        Prefix for the legend entries ("<label> within" / "<label> between").

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, for chaining.

    Raises
    ------
    KeyError
        If no row of ``df`` matches ``nextstrain_or_pango``.
    """
    # na=False drops rows with a missing group instead of failing on the mask.
    selected = df[df["group"].str.contains(nextstrain_or_pango, na=False)]
    if selected.empty:
        raise KeyError(f"no rows whose 'group' contains {nextstrain_or_pango!r}")

    within = selected[selected["comparison"] == "within"][::-1]
    between = selected[selected["comparison"] == "between"][::-1]
    y_positions = np.arange(len(within))

    ax.errorbar(within["mean"], y_positions, xerr=within["std"], fmt='o', color="blue", label=label + " within", capsize=2)
    ax.errorbar(between["mean"], y_positions + 0.2, xerr=between["std"], fmt='o', color="orange", label=label + " between", capsize=2)

    suffix = "_for_" + str(nextstrain_or_pango)
    y_ticklabels = [name.replace(suffix, "") for name in within["group"]]
    if y_ticklabels:
        # The top-most entry (first selected "within" row) is always shown
        # under this fixed name.
        y_ticklabels[-1] = "clade_membership"

    # Pin one tick per plotted row before labelling. Labelling the automatic
    # ticks (padded with an empty label) only lines up when the locator happens
    # to choose integer positions.
    ax.set_yticks(y_positions)
    ax.set_yticklabels(y_ticklabels)
    ax.set_xlim(0, 70)

    sns.despine()

    return ax

In [None]:
# 2x2 grid sharing both axes: rows are clade definitions, columns are time periods.
fig, ax = plt.subplots(2, 2, figsize=(9,9), dpi=120, sharex=True, sharey=True)

# (statistics table, axes, "group" pattern, legend prefix), drawn row by row.
sars_panels = [
    (within_between_df_sars_training, ax[0][0], 'Nextstrain_clade', "Nextstrain clade"),  # early, nextclade
    (within_between_df_sars_test, ax[0][1], 'Nextstrain_clade', "Nextstrain clade"),  # late, nextclade
    (within_between_df_sars_training, ax[1][0], 'Nextclade_pango_collapsed', "Pango"),  # early, pango
    (within_between_df_sars_test, ax[1][1], 'Nextclade_pango_collapsed', "Pango"),  # late, pango
]
for panel_df, panel_ax, clade_key, legend_prefix in sars_panels:
    make_subplot_sars(panel_df, panel_ax, clade_key, legend_prefix)

# Design notes:
# - late: each method is represented twice
# - clade membership should have both Nextstrain clade and Pango lineages
# - share rows and columns (2 by 2 figures): left Nextstrain clade, right Pango
# - generate both within_between dataframes for the different clade membership definitions

# Column titles go on the top row, row labels on the left column.
for col_index, col_title in enumerate(("SARS-CoV-2 2020/2022", "SARS-CoV-2 2022/2023")):
    ax[0][col_index].set_title(col_title)
for row_index, row_label in enumerate(("Nextstrain Clade", "Nextclade Pango")):
    ax[row_index][0].set_ylabel(row_label)

# One legend per row, anchored outside the right-hand panel.
for legend_ax in (ax[0][1], ax[1][1]):
    legend_ax.legend(frameon=False, bbox_to_anchor=(1.0, 1.0), loc="upper left")

plt.subplots_adjust(hspace=.0)#, wspace=.0)
sns.despine()

# make x axis bigger, move legend to x axis of one chart (rotate legend in another way, add y axis buffer to each figure)
# replace variable names (pca_label -> pca)
# replace clade_membership with actual clade_membership
# use nextclade_pango_collapsed in legend not just pango
# x axis label (pairwise genetic distance (nucleotides))

In [None]:
# Preview the seasonal influenza 2018-2020 statistics table.
within_between_df_test

In [None]:
# old v new, mcc, clade membership
import matplotlib
import numpy as np

# Single-panel view of the 2018-2020 influenza statistics, rows in table order.
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=120)

df = within_between_df_test

within_rows = df[df["comparison"] == "within"]
between_rows = df[df["comparison"] == "between"]
x_positions = np.arange(len(within_rows))

# Both series are labelled so the legend below has entries to show;
# "between" points are offset by 0.2 on the y axis.
ax.errorbar(within_rows["mean"], x_positions, xerr=within_rows["std"], fmt='o', color="blue", label="within")
ax.errorbar(between_rows["mean"], x_positions + 0.2, xerr=between_rows["std"], fmt='o', color="orange", label="between")

# Pin one tick per plotted row so each label sits on its own point,
# instead of relabelling whatever ticks the automatic locator picked.
ax.set_yticks(x_positions)
ax.set_yticklabels(list(within_rows["group"]))
ax.set_xlim(0)

ax.legend(
    frameon=False,
    bbox_to_anchor=(1.05, 1.1),
    loc="upper left"
)
sns.despine()

In [None]:
# old v new, mcc, clade membership
import matplotlib
import numpy as np

# Single-panel view of the SARS-CoV-2 statistics, rows in table order.
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=120)

# `within_between_df_sars` is never defined in this notebook, so the cell
# failed on a fresh kernel.
# NOTE(review): the 2022-2023 table is used here to mirror the influenza cell
# above, which plots its "test" table — confirm this is the intended dataset.
df = within_between_df_sars_test

within_rows = df[df["comparison"] == "within"]
between_rows = df[df["comparison"] == "between"]
x_positions = np.arange(len(within_rows))

# Both series are labelled so the legend below has entries to show;
# "between" points are offset by 0.2 on the y axis.
ax.errorbar(within_rows["mean"], x_positions, xerr=within_rows["std"], fmt='o', color="blue", label="within")
ax.errorbar(between_rows["mean"], x_positions + 0.2, xerr=between_rows["std"], fmt='o', color="orange", label="between")

# Pin one tick per plotted row so each label sits on its own point,
# instead of relabelling whatever ticks the automatic locator picked.
ax.set_yticks(x_positions)
ax.set_yticklabels(list(within_rows["group"]))
ax.set_xlim(0)

ax.legend(
    frameon=False,
    bbox_to_anchor=(1.05, 1.1),
    loc="upper left"
)
sns.despine()

# Annotated Embedding 

In [None]:
def get_parent_y(row):
    """Return the ``y_value`` of ``row``'s parent in ``annotated_embeddings``.

    The parent is the first row of the module-level ``annotated_embeddings``
    frame whose "strain" equals ``row['parent_name']``.

    Returns
    -------
    The parent's "y_value"; ``row['y_value']`` when no strain matches; NaN
    when ``row['parent_name']`` is ``None``.

    NOTE(review): a missing parent encoded as NaN rather than ``None`` is not
    caught by the ``None`` check; it matches no strain and therefore yields the
    row's own value — confirm that is intended.
    """
    parent = row['parent_name']
    if parent is None:
        return np.nan

    is_parent = annotated_embeddings['strain'] == parent
    matches = annotated_embeddings.loc[is_parent, 'y_value'].values
    return matches[0] if len(matches) > 0 else row['y_value']

def get_parent_mutation_length(row):
    """Return the ``divergence`` of ``row``'s parent in ``annotated_embeddings``.

    The parent is the first row of the module-level ``annotated_embeddings``
    frame whose "strain" equals ``row['parent_name']``.

    Returns
    -------
    The parent's "divergence"; ``row['divergence']`` when no strain matches;
    NaN when ``row['parent_name']`` is ``None``.
    """
    parent = row['parent_name']
    if parent is None:
        return np.nan

    is_parent = annotated_embeddings['strain'] == parent
    matches = annotated_embeddings.loc[is_parent, 'divergence'].values
    return matches[0] if len(matches) > 0 else row['divergence']

In [None]:
# Load the tab-separated annotated embeddings table. Later cells read its
# "strain", "parent_name", "y_value", "divergence" and "is_internal_node" columns.
annotated_embeddings = pd.read_csv("../seasonal-flu-nextstrain/results/annotated_embeddings.tsv", sep="\t")

In [None]:
# Preview the table as loaded.
annotated_embeddings

In [None]:
# Attach each row's parent coordinates: 'parent_y' from get_parent_y and
# 'parent_mutation' (the parent's divergence) from get_parent_mutation_length.
# Each helper call filters the whole frame, so the cost grows with rows x rows.
import numpy as np
annotated_embeddings['parent_y'] = annotated_embeddings.apply(get_parent_y, axis=1)
annotated_embeddings['parent_mutation'] = annotated_embeddings.apply(get_parent_mutation_length, axis=1)

In [None]:
# Mirror the vertical layout so the largest original y_value becomes 0.
# Both columns must be mirrored around the SAME maximum: 'parent_y' only holds
# values copied from 'y_value', so its own maximum can be smaller, and using it
# would shift every branch start relative to the node it belongs to.
# Note: this cell mutates the frame in place; running it twice undoes the flip.
y_value_max = annotated_embeddings["y_value"].max()

annotated_embeddings["y_value"] = y_value_max - annotated_embeddings["y_value"]

annotated_embeddings["parent_y"] = y_value_max - annotated_embeddings["parent_y"]

In [None]:
# Preview the table with the parent columns added and the y columns mirrored.
annotated_embeddings

In [None]:
import seaborn as sns

# Scatter of divergence (x) against y_value (y) for every row of the table.
plt.scatter(annotated_embeddings["divergence"], annotated_embeddings["y_value"])

In [None]:
import altair as alt
# Alias (not a copy) of the annotated table, used by the Altair cells below.
dataFrame = annotated_embeddings

In [None]:
# Point layer: one circle per selected row at (divergence, y_value), with an
# interval brush attached.
# NOTE(review): the layer is named `tips` but keeps rows where
# is_internal_node is True — confirm whether `== False` was intended.
base = alt.Chart(dataFrame[dataFrame["is_internal_node"] == True])
brush = alt.selection(type='interval', resolve='global')
tips = base.mark_circle().encode(
    x=alt.X(
        "divergence:Q",
        scale=alt.Scale(
            domain=(dataFrame["divergence"].min() - 0.0002, dataFrame["divergence"].max() + 0.0002)),
        title="Divergence",
        axis=alt.Axis(labels=True, ticks=True)
    ),
    y=alt.Y(
        "y_value:Q",
        title="",
        axis=alt.Axis(labels=False, ticks=False)
    ),
    # color=alt.condition(brush, if_false=alt.ColorValue('gray'), if_true=alt.Color(color, scale=alt.Scale(domain=domain, range=range_))),
    #tooltip=ToolTip
).add_selection(brush)

# Branch layer: one straight segment per row, from the parent
# (parent_mutation, parent_y) to the node itself (divergence, y_value).
# A rule mark is used because it consumes the x2/y2 channels; a line mark does
# not, and instead joins all (x, y) points into a single path.
lines = alt.Chart(dataFrame).mark_rule().encode(
    x=alt.X("parent_mutation:Q", scale=alt.Scale(domain=(dataFrame["divergence"].min() - 0.002, dataFrame["divergence"].max() + 0.002))),
    x2="divergence:Q",
    y=alt.Y("parent_y:Q", scale=alt.Scale(domain=(dataFrame["y_value"].min() - 1.0, dataFrame["y_value"].max() + 0.2))),
    y2="y_value:Q",
    color=alt.ColorValue("#cccccc")
)

In [None]:
# Rectangular branches, drawn as two rule layers per row.
# Horizontal part: at the parent's height (parent_y), spanning from the node's
# divergence to the parent's divergence (parent_mutation). A rule mark is
# required here: a line mark does not consume x2 and would join the points of
# all rows into one path instead of drawing one segment per row.
horizontal_lines = alt.Chart(dataFrame).mark_rule().encode(
    x="divergence:Q",
    x2="parent_mutation:Q",
    y=alt.Y("parent_y:Q", scale=alt.Scale(domain=(dataFrame["y_value"].min() - 1.0, dataFrame["y_value"].max() + 0.2))),
    color=alt.ColorValue("#cccccc")
)

# Vertical part: at the node's divergence, spanning from the parent's height
# (parent_y) to the node's own height (y_value).
vertical_lines = alt.Chart(dataFrame).mark_rule().encode(
    x="divergence:Q",
    y=alt.Y("parent_y:Q", scale=alt.Scale(domain=(dataFrame["y_value"].min() - 1.0, dataFrame["y_value"].max() + 0.2))),
    y2="y_value:Q",
    color=alt.ColorValue("#cccccc")
)

In [None]:
# Layer the horizontal and vertical branch segments into one chart.
lines_norm = (horizontal_lines + vertical_lines)

In [None]:
# Display the `lines` branch layer together with the `tips` point layer.
(lines+tips)

In [None]:
# Display the rectangular branch layers together with the `tips` point layer.
(lines_norm + tips)