In [None]:
import sys
sys.path.append("notebooks/scripts/")

In [None]:
import altair as alt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush
from Helpers import make_node_branch_widths, make_branch_lines_for_columns

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
# Input file paths supplied by the Snakemake rule that executes this notebook.
# Tab-delimited color table (read below with sep="\t").
colors_path = snakemake.input.colors
# Tree file passed to make_node_branch_widths.
tree_path = snakemake.input.tree
# Tab-delimited table of embeddings and annotations.
embeddings_path = snakemake.input.annotated_embeddings
# CSV with "method" and "normalized_vi" columns.
accuracy_path = snakemake.input.accuracy_table
# CSV with "principal components" and "explained variance" columns.
explained_variance_pca = snakemake.input.explained_variance_pca

In [None]:
# Output paths declared by the Snakemake rule, grouped by figure.

# All embeddings colored by clade membership.
interactive_chart_by_clades = snakemake.output.fullChart
static_chart_by_clades = snakemake.output.fullChartPNG

# All embeddings colored by cluster label.
interactive_chart_by_clusters = snakemake.output.fullChartHDBSCAN
static_chart_by_clusters = snakemake.output.fullChartHDBSCANPNG

# PCA explained variance and PCA supplemental figure.
explained_variance_pca_chart = snakemake.output.Explained_variance_PCA
interactive_pca_chart = snakemake.output.PCA_Supplement
static_pca_chart = snakemake.output.PCA_Supplement_PNG

# MDS supplemental figure.
interactive_mds_chart = snakemake.output.MDS_Supplement
static_mds_chart = snakemake.output.MDS_Supplement_PNG

# Optional outputs: fall back to None when the rule does not declare them, so
# the cells that write these files can be skipped.
output_tsne_recombinant_counts_png = getattr(snakemake.output, "tsne_recombinant_counts", None)
output_tsne_recombinant_counts_table = getattr(snakemake.output, "tsne_recombinant_counts_table", None)

In [None]:
# Name of the embeddings column holding each record's clade label; used below
# as the color field for the clade-colored charts.
clade_membership = snakemake.params.clade_membership

In [None]:
# Name of the embeddings column holding cluster labels for the PCA embedding.
pca_label = snakemake.params.pca_label

In [None]:
# Name of the embeddings column holding cluster labels for the MDS embedding.
mds_label = snakemake.params.mds_label

In [None]:
# Name of the embeddings column holding cluster labels for the t-SNE embedding.
tsne_label = snakemake.params.tsne_label

In [None]:
# Name of the embeddings column holding cluster labels for the UMAP embedding.
umap_label = snakemake.params.umap_label

In [None]:
# Flag controlling whether parent-child branch segments are layered under the
# embedding scatterplots.
plot_branches = snakemake.params.plot_branches

In [None]:
plot_branches

In [None]:
sns.set_style("ticks")

mpl.rcParams.update({
    # Disable top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Display and save figures at higher resolution for presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Display text at sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Render text with matplotlib's own engine rather than LaTeX.
    "text.usetex": False,
})

## Load data

In [None]:
# Load at most 400 rows and 400 columns of the color table. The zero-indexed
# row N - 1 is later used as the palette for N categories.
colors = pd.read_csv(
    colors_path,
    sep="\t",
    header=None,
    names=list(range(400)),
    nrows=400,
)

In [None]:
colors.head()

In [None]:
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
# Expose the "numdate" column under the name "date".
embeddings_df = embeddings_df.rename(columns={"numdate": "date"})

In [None]:
# Records without a clade label get the explicit "unassigned" category, which
# is later colored gray by build_color_range_for_domain.
embeddings_df[clade_membership] = embeddings_df[clade_membership].fillna("unassigned")

In [None]:
embeddings_df.head()

In [None]:
embeddings_df.columns

In [None]:
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
accuracy_df.head()

In [None]:
explained_variance_df = pd.read_csv(explained_variance_pca)

In [None]:
explained_variance_df.head()

## Set up branches, if requested

In [None]:
# Branch widths per tree node; computed unconditionally because the final merge
# below needs them even when branches are not plotted.
node_branch_widths = make_node_branch_widths(tree_path)
print(node_branch_widths.head())

