In [None]:
import sys
sys.path.append("notebooks/scripts/")

In [None]:
from collections import defaultdict

import altair as alt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush
from Helpers import make_node_branch_widths, make_branch_lines_for_columns
from Helpers import get_clade_label_chart

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
colors_path = snakemake.input.colors
tree_path = snakemake.input.tree
embeddings_path = snakemake.input.annotated_embeddings
accuracy_path = snakemake.input.accuracy_table
explained_variance_pca = snakemake.input.explained_variance_pca

In [None]:
interactive_chart_by_clades = snakemake.output.fullChart
static_chart_by_clades = snakemake.output.fullChartPNG

interactive_chart_by_clusters = snakemake.output.fullChartHDBSCAN
static_chart_by_clusters = snakemake.output.fullChartHDBSCANPNG

interactive_mds_chart = snakemake.output.MDS_Supplement
static_mds_chart = snakemake.output.MDS_Supplement_PNG

In [None]:
sns.set_style("ticks")

# Disable top and right spines.
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False

# Display and save figures at higher resolution for presentations and manuscripts.
mpl.rcParams['savefig.dpi'] = 300
mpl.rcParams['figure.dpi'] = 120

# Display text at sizes large enough for presentations and manuscripts.
mpl.rcParams['font.weight'] = "normal"
mpl.rcParams['axes.labelweight'] = "normal"
mpl.rcParams['font.size'] = 14
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 10
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14
mpl.rcParams['axes.titlesize'] = 14
mpl.rc('text', usetex=False)

In [None]:
max_items_per_column_in_legend = 16

## Load data

In [None]:
node_branch_widths = make_node_branch_widths(tree_path)

In [None]:
node_branch_widths.head()

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
embeddings_df = embeddings_df.rename(
    columns={
        "numdate": "date",
    }
)

In [None]:
embeddings_df.head()

In [None]:
clade_counts = embeddings_df.query("is_internal_node == False")["clade_membership"].value_counts()

In [None]:
clade_counts

In [None]:
clades_to_plot = sorted(clade_counts[clade_counts >= 10].index.values)

In [None]:
clades_to_plot

In [None]:
len(clades_to_plot)

In [None]:
colors = pd.read_csv(colors_path, sep="\t", names=[i for i in range(0,101)], nrows=101)

In [None]:
colors.head(16)

In [None]:
domain = clades_to_plot

In [None]:
clade_color_range = colors.iloc[len(clades_to_plot) - 1].dropna().tolist()

In [None]:
len(clade_color_range)

In [None]:
domain.append("other")

In [None]:
clade_color_range.append("#999999")

In [None]:
embeddings_df["clade_membership_color"] = embeddings_df["clade_membership"].apply(
    lambda clade: clade if clade in clades_to_plot else "other"
)

In [None]:
embeddings_df.head()

In [None]:
embeddings_df.query("is_internal_node == False")["clade_membership_color"].value_counts()

In [None]:
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
accuracy_df.head()

In [None]:
explained_variance_df = pd.read_csv(explained_variance_pca)

In [None]:
explained_variance_df.head()

## Set up branches

In [None]:
embeddings_df.head()

In [None]:
embedding_columns = [
    "pca1",
    "pca2",
    "mds1",
    "mds2",
    "mds3",
    "tsne_x",
    "tsne_y",
    "umap_x",
    "umap_y",
]

In [None]:
embedding_positions = embeddings_df.loc[
    :,
    ["strain", "parent_name", "clade_membership_color"] + embedding_columns
]

In [None]:
embedding_positions.head()

In [None]:
embedding_segments = embedding_positions.merge(
    embedding_positions,
    left_on="parent_name",
    right_on="strain",
    how="inner",
    suffixes=["", "_parent"],
).drop(
    columns=[
        "clade_membership_color",
        "strain_parent",
        "parent_name_parent",
    ]
).rename(
    columns={
        "clade_membership_color_parent": "clade_membership_color",
    }
).merge(
    node_branch_widths,
    left_on="strain",
    right_on="node",
    how="inner",
)

In [None]:
embedding_segments.head()

In [None]:
embedding_segments.shape

Add parent clade membership color to embeddings for use in tree plots.

In [None]:
parent_clade_membership_color = embedding_segments.loc[:, ["strain", "clade_membership_color"]].rename(
    columns={
        "clade_membership_color": "parent_clade_membership_color",
    }
)

In [None]:
parent_clade_membership_color.head()

In [None]:
embeddings_df = embeddings_df.merge(
    parent_clade_membership_color,
    on="strain",
    validate="1:1",
)

