In [None]:
import sys
# Extend the module search path so that modules stored in notebooks/scripts/
# can be imported (presumably where the `Helpers` module imported below lives —
# TODO confirm). The path is relative, so it only resolves when the working
# directory is the directory that contains `notebooks/`.
sys.path.append("notebooks/scripts/")

In [None]:
import altair as alt
from altair_saver import save
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from Helpers import linking_tree_with_plots_brush

%matplotlib inline

## Define inputs, outputs, and parameters

In [None]:
# Input files for this notebook. `snakemake` is not defined anywhere in the
# notebook itself; it is expected to be injected by Snakemake at run time, so
# these cells fail with a NameError when the notebook is run on its own.
colors_path = snakemake.input.colors  # tab-separated table of color hex codes
embeddings_path = snakemake.input.annotated_embeddings  # tab-separated embeddings table
accuracy_path = snakemake.input.accuracy_table  # comma-separated accuracy table
# NOTE(review): this is also a path (it is handed to pd.read_csv) despite the
# missing `_path` suffix; it differs only by case from the list
# `explained_variance_PCA` defined later.
explained_variance_pca = snakemake.input.explained_variance_pca

In [None]:
# Output paths for every figure this notebook writes. Each Altair chart has an
# "interactive" path (written with the chart's own `.save`) and a "static" path
# (written with altair_saver's `save`).

# All embeddings colored by clade membership.
interactive_chart_by_clades = snakemake.output.fullChart
static_chart_by_clades = snakemake.output.fullChartPNG

# All embeddings colored by cluster label.
interactive_chart_by_clusters = snakemake.output.fullChartHDBSCAN
static_chart_by_clusters = snakemake.output.fullChartHDBSCANPNG

# PCA supplement: the matplotlib explained-variance figure plus the PCA charts.
explained_variance_pca_chart = snakemake.output.Explained_variance_PCA
interactive_pca_chart = snakemake.output.PCA_Supplement
static_pca_chart = snakemake.output.PCA_Supplement_PNG

# MDS supplement.
interactive_mds_chart = snakemake.output.MDS_Supplement
static_mds_chart = snakemake.output.MDS_Supplement_PNG

In [None]:
# Start from seaborn's "ticks" theme, then apply the matplotlib overrides used
# for every static figure in this notebook in a single update.
sns.set_style("ticks")

mpl.rcParams.update({
    # Disable top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Display and save figures at higher resolution for presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Display text at sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Render text with matplotlib's own engine rather than LaTeX.
    "text.usetex": False,
})

## Load data

In [None]:
# Read the header-less, tab-separated color table. Columns are named 0..100 up
# front so that rows holding different numbers of colors can share one frame;
# unused cells become NaN.
colors = pd.read_csv(
    colors_path,
    sep="\t",
    names=list(range(101)),
)

In [None]:
# Preview the first rows of the color table.
colors.head()

In [None]:
# Load the annotated embeddings table (tab-separated, first row is the header).
embeddings_df = pd.read_csv(embeddings_path, sep="\t")

In [None]:
# Shorten two column names: "num_date" becomes "date" and "y_value" becomes "y".
# Presumably these are the names the plotting helper expects — TODO confirm.
embeddings_df = embeddings_df.rename(columns={"num_date": "date", "y_value": "y"})

In [None]:
# Give rows without a clade annotation an explicit "unassigned" label so the
# column has no missing values when used as a color domain.
embeddings_df["clade_membership"] = embeddings_df["clade_membership"].fillna("unassigned")

In [None]:
# Preview the embeddings after renaming columns and filling missing clades.
embeddings_df.head()

In [None]:
# List the available columns; the plotting cells below refer to them by name.
embeddings_df.columns

In [None]:
# Load the accuracy table. No separator is passed, so pandas parses it as
# comma-separated (unlike the tab-separated colors and embeddings tables).
accuracy_df = pd.read_csv(accuracy_path)

In [None]:
# Preview the accuracy table.
accuracy_df.head()

