In [None]:
import sys
# Add notebooks/scripts/ to the module search path. The path is relative, so it is
# resolved against the kernel's current working directory.
# NOTE(review): presumably this is what makes the `Helpers` import below work — confirm.
sys.path.append("notebooks/scripts/")

In [None]:
import altair as alt
from altair_saver import save
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
# Input file paths, read from the `snakemake` object.
# NOTE(review): `snakemake` is not defined anywhere in this notebook; presumably it is
# injected by Snakemake when the notebook is run as part of a rule — confirm.
colors_path = snakemake.input.colors
embeddings_path = snakemake.input.annotated_embeddings
accuracy_path = snakemake.input.accuracy_table
# NOTE(review): this holds a file path, not the variance values themselves. The loaded
# values are stored later under the very similar name `explained_variance_PCA`.
explained_variance_pca = snakemake.input.explained_variance_pca

In [None]:
# Output paths for the combined cluster figure, read from the `snakemake` object.
# The first is written with the chart's own .save() method and the second with
# altair_saver's save() at the end of the notebook.
interactive_chart_by_clusters = snakemake.output.fullChartHDBSCAN20182020
static_chart_by_clusters = snakemake.output.fullChartHDBSCANPNG20182020

## Load data

In [None]:
# Load the tab-separated color table with explicit integer column names 0..100.
# NOTE(review): presumably row N - 1 holds a palette of N colors and shorter rows are
# padded with NaN, which build_color_range_for_domain later drops — confirm.
colors = pd.read_csv(
    colors_path,
    sep="\t",
    names=list(range(101)),
)

In [None]:
# Load the tab-separated annotated embeddings table. Later cells read its
# per-method coordinate columns and `*_label` cluster columns.
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
# Shorten two column names: "num_date" becomes "date" and "y_value" becomes "y".
# NOTE(review): presumably these are the names the plotting helper expects — confirm.
embeddings_df = embeddings_df.rename(columns={"num_date": "date", "y_value": "y"})

In [None]:
# Preview the first rows of the renamed embeddings table.
embeddings_df.head()

In [None]:
# Load the accuracy table (comma-separated, pandas default). Its "embedding" and
# "normalized_VI" columns are used further down for the chart titles.
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
# Preview the first rows of the accuracy table.
accuracy_df.head()

In [None]:
# Load the PCA explained-variance table (comma-separated, pandas default).
# Despite its name, `explained_variance_pca` is the path to the file.
explained_variance_df = pd.read_csv(explained_variance_pca)

In [None]:
# Pull the "explained variance" column out as a plain Python list. Entries 0 and 1
# are used for the PC 1 and PC 2 axis titles.
# NOTE(review): the values are multiplied by 100 for display, so they are presumably
# fractions in [0, 1] — confirm.
explained_variance_PCA = explained_variance_df["explained variance"].values.tolist()

In [None]:
# Display the explained-variance values as a sanity check.
explained_variance_PCA

