## Imports

In [None]:
import sys
sys.path.extend(["./notebooks/scripts"])

In [None]:
import altair as alt
from altair_saver import save
import pandas as pd

from Helpers import linking_tree_with_plots_brush

import bcubed

In [None]:
# Configure the padding used when Altair embeds rendered charts: none on the
# left/right, one unit on the top/bottom.
alt.renderers.set_embed_options(padding=dict(left=0, right=0, bottom=1, top=1))

## Inputs

In [None]:
# Input paths, read from the `snakemake` object's `input` attribute.
# NOTE(review): `snakemake` is not defined in this notebook; presumably it is
# injected by Snakemake when the workflow executes the notebook -- confirm.
colors_path = snakemake.input.colors
embeddings_path = snakemake.input.annotated_embeddings
accuracy_path = snakemake.input.accuracy_table

# Tables whose "explained variance" column is read below for the PCA axis
# titles (HA-only and concatenated analyses).
explained_variance_pca_ha = snakemake.input.explained_variance_pca_ha
explained_variance_pca_concatenated = snakemake.input.explained_variance_pca_concatenated

# Output paths: one HTML and one PNG per embedding method, plus the combined
# chart of all embeddings.
output_pca_html = snakemake.output.HANAFullChartBrushablePCAHTML
output_pca_png = snakemake.output.HANAFullChartBrushablePCAPNG
output_mds_html = snakemake.output.HANAFullChartBrushableMDSHTML
output_mds_png = snakemake.output.HANAFullChartBrushableMDSPNG
output_tsne_html = snakemake.output.HANAFullChartBrushableTSNEHTML
output_tsne_png = snakemake.output.HANAFullChartBrushableTSNEPNG
output_umap_html = snakemake.output.HANAFullChartBrushableUMAPHTML
output_umap_png = snakemake.output.HANAFullChartBrushableUMAPPNG
output_full_html = snakemake.output.fullChartHTML
output_full_png = snakemake.output.fullChartPNG

## Load data

In [None]:
colors = pd.read_csv(colors_path, sep="\t", names=[i for i in range(0,101)])

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
embeddings_df.rename(
    columns={
        "y_value": "y",
        "num_date": "date",
    },
    inplace=True
)

In [None]:
embeddings_df.head()

In [None]:
# Parametrizing node_df
clade_membership = "MCC"

In [None]:
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
accuracy_df

In [None]:
explained_variance_df_ha = pd.read_csv(explained_variance_pca_ha)

In [None]:
explained_variance_df_ha

In [None]:
explained_variance_pca_ha_values = explained_variance_df_ha["explained variance"].values.tolist()

In [None]:
explained_variance_pca_ha_values

In [None]:
explained_variance_df_concatenated = pd.read_csv(explained_variance_pca_concatenated)

In [None]:
explained_variance_df_concatenated

In [None]:
explained_variance_pca_concatenated_values = explained_variance_df_concatenated["explained variance"].values.tolist()


In [None]:
explained_variance_pca_concatenated_values

