## Imports

In [None]:
import sys
sys.path.append("notebooks/scripts/")

In [None]:
import altair as alt
from augur.dates import numeric_date
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush
from Helpers import make_node_branch_widths

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
# `snakemake` is not imported anywhere in this notebook; it is expected to be
# injected into the namespace when the notebook is executed through Snakemake.

# Input files: color palette table, tree, annotated embeddings, and one
# accuracy table each for the HA-only and the concatenated embeddings.
colors_path = snakemake.input.colors
tree_path = snakemake.input.tree
embeddings_path = snakemake.input.annotated_embeddings
accuracy_path_ha = snakemake.input.accuracy_table_ha
accuracy_path_concatenated = snakemake.input.accuracy_table_concatenated

# Tables holding the PCA explained variance for the HA-only and the
# concatenated embeddings.
explained_variance_pca_ha = snakemake.input.explained_variance_pca_ha
explained_variance_pca_concatenated = snakemake.input.explained_variance_pca_concatenated

# Output paths: every figure is written once as HTML and once as PNG.
output_pca_html = snakemake.output.HANAFullChartBrushablePCAHTML
output_pca_png = snakemake.output.HANAFullChartBrushablePCAPNG
output_mds_html = snakemake.output.HANAFullChartBrushableMDSHTML
output_mds_png = snakemake.output.HANAFullChartBrushableMDSPNG
output_tsne_html = snakemake.output.HANAFullChartBrushableTSNEHTML
output_tsne_png = snakemake.output.HANAFullChartBrushableTSNEPNG
output_umap_html = snakemake.output.HANAFullChartBrushableUMAPHTML
output_umap_png = snakemake.output.HANAFullChartBrushableUMAPPNG
output_ha_na_html = snakemake.output.HANAChartHTML
output_ha_na_png = snakemake.output.HANAChartPNG
output_full_html = snakemake.output.fullChartHTML
output_full_png = snakemake.output.fullChartPNG

In [None]:
accuracy_column = snakemake.params.accuracy_column

In [None]:
# Set the padding embed option that Altair's renderer passes along when
# displaying charts in the notebook.
alt.renderers.set_embed_options(
    padding={"left": 0, "right": 0, "bottom": 1, "top": 1}
)

## Load data

In [None]:
node_branch_widths = make_node_branch_widths(tree_path)

In [None]:
# Read the tab-delimited color table. Because explicit integer column names
# (0-100) are supplied, the first line of the file is treated as data. At most
# 101 rows are read; row i is looked up later to obtain a palette of i + 1 colors.
colors = pd.read_csv(
    colors_path,
    sep="\t",
    names=list(range(101)),
    nrows=101,
)

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
embeddings_df.head()

In [None]:
# Attach the per-node values computed from the tree to each embedding record by
# matching the embedding's "strain" to the tree's "node".
# - No `how` is given, so pandas performs an inner join: records without a
#   match on the other side are dropped.
# - `validate="1:1"` makes pandas raise if either key contains duplicates.
# NOTE(review): this cell reassigns `embeddings_df`, so running it a second
# time merges the already merged frame again.
embeddings_df = embeddings_df.merge(
    node_branch_widths,
    left_on="strain",
    right_on="node",
    validate="1:1",
)

In [None]:
# Name of the `embeddings_df` column whose labels define the clade color scale
# and color the "Clade membership" charts; it is also passed along in the
# per-method chart calls.
clade_membership = "MCC"

In [None]:
accuracy_df_ha = pd.read_csv(accuracy_path_ha)

In [None]:
accuracy_df_ha

In [None]:
accuracy_df_concatenated = pd.read_csv(accuracy_path_concatenated)

In [None]:
accuracy_df_concatenated

In [None]:
explained_variance_df_ha = pd.read_csv(explained_variance_pca_ha)

In [None]:
explained_variance_df_ha

In [None]:
explained_variance_pca_ha_values = explained_variance_df_ha["explained variance"].values.tolist()

In [None]:
explained_variance_pca_ha_values

In [None]:
explained_variance_df_concatenated = pd.read_csv(explained_variance_pca_concatenated)

In [None]:
explained_variance_df_concatenated

In [None]:
explained_variance_pca_concatenated_values = explained_variance_df_concatenated["explained variance"].values.tolist()


In [None]:
explained_variance_pca_concatenated_values