## Plot all embeddings by cluster

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Build a list of colors matching the given domain of labels.

    Parameters
    ----------
    domain : sequence
        Labels to assign colors to, in the order they should be colored. Any
        iterable of values is accepted (list, tuple, array, Series, ...).
    colors : pandas.DataFrame
        Color table in which the row labeled ``N - 1`` holds the colors to use
        for a domain of ``N`` values; missing entries in that row are dropped.
    value_for_unassigned : optional
        Label that marks "unassigned" records. When this label is present in
        the domain, its color is replaced with the gray ``"#999999"``.

    Returns
    -------
    list
        Colors taken from the selected row, with the color at the position of
        ``value_for_unassigned`` replaced by gray. An empty domain yields an
        empty list.

    Raises
    ------
    KeyError
        If ``colors`` has no row labeled ``len(domain) - 1``.
    """
    # Work on a real list so that membership tests and ``.index`` look at the
    # values themselves (a pandas Series, for example, would test its index).
    domain = list(domain)

    # Nothing to color; without this guard we would look up row -1.
    if not domain:
        return []

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    range_ = colors.loc[len(domain) - 1].dropna().values.tolist()

    # Replace known values for "unassigned" clade or cluster labels.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = "#999999"

    return range_

In [None]:
# Color domain: the sorted distinct PCA cluster labels.
pca_label_color_domain = sorted(
    embeddings_df["pca_label"].drop_duplicates().values
)
# Matching color range; the label -1 is treated as "unassigned" and colored gray.
pca_label_color_range = build_color_range_for_domain(
    pca_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Build the linked charts for the PCA embedding, colored by PCA cluster label.
# The axis titles report the explained variance of each component as a percentage
# rounded to two decimals.
# NOTE(review): the positional arguments look like (data, coordinate columns, axis
# titles, color field, extra fields, color domain, color range) — confirm against
# Helpers.linking_tree_with_plots_brush.
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1", "pca2"],
    [
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
    ],
    "pca_label:N",
    ["strain", "clade_membership", "pca_label"],
    pca_label_color_domain,
    pca_label_color_range,
)

In [None]:
# Color domain: the sorted distinct MDS cluster labels.
mds_label_color_domain = sorted(
    embeddings_df["mds_label"].drop_duplicates().values
)
# Matching color range; the label -1 is treated as "unassigned" and colored gray.
mds_label_color_range = build_color_range_for_domain(
    mds_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Build the linked charts for the MDS embedding, colored by MDS cluster label.
# NOTE(review): the positional arguments look like (data, coordinate columns, axis
# titles, color field, extra fields, color domain, color range) — confirm against
# Helpers.linking_tree_with_plots_brush.
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1", "mds2"],
    ["MDS 1", "MDS 2"],
    "mds_label:N",
    ["strain", "clade_membership", "mds_label"],
    mds_label_color_domain,
    mds_label_color_range,
)

In [None]:
# Color domain: the sorted distinct t-SNE cluster labels.
tsne_label_color_domain = sorted(
    embeddings_df["t-sne_label"].drop_duplicates().values
)
# Matching color range; the label -1 is treated as "unassigned" and colored gray.
tsne_label_color_range = build_color_range_for_domain(
    tsne_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Build the linked charts for the t-SNE embedding, colored by t-SNE cluster label.
# NOTE(review): the positional arguments look like (data, coordinate columns, axis
# titles, color field, extra fields, color domain, color range) — confirm against
# Helpers.linking_tree_with_plots_brush.
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["tsne_x", "tsne_y"],
    ["t-SNE 1", "t-SNE 2"],
    "t-sne_label:N",
    ["strain", "clade_membership", "t-sne_label"],
    tsne_label_color_domain,
    tsne_label_color_range,
)

In [None]:
# Color domain: the sorted distinct UMAP cluster labels.
umap_label_color_domain = sorted(
    embeddings_df["umap_label"].drop_duplicates().values
)
# Matching color range; the label -1 is treated as "unassigned" and colored gray.
umap_label_color_range = build_color_range_for_domain(
    umap_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Build the linked charts for the UMAP embedding, colored by UMAP cluster label.
# NOTE(review): the positional arguments look like (data, coordinate columns, axis
# titles, color field, extra fields, color domain, color range) — confirm against
# Helpers.linking_tree_with_plots_brush.
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["umap_x", "umap_y"],
    ["UMAP 1", "UMAP 2"],
    "umap_label:N",
    ["strain", "clade_membership", "umap_label"],
    umap_label_color_domain,
    umap_label_color_range,
)

In [None]:
# Map each embedding name to its normalized VI value. Each (embedding, normalized_VI)
# row becomes one key/value pair.
accuracy_by_method = dict(
    accuracy_df[["embedding", "normalized_VI"]].values
)

In [None]:
# Display the lookup; the keys "pca", "mds", "t-sne" and "umap" are required below.
accuracy_by_method

In [None]:
# For each method, place its two charts side by side. Only the second chart gets the
# "Normalized VI" title, because .properties() binds more tightly than `|`.
composed_pca_by_cluster = (
    pca_by_cluster[0]
    | pca_by_cluster[1].properties(title=f"Normalized VI: {accuracy_by_method['pca']}")
)

composed_mds_by_cluster = (
    mds_by_cluster[0]
    | mds_by_cluster[1].properties(title=f"Normalized VI: {accuracy_by_method['mds']}")
)

composed_tsne_by_cluster = (
    tsne_by_cluster[0]
    | tsne_by_cluster[1].properties(title=f"Normalized VI: {accuracy_by_method['t-sne']}")
)

composed_umap_by_cluster = (
    umap_by_cluster[0]
    | umap_by_cluster[1].properties(title=f"Normalized VI: {accuracy_by_method['umap']}")
)

In [None]:
# Stack the four per-method rows vertically. Color scales are resolved independently
# at every level so each method keeps its own cluster-label color scale.
pca_mds = alt.vconcat(
    composed_pca_by_cluster,
    composed_mds_by_cluster,
).resolve_scale(color="independent")
tsne_umap = alt.vconcat(
    composed_tsne_by_cluster,
    composed_umap_by_cluster,
).resolve_scale(color="independent")
full_chart_by_cluster = alt.vconcat(
    pca_mds,
    tsne_umap,
).resolve_scale(color="independent")
full_chart_by_cluster

In [None]:
# Write the combined figure to both output paths.
# The chart's own .save() writes to the first path.
# NOTE(review): presumably an HTML file, given the variable name — confirm.
full_chart_by_cluster.save(interactive_chart_by_clusters)
# altair_saver's save() writes to the second path with scale_factor=2.0.
# NOTE(review): presumably a PNG rendered at twice the default resolution — confirm.
save(full_chart_by_cluster, static_chart_by_clusters, scale_factor=2.0)