if plot_branches:
    embedding_columns = ["pca1", "pca2", "mds1", "mds2", "mds3", "tsne_x", "tsne_y", "umap_x", "umap_y"]

    # Keep only identifiers, the clade label, and the embedding coordinates.
    embedding_positions = embeddings_df.loc[
        :, ["strain", "parent_name", clade_membership] + embedding_columns
    ]

    # Join every record to its parent so each row carries both endpoints of a
    # branch; the parent's columns receive the "_parent" suffix.
    embedding_segments = embedding_positions.merge(
        embedding_positions,
        how="inner",
        left_on="parent_name",
        right_on="strain",
        suffixes=["", "_parent"],
    )

    # Label each segment with the parent's clade instead of the child's and
    # discard the parent's redundant identifier columns.
    embedding_segments = embedding_segments.drop(
        columns=[clade_membership, "strain_parent", "parent_name_parent"]
    )
    embedding_segments = embedding_segments.rename(
        columns={f"{clade_membership}_parent": clade_membership}
    )

    # Attach the branch width of the child node to each segment.
    embedding_segments = embedding_segments.merge(
        node_branch_widths,
        how="inner",
        left_on="strain",
        right_on="node",
    )

    print(embedding_segments.head())
    print(embedding_segments.shape)

# Always annotate branch widths for plotting tree panels.
embeddings_df = embeddings_df.merge(
    node_branch_widths, left_on="strain", right_on="node", validate="1:1"
)

## Set up colors

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None):
    """Return one color per value in ``domain``.

    Parameters
    ----------
    domain : sequence
        Ordered category values (list, NumPy array, or any other iterable).
    colors : pandas.DataFrame
        Palette table whose row labeled ``N - 1`` holds the colors to use for
        ``N`` categories; unused cells in that row are missing values.
    value_for_unassigned : object, optional
        Category value that marks "unassigned" records. When present in
        ``domain``, its color is replaced with gray (``"#999999"``).

    Returns
    -------
    list
        Colors selected for ``domain``. Empty when ``domain`` is empty.

    Raises
    ------
    KeyError
        If ``colors`` has no row labeled ``len(domain) - 1``.
    """
    # Materialize the domain so membership tests and ``.index`` work for any
    # sequence type, not only for lists.
    domain = list(domain)

    # An empty domain needs no colors; without this guard the lookup below
    # would request the nonexistent row -1.
    if not domain:
        return []

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    range_ = colors.loc[len(domain) - 1].dropna().values.tolist()

    # Replace known values for "unassigned" clade or cluster labels.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = "#999999"

    return range_

In [None]:
# Sorted list of the distinct clade labels, used as the color scale domain.
domain = sorted(embeddings_df[clade_membership].unique().tolist())

In [None]:
range_ = build_color_range_for_domain(domain, colors, value_for_unassigned="unassigned")

In [None]:
len(domain)

In [None]:
len(range_)

## Plot PCA variance and embeddings

In [None]:
# Plot the explained variance of each principal component as points and save
# the figure to the path declared by the workflow.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(
    explained_variance_df["principal components"],
    explained_variance_df["explained variance"],
    "o",
)
ax.set(xlabel="Principal Component", ylabel="Explained Variance")

# Anchor the y-axis at zero so the values are shown on an absolute scale.
ax.set_ylim(bottom=0)

fig.tight_layout()
fig.savefig(explained_variance_pca_chart)

In [None]:
explained_variance_PCA = explained_variance_df["explained variance"].values.tolist()

In [None]:
embeddings_df.head()

In [None]:
legend_columns = int(np.ceil(len(domain) / 30))

In [None]:
# Build the linked tree and PCA panels colored by clade; axis titles report the
# percentage of variance explained by each of the first two components.
pca_charts = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1", "pca2"],
    [
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
    ],
    f"{clade_membership}:N",
    "Clade membership",
    ["strain", clade_membership],
    domain,
    range_,
    legend_columns=legend_columns,
)

In [None]:
# Stack the first panel over the PCA scatterplot; when branches are requested,
# the branch segments are layered underneath the points.
if not plot_branches:
    pca_chart = pca_charts[0] & pca_charts[1]
else:
    pca_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "pca1", "pca2", domain, range_, f"{clade_membership}:N"
    )
    pca_chart = pca_charts[0] & (pca_branch_lines + pca_charts[1])

pca_chart = pca_chart.configure_axis(grid=False).configure_view(stroke=None)

In [None]:
pca_chart

In [None]:
# Write the PCA figure twice: once to the interactive output path and once as
# a PNG rendered at twice the default scale.
pca_chart.save(interactive_pca_chart)
pca_chart.save(static_pca_chart, format="png", scale_factor=2.0)

## Plot MDS embeddings