## Build color scales

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Build a list of colors with one entry per value in ``domain``.

    Parameters
    ----------
    domain : sequence
        Ordered values that need a color.
    colors : pandas.DataFrame
        Palette table in which the row labeled ``N - 1`` holds ``N`` colors;
        unused cells in a row are missing values.
    value_for_unassigned : optional
        Value in ``domain`` that marks "unassigned" records. When it is present
        in ``domain`` it gets the gray ``"#999999"`` at its own position and the
        remaining values share a palette with one fewer color.

    Returns
    -------
    list
        Colors in the same order as ``domain``. An empty ``domain`` yields an
        empty list, and a ``domain`` holding only the unassigned value yields
        ``["#999999"]``.

    Raises
    ------
    KeyError
        If ``colors`` has no row for the number of colors required.
    """
    domain = list(domain)
    has_unassigned = (
        value_for_unassigned is not None and value_for_unassigned in domain
    )

    # The unassigned value always gets gray, so it does not consume a palette color.
    n_palette_colors = len(domain) - 1 if has_unassigned else len(domain)

    if n_palette_colors == 0:
        # Nothing to look up. Without this guard the row label would be -1,
        # which does not exist in the palette table.
        range_ = []
    else:
        # Rows are zero-indexed, so N colors live in row N - 1.
        range_ = colors.loc[n_palette_colors - 1].dropna().values.tolist()

    if has_unassigned:
        # Put gray at the position the unassigned value occupies in the domain.
        range_.insert(domain.index(value_for_unassigned), "#999999")

    return range_

In [None]:
# Distinct, non-missing clade labels in order of first appearance. The
# "unassigned" label is left out here; a later cell prepends it to the domain.
clade_color_domain = (
    embeddings_df[clade_membership]
    .drop_duplicates()
    .dropna()
    .loc[lambda labels: labels != "unassigned"]
    .tolist()
)

In [None]:
clade_color_range = build_color_range_for_domain(clade_color_domain, colors)

In [None]:
def _clade_label_number(label_and_color):
    """Return the integer found after the first underscore of the clade label."""
    label, _color = label_and_color
    return int(label.split("_")[1])


# Order the clades by the number embedded in their label while keeping each
# clade paired with the color it was given before sorting.
clade_color_domain, clade_color_range = zip(
    *sorted(zip(clade_color_domain, clade_color_range), key=_clade_label_number)
)

In [None]:
clade_color_domain = ["unassigned"] + list(clade_color_domain)

In [None]:
clade_color_range = ["#999999"] + list(clade_color_range)

In [None]:
clade_color_domain

In [None]:
clade_color_range

## PCA

In [None]:
pca_ha_label_color_domain =  sorted(embeddings_df["pca_ha_label"].drop_duplicates().dropna().values)

In [None]:
pca_ha_label_color_range = build_color_range_for_domain(
    pca_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_pca_ha = accuracy_df_ha.query("method == 'pca'").iloc[0][accuracy_column]

In [None]:
accuracy_pca_ha

In [None]:
# Axis titles report the percentage of variance explained by the first two
# components of the HA-only PCA.
pca_ha_axis_titles = [
    f"PC {component} (Explained variance: {round(explained_variance_pca_ha_values[component - 1] * 100, 2)}%)"
    for component in (1, 2)
]

pca_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_ha", "pca2_ha"],
    pca_ha_axis_titles,
    "pca_ha_label:N",
    "PCA cluster",
    ["strain:N", clade_membership, "pca_ha_label:N"],
    pca_ha_label_color_domain,
    pca_ha_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one (presumably the tree -- confirm against Helpers).
pca_ha_titled_chart = pca_ha_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_pca_ha, 2)}"
)
pca_ha_chart = pca_ha_list_of_chart[0] | pca_ha_titled_chart

In [None]:
pca_concatenated_label_color_domain = sorted(embeddings_df["pca_concatenated_label"].drop_duplicates().dropna().values)

In [None]:
pca_concatenated_label_color_range = build_color_range_for_domain(
    pca_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_pca_concatenated = accuracy_df_concatenated.query("method == 'pca'").iloc[0][accuracy_column]

In [None]:
accuracy_pca_concatenated

In [None]:
# Axis titles report the percentage of variance explained by the first two
# components of the PCA on the concatenated embedding.
pca_concatenated_axis_titles = [
    f"PC {component} (Explained variance: {round(explained_variance_pca_concatenated_values[component - 1] * 100, 2)}%)"
    for component in (1, 2)
]

pca_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_concatenated", "pca2_concatenated"],
    pca_concatenated_axis_titles,
    "pca_concatenated_label:N",
    "PCA cluster",
    ["strain:N", clade_membership, "pca_concatenated_label:N"],
    pca_concatenated_label_color_domain,
    pca_concatenated_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
pca_concatenated_titled_chart = pca_concatenated_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_pca_concatenated, 2)}"
)
pca_concatenated_chart = pca_concatenated_list_of_chart[0] | pca_concatenated_titled_chart

In [None]:
# Stack the HA-only row above the concatenated row. The color scale is resolved
# independently so that each row uses its own cluster colors. Grid lines and
# the view border are turned off for the whole figure.
pca_final_chart = alt.vconcat(
    pca_ha_chart,
    pca_concatenated_chart
).resolve_scale(
    color="independent",
).configure_axis(grid=False).configure_view(stroke=None)
# Display the figure as the cell output.
pca_final_chart

In [None]:
pca_final_chart.save(output_pca_html)
pca_final_chart.save(output_pca_png, format="png", scale_factor=2.0)

## MDS

In [None]:
mds_ha_label_color_domain =  sorted(embeddings_df["mds_ha_label"].drop_duplicates().dropna().values)

In [None]:
mds_ha_label_color_range = build_color_range_for_domain(
    mds_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_mds_ha = accuracy_df_ha.query("method == 'mds'").iloc[0][accuracy_column]

In [None]:
accuracy_mds_ha

In [None]:
# Tree and HA-only MDS embedding, colored by MDS cluster label with a
# two-column legend.
mds_ha_axis_titles = ["MDS 1", "MDS 2"]

mds_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1_ha", "mds2_ha"],
    mds_ha_axis_titles,
    "mds_ha_label:N",
    "MDS cluster",
    ["strain:N", clade_membership, "mds_ha_label:N"],
    mds_ha_label_color_domain,
    mds_ha_label_color_range,
    legend_columns=2,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
mds_ha_titled_chart = mds_ha_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_mds_ha, 2)}"
)
mds_ha_chart = mds_ha_list_of_chart[0] | mds_ha_titled_chart

In [None]:
mds_concatenated_label_color_domain = sorted(embeddings_df["mds_concatenated_label"].drop_duplicates().dropna().values)

In [None]:
mds_concatenated_label_color_range = build_color_range_for_domain(
    mds_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_mds_concatenated = accuracy_df_concatenated.query("method == 'mds'").iloc[0][accuracy_column]

In [None]:
accuracy_mds_concatenated

In [None]:
# Tree and concatenated MDS embedding, colored by MDS cluster label with a
# two-column legend.
mds_concatenated_axis_titles = ["MDS 1", "MDS 2"]

mds_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1_concatenated", "mds2_concatenated"],
    mds_concatenated_axis_titles,
    "mds_concatenated_label:N",
    "MDS cluster",
    ["strain:N", clade_membership, "mds_concatenated_label:N"],
    mds_concatenated_label_color_domain,
    mds_concatenated_label_color_range,
    legend_columns=2,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
mds_concatenated_titled_chart = mds_concatenated_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_mds_concatenated, 2)}"
)
mds_concatenated_chart = mds_concatenated_list_of_chart[0] | mds_concatenated_titled_chart

In [None]:
# Stack the HA-only row above the concatenated row. The color scale is resolved
# independently so that each row uses its own cluster colors. Grid lines and
# the view border are turned off for the whole figure.
mds_final_chart = alt.vconcat(
    mds_ha_chart,
    mds_concatenated_chart
).resolve_scale(
    color="independent",
).configure_axis(grid=False).configure_view(stroke=None)

In [None]:
mds_final_chart

In [None]:
mds_final_chart.save(output_mds_html)
mds_final_chart.save(output_mds_png, format="png", scale_factor=2.0)

## t-SNE 

In [None]:
tsne_ha_label_color_domain =  sorted(embeddings_df["t-sne_ha_label"].drop_duplicates().dropna().values)

In [None]:
tsne_ha_label_color_range = build_color_range_for_domain(
    tsne_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_tsne_ha = accuracy_df_ha.query("method == 't-sne'").iloc[0][accuracy_column]

In [None]:
accuracy_tsne_ha

In [None]:
# Tree and HA-only t-SNE embedding, colored by t-SNE cluster label.
tsne_ha_axis_titles = ["t-SNE 1", "t-SNE 2"]

tsne_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["tsne_x_ha", "tsne_y_ha"],
    tsne_ha_axis_titles,
    "t-sne_ha_label:N",
    "t-SNE cluster",
    ["strain:N", clade_membership, "t-sne_ha_label:N"],
    tsne_ha_label_color_domain,
    tsne_ha_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
tsne_ha_titled_chart = tsne_ha_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_tsne_ha, 2)}"
)
tsne_ha_chart = tsne_ha_list_of_chart[0] | tsne_ha_titled_chart

In [None]:
tsne_concatenated_label_color_domain = sorted(embeddings_df["t-sne_concatenated_label"].drop_duplicates().dropna().values)

In [None]:
tsne_concatenated_label_color_range = build_color_range_for_domain(
    tsne_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_tsne_concatenated = accuracy_df_concatenated.query("method == 't-sne'").iloc[0][accuracy_column]

In [None]:
accuracy_tsne_concatenated

In [None]:
# Tree and concatenated t-SNE embedding, colored by t-SNE cluster label.
tsne_concatenated_axis_titles = ["t-SNE 1", "t-SNE 2"]

tsne_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["tsne_x_concatenated", "tsne_y_concatenated"],
    tsne_concatenated_axis_titles,
    "t-sne_concatenated_label:N",
    "t-SNE cluster",
    ["strain:N", clade_membership, "t-sne_concatenated_label:N"],
    tsne_concatenated_label_color_domain,
    tsne_concatenated_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
tsne_concatenated_titled_chart = tsne_concatenated_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_tsne_concatenated, 2)}"
)
tsne_concatenated_chart = tsne_concatenated_list_of_chart[0] | tsne_concatenated_titled_chart

In [None]:
# Stack the HA-only row above the concatenated row. The color scale is resolved
# independently so that each row uses its own cluster colors. Grid lines and
# the view border are turned off for the whole figure.
tsne_final_chart = alt.vconcat(
    tsne_ha_chart,
    tsne_concatenated_chart
).resolve_scale(
    color="independent",
).configure_axis(grid=False).configure_view(stroke=None)
# Display the figure as the cell output.
tsne_final_chart

In [None]:
tsne_final_chart.save(output_tsne_html)
tsne_final_chart.save(output_tsne_png, format="png", scale_factor=2.0)

## UMAP

In [None]:
umap_ha_label_color_domain =  sorted(embeddings_df["umap_ha_label"].drop_duplicates().dropna().values)

In [None]:
umap_ha_label_color_range = build_color_range_for_domain(
    umap_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_umap_ha = accuracy_df_ha.query("method == 'umap'").iloc[0][accuracy_column]

In [None]:
accuracy_umap_ha

In [None]:
# Tree and HA-only UMAP embedding, colored by UMAP cluster label.
umap_ha_axis_titles = ["UMAP 1", "UMAP 2"]

umap_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["umap_x_ha", "umap_y_ha"],
    umap_ha_axis_titles,
    "umap_ha_label:N",
    "UMAP cluster",
    ["strain:N", clade_membership, "umap_ha_label:N"],
    umap_ha_label_color_domain,
    umap_ha_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
umap_ha_titled_chart = umap_ha_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_umap_ha, 2)}"
)
umap_ha_chart = umap_ha_list_of_chart[0] | umap_ha_titled_chart

In [None]:
umap_concatenated_label_color_domain = sorted(embeddings_df["umap_concatenated_label"].drop_duplicates().dropna().values)

In [None]:
umap_concatenated_label_color_range = build_color_range_for_domain(
    umap_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_umap_concatenated = accuracy_df_concatenated.query("method == 'umap'").iloc[0][accuracy_column]

In [None]:
accuracy_umap_concatenated

In [None]:
# Tree and concatenated UMAP embedding, colored by UMAP cluster label.
umap_concatenated_axis_titles = ["UMAP 1", "UMAP 2"]

umap_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["umap_x_concatenated", "umap_y_concatenated"],
    umap_concatenated_axis_titles,
    "umap_concatenated_label:N",
    "UMAP cluster",
    ["strain:N", clade_membership, "umap_concatenated_label:N"],
    umap_concatenated_label_color_domain,
    umap_concatenated_label_color_range,
)

# Title the second returned chart with the accuracy and place it to the right
# of the first one.
umap_concatenated_titled_chart = umap_concatenated_list_of_chart[1].properties(
    title=f"Normalized VI: {round(accuracy_umap_concatenated, 2)}"
)
umap_concatenated_chart = umap_concatenated_list_of_chart[0] | umap_concatenated_titled_chart

In [None]:
# Stack the HA-only row above the concatenated row. The color scale is resolved
# independently so that each row uses its own cluster colors. Grid lines and
# the view border are turned off for the whole figure.
umap_final_chart = alt.vconcat(
    umap_ha_chart,
    umap_concatenated_chart
).resolve_scale(
    color="independent",
).configure_axis(grid=False).configure_view(stroke=None)
# Display the figure as the cell output.
umap_final_chart

In [None]:
umap_final_chart.save(output_umap_html)
umap_final_chart.save(output_umap_png, format="png", scale_factor=2.0)

## All embeddings by clade membership

In [None]:
# TODO:
# - Add MCC accuracies as titles per plot

# Build the charts for all eight embeddings in one call, colored by clade
# membership. Columns and axis titles are listed as consecutive (x, y) pairs in
# this order: MDS, t-SNE, PCA, UMAP, each as concatenated first and HA-only second.
# Later cells index the result as charts[0] plus charts[1]..charts[8] in that
# same pair order, so the order must not change.
charts = linking_tree_with_plots_brush(
    embeddings_df,
    [
        "mds1_concatenated",
        "mds2_concatenated",
        "mds1_ha",
        "mds2_ha",
        "tsne_x_concatenated",
        "tsne_y_concatenated",
        "tsne_x_ha",
        "tsne_y_ha",
        "pca1_concatenated",
        "pca2_concatenated",
        "pca1_ha",
        "pca2_ha",
        "umap_x_concatenated",
        "umap_y_concatenated",
        "umap_x_ha",
        "umap_y_ha",
    ],
    [
        "MDS 1",
        "MDS 2",
        "MDS 1",
        "MDS 2",
        "t-SNE 1",
        "t-SNE 2",
        "t-SNE 1",
        "t-SNE 2",
        # The values come from the explained-variance tables, so the label says
        # "Explained variance", matching the per-method PCA figures.
        f"PC 1 (Explained variance: {round(explained_variance_pca_concatenated_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_concatenated_values[1] * 100, 2)}%)",
        f"PC 1 (Explained variance: {round(explained_variance_pca_ha_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_ha_values[1] * 100, 2)}%)",
        "UMAP 1",
        "UMAP 2",
        "UMAP 1",
        "UMAP 2",
    ],
    clade_membership + ":N",
    "Clade membership",
    ["strain", clade_membership],
    clade_color_domain,
    clade_color_range,
)

In [None]:
def _normalized_vi_title(accuracy):
    """Format an accuracy value as the "Normalized VI" chart title."""
    return "Normalized VI: " + str(round(accuracy, 2))


# One row per method (PCA, MDS, t-SNE, UMAP): HA-only chart on the left and the
# concatenated (HA and NA) chart on the right. Only the first row carries the
# column captions.
full_pca_row = (
    charts[6].properties(title=["HA only", _normalized_vi_title(accuracy_pca_ha)])
    | charts[5].properties(title=["HA and NA", _normalized_vi_title(accuracy_pca_concatenated)])
)
full_mds_row = (
    charts[2].properties(title=_normalized_vi_title(accuracy_mds_ha))
    | charts[1].properties(title=_normalized_vi_title(accuracy_mds_concatenated))
)
full_tsne_row = (
    charts[4].properties(title=_normalized_vi_title(accuracy_tsne_ha))
    | charts[3].properties(title=_normalized_vi_title(accuracy_tsne_concatenated))
)
full_umap_row = (
    charts[8].properties(title=_normalized_vi_title(accuracy_umap_ha))
    | charts[7].properties(title=_normalized_vi_title(accuracy_umap_concatenated))
)

chart_embeddings = alt.vconcat(
    full_pca_row,
    full_mds_row,
    full_tsne_row,
    full_umap_row,
).configure_axis(grid=False).configure_view(stroke=None)

In [None]:
chart_embeddings

In [None]:
chart_embeddings.save(output_full_html)
chart_embeddings.save(output_full_png, format="png", scale_factor=2.0)

In [None]:
# Keep the records whose "clade_membership" column equals "A2". Inside the
# query string the name refers to the DataFrame column called
# "clade_membership", not to the notebook variable `clade_membership` (which
# holds the column name "MCC").
all_a2_records = embeddings_df.query("(clade_membership == 'A2')")

In [None]:
# Take the two most frequent "MCC" labels among the A2 records, most frequent
# first. The unpacking raises ValueError when fewer than two distinct labels exist.
# NOTE(review): the variable names imply the most frequent label is A2 and the
# runner-up is A2/re -- confirm this holds for the data.
a2_mcc, a2_re_mcc = all_a2_records["MCC"].value_counts().head(2).index.values

In [None]:
a2_mcc

In [None]:
a2_re_mcc

In [None]:
a2_records = all_a2_records[all_a2_records["MCC"] == a2_mcc]

In [None]:
a2_re_records = all_a2_records[all_a2_records["MCC"] == a2_re_mcc]

In [None]:
a2_records.shape

In [None]:
a2_re_records.shape

In [None]:
a2_y_value = a2_records["y_value"].min() + ((a2_records["y_value"].max() - a2_records["y_value"].min()) / 2)

In [None]:
a2_max_divergence = a2_records["divergence"].max() + 0.001

In [None]:
a2_re_y_value = a2_re_records["y_value"].min() + (
    (a2_re_records["y_value"].max() - a2_re_records["y_value"].min()) / 2
)

In [None]:
a2_re_max_divergence = a2_re_records["divergence"].max() + 0.001

In [None]:
# One row per text label to draw: the label text plus the "divergence" and
# "y_value" coordinates computed in the preceding cells for each group.
text_df = pd.DataFrame([
    {
        "divergence": a2_max_divergence,
        "y_value": a2_y_value,
        "text": "A2",
    },
    {
        "divergence": a2_re_max_divergence,
        "y_value": a2_re_y_value,
        "text": "A2/re",
    },
])

In [None]:
text_df

In [None]:
a2_labels = alt.Chart(text_df).mark_text().encode(
    x="divergence:Q",
    y="y_value:Q",
    text="text:N",
)

In [None]:
def _ha_na_vi_title(accuracy_concatenated, accuracy_ha):
    """Title showing the concatenated accuracy with the HA-only accuracy in parentheses."""
    return f"Normalized VI: {round(accuracy_concatenated, 2):.2f} ({round(accuracy_ha, 2):.2f})"


# First chart of `charts` (presumably the tree -- confirm against Helpers) with
# the A2 text labels layered on top.
ha_na_labeled_top = charts[0] + a2_labels

# Concatenated embeddings only: PCA and MDS in one row, t-SNE and UMAP below.
ha_na_pca_mds_row = (
    charts[5].properties(title=_ha_na_vi_title(accuracy_pca_concatenated, accuracy_pca_ha))
    | charts[1].properties(title=_ha_na_vi_title(accuracy_mds_concatenated, accuracy_mds_ha))
)
ha_na_tsne_umap_row = (
    charts[3].properties(title=_ha_na_vi_title(accuracy_tsne_concatenated, accuracy_tsne_ha))
    | charts[7].properties(title=_ha_na_vi_title(accuracy_umap_concatenated, accuracy_umap_ha))
)

ha_na_only_chart = (
    ha_na_labeled_top & ha_na_pca_mds_row & ha_na_tsne_umap_row
).configure_axis(grid=False).configure_view(stroke=None)

In [None]:
ha_na_only_chart

In [None]:
ha_na_only_chart.save(output_ha_na_html)
ha_na_only_chart.save(output_ha_na_png, format="png", scale_factor=2.0)

In [None]:
# Accuracy pairs per method: (concatenated, HA only).
poster_accuracies = {
    "pca": (accuracy_pca_concatenated, accuracy_pca_ha),
    "mds": (accuracy_mds_concatenated, accuracy_mds_ha),
    "tsne": (accuracy_tsne_concatenated, accuracy_tsne_ha),
    "umap": (accuracy_umap_concatenated, accuracy_umap_ha),
}

# Titles show the concatenated accuracy with the HA-only accuracy in parentheses.
poster_titles = {
    method: f"Normalized VI: {round(concatenated, 2):.2f} ({round(ha_only, 2):.2f})"
    for method, (concatenated, ha_only) in poster_accuracies.items()
}

# Wide first chart on top, then the four concatenated embeddings in a single row.
poster_embeddings_by_clade = alt.vconcat(
    charts[0].properties(width=1100),
    (
        charts[5].properties(title=poster_titles["pca"])
        | charts[1].properties(title=poster_titles["mds"])
        | charts[3].properties(title=poster_titles["tsne"])
        | charts[7].properties(title=poster_titles["umap"])
    ),
).configure_axis(grid=False).configure_view(stroke=None)
# Display the figure as the cell output.
poster_embeddings_by_clade