In [None]:
import sys
sys.path.append("notebooks/scripts/")

In [None]:
import altair as alt
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush
from Helpers import make_node_branch_widths, make_branch_lines_for_columns

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
# Input paths come from the `snakemake` object, which is not defined in this
# notebook; it is supplied when Snakemake executes the notebook.
colors_path = snakemake.input.colors  # tab-delimited palette table read below
tree_path = snakemake.input.tree  # passed to make_node_branch_widths
embeddings_path = snakemake.input.annotated_embeddings  # tab-delimited per-node embedding table
accuracy_path = snakemake.input.accuracy_table  # CSV read for "method" / "normalized_vi" columns
explained_variance_pca = snakemake.input.explained_variance_pca  # CSV read for its "explained variance" column

In [None]:
# Output paths from Snakemake. Each chart is saved twice: once via
# Chart.save(path) with the format inferred from the path, and once as a PNG.
interactive_chart_by_clades = snakemake.output.fullChart
static_chart_by_clades = snakemake.output.fullChartPNG

interactive_chart_by_clusters = snakemake.output.fullChartHDBSCAN20182020
static_chart_by_clusters = snakemake.output.fullChartHDBSCANPNG20182020

interactive_mds_chart = snakemake.output.MDS_Supplement
static_mds_chart = snakemake.output.MDS_Supplement_PNG

## Load data

In [None]:
node_branch_widths = make_node_branch_widths(tree_path)

In [None]:
node_branch_widths.head()

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
embeddings_df.head()

In [None]:
embeddings_df = embeddings_df.rename(
    columns={
        "num_date": "date",
    }
)

In [None]:
clade_counts = embeddings_df.query("is_internal_node == False")["clade_membership"].value_counts()

In [None]:
clade_counts

Only assign distinct colors to clades with at least 10 samples (tips). All smaller clades are grouped into an "other" category drawn in gray. This makes the larger clades easy to distinguish while using fewer colors.

In [None]:
clades_to_plot_with_color = sorted(clade_counts[clade_counts >= 10].index.values)

In [None]:
clades_to_plot_with_color

In [None]:
clades_to_plot = sorted(embeddings_df["clade_membership"].drop_duplicates().values)

In [None]:
clades_to_plot

In [None]:
domain = clades_to_plot_with_color

In [None]:
len(clades_to_plot_with_color)

In [None]:
colors = pd.read_csv(colors_path, sep="\t", names=[i for i in range(0,101)], nrows=101)

In [None]:
clade_color_range = colors.iloc[len(clades_to_plot_with_color) - 1].dropna().tolist()

In [None]:
len(clade_color_range)

In [None]:
domain.append("other")

In [None]:
clade_color_range.append("#999999")

In [None]:
embeddings_df["clade_membership_color"] = embeddings_df["clade_membership"].apply(
    lambda clade: clade if clade in clades_to_plot_with_color else "other"
)

In [None]:
embeddings_df.head()

In [None]:
embeddings_df["clade_membership_color"].value_counts()

In [None]:
domain

In [None]:
clade_color_range

In [None]:
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
accuracy_df.head()

In [None]:
explained_variance_df = pd.read_csv(explained_variance_pca)

In [None]:
explained_variance_PCA = explained_variance_df["explained variance"].values.tolist()

In [None]:
explained_variance_PCA

## Setup branches

In [None]:
embedding_columns = [
    "pca1",
    "pca2",
    "mds1",
    "mds2",
    "mds3",
    "tsne_x",
    "tsne_y",
    "umap_x",
    "umap_y",
]

In [None]:
embedding_positions = embeddings_df.loc[
    :,
    ["strain", "parent_name", "clade_membership_color"] + embedding_columns
]

In [None]:
embedding_positions.head()

In [None]:
# Build one row per branch by joining each node to its parent's row.
# After the self-merge, the parent's columns carry a "_parent" suffix. The
# child's own color is dropped and the parent's color is renamed to
# "clade_membership_color", so each branch is colored by its parent's label.
# Both merges are inner joins: nodes whose parent_name matches no strain
# (presumably the root) and nodes absent from node_branch_widths are dropped.
embedding_segments = embedding_positions.merge(
    embedding_positions,
    left_on="parent_name",
    right_on="strain",
    how="inner",
    suffixes=["", "_parent"],
).drop(
    columns=[
        "clade_membership_color",
        "strain_parent",
        "parent_name_parent",
    ]
).rename(
    columns={
        "clade_membership_color_parent": "clade_membership_color",
    }
).merge(
    node_branch_widths,
    left_on="strain",
    right_on="node",
    how="inner",
)