In [None]:
# Load the PCA explained-variance table (parsed as comma-separated by default).
# Note that `explained_variance_pca` holds the input file path.
explained_variance_df = pd.read_csv(explained_variance_pca)

In [None]:
# Preview the explained-variance table.
explained_variance_df.head()

## Set up colors

In [None]:
def build_color_range_for_domain(domain, colors, value_for_unassigned=None, unassigned_color="#999999"):
    """Return one color per entry of ``domain``, taken from the color table.

    Parameters
    ----------
    domain : sequence
        Category values to color (for example clade names or cluster labels).
        Any sequence is accepted; it is converted to a list internally.
    colors : pandas.DataFrame
        Color table whose row labelled ``N - 1`` holds the colors to use for
        ``N`` categories. Missing cells in that row are dropped.
    value_for_unassigned : optional
        Domain value that marks unassigned records. When given and present in
        ``domain``, the color at its position is replaced with
        ``unassigned_color``.
    unassigned_color : str, optional
        Color used for ``value_for_unassigned``. Defaults to gray ("#999999").

    Returns
    -------
    list
        Colors from the selected row of ``colors``, in row order.

    Raises
    ------
    KeyError
        If ``colors`` has no row for the size of ``domain`` (including an
        empty ``domain``).
    """
    domain = list(domain)

    # Rows are zero-indexed, so to get N colors, we select row N - 1.
    row_label = len(domain) - 1
    if row_label not in colors.index:
        raise KeyError(
            f"the color table has no row for a domain of {len(domain)} value(s)"
        )

    range_ = colors.loc[row_label].dropna().values.tolist()

    # Replace known values for "unassigned" clade or cluster labels.
    if value_for_unassigned is not None and value_for_unassigned in domain:
        range_[domain.index(value_for_unassigned)] = unassigned_color

    return range_

In [None]:
# Distinct clade labels in order of first appearance; used as the color domain
# for every chart colored by clade.
domain = embeddings_df["clade_membership"].unique().tolist()

In [None]:
# Pick one color per clade label.
# NOTE(review): `value_for_unassigned` is not passed here, so the "unassigned"
# label filled in earlier keeps its palette color instead of gray, unlike the
# cluster charts below which pass -1 — confirm this is intended.
range_ = build_color_range_for_domain(domain, colors)

## Plot PCA variance and embeddings

In [None]:
# Plot the explained variance of each principal component as unconnected
# points and write the figure to its output path.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(
    explained_variance_df["principal components"],
    explained_variance_df["explained variance"],
    "o",
)
ax.set(xlabel="Principal Component", ylabel="Explained Variance")

# Anchor the y-axis at zero; the upper limit stays automatic.
ax.set_ylim(bottom=0)

fig.tight_layout()
fig.savefig(explained_variance_pca_chart)

In [None]:
# Plain Python list of explained-variance values, indexed by component
# position (0 is the first row of the table), used to build axis titles.
explained_variance_PCA = explained_variance_df["explained variance"].tolist()

In [None]:
# Re-display the embeddings table before plotting (same frame as shown above).
embeddings_df.head()

In [None]:
# Build the linked charts for the first eight principal components, colored by
# clade. Each axis title reports that component's explained variance as a
# percentage rounded to two decimals.
pca_charts = linking_tree_with_plots_brush(
    embeddings_df,
    [f"pca{component}" for component in range(1, 9)],
    [
        f"PC {component} (Explained Variance : {round(explained_variance_PCA[component - 1] * 100, 2)}%)"
        for component in range(1, 9)
    ],
    "clade_membership:N",
    ["strain", "clade_membership"],
    domain,
    range_,
)

In [None]:
# Compose the PCA figure: `&` stacks charts vertically and `|` places them side
# by side, giving the first chart on top followed by two rows of two charts.
# Presumably element 0 is the tree and elements 1-4 are the component-pair
# scatterplots — TODO confirm against the helper.
pca_chart = (pca_charts[0]) & (pca_charts[1] | pca_charts[2]) & (pca_charts[3] | pca_charts[4])
pca_chart