In [None]:
# Build linked panels for the MDS embedding colored by clade.
# NOTE(review): "mds2" is listed twice because the helper appears to consume
# column names in consecutive pairs, giving (mds1, mds2) and (mds2, mds3)
# scatterplots as mds_charts[1] and mds_charts[2] -- confirm against Helpers.
mds_charts = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2', 'mds2', 'mds3'],
    ["MDS 1", "MDS 2", "MDS 2", "MDS 3"],
    f"{clade_membership}:N",
    "Clade membership",
    ['strain', clade_membership],
    domain,
    range_,
    legend_columns=legend_columns,
)

In [None]:
# Stack the first panel over the two MDS scatterplots placed side by side; when
# branches are requested, each scatterplot gets its own branch segments
# layered underneath.
if not plot_branches:
    mds_chart = mds_charts[0] & (mds_charts[1] | mds_charts[2])
else:
    mds12_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "mds1", "mds2", domain, range_, f"{clade_membership}:N"
    )
    mds23_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "mds2", "mds3", domain, range_, f"{clade_membership}:N"
    )

    mds_chart = mds_charts[0] & (
        (mds12_branch_lines + mds_charts[1]) | (mds23_branch_lines + mds_charts[2])
    )

mds_chart = mds_chart.configure_axis(grid=False).configure_view(stroke=None)

In [None]:
mds_chart

In [None]:
# Write the MDS figure twice: once to the interactive output path and once as
# a PNG rendered at twice the default scale.
mds_chart.save(interactive_mds_chart)
mds_chart.save(static_mds_chart, format="png", scale_factor=2.0)

## Plot all embeddings by clade

In [None]:
# Build linked panels for all four embeddings colored by clade. Column names
# and axis titles are listed pairwise in the order MDS, t-SNE, PCA, UMAP.
data = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1", "mds2", "tsne_x", "tsne_y", "pca1", "pca2", "umap_x", "umap_y"],
    [
        "MDS 1",
        "MDS 2",
        "t-SNE 1",
        "t-SNE 2",
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
        "UMAP 1",
        "UMAP 2",
    ],
    f"{clade_membership}:N",
    "Clade membership",
    ["strain", clade_membership],
    domain,
    range_,
    legend_columns=legend_columns,
)

In [None]:
# Name the per-embedding panels. Indexes 1-4 follow the order of the column
# pairs passed above (MDS, t-SNE, PCA, UMAP); data[0] is used separately as
# the top panel of the full chart.
mds, tsne, pca, umap = data[1], data[2], data[3], data[4]

In [None]:
# Arrange the four embeddings into two rows (PCA | MDS and t-SNE | UMAP). When
# branches are requested, each scatterplot gets clade-colored branch segments
# layered underneath its points.
if not plot_branches:
    PCAMDS = pca | mds
    TSNEUMAP = tsne | umap
else:
    pca_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "pca1", "pca2", domain, range_, f"{clade_membership}:N"
    )
    mds_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "mds1", "mds2", domain, range_, f"{clade_membership}:N"
    )
    tsne_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "tsne_x", "tsne_y", domain, range_, f"{clade_membership}:N"
    )
    umap_branch_lines = make_branch_lines_for_columns(
        embedding_segments, "umap_x", "umap_y", domain, range_, f"{clade_membership}:N"
    )

    PCAMDS = (pca_branch_lines + pca) | (mds_branch_lines + mds)
    TSNEUMAP = (tsne_branch_lines + tsne) | (umap_branch_lines + umap)

In [None]:
# Stack the two rows of embedding panels, then place the first panel returned
# by linking_tree_with_plots_brush (data[0]) above them. The chart is displayed
# by the next cell, so no bare expression is needed in the middle of this one.
embeddings = alt.vconcat(PCAMDS, TSNEUMAP)
fullChart = (
    alt.vconcat(data[0], embeddings)
    .configure_axis(grid=False)
    .configure_view(stroke=None)
)

In [None]:
fullChart

In [None]:
# Write the clade-colored figure twice: once to the interactive output path
# and once as a PNG rendered at twice the default scale.
fullChart.save(interactive_chart_by_clades)
fullChart.save(static_chart_by_clades, format="png", scale_factor=2.0)

In [None]:
# Poster layout: a wide top panel above a single row holding the PCA, MDS,
# t-SNE, and UMAP scatterplots (data[3], data[1], data[2], data[4]).
poster_embeddings_by_clade = (
    alt.vconcat(
        data[0].properties(width=1100),
        data[3] | data[1] | data[2] | data[4],
    )
    .configure_axis(grid=False)
    .configure_view(stroke=None)
)
poster_embeddings_by_clade

