## Imports

In [None]:
import sys
sys.path.extend(["./notebooks/scripts"])

In [None]:
import altair as alt
from altair_saver import save
from augur.utils import json_to_tree, read_node_data
import json
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import re
import seaborn as sns
from sklearn.metrics import confusion_matrix, matthews_corrcoef

from Helpers import linking_tree_with_plots_clickable, linking_tree_with_plots_brush, scatterplot_with_tooltip_interactive
from Helpers import get_y_positions, get_euclidean_data_frame

#%matplotlib inline

In [None]:
# Trim the padding Vega-Embed places around every rendered Altair chart.
alt.renderers.set_embed_options(
    padding=dict(left=0, right=0, bottom=1, top=1)
)

In [None]:
sns.set_style("ticks")
mpl.rcParams.update({
    # Hide the top and right spines on every axes.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Show figures at 100 dpi on screen; write saved files at 300 dpi.
    "savefig.dpi": 300,
    "figure.dpi": 100,
    # Regular-weight text at sizes readable in presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 10,
    "axes.labelsize": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "axes.titlesize": 8,
})
mpl.rc("text", usetex=False)

## Inputs

In [None]:
# Input files declared by the Snakemake rule that executes this notebook.
# NOTE(review): the MDS, t-SNE and UMAP cells further down read names such as
# `mds_df_ha`, `tsne_df_ha`, `umap_df_ha`, `node_df_ha` and `mcc_calc_df` that are
# not bound anywhere in this notebook; confirm whether they belong here.
snakemake_input = snakemake.input
colors_path = snakemake_input.colors
embeddings_path = snakemake_input.annotated_embeddings
accuracy_path = snakemake_input.accuracy_table
explained_variance_pca_ha = snakemake_input.explained_variance_pca_ha
explained_variance_pca_concatenated = snakemake_input.explained_variance_pca_concatenated

# Interactive HTML and static PNG destinations for each figure.
snakemake_output = snakemake.output
output_pca_html = snakemake_output.HANAFullChartBrushablePCAHTML
output_pca_png = snakemake_output.HANAFullChartBrushablePCAPNG
output_mds_html = snakemake_output.HANAFullChartBrushableMDSHTML
output_mds_png = snakemake_output.HANAFullChartBrushableMDSPNG
output_tsne_html = snakemake_output.HANAFullChartBrushableTSNEHTML
output_tsne_png = snakemake_output.HANAFullChartBrushableTSNEPNG
output_full_html = snakemake_output.fullChartHTML
output_full_png = snakemake_output.fullChartPNG

## Load data

In [None]:
colors = pd.read_csv(colors_path, sep="\t", names=[i for i in range(0,101)])

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
embeddings_df.rename(
    columns={
        "y_value": "y",
        "num_date": "date",
    },
    inplace=True
)

In [None]:
embeddings_df.head()

In [None]:
# Parametrizing node_df
clade_membership = "MCC"

In [None]:
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
accuracy_df

In [None]:
explained_variance_df_ha = pd.read_csv(explained_variance_pca_ha)

In [None]:
explained_variance_df_ha

In [None]:
explained_variance_pca_ha_values = explained_variance_df_ha["explained variance"].values.tolist()

In [None]:
explained_variance_pca_ha_values

In [None]:
explained_variance_df_concatenated = pd.read_csv(explained_variance_pca_concatenated)

In [None]:
explained_variance_df_concatenated

In [None]:
explained_variance_pca_concatenated_values = explained_variance_df_concatenated["explained variance"].values.tolist()


In [None]:
explained_variance_pca_concatenated_values