In [None]:
embedding_segments.head()

In [None]:
embedding_segments.shape

## Plot MDS embeddings

In [None]:
# Brush-linked MDS charts colored by clade. The coordinate columns are paired as
# (MDS 1, MDS 2) and (MDS 2, MDS 3). The result is indexed below as [0] (placed
# on top) and [1], [2] (each layered with branch lines).
list_of_chart = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1", "mds2", "mds2", "mds3"],
    ["MDS 1", "MDS 2", "MDS 2", "MDS 3"],
    "clade_membership_color:N",
    "Clade membership",
    ['strain', "clade_membership"],
    domain,
    clade_color_range,
)

In [None]:
# Branch-line layers for the two MDS projections, colored with the clade palette.
mds12_branch_lines, mds23_branch_lines = [
    make_branch_lines_for_columns(
        embedding_segments,
        x_column,
        y_column,
        domain,
        clade_color_range,
    )
    for x_column, y_column in (("mds1", "mds2"), ("mds2", "mds3"))
]

In [None]:
# Layout: list_of_chart[0] on top; below it, MDS 1 vs 2 and MDS 2 vs 3 side by
# side, each layered with its branch lines. Axis grids and view borders are
# turned off.
MDSFluBrush = (
    (list_of_chart[0]) &
    (
        (mds12_branch_lines + list_of_chart[1]) |
        (mds23_branch_lines + list_of_chart[2])
    )
).configure_axis(grid=False).configure_view(stroke=None)
MDSFluBrush

In [None]:
# Save both the interactive version and a PNG at 2x resolution.
MDSFluBrush.save(interactive_mds_chart)
MDSFluBrush.save(static_mds_chart, format="png", scale_factor=2.0)

## Plot all embeddings by clade

In [None]:
# Brush-linked charts for every embedding, colored by clade. The PC axis titles
# include each component's explained variance as a percentage. Based on how
# the result is used below, data[0] is placed on top and data[1]..data[4] are
# presumably the MDS, t-SNE, PCA, and UMAP charts (they are layered with those
# branch lines). TODO: confirm against linking_tree_with_plots_brush.
data = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2', 'tsne_x', 'tsne_y', 'pca1', 'pca2', 'umap_x', 'umap_y'],
    [
        'MDS 1',
        'MDS 2',
        't-SNE 1',
        't-SNE 2',
        'PC 1 (Explained Variance : {}%'.format(round(explained_variance_PCA[0]*100,2)) + ")",
        'PC 2 (Explained Variance : {}%'.format(round(explained_variance_PCA[1]*100,2)) + ")",
        'UMAP 1',
        'UMAP 2'
    ],
    'clade_membership_color:N',
    "Clade membership",
    ['strain', 'clade_membership'],
    domain,
    clade_color_range
)

In [None]:
pca = data[3]

In [None]:
pca_branch_lines = make_branch_lines_for_columns(embedding_segments, "pca1", "pca2", domain, clade_color_range)

In [None]:
mds_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds1", "mds2", domain, clade_color_range)

In [None]:
tsne_branch_lines = make_branch_lines_for_columns(embedding_segments, "tsne_x", "tsne_y", domain, clade_color_range)

In [None]:
umap_branch_lines = make_branch_lines_for_columns(embedding_segments, "umap_x", "umap_y", domain, clade_color_range)

In [None]:
# Two rows of embeddings, each layered with its branch lines:
# PCA | MDS on top, t-SNE | UMAP below.
PCAMDS = (
    (pca_branch_lines + data[3]) |
    (mds_branch_lines + data[1])
)
TSNEUMAP = (
    (tsne_branch_lines + data[2]) |
    (umap_branch_lines + data[4])
)
embeddings = alt.vconcat(PCAMDS, TSNEUMAP)

# Only a cell's final expression is rendered, so the full chart (data[0] above
# the embedding grid) is the single output of this cell.
fullChart = alt.vconcat(data[0], embeddings).configure_axis(grid=False).configure_view(stroke=None)
fullChart