In [None]:
# Write the PCA figure twice: once through the chart's own `save` method to the
# interactive path, and once through altair_saver's `save` to the static path,
# rendered at twice the default scale.
pca_chart.save(interactive_pca_chart)
save(pca_chart, static_pca_chart, scale_factor=2.0)

## Plot MDS embeddings

In [None]:
# Build the linked charts for the four MDS dimensions, colored by clade.
mds_charts = linking_tree_with_plots_brush(
    embeddings_df,
    [f"mds{dimension}" for dimension in range(1, 5)],
    [f"MDS {dimension}" for dimension in range(1, 5)],
    "clade_membership:N",
    ["strain", "clade_membership"],
    domain,
    range_,
)

In [None]:
# Compose the MDS figure: the first chart on top, with the next two charts side
# by side beneath it (`&` stacks vertically, `|` concatenates horizontally).
mds_chart = (mds_charts[0]) & (mds_charts[1] | mds_charts[2])
mds_chart

In [None]:
# Write the MDS figure to its interactive path (chart `save` method) and to its
# static path (altair_saver `save`, rendered at twice the default scale).
mds_chart.save(interactive_mds_chart)
save(mds_chart, static_mds_chart, scale_factor=2.0)

## Plot all embeddings by clade

In [None]:
# Build the linked charts for the first two dimensions of every embedding
# method (MDS, t-SNE, PCA, UMAP — in that column order), colored by clade.
# Only the PCA axis titles carry an explained-variance percentage.
data = linking_tree_with_plots_brush(
    embeddings_df,
    ["mds1", "mds2", "tsne_x", "tsne_y", "pca1", "pca2", "umap_x", "umap_y"],
    [
        "MDS 1",
        "MDS 2",
        "t-SNE 1",
        "t-SNE 2",
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
        "UMAP 1",
        "UMAP 2",
    ],
    "clade_membership:N",
    ["strain", "clade_membership"],
    domain,
    range_,
)

In [None]:
# Arrange the four embedding panels in a 2x2 grid beneath the first chart.
# NOTE(review): the indices assume the helper returns the tree first, followed
# by one chart per column pair in the order passed in (MDS, t-SNE, PCA, UMAP),
# which makes data[3] the PCA panel and data[1] the MDS panel — TODO confirm.
PCAMDS = data[3] | data[1]
TSNEUMAP = data[2] | data[4]
embeddings = alt.vconcat(PCAMDS, TSNEUMAP)

# Only the final expression of a cell is displayed, so the combined figure is
# the single thing shown here.
fullChart = alt.vconcat(data[0], embeddings)
fullChart

In [None]:
# Write the by-clade figure to its interactive path (chart `save` method) and
# to its static path (altair_saver `save`, rendered at twice the default scale).
fullChart.save(interactive_chart_by_clades)
save(fullChart, static_chart_by_clades, scale_factor=2.0)

## Plot all embeddings by cluster

In [None]:
# Sorted distinct PCA cluster labels and one color per label; the label -1 is
# treated as "unassigned" and drawn in gray.
pca_label_color_domain = sorted(embeddings_df["pca_label"].unique())
pca_label_color_range = build_color_range_for_domain(
    pca_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Linked charts for the first two principal components, colored by PCA cluster
# label. Axis titles report explained variance as a percentage rounded to two
# decimals.
pca_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ["pca1", "pca2"],
    [
        f"PC 1 (Explained Variance : {round(explained_variance_PCA[0] * 100, 2)}%)",
        f"PC 2 (Explained Variance : {round(explained_variance_PCA[1] * 100, 2)}%)",
    ],
    "pca_label:N",
    ["strain", "clade_membership", "pca_label"],
    pca_label_color_domain,
    pca_label_color_range,
)