## Build color scales

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Pick one palette color per domain value from the colors table.

    Parameters
    ----------
    domain : sequence
        Distinct values to color. Any sequence works (list, tuple, numpy array,
        pandas Index); it is copied to a list before use.
    colors : pandas.DataFrame
        Palette table in which row ``N - 1`` holds ``N`` colors and the remaining
        cells are NaN.
    value_for_unassigned : optional
        Domain value (e.g. ``"unassigned"`` or ``-1``) to draw in grey
        (``"#999999"``) instead of a palette color. Ignored when it is absent
        from ``domain``.

    Returns
    -------
    list of str
        Colors aligned position-by-position with ``domain``.

    Raises
    ------
    ValueError
        If ``domain`` is empty.
    """
    # Copy to a list so `.index` works for numpy arrays and pandas Index objects too.
    domain = list(domain)
    if not domain:
        raise ValueError("domain must contain at least one value")

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    range_ = colors.loc[len(domain) - 1].dropna().values.tolist()

    # Replace known values for "unassigned" clade or cluster labels with grey.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = "#999999"

    return range_

In [None]:
clade_color_domain = embeddings_df[clade_membership].drop_duplicates().values.tolist()

In [None]:
clade_color_range = build_color_range_for_domain(clade_color_domain, colors, value_for_unassigned="unassigned")

## PCA

In [None]:
pca_ha_label_color_domain =  sorted(embeddings_df["pca_label_ha"].drop_duplicates().values)

In [None]:
pca_ha_label_color_range = build_color_range_for_domain(
    pca_ha_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_pca_ha = accuracy_df.query(
    "(embedding == 'pca') & (analysis_name == 'ha')"
).iloc[0]["MCC"]

In [None]:
pca_ha_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_ha", "pca2_ha"],
    [
        f"PCA 1 (Explained variance: {round(explained_variance_pca_ha_values[0] * 100, 2)}%)",
        f"PCA 2 (Explained variance: {round(explained_variance_pca_ha_values[1] * 100, 2)}%)"
    ],
    "pca_label_ha:N",
    ["strain:N", clade_membership, "pca_label_ha:N"],
    pca_ha_label_color_domain,
    pca_ha_label_color_range,
)

pca_ha_chart = (
    pca_ha_list_of_chart[0] | pca_ha_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_pca_ha, 4))
    )
)

In [None]:
pca_concatenated_label_color_domain =  sorted(embeddings_df["pca_label_concatenated"].drop_duplicates().values)

In [None]:
pca_concatenated_label_color_range = build_color_range_for_domain(
    pca_concatenated_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
accuracy_pca_concatenated = accuracy_df.query(
    "(embedding == 'pca') & (analysis_name == 'concatenated')"
).iloc[0]["MCC"]

In [None]:
pca_concatenated_list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1_concatenated", "pca2_concatenated"],
    ["PCA 1", "PCA 2"],
    "pca_label_concatenated:N",
    ["strain:N", clade_membership, "pca_label_concatenated:N"],
    pca_concatenated_label_color_domain,
    pca_concatenated_label_color_range,
)

pca_concatenated_chart = (
    pca_concatenated_list_of_chart[0] | (pca_concatenated_list_of_chart[1].properties(
        title="MCC: " + str(round(accuracy_pca_concatenated, 4)))
    )
)

In [None]:
# TODO: fix conflicting color scales
pca_final_chart = alt.vconcat(pca_ha_chart, pca_concatenated_chart)
pca_final_chart

In [None]:
pca_final_chart.save(output_pca_html)
save(pca_final_chart, output_pca_png, scale_factor=2.0)

## MDS

In [None]:
MDS_df_ha = pd.read_csv(mds_df_ha,index_col=0)
MDS_df_concatenated = pd.read_csv(mds_df_concatenated,index_col=0)

In [None]:
merged_mds_df_ha = MDS_df_ha.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")
merged_mds_df_concatenated = MDS_df_concatenated.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")

In [None]:
merged_mds_df_ha

In [None]:
domain =  sorted(merged_mds_df_ha[clade_membership].drop_duplicates().values)
range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
chart_12_mds = scatterplot_with_tooltip_interactive(merged_mds_df_ha,'mds1','mds2',"mds1","mds2",['strain',clade_membership],clade_membership+":N", domain, range_)
chart_12_mds

In [None]:
domain =  sorted(merged_mds_df_ha[clade_membership].drop_duplicates().values)
range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
list_of_chart_ha = linking_tree_with_plots_brush(merged_mds_df_ha,['mds1','mds2'],["MDS1", "MDS2"], clade_membership+":N", ['strain',clade_membership], domain, range_)
chart_ha = list_of_chart_ha[0]|list_of_chart_ha[1]
domain =  sorted(merged_mds_df_concatenated[clade_membership].drop_duplicates().values)
range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
list_of_chart_concatenated = linking_tree_with_plots_brush(merged_mds_df_concatenated,['mds1','mds2'],["MDS1", "MDS2"], clade_membership+":N", ['strain',clade_membership], domain, range_)
chart_concat = list_of_chart_concatenated[0]|list_of_chart_concatenated[1]
alt.vconcat(chart_ha, chart_concat)

In [None]:
MDS_df_ha = pd.read_csv(mds_df_ha,index_col=0)
MDS_df_concatenated = pd.read_csv(mds_df_concatenated,index_col=0)

In [None]:
merged_mds_df_ha = MDS_df_ha.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")
merged_mds_df_concatenated = MDS_df_concatenated.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")

In [None]:
mcc_calc_mds = MDS_df_ha.merge(mcc_calc_df[["strain", "date", "y", clade_membership]], on="strain")
mcc_calc_mds_concatenated = MDS_df_concatenated.merge(mcc_calc_df[["strain", "date", "y", clade_membership]], on="strain")
KDE_df_normal = get_euclidean_data_frame(sampled_df=mcc_calc_mds, column_for_analysis=clade_membership, embedding="method", column_list=['mds1', 'mds2'])

In [None]:
# Exploratory scatterplot of the HA MDS embedding colored by its `mds_label` values.
# NOTE(review): unlike the final MDS cell, `domain` here is unsorted and no grey is
# reserved for label -1; this cell also overwrites the global `domain`/`range_`.
domain =  merged_mds_df_ha["mds_label"].drop_duplicates().values
# Row len(domain) - 1 of the palette table supplies len(domain) colors.
range_ = colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
chart_12_mds = scatterplot_with_tooltip_interactive(merged_mds_df_ha,'mds1','mds2',"mds1","mds2",['strain',clade_membership],'mds_label:N', domain, range_)
chart_12_mds

In [None]:
from sklearn.metrics import confusion_matrix, matthews_corrcoef

In [None]:
# Score the HA MDS clusters: compare the `clade_status` column built from `mds_label`
# with the one in `KDE_df_normal` (built from clade membership) via a confusion
# matrix and the Matthews correlation coefficient.
# NOTE(review): assumes both frames list the same strain pairs in the same row order — confirm.
KDE_df_cluster = get_euclidean_data_frame(sampled_df=mcc_calc_mds[["mds1", "mds2", "strain", "mds_label"]], column_for_analysis="mds_label", embedding="mds", column_list=["mds1", "mds2"])
confusion_matrix_val_ha = confusion_matrix(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])
matthews_cc_val_ha = matthews_corrcoef(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])

In [None]:
# Same scoring for the concatenated MDS clusters.
# NOTE(review): `KDE_df_normal` was built from the HA merge (`mcc_calc_mds`), not
# `mcc_calc_mds_concatenated`; confirm both cover the same strains in the same order.
KDE_df_cluster = get_euclidean_data_frame(sampled_df=mcc_calc_mds_concatenated[["mds1", "mds2", "strain", "mds_label"]], column_for_analysis="mds_label", embedding="mds", column_list=["mds1", "mds2"])
confusion_matrix_val_concatenated = confusion_matrix(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])
matthews_cc_val_concatenated = matthews_corrcoef(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])

In [None]:
# Tree linked to each MDS embedding, colored by clade membership and titled with the MCC above.
# Domains are in first-appearance order; row len(domain) - 1 supplies len(domain) colors.
domain =  merged_mds_df_ha[clade_membership].drop_duplicates().values
range_ = colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
list_of_chart_ha = linking_tree_with_plots_brush(merged_mds_df_ha,['mds1','mds2'],["MDS1", "MDS2"], clade_membership+":N", ['strain',clade_membership], domain, range_)
chart_ha = list_of_chart_ha[0]|list_of_chart_ha[1].properties(title="MCC: " + str(round(matthews_cc_val_ha,4)))
domain =  merged_mds_df_concatenated[clade_membership].drop_duplicates().values
range_ = colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
list_of_chart_concatenated = linking_tree_with_plots_brush(merged_mds_df_concatenated,['mds1','mds2'],["MDS1", "MDS2"], clade_membership+":N", ['strain',clade_membership], domain, range_)
chart_concat = list_of_chart_concatenated[0]|list_of_chart_concatenated[1].properties(title="MCC: " + str(round(matthews_cc_val_concatenated,4)))
chart_total = alt.vconcat(chart_ha, chart_concat)
chart_total


In [None]:
domain =  sorted(merged_mds_df_ha["mds_label"].drop_duplicates().values)
if -1 in domain:
    range_ = ["#999999"] + colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
else: 
    range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
list_of_chart_ha = linking_tree_with_plots_brush(merged_mds_df_ha,['mds1','mds2'],["MDS1", "MDS2"], 'mds_label:N', ['strain',clade_membership], domain, range_)
chart_ha = list_of_chart_ha[0]|list_of_chart_ha[1].properties(title="MCC: " + str(round(matthews_cc_val_ha,4)))
domain =  sorted(merged_mds_df_ha["mds_label"].drop_duplicates().values)
if -1 in domain:
    range_ = ["#999999"] + colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
else: 
    range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
list_of_chart_concatenated = linking_tree_with_plots_brush(merged_mds_df_concatenated,['mds1','mds2'],["MDS1", "MDS2"], 'mds_label:N', ['strain',clade_membership], domain, range_)
chart_concat = list_of_chart_concatenated[0]|list_of_chart_concatenated[1].properties(title="MCC: " + str(round(matthews_cc_val_concatenated,4)))
final_chart = alt.vconcat(chart_ha, chart_concat)
final_chart

In [None]:
final_chart.save(output_mds_html)
#save(final_chart, output_mds_png, scale_factor=2.0)

## HDBSCAN clustering on t-SNE 

In [None]:
# NOTE(review): `tsne_df_ha`, `tsne_df_concatenated`, `node_df_ha` and `mcc_calc_df`
# are not bound anywhere in this notebook; confirm where they should come from.
TSNE_df_ha = pd.read_csv(tsne_df_ha, index_col=0)
TSNE_df_concatenated = pd.read_csv(tsne_df_concatenated,index_col=0)

In [None]:
# Attach `date`, `y` and clade membership from `node_df_ha` to each t-SNE embedding by strain.
merged_tsne_df_ha = TSNE_df_ha.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")
merged_tsne_df_concatenated = TSNE_df_concatenated.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")

In [None]:
# Rebuild `KDE_df_normal` from clade membership on the HA t-SNE merge; this
# replaces the MDS version computed earlier.
mcc_calc_tsne = TSNE_df_ha.merge(mcc_calc_df[["strain", "date", "y", clade_membership]], on="strain")
mcc_calc_tsne_concatenated = TSNE_df_concatenated.merge(mcc_calc_df[["strain", "date", "y", clade_membership]], on="strain")
KDE_df_normal = get_euclidean_data_frame(sampled_df=mcc_calc_tsne, column_for_analysis=clade_membership, embedding="method", column_list=['tsne_x', 'tsne_y'])

In [None]:
# Exploratory scatterplot of the HA t-SNE embedding colored by `t-sne_label`
# (unsorted domain; overwrites the global `domain`/`range_`).
domain =  merged_tsne_df_ha["t-sne_label"].drop_duplicates().values
range_ = colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
chart_12_tsne = scatterplot_with_tooltip_interactive(merged_tsne_df_ha,'tsne_x','tsne_y',"tsne_x","tsne_y",['strain',clade_membership],'t-sne_label:N', domain, range_)
chart_12_tsne

In [None]:
# Score the HA t-SNE clusters against clade membership (confusion matrix + MCC).
# NOTE(review): assumes both frames list the same strain pairs in the same row order — confirm.
KDE_df_cluster = get_euclidean_data_frame(sampled_df=mcc_calc_tsne[["tsne_x", "tsne_y", "strain", "t-sne_label"]], column_for_analysis="t-sne_label", embedding="tsne", column_list=["tsne_x", "tsne_y"])
confusion_matrix_val_ha = confusion_matrix(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])
matthews_cc_val_ha = matthews_corrcoef(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])

In [None]:
# Same scoring for the concatenated t-SNE clusters.
# NOTE(review): `KDE_df_normal` was built from the HA merge (`mcc_calc_tsne`); confirm alignment.
KDE_df_cluster = get_euclidean_data_frame(sampled_df=mcc_calc_tsne_concatenated[["tsne_x", "tsne_y", "strain", "t-sne_label"]], column_for_analysis="t-sne_label", embedding="tsne", column_list=["tsne_x", "tsne_y"])
confusion_matrix_val_concatenated = confusion_matrix(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])
matthews_cc_val_concatenated = matthews_corrcoef(KDE_df_normal["clade_status"], KDE_df_cluster["clade_status"])

In [None]:
domain =  sorted(merged_tsne_df_ha["t-sne_label"].drop_duplicates().values)
if -1 in domain:
    range_ = ["#999999"] + colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
else: 
    range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
list_of_chart_ha = linking_tree_with_plots_brush(merged_tsne_df_ha,['tsne_x','tsne_y'],["MDS1", "MDS2"], 't-sne_label:N', ['strain',clade_membership], domain, range_)
chart_ha = list_of_chart_ha[0]|list_of_chart_ha[1].properties(title="MCC: " + str(round(matthews_cc_val_ha,4)))
domain =  merged_tsne_df_concatenated["t-sne_label"].drop_duplicates().values
range_ = colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
list_of_chart_concatenated = linking_tree_with_plots_brush(merged_tsne_df_concatenated,['tsne_x','tsne_y'],["MDS1", "MDS2"], 't-sne_label:N', ['strain',clade_membership], domain, range_)
chart_concat = list_of_chart_concatenated[0]|list_of_chart_concatenated[1].properties(title="MCC: " + str(round(matthews_cc_val_concatenated,4)))
final_chart = alt.vconcat(chart_ha, chart_concat).resolve_scale(color='independent')
final_chart

In [None]:
save(final_chart, output_tsne_html)
#save(final_chart, output_tsne_png, scale_factor=2.0)

## t-SNE embeddings colored by clade membership

In [None]:
# HA t-SNE embedding colored by clade membership.
# NOTE(review): the `-1 in domain` check mirrors the cluster-label cells; here the
# domain holds clade labels, so it only fires if a clade is literally -1.
domain =  sorted(merged_tsne_df_ha[clade_membership].drop_duplicates().values)
if -1 in domain:
    range_ = ["#999999"] + colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
else: 
    range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
scatterplot_with_tooltip_interactive(merged_tsne_df_ha,'tsne_x','tsne_y','tsne_x','tsne_y',['strain', clade_membership],clade_membership+":N", domain, range_)

In [None]:
# Concatenated t-SNE embedding colored by clade membership. The `domain`/`range_`
# set here are reused by the two linked-chart cells below.
domain =  sorted(merged_tsne_df_concatenated[clade_membership].drop_duplicates().values)
if -1 in domain:
    range_ = ["#999999"] + colors[len(domain)-1:len(domain)].dropna(axis=1).values.tolist()[0]
else: 
    range_ = colors[len(domain):len(domain)+1].dropna(axis=1).values.tolist()[0]
scatterplot_with_tooltip_interactive(merged_tsne_df_concatenated,'tsne_x','tsne_y','tsne_x','tsne_y',['strain', clade_membership],clade_membership+":N", domain, range_)

In [None]:
# NOTE(review): the HA chart uses `domain`/`range_` built from
# `merged_tsne_df_concatenated` in the previous cell, not from `merged_tsne_df_ha`.
list_of_chart_ha = linking_tree_with_plots_brush(
    merged_tsne_df_ha,
    ['tsne_x','tsne_y'],
    ['tsne_x','tsne_y'],
    clade_membership+":N",
    ["strain:N", clade_membership+":N"],
    domain,
    range_
)
chart_tsne_ha = list_of_chart_ha[0]|list_of_chart_ha[1]
chart_tsne_ha

In [None]:
list_of_chart_concatenated = linking_tree_with_plots_brush(
    merged_tsne_df_concatenated,
    ['tsne_x','tsne_y'],
    ['tsne_x','tsne_y'],
    clade_membership+":N",
    ["strain:N", clade_membership+":N"],
    domain,
    range_
)
chart_tsne_concatenated = list_of_chart_concatenated[0]|list_of_chart_concatenated[1]
chart_tsne_concatenated

In [None]:
# Stack the HA and concatenated clade-colored t-SNE views vertically.
chart_tsne_ha & chart_tsne_concatenated

## Load UMAP embeddings

In [None]:
# NOTE(review): `umap_df_ha`, `umap_df_concatenated` and `node_df_ha` are not bound
# anywhere in this notebook; confirm where they should come from.
UMAP_df_ha = pd.read_csv(umap_df_ha, index_col=0)
UMAP_df_concatenated = pd.read_csv(umap_df_concatenated, index_col=0)

In [None]:
UMAP_df_concatenated

In [None]:
# Attach `date`, `y` and clade membership from `node_df_ha` to each UMAP embedding by strain.
merged_umap_df_ha = UMAP_df_ha.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")
merged_umap_df_concatenated = UMAP_df_concatenated.merge(node_df_ha[["strain", "date", "y", clade_membership]], on="strain")

In [None]:
# Sanity check: True when both UMAP tables have identical index values in identical order.
UMAP_df_ha.index.tolist() == UMAP_df_concatenated.index.values.tolist()

## All embeddings by clade membership

In [None]:
# Link the tree to every embedding (MDS, t-SNE, PCA and UMAP for both the
# concatenated and HA analyses), all colored by clade membership.
# TODO: Add MCC accuracies as titles per plot.
data = linking_tree_with_plots_brush(
    embeddings_df,
    [
        'mds1_concatenated',
        'mds2_concatenated',
        'mds1_ha',
        'mds2_ha',
        'tsne_x_concatenated',
        'tsne_y_concatenated',
        'tsne_x_ha',
        'tsne_y_ha',
        'pca1_concatenated',
        'pca2_concatenated',
        'pca1_ha',
        'pca2_ha',
        'umap_x_concatenated',
        'umap_y_concatenated',
        'umap_x_ha',
        'umap_y_ha',
    ],
    [
        'MDS 1',
        'MDS 2',
        'MDS 1',
        'MDS 2',
        't-SNE 1',
        't-SNE 2',
        't-SNE 1',
        't-SNE 2',
        f"PCA 1 (Explained variance: {round(explained_variance_pca_concatenated_values[0] * 100, 2)}%)",
        f"PCA 2 (Explained variance: {round(explained_variance_pca_concatenated_values[1] * 100, 2)}%)",
        f"PCA 1 (Explained variance: {round(explained_variance_pca_ha_values[0] * 100, 2)}%)",
        f"PCA 2 (Explained variance: {round(explained_variance_pca_ha_values[1] * 100, 2)}%)",
        'UMAP 1',
        'UMAP 2',
        'UMAP 1',
        'UMAP 2',
    ],
    clade_membership+":N",
    ['strain', clade_membership],
    # Use the clade palette built from `embeddings_df` itself (grey for "unassigned")
    # rather than whatever `domain`/`range_` an earlier exploratory cell left behind.
    clade_color_domain,
    clade_color_range,
)

In [None]:
# Stack the panels vertically: data[0] on top, then pairs of the remaining charts side by side.
# NOTE(review): assumes linking_tree_with_plots_brush returns the tree chart at index 0
# followed by one scatterplot per column pair in the order passed above (so 6|5 = PCA,
# 2|1 = MDS, 4|3 = t-SNE, 8|7 = UMAP, HA on the left) — confirm against Helpers.
chart_embeddings = alt.vconcat(data[0], data[6]|data[5], data[2]|data[1], data[4]|data[3], data[8]|data[7])
chart_embeddings

In [None]:
# Write the interactive HTML and a PNG rendered at 2x scale.
chart_embeddings.save(output_full_html)
save(chart_embeddings, output_full_png, scale_factor=2.0)