## Plot all embeddings by cluster

In [None]:
# Distinct, non-missing PCA cluster labels in sorted order, with one color per
# label; the label -1 is treated as "unassigned" and drawn in gray.
pca_label_color_domain = sorted(embeddings_df[pca_label].dropna().unique())
pca_label_color_range = build_color_range_for_domain(
    pca_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
pca_legend_columns = int(np.ceil(len(pca_label_color_domain) / 20))

In [None]:
# Build linked panels for the PCA embedding colored by PCA cluster label.
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1", "pca2"],
    [
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
    ],
    f"{pca_label}:N",
    "PCA cluster",
    ["strain", clade_membership, pca_label],
    pca_label_color_domain,
    pca_label_color_range,
    legend_columns=pca_legend_columns,
)

In [None]:
# Distinct, non-missing MDS cluster labels in sorted order, with one color per
# label; the label -1 is treated as "unassigned" and drawn in gray.
mds_label_color_domain = sorted(embeddings_df[mds_label].dropna().unique())
mds_label_color_range = build_color_range_for_domain(
    mds_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
mds_legend_columns = int(np.ceil(len(mds_label_color_domain) / 20))

In [None]:
# Build linked panels for the MDS embedding colored by MDS cluster label.
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    f'{mds_label}:N',
    "MDS cluster",
    ['strain', clade_membership, mds_label],
    mds_label_color_domain,
    mds_label_color_range,
    legend_columns=mds_legend_columns,
)

In [None]:
# Distinct, non-missing t-SNE cluster labels in sorted order, with one color
# per label; the label -1 is treated as "unassigned" and drawn in gray.
tsne_label_color_domain = sorted(embeddings_df[tsne_label].dropna().unique())
tsne_label_color_range = build_color_range_for_domain(
    tsne_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
tsne_legend_columns = int(np.ceil(len(tsne_label_color_domain) / 25))

In [None]:
tsne_legend_columns

In [None]:
# Build linked panels for the t-SNE embedding colored by t-SNE cluster label.
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    f'{tsne_label}:N',
    "t-SNE cluster",
    ['strain', clade_membership, tsne_label],
    tsne_label_color_domain,
    tsne_label_color_range,
    legend_columns=tsne_legend_columns,
)

In [None]:
# Distinct, non-missing UMAP cluster labels in sorted order, with one color per
# label; the label -1 is treated as "unassigned" and drawn in gray.
umap_label_color_domain = sorted(embeddings_df[umap_label].dropna().unique())
umap_label_color_range = build_color_range_for_domain(
    umap_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
umap_legend_columns = int(np.ceil(len(umap_label_color_domain) / 20))

In [None]:
# Build linked panels for the UMAP embedding colored by UMAP cluster label.
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    f'{umap_label}:N',
    "UMAP cluster",
    ['strain', clade_membership, umap_label],
    umap_label_color_domain,
    umap_label_color_range,
    legend_columns=umap_legend_columns,
)

In [None]:
# Map each method name to its normalized VI; if a method appears in several
# rows, the last row wins.
accuracy_by_method = dict(accuracy_df[["method", "normalized_vi"]].to_numpy())

In [None]:
accuracy_by_method

In [None]:
# Select the scatterplot panel for each embedding; when branches are requested,
# layer branch segments (built without a color encoding) underneath the points.
if not plot_branches:
    pca_clusters = pca_by_cluster[1]
    mds_clusters = mds_by_cluster[1]
    tsne_clusters = tsne_by_cluster[1]
    umap_clusters = umap_by_cluster[1]
else:
    pca_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "pca1", "pca2")
    pca_clusters = pca_cluster_branch_lines + pca_by_cluster[1]

    mds_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "mds1", "mds2")
    mds_clusters = mds_cluster_branch_lines + mds_by_cluster[1]

    tsne_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "tsne_x", "tsne_y")
    tsne_clusters = tsne_cluster_branch_lines + tsne_by_cluster[1]

    umap_cluster_branch_lines = make_branch_lines_for_columns(embedding_segments, "umap_x", "umap_y")
    umap_clusters = umap_cluster_branch_lines + umap_by_cluster[1]

In [None]:
# For each embedding, place the first linked panel beside the cluster-colored
# scatterplot and title the scatterplot with the method's normalized VI
# (formatted to two decimals). Keys must match the "method" values in the
# accuracy table: "pca", "mds", "t-sne", and "umap".
composed_pca_by_cluster = pca_by_cluster[0] | pca_clusters.properties(
    title= f"Normalized VI: {accuracy_by_method['pca']:.2f}"
)