In [None]:
# Save both the interactive version and a PNG at 2x resolution.
fullChart.save(interactive_chart_by_clades)
fullChart.save(static_chart_by_clades, format="png", scale_factor=2.0)

In [None]:
# Poster variant: data[0] widened to 1100 above a single row of the embedding
# charts (presumably PCA, MDS, t-SNE, UMAP), without branch-line layers.
poster_embeddings_by_clade = alt.vconcat(
    data[0].properties(width=1100),
    (data[3] | data[1] | data[2] | data[4]),
).configure_axis(grid=False).configure_view(stroke=None)
poster_embeddings_by_clade

## Plot all embeddings by cluster

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Return a list of colors aligned position-by-position with ``domain``.

    Parameters
    ----------
    domain : sequence
        Ordered category values that need colors (e.g. cluster labels).
    colors : pandas.DataFrame
        Palette table in which row ``N - 1`` (by position) holds ``N`` colors;
        missing trailing cells are dropped.
    value_for_unassigned : optional
        Label for unassigned samples (e.g. ``-1``). If it appears in
        ``domain``, its color is replaced with gray ``#999999``.

    Returns
    -------
    list
        One color per value in ``domain``; an empty list for an empty domain.
    """
    domain = list(domain)

    # An empty domain needs no colors and has no palette row to select.
    if not domain:
        return []

    # Rows are zero-indexed, so to get N colors, we select row N - 1 by position.
    range_ = colors.iloc[len(domain) - 1].dropna().tolist()

    # Replace known values for "unassigned" clade or cluster labels.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = "#999999"

    return range_

In [None]:
pca_label_color_domain =  sorted(embeddings_df["pca_label"].drop_duplicates().dropna().values)
pca_label_color_range = build_color_range_for_domain(
    pca_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['pca1', 'pca2'],
    [
        'PC 1 (Explained Variance : {}%'.format(round(explained_variance_PCA[0] * 100, 2)) + ")",
        'PC 2 (Explained Variance : {}%'.format(round(explained_variance_PCA[1] * 100, 2)) + ")"
    ],
    'pca_label:N',
    "PCA cluster",
    ['strain', 'clade_membership', 'pca_label'],
    pca_label_color_domain,
    pca_label_color_range,
)

In [None]:
mds_label_color_domain =  sorted(embeddings_df["mds_label"].drop_duplicates().dropna().values)
mds_label_color_range = build_color_range_for_domain(
    mds_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    'mds_label:N',
    "MDS cluster",
    ['strain', 'clade_membership', 'mds_label'],
    mds_label_color_domain,
    mds_label_color_range,
)

In [None]:
tsne_label_color_domain =  sorted(embeddings_df["t-sne_label"].drop_duplicates().dropna().values)
tsne_label_color_range = build_color_range_for_domain(
    tsne_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    't-sne_label:N',
    "t-SNE cluster",
    ['strain', 'clade_membership', 't-sne_label'],
    tsne_label_color_domain,
    tsne_label_color_range,
)

In [None]:
umap_label_color_domain =  sorted(embeddings_df["umap_label"].drop_duplicates().dropna().values)
umap_label_color_range = build_color_range_for_domain(
    umap_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    'umap_label:N',
    "UMAP cluster",
    ['strain', 'clade_membership', 'umap_label'],
    umap_label_color_domain,
    umap_label_color_range,
)

In [None]:
accuracy_by_method = dict(accuracy_df.loc[:, ["method", "normalized_vi"]].values)

In [None]:
accuracy_by_method

In [None]:
# Branch-line layers for the cluster figures (no clade palette is passed here).
(
    pca_cluster_branch_lines,
    mds_cluster_branch_lines,
    tsne_cluster_branch_lines,
    umap_cluster_branch_lines,
) = (
    make_branch_lines_for_columns(embedding_segments, x_column, y_column)
    for x_column, y_column in (
        ("pca1", "pca2"),
        ("mds1", "mds2"),
        ("tsne_x", "tsne_y"),
        ("umap_x", "umap_y"),
    )
)

In [None]:
def compose_cluster_chart(charts_by_cluster, branch_lines, method):
    """Place charts_by_cluster[0] beside charts_by_cluster[1] layered with branch lines.

    The right-hand panel is titled with the method's normalized VI from
    accuracy_by_method, formatted to two decimals.
    """
    left_panel = charts_by_cluster[0]
    right_panel = (branch_lines + charts_by_cluster[1]).properties(
        title=f"Normalized VI: {accuracy_by_method[method]:.2f}"
    )
    return left_panel | right_panel


composed_pca_by_cluster = compose_cluster_chart(pca_by_cluster, pca_cluster_branch_lines, 'pca')

composed_mds_by_cluster = compose_cluster_chart(mds_by_cluster, mds_cluster_branch_lines, 'mds')

composed_tsne_by_cluster = compose_cluster_chart(tsne_by_cluster, tsne_cluster_branch_lines, 't-sne')

composed_umap_by_cluster = compose_cluster_chart(umap_by_cluster, umap_cluster_branch_lines, 'umap')

In [None]:
pca_mds = alt.vconcat(composed_pca_by_cluster, composed_mds_by_cluster).resolve_scale(color='independent')
tsne_umap = alt.vconcat(composed_tsne_by_cluster, composed_umap_by_cluster).resolve_scale(color='independent')
full_chart_by_cluster = alt.vconcat(pca_mds, tsne_umap).resolve_scale(color='independent').configure_axis(grid=False).configure_view(stroke=None)
full_chart_by_cluster

In [None]:
full_chart_by_cluster.save(interactive_chart_by_clusters)
full_chart_by_cluster.save(static_chart_by_clusters, format="png", scale_factor=2.0)

Rebuild the cluster charts without legends for the poster layout.

In [None]:
# Same PCA cluster chart as above, but with plot_legend=False for the poster.
# This rebinds pca_by_cluster, replacing the earlier legend-bearing version.
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['pca1', 'pca2'],
    [
        'PC 1 (Explained Variance : {}%'.format(round(explained_variance_PCA[0] * 100, 2)) + ")",
        'PC 2 (Explained Variance : {}%'.format(round(explained_variance_PCA[1] * 100, 2)) + ")"
    ],
    'pca_label:N',
    "PCA cluster",
    ['strain', 'clade_membership', 'pca_label'],
    pca_label_color_domain,
    pca_label_color_range,
    plot_legend=False,
)

In [None]:
# Legend-free MDS cluster chart (rebinds mds_by_cluster).
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    'mds_label:N',
    "MDS cluster",
    ['strain', 'clade_membership', 'mds_label'],
    mds_label_color_domain,
    mds_label_color_range,
    plot_legend=False,
)

In [None]:
# Legend-free t-SNE cluster chart (rebinds tsne_by_cluster).
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    't-sne_label:N',
    "t-SNE cluster",
    ['strain', 'clade_membership', 't-sne_label'],
    tsne_label_color_domain,
    tsne_label_color_range,
    plot_legend=False,
)

In [None]:
# Legend-free UMAP cluster chart (rebinds umap_by_cluster).
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    'umap_label:N',
    "UMAP cluster",
    ['strain', 'clade_membership', 'umap_label'],
    umap_label_color_domain,
    umap_label_color_range,
    plot_legend=False,
)

In [None]:
# Every poster panel shares the same 250 x 250 size.
poster_panel_size = dict(width=250, height=250)

poster_embeddings_by_cluster = alt.vconcat(
    # Top row: chart [0] from each method, each with its own color scale.
    (
        pca_by_cluster[0].properties(**poster_panel_size) |
        mds_by_cluster[0].properties(**poster_panel_size) |
        tsne_by_cluster[0].properties(**poster_panel_size) |
        umap_by_cluster[0].properties(**poster_panel_size)
    ).resolve_scale(color='independent'),
    # Bottom row: chart [1] from each method, titled with its normalized VI.
    (
        pca_by_cluster[1].properties(**poster_panel_size, title=f"Normalized VI: {accuracy_by_method['pca']:.2f}") |
        mds_by_cluster[1].properties(**poster_panel_size, title=f"Normalized VI: {accuracy_by_method['mds']:.2f}") |
        tsne_by_cluster[1].properties(**poster_panel_size, title=f"Normalized VI: {accuracy_by_method['t-sne']:.2f}") |
        umap_by_cluster[1].properties(**poster_panel_size, title=f"Normalized VI: {accuracy_by_method['umap']:.2f}")
    ).resolve_scale(color='independent')
).configure_legend(disable=True).configure_axis(grid=False).configure_view(stroke=None)

poster_embeddings_by_cluster