In [None]:
embeddings_df.head()

In [None]:
(embeddings_df["clade_membership_color"] != embeddings_df["parent_clade_membership_color"]).sum()

In [None]:
embeddings_df["clade_membership_short"] = embeddings_df["clade_membership_color"].apply(
    lambda clade: clade.split("/")[-1]
)

In [None]:
clade_label_positions_in_tree = embeddings_df.loc[
    (embeddings_df["is_internal_node"]) & (embeddings_df["clade_membership_short"] != "other"),
    ["clade_membership_short", "divergence", "y_value"]
].sort_values([
    "clade_membership_short",
    "divergence",
]).groupby(
    "clade_membership_short"
).first().reset_index()

clade_label_positions_in_tree["divergence"] = clade_label_positions_in_tree["divergence"] - 0.001

clade_label_positions_in_tree["y_value"] = clade_label_positions_in_tree["y_value"] + 25

clade_labels_for_tree_chart = alt.Chart(clade_label_positions_in_tree).mark_text().encode(
    x="divergence:Q",
    y="y_value:Q",
    text="clade_membership_short:N",
)

## Plot PCA variance and embeddings

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(
    explained_variance_df["principal components"],
    explained_variance_df["explained variance"],
    "o"
)

ax.set_xlabel("Principal Component")
ax.set_ylabel("Explained Variance")
ax.set_ylim(bottom=0)

plt.tight_layout()

In [None]:
explained_variance_PCA = explained_variance_df["explained variance"].values.tolist()

In [None]:
pcs = explained_variance_df["principal components"].values

In [None]:
pcs

In [None]:
(tree, pca) = linking_tree_with_plots_brush(
    embeddings_df,
    [f"pca{pc}" for pc in pcs],
    [
        f"PC {pc} (Explained Variance : {variance * 100:.2f}%)"
        for pc, variance in zip(pcs, explained_variance_PCA)
    ],
    "clade_membership_color:N",
    "Clade membership",
    ['strain', "clade_membership"],
    domain,
    clade_color_range,
)

In [None]:
pca_branch_lines = make_branch_lines_for_columns(embedding_segments, "pca1", "pca2", domain, clade_color_range)

In [None]:
PCAFluBrush = (tree & (pca_branch_lines + pca)).configure_axis(grid=False).configure_view(stroke=None)
PCAFluBrush

## Plot MDS embeddings

In [None]:
(tree, mds12, mds23) = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1", "mds2", "mds2", "mds3"],
    ["MDS 1", "MDS 2", "MDS 2", "MDS 3"],
    "clade_membership_color:N",
    "Clade membership",
    ['strain', "clade_membership"],
    domain,
    clade_color_range,
)

In [None]:
mds12_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds1", "mds2", domain, clade_color_range)

In [None]:
mds23_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds2", "mds3", domain, clade_color_range)

In [None]:
clade_labels_for_mds12_chart = get_clade_label_chart(
    embeddings_df,
    "mds1",
    "mds2",
    "clade_membership_short",
    xoffset_by_label={
        "A1": -2,
    },
    yoffset_by_label={
        "135N": -1,
    }
)

In [None]:
clade_labels_for_mds23_chart = get_clade_label_chart(
    embeddings_df,
    "mds2",
    "mds3",
    "clade_membership_short",
    drop_labels={"A2"},
    xoffset_by_label={
        "135N": -3,
    },
    yoffset_by_label={
        "A1": 2,
    }
)

In [None]:
MDSFluBrush = (
    (tree + clade_labels_for_tree_chart) &
    (
        (mds12_branch_lines + mds12 + clade_labels_for_mds12_chart) |
        (mds23_branch_lines + mds23 + clade_labels_for_mds23_chart)
    )
).configure_axis(grid=False).configure_view(stroke=None)
MDSFluBrush

In [None]:
MDSFluBrush.save(interactive_mds_chart)
MDSFluBrush.save(static_mds_chart, format="png", ppi=300)

## Plot all embeddings by clade

In [None]:
data = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2', 'tsne_x', 'tsne_y', 'pca1', 'pca2', 'umap_x', 'umap_y'],
    [
        'MDS 1',
        'MDS 2',
        't-SNE 1',
        't-SNE 2',
        'PC 1 (Explained Variance : {}%'.format(round(explained_variance_PCA[0]*100,2)) + ")",
        'PC 2 (Explained Variance : {}%'.format(round(explained_variance_PCA[1]*100,2)) + ")",
        'UMAP 1',
        'UMAP 2'
    ],
    'clade_membership_color:N',
    "Clade membership",
    ['strain', 'clade_membership'],
    domain,
    clade_color_range
)