composed_mds_by_cluster = mds_by_cluster[0] | mds_clusters.properties(
    title= f"Normalized VI: {accuracy_by_method['mds']:.2f}"
)

composed_tsne_by_cluster = tsne_by_cluster[0] | tsne_clusters.properties(
    title= f"Normalized VI: {accuracy_by_method['t-sne']:.2f}"
)

composed_umap_by_cluster = umap_by_cluster[0] | umap_clusters.properties(
    title= f"Normalized VI: {accuracy_by_method['umap']:.2f}"
)

In [None]:
# Stack the four composed rows. Color scales are resolved independently at
# every level because each embedding has its own set of cluster labels.
pca_mds = alt.vconcat(
    composed_pca_by_cluster,
    composed_mds_by_cluster,
).resolve_scale(color="independent")

tsne_umap = alt.vconcat(
    composed_tsne_by_cluster,
    composed_umap_by_cluster,
).resolve_scale(color="independent")

full_chart_by_cluster = (
    alt.vconcat(pca_mds, tsne_umap)
    .resolve_scale(color="independent")
    .configure_axis(grid=False)
    .configure_view(stroke=None)
)

In [None]:
full_chart_by_cluster

In [None]:
# Write the cluster-colored figure twice: once to the interactive output path
# and once as a PNG rendered at twice the default scale.
full_chart_by_cluster.save(interactive_chart_by_clusters)
full_chart_by_cluster.save(static_chart_by_clusters, format="png", scale_factor=2.0)

In [None]:
# Pick the first embeddings column whose name starts with "t-sne_label" and
# fail with an explicit message when none exists, instead of a bare IndexError
# from indexing an empty list. Column names are converted with str() so that
# non-string labels cannot break the prefix test.
tsne_label_column = next(
    (
        column
        for column in embeddings_df.columns
        if str(column).startswith("t-sne_label")
    ),
    None,
)
if tsne_label_column is None:
    raise ValueError(
        "No column starting with 't-sne_label' was found in the embeddings table; "
        "cannot summarize recombinant lineages by t-SNE cluster."
    )

In [None]:
tsne_label_column

In [None]:
embeddings_df[tsne_label_column] != -1

In [None]:
# Boolean mask of records whose collapsed Pango lineage starts with "X".
# Passing na=False makes missing lineages evaluate to False directly and
# returns a proper boolean Series, avoiding the object-dtype fillna downcast
# that newer pandas versions deprecate.
embeddings_df["Nextclade_pango_collapsed"].str.startswith("X", na=False)

In [None]:
embeddings_df.loc[embeddings_df["is_internal_node"] == False]

In [None]:
# Count records per (collapsed Pango lineage, t-SNE cluster) pair, restricted to:
#   - non-internal nodes,
#   - lineages whose name starts with "X" (missing lineages count as False via
#     na=False, which also yields a real boolean mask), and
#   - records with a t-SNE cluster label other than -1.
# Only pairs with at least 10 records are kept. The trailing .copy() gives later
# cells an independent frame to modify, so assigning to its columns does not
# operate on a slice of the intermediate result.
tsne_recombinant_counts = embeddings_df.loc[
    (
        (embeddings_df["is_internal_node"] == False) &
        (embeddings_df["Nextclade_pango_collapsed"].str.startswith("X", na=False)) &
        (embeddings_df[tsne_label_column] != -1)
    ),
    [
        "Nextclade_pango_collapsed",
        tsne_label_column,
    ]
].value_counts().reset_index(name="count").query("count >= 10").copy()

In [None]:
# Cast cluster labels to integers before plotting.
# NOTE(review): presumably the labels load as floats because the column also
# contains missing values -- confirm against the embeddings table.
tsne_recombinant_counts[tsne_label_column] = tsne_recombinant_counts[tsne_label_column].astype(int)

In [None]:
tsne_recombinant_counts.shape

In [None]:
# Bubble chart of the counts: lineages on the x-axis, t-SNE clusters on the
# y-axis (both treated as nominal), and circle size encoding the number of
# records in each pair.
tsne_recombinant_counts_chart = alt.Chart(tsne_recombinant_counts).mark_circle().encode(
    x=alt.X("Nextclade_pango_collapsed:N", title="Recombinant Pango lineage"),
    y=alt.Y(f"{tsne_label_column}:N", title="Cluster from t-SNE"),
    size="count:Q",
    tooltip=["Nextclade_pango_collapsed:N", f"{tsne_label_column}:N", "count:Q"],
).properties(
    width=600,
    height=600,
)
tsne_recombinant_counts_chart