In [None]:
# Sorted distinct MDS cluster labels and one color per label; the label -1 is
# treated as "unassigned" and drawn in gray.
mds_label_color_domain = sorted(embeddings_df["mds_label"].unique())
mds_label_color_range = build_color_range_for_domain(
    mds_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Linked charts for the first two MDS dimensions, colored by MDS cluster label
# (encoded as a nominal field). The fifth argument lists strain, clade, and
# cluster label; presumably these are the tooltip fields — TODO confirm against
# the helper.
mds_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['mds1', 'mds2'],
    ['MDS 1', 'MDS 2'],
    'mds_label:N',
    ['strain', 'clade_membership', 'mds_label'],
    mds_label_color_domain,
    mds_label_color_range,
)

In [None]:
# Sorted distinct t-SNE cluster labels and one color per label; the label -1 is
# treated as "unassigned" and drawn in gray.
tsne_label_color_domain = sorted(embeddings_df["t-sne_label"].unique())
tsne_label_color_range = build_color_range_for_domain(
    tsne_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Linked charts for the two t-SNE dimensions, colored by t-SNE cluster label.
# Note the label column is spelled "t-sne_label" (with a hyphen) while the
# coordinate columns are "tsne_x" / "tsne_y".
tsne_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['tsne_x', 'tsne_y'],
    ['t-SNE 1', 't-SNE 2'],
    't-sne_label:N',
    ['strain', 'clade_membership', 't-sne_label'],
    tsne_label_color_domain,
    tsne_label_color_range,
)

In [None]:
# Sorted distinct UMAP cluster labels and one color per label; the label -1 is
# treated as "unassigned" and drawn in gray.
umap_label_color_domain = sorted(embeddings_df["umap_label"].unique())
umap_label_color_range = build_color_range_for_domain(
    umap_label_color_domain, colors, value_for_unassigned=-1
)

In [None]:
# Linked charts for the two UMAP dimensions, colored by UMAP cluster label
# (encoded as a nominal field).
umap_by_cluster = linking_tree_with_plots_brush(
    embeddings_df,
    ['umap_x', 'umap_y'],
    ['UMAP 1', 'UMAP 2'],
    'umap_label:N',
    ['strain', 'clade_membership', 'umap_label'],
    umap_label_color_domain,
    umap_label_color_range,
)

In [None]:
# Map each embedding name to its normalized VI value by turning the two-column
# slice of the accuracy table into (key, value) pairs.
accuracy_by_method = dict(
    accuracy_df[["embedding", "normalized_VI"]].to_numpy()
)

In [None]:
# Display the lookup; the next cell reads the keys 'pca', 'mds', 't-sne' and
# 'umap' from it.
accuracy_by_method

In [None]:
def _compose_with_vi_title(charts, method):
    """Place ``charts[0]`` beside ``charts[1]``, titling the latter with the
    normalized VI value recorded for ``method`` in ``accuracy_by_method``."""
    titled_chart = charts[1].properties(
        title=f"Normalized VI Value: {accuracy_by_method[method]}"
    )
    return charts[0] | titled_chart


# One side-by-side pair per embedding method; only the second chart of each
# pair carries the title.
composed_pca_by_cluster = _compose_with_vi_title(pca_by_cluster, "pca")
composed_mds_by_cluster = _compose_with_vi_title(mds_by_cluster, "mds")
composed_tsne_by_cluster = _compose_with_vi_title(tsne_by_cluster, "t-sne")
composed_umap_by_cluster = _compose_with_vi_title(umap_by_cluster, "umap")

In [None]:
# Stack the four method rows vertically. Every level resolves the color scale
# independently, because each method has its own cluster labels and therefore
# its own color domain and range.
pca_mds = alt.vconcat(composed_pca_by_cluster, composed_mds_by_cluster).resolve_scale(color='independent')
tsne_umap = alt.vconcat(composed_tsne_by_cluster, composed_umap_by_cluster).resolve_scale(color='independent')
full_chart_by_cluster = alt.vconcat(pca_mds, tsne_umap).resolve_scale(color='independent')
full_chart_by_cluster

In [None]:
# Write the by-cluster figure to its interactive path (chart `save` method) and
# to its static path (altair_saver `save`, rendered at twice the default scale).
full_chart_by_cluster.save(interactive_chart_by_clusters)
save(full_chart_by_cluster, static_chart_by_clusters, scale_factor=2.0)