In [None]:
clade_labels_for_pca_chart = get_clade_label_chart(
    embeddings_df,
    "pca1",
    "pca2",
    "clade_membership_short",
    xoffset_by_label={
        "A1b": 1,
        "135N": -0.75,
        "A4": -0.5,
    },
    yoffset_by_label={
        "135K": 0.5,
    }
)

In [None]:
pca = data[3] + clade_labels_for_pca_chart

In [None]:
pca

In [None]:
pca_branch_lines = make_branch_lines_for_columns(embedding_segments, "pca1", "pca2", domain, clade_color_range)

In [None]:
clade_labels_for_mds_chart = get_clade_label_chart(
    embeddings_df,
    "mds1",
    "mds2",
    "clade_membership_short",
    xoffset_by_label={
        "A1": -2,
    },
    yoffset_by_label={
        "135N": -1,
    }
)

In [None]:
mds = data[1] + clade_labels_for_mds_chart

In [None]:
mds_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds1", "mds2", domain, clade_color_range)

In [None]:
clade_labels_for_tsne_chart = get_clade_label_chart(
    embeddings_df,
    "tsne_x",
    "tsne_y",
    "clade_membership_short",
    xoffset_by_label={
        "A1b": 1,
        "135K": -1,
    }
)

In [None]:
tsne = data[2] + clade_labels_for_tsne_chart

In [None]:
tsne_branch_lines = make_branch_lines_for_columns(embedding_segments, "tsne_x", "tsne_y", domain, clade_color_range)

In [None]:
clade_labels_for_umap_chart = get_clade_label_chart(
    embeddings_df,
    "umap_x",
    "umap_y",
    "clade_membership_short",
    xoffset_by_label={
        "135K": -1,
        "A1b": 1,
    },
    yoffset_by_label={
        "135K": 1,
        "A1b": 1,
    }
)

In [None]:
umap = data[4] + clade_labels_for_umap_chart

In [None]:
umap_branch_lines = make_branch_lines_for_columns(embedding_segments, "umap_x", "umap_y", domain, clade_color_range)

In [None]:
(
    (
        (pca_branch_lines) |
        (mds_branch_lines)
    ) &
    (
        (tsne_branch_lines) |
        (umap_branch_lines)
    )
).configure_axis(grid=False).configure_view(stroke=None)

In [None]:
PCAMDS = (
    (pca_branch_lines + pca) |
    (mds_branch_lines + mds)
)
TSNEUMAP = (
    (tsne_branch_lines + tsne) |
    (umap_branch_lines + umap)
)
embeddings = alt.vconcat(PCAMDS,TSNEUMAP)
embeddings
fullChart = alt.vconcat((data[0] + clade_labels_for_tree_chart), embeddings).configure_axis(grid=False).configure_view(stroke=None)
fullChart

In [None]:
fullChart.save(interactive_chart_by_clades)
fullChart.save(static_chart_by_clades, format="png", ppi=300)

## Plot all embeddings by cluster

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None, unassigned_color="#999999"):
    """Return one color per value in ``domain``, taken from the palette table ``colors``.

    Parameters
    ----------
    domain : sequence
        Ordered category values (e.g. cluster labels) to color.
    colors : pandas.DataFrame
        Palette table in which row N - 1 holds N colors (remaining cells NaN).
    value_for_unassigned : optional
        Domain value marking "unassigned" items (callers here pass -1). Its
        palette color is replaced by ``unassigned_color``. Ignored when None
        or absent from ``domain``.
    unassigned_color : str
        Color used for ``value_for_unassigned``; defaults to "#999999".

    Returns
    -------
    list
        Colors aligned position-by-position with ``domain``. Empty when
        ``domain`` is empty.

    Raises
    ------
    ValueError
        If ``domain`` has more values than the palette table has rows.
    """
    # Accept any sequence (tuple, numpy array, ...) and support .index() below.
    domain = list(domain)
    if not domain:
        # colors.loc[-1] would raise KeyError; nothing to color anyway.
        return []

    if len(domain) > len(colors):
        raise ValueError(
            f"Palette table has {len(colors)} rows, cannot color {len(domain)} categories."
        )

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    range_ = colors.loc[len(domain) - 1].dropna().tolist()

    # Replace the palette color of a known "unassigned" clade or cluster label.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = unassigned_color

    return range_