## Build color scales

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Return a list of colors, one per entry of ``domain``.

    Args:
        domain: Ordered list of values (for example, clade or cluster labels)
            that each need a color.
        colors: pandas DataFrame of palettes in which the row labeled
            ``N - 1`` holds the palette for ``N`` values; missing cells in
            that row are dropped.
        value_for_unassigned: Optional domain value that marks "unassigned"
            records. When it occurs in ``domain``, the color at its position
            is replaced with gray (``"#999999"``).

    Returns:
        The list of colors for ``domain``, or an empty list when ``domain``
        is empty.

    Raises:
        KeyError: If ``colors`` has no row labeled ``len(domain) - 1``.
    """
    # An empty domain needs no colors. Returning early also avoids looking up
    # row -1, which does not exist in a zero-indexed palette table.
    if len(domain) == 0:
        return []

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    range_ = colors.loc[len(domain) - 1].dropna().values.tolist()

    # Replace known values for "unassigned" clade or cluster labels.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = "#999999"

    return range_

In [None]:
clade_color_domain = embeddings_df[clade_membership].drop_duplicates().values.tolist()

In [None]:
# Order MCCs with "unassigned" always listed first followed by MCCs
# in numerical order.
clade_color_domain = sorted(
    clade_color_domain,
    key=lambda value: -1 if value == "unassigned" else int(value.split("_")[-1])
)

In [None]:
clade_color_range = build_color_range_for_domain(clade_color_domain, colors, value_for_unassigned="unassigned")

## PCA

In [None]:
# Distinct values of the `pca_label_ha` column, sorted so the color domain
# has a stable order.
pca_ha_label_color_domain =  sorted(embeddings_df["pca_label_ha"].drop_duplicates().values)

In [None]:
# One color per label; build_color_range_for_domain replaces the color for
# label -1 with gray ("#999999") when that label is present.
pca_ha_label_color_range = build_color_range_for_domain(
    pca_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'pca' and
# analysis 'ha'; it is rounded and shown in the chart title below.
accuracy_pca_ha = accuracy_df.query(
    "(embedding == 'pca') & (analysis_name == 'ha')"
).iloc[0]["MCC"]

In [None]:
# Arguments as passed here: data frame, the two embedding columns, their axis
# titles (with the explained variance as a percentage), the color field, the
# tooltip fields, and the color domain/range.
# NOTE(review): the helper's parameter names live in Helpers and are not
# visible in this notebook -- confirm against that module.
pca_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_ha", "pca2_ha"],
    [
        f"PC 1 (Explained variance: {round(explained_variance_pca_ha_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_ha_values[1] * 100, 2)}%)"
    ],
    "pca_label_ha:N",
    ["strain:N", clade_membership, "pca_label_ha:N"],
    pca_ha_label_color_domain,
    pca_ha_label_color_range,
)

# Place the first two returned charts side by side and title the second one
# with the MCC value rounded to four decimals. Presumably element 0 is the
# tree and element 1 the embedding scatterplot -- verify against Helpers.
pca_ha_chart = (
    pca_ha_list_of_chart[0] | pca_ha_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_pca_ha, 4))
    )
)

In [None]:
# Distinct values of the `pca_label_concatenated` column, sorted so the color
# domain has a stable order.
pca_concatenated_label_color_domain = sorted(embeddings_df["pca_label_concatenated"].drop_duplicates().values)

In [None]:
# One color per label; label -1 (when present) is drawn in gray.
pca_concatenated_label_color_range = build_color_range_for_domain(
    pca_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'pca' and
# analysis 'concatenated'.
accuracy_pca_concatenated = accuracy_df.query(
    "(embedding == 'pca') & (analysis_name == 'concatenated')"
).iloc[0]["MCC"]

In [None]:
# Same call as for the HA-only chart, using the concatenated embedding
# columns, labels, and explained variance values.
pca_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_concatenated", "pca2_concatenated"],
    [
        f"PC 1 (Explained variance: {round(explained_variance_pca_concatenated_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_concatenated_values[1] * 100, 2)}%)",
    ],
    "pca_label_concatenated:N",
    ["strain:N", clade_membership, "pca_label_concatenated:N"],
    pca_concatenated_label_color_domain,
    pca_concatenated_label_color_range,
)

# First two returned charts side by side; the second is titled with the
# rounded MCC value.
pca_concatenated_chart = (
    pca_concatenated_list_of_chart[0] | (pca_concatenated_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_pca_concatenated, 4)))
    )
)

In [None]:
# Stack the HA-only row above the concatenated row. Color scales are resolved
# independently because each row uses its own label domain and range.
pca_final_chart = alt.vconcat(
    pca_ha_chart,
    pca_concatenated_chart
).resolve_scale(
    color="independent",
)
# Bare expression: displays the chart as the cell output.
pca_final_chart

In [None]:
embeddings_df.columns

In [None]:
cdict = embeddings_df[["strain", "pca_label_ha"]].set_index("strain").transpose().to_dict()

for k, v in cdict.items():
    cdict[k] = set(v.values())

In [None]:
ldict = embeddings_df[["strain", clade_membership]].set_index("strain").transpose().to_dict()

for k, v in ldict.items():
    ldict[k] = set(v.values())

In [None]:
precision = bcubed.precision(cdict, ldict)
recall = bcubed.recall(cdict, ldict)
fscore_pca = bcubed.fscore(precision, recall)

In [None]:
fscore_pca

In [None]:
pca_final_chart.save(output_pca_html)
save(pca_final_chart, output_pca_png, scale_factor=2.0)

## MDS

In [None]:
# Distinct values of the `mds_label_ha` column, sorted so the color domain
# has a stable order.
mds_ha_label_color_domain =  sorted(embeddings_df["mds_label_ha"].drop_duplicates().values)

In [None]:
# One color per label; build_color_range_for_domain replaces the color for
# label -1 with gray ("#999999") when that label is present.
mds_ha_label_color_range = build_color_range_for_domain(
    mds_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'mds' and
# analysis 'ha'; it is rounded and shown in the chart title below.
accuracy_mds_ha = accuracy_df.query(
    "(embedding == 'mds') & (analysis_name == 'ha')"
).iloc[0]["MCC"]

In [None]:
# Arguments as passed here: data frame, the two embedding columns, their axis
# titles, the color field, the tooltip fields, and the color domain/range.
# NOTE(review): the helper's parameter names live in Helpers and are not
# visible in this notebook -- confirm against that module.
mds_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1_ha", "mds2_ha"],
    [
        "MDS 1",
        "MDS 2"
    ],
    "mds_label_ha:N",
    ["strain:N", clade_membership, "mds_label_ha:N"],
    mds_ha_label_color_domain,
    mds_ha_label_color_range,
)

# Place the first two returned charts side by side and title the second one
# with the MCC value rounded to four decimals. Presumably element 0 is the
# tree and element 1 the embedding scatterplot -- verify against Helpers.
mds_ha_chart = (
    mds_ha_list_of_chart[0] | mds_ha_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_mds_ha, 4))
    )
)

In [None]:
# Distinct values of the `mds_label_concatenated` column, sorted so the color
# domain has a stable order.
mds_concatenated_label_color_domain = sorted(embeddings_df["mds_label_concatenated"].drop_duplicates().values)

In [None]:
# One color per label; label -1 (when present) is drawn in gray.
mds_concatenated_label_color_range = build_color_range_for_domain(
    mds_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'mds' and
# analysis 'concatenated'.
accuracy_mds_concatenated = accuracy_df.query(
    "(embedding == 'mds') & (analysis_name == 'concatenated')"
).iloc[0]["MCC"]

In [None]:
# Same call as for the HA-only chart, using the concatenated embedding
# columns and labels.
mds_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1_concatenated", "mds2_concatenated"],
    ["MDS 1", "MDS 2"],
    "mds_label_concatenated:N",
    ["strain:N", clade_membership, "mds_label_concatenated:N"],
    mds_concatenated_label_color_domain,
    mds_concatenated_label_color_range,
)

# First two returned charts side by side; the second is titled with the
# rounded MCC value.
mds_concatenated_chart = (
    mds_concatenated_list_of_chart[0] | (mds_concatenated_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_mds_concatenated, 4)))
    )
)

In [None]:
# Stack the HA-only row above the concatenated row. Color scales are resolved
# independently because each row uses its own label domain and range.
mds_final_chart = alt.vconcat(
    mds_ha_chart,
    mds_concatenated_chart
).resolve_scale(
    color="independent",
)
# Bare expression: displays the chart as the cell output.
mds_final_chart

In [None]:
cdict = embeddings_df[["strain", "mds_label_ha"]].set_index("strain").transpose().to_dict()

for k, v in cdict.items():
    cdict[k] = set(v.values())

In [None]:
ldict = embeddings_df[["strain", clade_membership]].set_index("strain").transpose().to_dict()

for k, v in ldict.items():
    ldict[k] = set(v.values())

In [None]:
precision = bcubed.precision(cdict, ldict)
recall = bcubed.recall(cdict, ldict)
fscore_mds = bcubed.fscore(precision, recall)

In [None]:
fscore_mds

In [None]:
mds_final_chart.save(output_mds_html)
save(mds_final_chart, output_mds_png, scale_factor=2.0)

## t-SNE 

In [None]:
# Distinct values of the `t-sne_label_ha` column, sorted so the color domain
# has a stable order.
tsne_ha_label_color_domain =  sorted(embeddings_df["t-sne_label_ha"].drop_duplicates().values)

In [None]:
# One color per label; build_color_range_for_domain replaces the color for
# label -1 with gray ("#999999") when that label is present.
tsne_ha_label_color_range = build_color_range_for_domain(
    tsne_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 't-sne' and
# analysis 'ha'; it is rounded and shown in the chart title below.
accuracy_tsne_ha = accuracy_df.query(
    "(embedding == 't-sne') & (analysis_name == 'ha')"
).iloc[0]["MCC"]

In [None]:
# Arguments as passed here: data frame, the two embedding columns, their axis
# titles, the color field, the tooltip fields, and the color domain/range.
# Note that the coordinate columns are spelled `tsne_*` while the label
# column is spelled `t-sne_*`.
# NOTE(review): the helper's parameter names live in Helpers and are not
# visible in this notebook -- confirm against that module.
tsne_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["tsne_x_ha", "tsne_y_ha"],
    [
        "t-SNE 1",
        "t-SNE 2"
    ],
    "t-sne_label_ha:N",
    ["strain:N", clade_membership, "t-sne_label_ha:N"],
    tsne_ha_label_color_domain,
    tsne_ha_label_color_range,
)

# Place the first two returned charts side by side and title the second one
# with the MCC value rounded to four decimals. Presumably element 0 is the
# tree and element 1 the embedding scatterplot -- verify against Helpers.
tsne_ha_chart = (
    tsne_ha_list_of_chart[0] | tsne_ha_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_tsne_ha, 4))
    )
)

In [None]:
# Distinct values of the `t-sne_label_concatenated` column, sorted so the
# color domain has a stable order.
tsne_concatenated_label_color_domain = sorted(embeddings_df["t-sne_label_concatenated"].drop_duplicates().values)

In [None]:
# One color per label; label -1 (when present) is drawn in gray.
tsne_concatenated_label_color_range = build_color_range_for_domain(
    tsne_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 't-sne' and
# analysis 'concatenated'.
accuracy_tsne_concatenated = accuracy_df.query(
    "(embedding == 't-sne') & (analysis_name == 'concatenated')"
).iloc[0]["MCC"]

In [None]:
# Same call as for the HA-only chart, using the concatenated embedding
# columns and labels.
tsne_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["tsne_x_concatenated", "tsne_y_concatenated"],
    ["t-SNE 1", "t-SNE 2"],
    "t-sne_label_concatenated:N",
    ["strain:N", clade_membership, "t-sne_label_concatenated:N"],
    tsne_concatenated_label_color_domain,
    tsne_concatenated_label_color_range,
)

# First two returned charts side by side; the second is titled with the
# rounded MCC value.
tsne_concatenated_chart = (
    tsne_concatenated_list_of_chart[0] | (tsne_concatenated_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_tsne_concatenated, 4)))
    )
)

In [None]:
# Stack the HA-only row above the concatenated row. Color scales are resolved
# independently because each row uses its own label domain and range.
tsne_final_chart = alt.vconcat(
    tsne_ha_chart,
    tsne_concatenated_chart
).resolve_scale(
    color="independent",
)
# Bare expression: displays the chart as the cell output.
tsne_final_chart

In [None]:
cdict = embeddings_df[["strain", "t-sne_label_ha"]].set_index("strain").transpose().to_dict()

for k, v in cdict.items():
    cdict[k] = set(v.values())

In [None]:
ldict = embeddings_df[["strain", clade_membership]].set_index("strain").transpose().to_dict()

for k, v in ldict.items():
    ldict[k] = set(v.values())

In [None]:
precision = bcubed.precision(cdict, ldict)
recall = bcubed.recall(cdict, ldict)
fscore_tsne = bcubed.fscore(precision, recall)

In [None]:
fscore_tsne

In [None]:
save(tsne_final_chart, output_tsne_html)
save(tsne_final_chart, output_tsne_png, scale_factor=2.0)

## UMAP

In [None]:
# Distinct values of the `umap_label_ha` column, sorted so the color domain
# has a stable order.
umap_ha_label_color_domain =  sorted(embeddings_df["umap_label_ha"].drop_duplicates().values)

In [None]:
# One color per label; build_color_range_for_domain replaces the color for
# label -1 with gray ("#999999") when that label is present.
umap_ha_label_color_range = build_color_range_for_domain(
    umap_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'umap' and
# analysis 'ha'; it is rounded and shown in the chart title below.
accuracy_umap_ha = accuracy_df.query(
    "(embedding == 'umap') & (analysis_name == 'ha')"
).iloc[0]["MCC"]

In [None]:
# Arguments as passed here: data frame, the two embedding columns, their axis
# titles, the color field, the tooltip fields, and the color domain/range.
# NOTE(review): the helper's parameter names live in Helpers and are not
# visible in this notebook -- confirm against that module.
umap_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["umap_x_ha", "umap_y_ha"],
    [
        "UMAP 1",
        "UMAP 2"
    ],
    "umap_label_ha:N",
    ["strain:N", clade_membership, "umap_label_ha:N"],
    umap_ha_label_color_domain,
    umap_ha_label_color_range,
)

# Place the first two returned charts side by side and title the second one
# with the MCC value rounded to four decimals. Presumably element 0 is the
# tree and element 1 the embedding scatterplot -- verify against Helpers.
umap_ha_chart = (
    umap_ha_list_of_chart[0] | umap_ha_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_umap_ha, 4))
    )
)

In [None]:
# Distinct values of the `umap_label_concatenated` column, sorted so the
# color domain has a stable order.
umap_concatenated_label_color_domain = sorted(embeddings_df["umap_label_concatenated"].drop_duplicates().values)

In [None]:
# One color per label; label -1 (when present) is drawn in gray.
umap_concatenated_label_color_range = build_color_range_for_domain(
    umap_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
# "MCC" value of the first accuracy-table row for embedding 'umap' and
# analysis 'concatenated'.
accuracy_umap_concatenated = accuracy_df.query(
    "(embedding == 'umap') & (analysis_name == 'concatenated')"
).iloc[0]["MCC"]

In [None]:
# Same call as for the HA-only chart, using the concatenated embedding
# columns and labels.
umap_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["umap_x_concatenated", "umap_y_concatenated"],
    ["UMAP 1", "UMAP 2"],
    "umap_label_concatenated:N",
    ["strain:N", clade_membership, "umap_label_concatenated:N"],
    umap_concatenated_label_color_domain,
    umap_concatenated_label_color_range,
)

# First two returned charts side by side; the second is titled with the
# rounded MCC value.
umap_concatenated_chart = (
    umap_concatenated_list_of_chart[0] | (umap_concatenated_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_umap_concatenated, 4)))
    )
)

In [None]:
# Stack the HA-only row above the concatenated row. Color scales are resolved
# independently because each row uses its own label domain and range.
umap_final_chart = alt.vconcat(
    umap_ha_chart,
    umap_concatenated_chart
).resolve_scale(
    color="independent",
)
# Bare expression: displays the chart as the cell output.
umap_final_chart

In [None]:
cdict = embeddings_df[["strain", "umap_label_ha"]].set_index("strain").transpose().to_dict()

for k, v in cdict.items():
    cdict[k] = set(v.values())

In [None]:
ldict = embeddings_df[["strain", clade_membership]].set_index("strain").transpose().to_dict()

for k, v in ldict.items():
    ldict[k] = set(v.values())

In [None]:
precision = bcubed.precision(cdict, ldict)
recall = bcubed.recall(cdict, ldict)
fscore_umap = bcubed.fscore(precision, recall)

In [None]:
fscore_umap

In [None]:
save(umap_final_chart, output_umap_html)
save(umap_final_chart, output_umap_png, scale_factor=2.0)

## All embeddings by clade membership

In [None]:
# Build the linked charts for every embedding, all colored by clade
# membership. Columns and axis titles are listed as consecutive (x, y) pairs
# in the order: MDS, t-SNE, PCA, UMAP, each with the concatenated analysis
# before the HA-only analysis.
#
# The next cell uses charts[0] as the first row and charts[1] through
# charts[8] as the embedding panels, where the MCC accuracies are added as
# titles. Presumably charts[i] corresponds to the i-th column pair listed
# here -- verify against Helpers.
charts = linking_tree_with_plots_brush(
    embeddings_df,
    [
        'mds1_concatenated',
        'mds2_concatenated',
        'mds1_ha',
        'mds2_ha',
        'tsne_x_concatenated',
        'tsne_y_concatenated',
        'tsne_x_ha',
        'tsne_y_ha',
        'pca1_concatenated',
        'pca2_concatenated',
        'pca1_ha',
        'pca2_ha',
        'umap_x_concatenated',
        'umap_y_concatenated',
        'umap_x_ha',
        'umap_y_ha',
    ],
    [
        'MDS 1',
        'MDS 2',
        'MDS 1',
        'MDS 2',
        't-SNE 1',
        't-SNE 2',
        't-SNE 1',
        't-SNE 2',
        # PCA axis titles report the explained variance as a percentage,
        # using the same wording as the per-method PCA charts above.
        f"PC 1 (Explained variance: {round(explained_variance_pca_concatenated_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_concatenated_values[1] * 100, 2)}%)",
        f"PC 1 (Explained variance: {round(explained_variance_pca_ha_values[0] * 100, 2)}%)",
        f"PC 2 (Explained variance: {round(explained_variance_pca_ha_values[1] * 100, 2)}%)",
        'UMAP 1',
        'UMAP 2',
        'UMAP 1',
        'UMAP 2',
    ],
    clade_membership+":N",
    ['strain', clade_membership],
    clade_color_domain,
    clade_color_range,
)

In [None]:
chart_embeddings = alt.vconcat(
    charts[0],
    charts[6].properties(title=["HA only", "MCC: " + str(round(accuracy_pca_ha, 4))]) | charts[5].properties(title=["HA and NA", "MCC: " + str(round(accuracy_pca_concatenated, 4))]),
    charts[2].properties(title="MCC: " + str(round(accuracy_mds_ha, 4))) | charts[1].properties(title="MCC: " + str(round(accuracy_mds_concatenated, 4))),
    charts[4].properties(title="MCC: " + str(round(accuracy_tsne_ha, 4))) | charts[3].properties(title="MCC: " + str(round(accuracy_tsne_concatenated, 4))),
    charts[8].properties(title="MCC: " + str(round(accuracy_umap_ha, 4))) | charts[7].properties(title="MCC: " + str(round(accuracy_umap_concatenated, 4)))
)
chart_embeddings

In [None]:
chart_embeddings.save(output_full_html)
save(chart_embeddings, output_full_png, scale_factor=2.0)

In [None]:
# Summary for the HA-only analyses: for each embedding method, print its
# name, then the B-cubed F-score, then the MCC value, one item per line.
print("the FIRST value is the fscore, the second is the MCC.")
print("pca_ha:", fscore_pca, accuracy_pca_ha, sep="\n")
print("mds_ha:", fscore_mds, accuracy_mds_ha, sep="\n")
print("t-sne_ha:", fscore_tsne, accuracy_tsne_ha, sep="\n")
print("umap_ha:", fscore_umap, accuracy_umap_ha, sep="\n")