In [None]:
# The PNG is an optional output; skip saving when the rule does not declare it.
if output_tsne_recombinant_counts_png:
    tsne_recombinant_counts_chart.save(output_tsne_recombinant_counts_png, format="png", scale_factor=2.0)

In [None]:
# Convert cluster labels to strings for the table output and the summaries below.
tsne_recombinant_counts[tsne_label_column] = tsne_recombinant_counts[tsne_label_column].astype(str)

In [None]:
tsne_recombinant_counts

In [None]:
# The table is an optional output; skip writing when the rule does not declare
# it. Written as CSV without the DataFrame index.
if output_tsne_recombinant_counts_table:
    tsne_recombinant_counts.to_csv(
        output_tsne_recombinant_counts_table,
        index=False,
    )

In [None]:
tsne_recombinant_counts["Nextclade_pango_collapsed"].value_counts().shape

In [None]:
(tsne_recombinant_counts["Nextclade_pango_collapsed"].value_counts() == 1).sum()

In [None]:
tsne_recombinant_counts[tsne_label_column].value_counts().shape

In [None]:
(tsne_recombinant_counts[tsne_label_column].value_counts() == 1).sum()

## Prepare figures for poster

In [None]:
# Rebuild the PCA cluster panels without a legend for the poster layout.
# This re-binds `pca_by_cluster`, replacing the chart list created earlier.
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1", "pca2"],
    [
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
    ],
    f"{pca_label}:N",
    "PCA cluster",
    ["strain", clade_membership, pca_label],
    pca_label_color_domain,
    pca_label_color_range,
    legend_columns=pca_legend_columns,
    plot_legend=False,
)

In [None]:
# Rebuild the MDS cluster panels without a legend for the poster layout.
# This re-binds `mds_by_cluster`, replacing the chart list created earlier.
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    f'{mds_label}:N',
    "MDS cluster",
    ['strain', clade_membership, mds_label],
    mds_label_color_domain,
    mds_label_color_range,
    legend_columns=mds_legend_columns,
    plot_legend=False,
)

In [None]:
# Rebuild the t-SNE cluster panels without a legend for the poster layout.
# This re-binds `tsne_by_cluster`, replacing the chart list created earlier.
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    f'{tsne_label}:N',
    "t-SNE cluster",
    ['strain', clade_membership, tsne_label],
    tsne_label_color_domain,
    tsne_label_color_range,
    legend_columns=tsne_legend_columns,
    plot_legend=False,
)

In [None]:
# Rebuild the UMAP cluster panels without a legend for the poster layout.
# This re-binds `umap_by_cluster`, replacing the chart list created earlier.
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    f'{umap_label}:N',
    "UMAP cluster",
    ['strain', clade_membership, umap_label],
    umap_label_color_domain,
    umap_label_color_range,
    legend_columns=umap_legend_columns,
    plot_legend=False,
)

In [None]:
# Poster layout by cluster: the top row holds the first panel of each embedding
# and the bottom row the matching scatterplots titled with normalized VI. All
# panels are 250x250, color scales are independent per panel, and legends are
# disabled for the whole figure.
poster_embeddings_by_cluster = (
    alt.vconcat(
        (
            pca_by_cluster[0].properties(width=250, height=250)
            | mds_by_cluster[0].properties(width=250, height=250)
            | tsne_by_cluster[0].properties(width=250, height=250)
            | umap_by_cluster[0].properties(width=250, height=250)
        ).resolve_scale(color="independent"),
        (
            pca_by_cluster[1].properties(
                width=250, height=250, title=f"Normalized VI: {accuracy_by_method['pca']:.2f}"
            )
            | mds_by_cluster[1].properties(
                width=250, height=250, title=f"Normalized VI: {accuracy_by_method['mds']:.2f}"
            )
            | tsne_by_cluster[1].properties(
                width=250, height=250, title=f"Normalized VI: {accuracy_by_method['t-sne']:.2f}"
            )
            | umap_by_cluster[1].properties(
                width=250, height=250, title=f"Normalized VI: {accuracy_by_method['umap']:.2f}"
            )
        ).resolve_scale(color="independent"),
    )
    .configure_legend(disable=True)
    .configure_axis(grid=False)
    .configure_view(stroke=None)
)

poster_embeddings_by_cluster