In [None]:
pca_label_color_domain =  sorted(embeddings_df["pca_label"].drop_duplicates().dropna().values)
pca_label_color_range = build_color_range_for_domain(
    pca_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
pca_legend_columns = 2 if len(pca_label_color_domain) > max_items_per_column_in_legend else 1

In [None]:
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['pca1', 'pca2'],
    [
        'PC 1 (Explained Variance : {}%'.format(round(explained_variance_PCA[0] * 100, 2)) + ")",
        'PC 2 (Explained Variance : {}%'.format(round(explained_variance_PCA[1] * 100, 2)) + ")"
    ],
    'pca_label:N',
    "PCA cluster",
    ['strain', 'clade_membership', 'pca_label'],
    pca_label_color_domain,
    pca_label_color_range,
    legend_columns=pca_legend_columns,
    color_branches=False,
)

In [None]:
mds_label_color_domain =  sorted(embeddings_df["mds_label"].drop_duplicates().dropna().values)
mds_label_color_range = build_color_range_for_domain(
    mds_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
mds_legend_columns = 2 if len(mds_label_color_domain) > max_items_per_column_in_legend else 1

In [None]:
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    'mds_label:N',
    "MDS cluster",
    ['strain', 'clade_membership', 'mds_label'],
    mds_label_color_domain,
    mds_label_color_range,
    legend_columns=mds_legend_columns,
    color_branches=False,
)

In [None]:
tsne_label_color_domain =  sorted(embeddings_df["t-sne_label"].drop_duplicates().dropna().values)
tsne_label_color_range = build_color_range_for_domain(
    tsne_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
tsne_legend_columns = 2 if len(tsne_label_color_domain) > max_items_per_column_in_legend else 1

In [None]:
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    't-sne_label:N',
    "t-SNE cluster",
    ['strain', 'clade_membership', 't-sne_label'],
    tsne_label_color_domain,
    tsne_label_color_range,
    legend_columns=tsne_legend_columns,
    color_branches=False,
)

In [None]:
umap_label_color_domain =  sorted(embeddings_df["umap_label"].drop_duplicates().dropna().values)
umap_label_color_range = build_color_range_for_domain(
    umap_label_color_domain,
    colors,
    value_for_unassigned=-1,
)

In [None]:
umap_legend_columns = 2 if len(umap_label_color_domain) > max_items_per_column_in_legend else 1

In [None]:
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    'umap_label:N',
    "UMAP cluster",
    ['strain', 'clade_membership', 'umap_label'],
    umap_label_color_domain,
    umap_label_color_range,
    legend_columns=umap_legend_columns,
    color_branches=False,
)

In [None]:
accuracy_by_method = dict(accuracy_df.loc[:, ["method", "normalized_vi"]].values)

In [None]:
accuracy_by_method

In [None]:
pca_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "pca1", "pca2")
mds_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds1", "mds2")
tsne_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "tsne_x", "tsne_y")
umap_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "umap_x", "umap_y")

In [None]:
composed_pca_by_cluster = pca_by_cluster[0] | (pca_cluster_branch_lines + pca_by_cluster[1]).properties(
    title= f"Normalized VI: {accuracy_by_method['pca']:.2f}"
)

composed_mds_by_cluster = mds_by_cluster[0] | (mds_cluster_branch_lines + mds_by_cluster[1]).properties(
    title= f"Normalized VI: {accuracy_by_method['mds']:.2f}"
)

composed_tsne_by_cluster = tsne_by_cluster[0] | (tsne_cluster_branch_lines + tsne_by_cluster[1]).properties(
    title= f"Normalized VI: {accuracy_by_method['t-sne']:.2f}"
)

composed_umap_by_cluster = umap_by_cluster[0] | (umap_cluster_branch_lines + umap_by_cluster[1]).properties(
    title= f"Normalized VI: {accuracy_by_method['umap']:.2f}"
)

In [None]:
pca_mds = alt.vconcat(composed_pca_by_cluster, composed_mds_by_cluster).resolve_scale(color='independent')
tsne_umap = alt.vconcat(composed_tsne_by_cluster, composed_umap_by_cluster).resolve_scale(color='independent')
full_chart_by_cluster = alt.vconcat(pca_mds, tsne_umap).resolve_scale(color='independent').configure_axis(grid=False).configure_view(stroke=None)
full_chart_by_cluster

In [None]:
full_chart_by_cluster.save(interactive_chart_by_clusters)
full_chart_by_cluster.save(static_chart_by_clusters, format="png", ppi=300)