# Figure 1. Model Schematic

Summarize tree and frequencies for two timepoints from simulated data for Figure 1.

Note: [this notebook is executed by Snakemake](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration) and expects to have a global `snakemake` variable that provides input and output files and optionally params.

In [None]:
# Paths come from the `snakemake` global injected by Snakemake's notebook integration.
# Inputs: auspice tree JSONs for timepoints t and u ...
tree_for_timepoint_t, tree_for_timepoint_u = (
    snakemake.input.tree_for_timepoint_t,
    snakemake.input.tree_for_timepoint_u,
)
# ... and auspice tip-frequency JSONs for the same two timepoints.
frequencies_for_timepoint_t, frequencies_for_timepoint_u = (
    snakemake.input.frequencies_for_timepoint_t,
    snakemake.input.frequencies_for_timepoint_u,
)

# Output: where the finished figure is saved.
distance_model_figure = snakemake.output.figure

In [None]:
# Example paths for running this notebook interactively instead of through Snakemake.
# This cell is a bare string literal, so executing it has no effect. To use these
# paths, remove the surrounding triple quotes and skip the Snakemake cell above.
"""
# Define inputs.
tree_for_timepoint_t = "../results/auspice/flu_simulated_simulated_sample_3_2029-10-01_tree.json"
tree_for_timepoint_u = "../results/auspice/flu_simulated_simulated_sample_3_2030-10-01_tree.json"
frequencies_for_timepoint_t = "../results/auspice/flu_simulated_simulated_sample_3_2029-10-01_tip-frequencies.json"
frequencies_for_timepoint_u = "../results/auspice/flu_simulated_simulated_sample_3_2030-10-01_tip-frequencies.json"

# Define outputs.
distance_model_figure = "../manuscript/figures/distance-based-fitness-model.pdf"
"""

In [None]:
from augur.titer_model import TiterCollection
from augur.utils import json_to_tree
import datetime
import json
import matplotlib as mpl
import matplotlib.dates as mdates
from matplotlib import gridspec
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np
import pandas as pd
import seaborn as sns

from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from treetime.utils import numeric_date

%matplotlib inline

In [None]:
np.random.seed(314159)

In [None]:
sns.set_style("ticks")

In [None]:
# Display figures at a reasonable default size.
mpl.rcParams['figure.figsize'] = (6, 4)

# Disable top and right spines.
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
    
# Display and save figures at higher resolution for presentations and manuscripts.
mpl.rcParams['savefig.dpi'] = 300
mpl.rcParams['figure.dpi'] = 100

# Display text at sizes large enough for presentations and manuscripts.
mpl.rcParams['font.weight'] = "normal"
mpl.rcParams['axes.labelweight'] = "normal"
mpl.rcParams['font.size'] = 14
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 12
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14

mpl.rc('text', usetex=False)

In [None]:
tip_size = 10
end_date = 2004.3

In [None]:
def float_to_datestring(time):
    """Convert a floating point date from TreeTime `numeric_date` to a date string.

    TreeTime encodes 1-based day ``d`` of a year as ``year + (d - 0.5) / days_in_year``,
    where ``days_in_year`` is 366 in leap years and 365 otherwise. A fraction ``f``
    of the year therefore falls on day ``floor(f * days_in_year) + 1``. The same rule
    maps other decimal dates (e.g. ``2029.75``) to the day containing that instant.

    Parameters
    ----------
    time : float
        Decimal year, e.g. ``2029.75``.

    Returns
    -------
    str
        The date formatted as ``YYYY-MM-DD``.
    """
    # Extract the year and the fraction of that year that has elapsed.
    year = int(time)
    remainder = time - year

    # Use the real length of this calendar year (day-of-year of Dec 31).
    days_in_year = datetime.date(year, 12, 31).timetuple().tm_yday

    # Day of the year is 1-based. The `min` guards against floating point
    # round-off pushing the fraction up to a full year.
    tm_yday = min(int(remainder * days_in_year) + 1, days_in_year)

    date = datetime.date(year, 1, 1) + datetime.timedelta(days=tm_yday - 1)

    # Build the date string with zero-padded months and days.
    return "%s-%.2i-%.2i" % (date.year, date.month, date.day)

In [None]:
def plot_tree_by_datetime(tree, color_by_trait=None, size_by_trait=None, initial_branch_width=5, tip_size=10,
              start_date=None, end_date=None, include_color_bar=False, ax=None, colorbar_ax=None,
              earliest_node_date=None, default_color="#cccccc", default_color_branch="#999999", override_y_values=None,
              cmap=None, default_size=0.001, plot_projection_from_date=None, plot_projection_to_date=None,
              projection_attr="projected_frequency", projection_line_threshold=1e-2, size_scaler=1e3):
    """Plot a BioPython Phylo tree in the BALTIC-style with dates on the x-axis.

    Each node is drawn at x = ``node.attr["collection_date_ordinal"]`` (a date
    ordinal) and y = ``max(yvalue) - node.yvalue`` unless overridden by name.

    Parameters
    ----------
    tree : Bio.Phylo tree or clade
        Nodes must provide ``yvalue``, ``parent``, ``clades`` and an ``attr``
        dict containing ``"collection_date_ordinal"``.
    color_by_trait : str, optional
        ``attr`` key used to color tips and single-child internal nodes.
    size_by_trait : str, optional
        ``attr`` key whose square root, times ``size_scaler``, gives the tip
        marker area. Tips without this key use ``tip_size``.
    initial_branch_width : float
        Base branch width. Internal branches are widened in proportion to the
        fraction of all tips that descend from them.
    tip_size : float
        Marker area for tips without ``size_by_trait``.
    start_date, end_date : float, optional
        x-axis limits (date ordinals).
    include_color_bar : bool
        Draw a horizontal colorbar. Requires ``color_by_trait`` and a
        continuous (non-dict) colormap.
    ax, colorbar_ax : matplotlib Axes, optional
        Axes to draw into. A new figure is created when ``ax`` is None.
    earliest_node_date : float, optional
        Nodes dated before this ordinal are not drawn.
    default_color, default_color_branch : str
        Fallback node color, and the color used for every branch.
    override_y_values : dict, optional
        Maps node names to explicit y positions.
    cmap : dict or matplotlib colormap, optional
        A dict maps trait values directly to colors. A colormap is applied to
        trait values normalized to their range in the tree. Defaults to viridis.
    default_size : float
        Fallback passed to ``attr.get`` for ``size_by_trait``. It is only
        consulted when the key is present, so it currently has no effect.
    plot_projection_from_date, plot_projection_to_date : float, optional
        When both are given, tips whose ``projection_attr`` exceeds
        ``projection_line_threshold`` get a second marker 1-60 days before
        ``plot_projection_to_date`` (random jitter), joined by a dashed line.
    projection_attr : str
        ``attr`` key holding the projected frequency.
    projection_line_threshold : float
        Minimum projected frequency for drawing a projected marker.
    size_scaler : float
        Multiplier applied to sqrt(frequency) to get marker areas.

    Returns
    -------
    tuple
        ``(ax, colorbar_ax)``. ``colorbar_ax`` is None unless it was passed in
        or created here.

    Raises
    ------
    ValueError
        If ``include_color_bar`` is requested without a continuous colormap.
    """
    # Plot H3N2 tree in BALTIC style from Bio.Phylo tree.
    if override_y_values is None:
        override_y_values = {}

    # y positions are flipped so that yvalue 0 is drawn at the top.
    y_span = max(node.yvalue for node in tree.find_clades())

    # Setup colors. Dict colormaps map trait values straight to colors; any other
    # colormap needs a normalization over the observed trait values.
    trait_name = color_by_trait
    norm = None
    if color_by_trait:
        if cmap is None:
            cmap = mpl.cm.viridis

        if not isinstance(cmap, dict):
            traits = [k.attr[trait_name] for k in tree.find_clades() if trait_name in k.attr]
            norm = mpl.colors.Normalize(min(traits), max(traits)) if traits else mpl.colors.Normalize()

    if include_color_bar and norm is None:
        raise ValueError("include_color_bar requires color_by_trait and a continuous (non-dict) colormap")

    #
    # Setup the figure grid.
    #

    if ax is None:
        if include_color_bar:
            fig = plt.figure(figsize=(8, 6), facecolor='w')
            gs = gridspec.GridSpec(2, 1, height_ratios=[14, 1], width_ratios=[1], hspace=0.1, wspace=0.1)
            ax = fig.add_subplot(gs[0])
            colorbar_ax = fig.add_subplot(gs[1])
        else:
            fig = plt.figure(figsize=(8, 4), facecolor='w')
            gs = gridspec.GridSpec(1, 1)
            ax = fig.add_subplot(gs[0])

    # Total number of tips, used to scale internal branch widths.
    n_tips = sum(1 for k in tree.find_clades() if k.is_terminal())

    # Setup arrays for tip and internal node coordinates.
    tip_circles_x = []
    tip_circles_y = []
    tip_circles_color = []
    tip_circle_sizes = []
    node_circles_x = []
    node_circles_y = []
    node_circles_color = []
    branch_line_segments = []
    branch_line_widths = []
    branch_line_colors = []
    projection_line_segments = []

    for k in tree.find_clades():
        x = k.attr["collection_date_ordinal"]

        if earliest_node_date and x < earliest_node_date:
            continue

        if k.name in override_y_values:
            y = override_y_values[k.name]
        else:
            y = y_span - k.yvalue

        # Branches start at the parent's date; the root gets a zero-length branch.
        xp = None if k.parent is None else k.parent.attr["collection_date_ordinal"]
        if xp is None:
            xp = x

        c = default_color
        if color_by_trait and trait_name in k.attr:
            if isinstance(cmap, dict):
                c = cmap[k.attr[trait_name]]
            else:
                c = cmap(norm(k.attr[trait_name]))

        branchWidth = initial_branch_width
        if k.is_terminal():
            if size_by_trait is not None and size_by_trait in k.attr:
                s = (size_scaler * np.sqrt(k.attr.get(size_by_trait, default_size)))
            else:
                s = tip_size

            tip_circle_sizes.append(s)
            tip_circles_x.append(x)
            tip_circles_y.append(y)
            tip_circles_color.append(c)

            if plot_projection_to_date is not None and plot_projection_from_date is not None:
                if k.attr.get(projection_attr, 0.0) > projection_line_threshold:
                    # Second marker for the projected frequency, jittered so that
                    # overlapping projections stay visible.
                    future_s = (size_scaler * np.sqrt(k.attr.get(projection_attr)))
                    future_x = plot_projection_to_date + np.random.randint(-60, 0)

                    tip_circle_sizes.append(future_s)
                    tip_circles_x.append(future_x)
                    tip_circles_y.append(y)
                    tip_circles_color.append(c)

                    projection_line_segments.append([(x + 1, y), (future_x, y)])

        else:
            n_descendant_tips = sum(1 for child in k.find_clades() if child.is_terminal())

            # Scale branch widths by the fraction of tips descending from this node.
            branchWidth += initial_branch_width * n_descendant_tips / float(n_tips)

            if len(k.clades) == 1:
                node_circles_x.append(x)
                node_circles_y.append(y)
                node_circles_color.append(c)

            # Vertical connector spanning this node's first and last children.
            ax.plot([x, x], [y_span - k.clades[-1].yvalue, y_span - k.clades[0].yvalue], lw=branchWidth, color=default_color_branch, ls='-', zorder=9, solid_capstyle='round')

        branch_line_segments.append([(xp, y), (x, y)])
        branch_line_widths.append(branchWidth)
        branch_line_colors.append(default_color_branch)

    branch_lc = LineCollection(branch_line_segments, zorder=9)
    branch_lc.set_color(branch_line_colors)
    branch_lc.set_linewidth(branch_line_widths)
    branch_lc.set_linestyle("-")
    ax.add_collection(branch_lc)

    if len(projection_line_segments) > 0:
        projection_lc = LineCollection(projection_line_segments, zorder=-10)
        projection_lc.set_color("#cccccc")
        projection_lc.set_linewidth(1)
        projection_lc.set_linestyle("--")
        projection_lc.set_alpha(0.5)
        ax.add_collection(projection_lc)

    # Add circles for tips and internal nodes.
    tip_circle_sizes = np.array(tip_circle_sizes)
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes, facecolor=tip_circles_color, edgecolors='#000000', linewidths=0.5, alpha=0.75, zorder=11)
    # Mark single-child internal nodes.
    ax.scatter(node_circles_x, node_circles_y, facecolor=node_circles_color, s=50, edgecolor='none', zorder=10, lw=2, marker='|')

    # Hide every spine except the bottom (date) axis.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.tick_params(axis='y', size=0)
    ax.set_yticklabels([])

    if start_date:
        ax.set_xlim(left=start_date)

    if end_date:
        ax.set_xlim(right=end_date)

    if include_color_bar:
        cb1 = mpl.colorbar.ColorbarBase(
            colorbar_ax,
            cmap=cmap,
            norm=norm,
            orientation='horizontal'
        )
        cb1.set_label(color_by_trait)

    return ax, colorbar_ax

## Load trees

Load an auspice tree for both timepoint t and timepoint u. The first tree needs to be annotated with the projected frequency at time u and weighted distance to the future.

Both trees need to be annotated with amino acid sequences for the tips as an `aa_sequence` key in each tip's `attr` attribute.

In [None]:
with open(tree_for_timepoint_t, "r") as fh:
    tree_json_for_t = json.load(fh)
    
tree_for_t = json_to_tree(tree_json_for_t)

In [None]:
latest_sample_date_in_t = max([node.attr["num_date"] for node in tree_for_t.find_clades(terminal=True)])

In [None]:
latest_sample_date_in_t

In [None]:
earliest_date_to_plot = latest_sample_date_in_t - 2.0

In [None]:
with open(tree_for_timepoint_u, "r") as fh:
    tree_json_for_u = json.load(fh)

In [None]:
tree_for_u = json_to_tree(tree_json_for_u)

In [None]:
tree_for_u

Annotate ordinal collection dates from floating point dates on both trees.

In [None]:
# Both trees get the same annotation: each node's decimal-year date as a date
# ordinal, which is the x coordinate used for plotting.
for tree_to_annotate in (tree_for_t, tree_for_u):
    for node in tree_to_annotate.find_clades():
        node_timestamp = pd.to_datetime(float_to_datestring(node.attr["num_date"]))
        node.attr["collection_date_ordinal"] = node_timestamp.toordinal()

## Load frequencies

Load tip frequencies from auspice. These should include a `pivots` list and a `projection_pivot` key, and each tip should have frequencies for one or more pivots after the projection pivot.

In [None]:
with open(frequencies_for_timepoint_t, "r") as fh:
    frequencies_for_t = json.load(fh)

In [None]:
with open(frequencies_for_timepoint_u, "r") as fh:
    frequencies_for_u = json.load(fh)

In [None]:
pivots = frequencies_for_t.pop("pivots")

In [None]:
projection_pivot = frequencies_for_t.pop("projection_pivot")

In [None]:
projection_pivot_index_for_t = pivots.index(projection_pivot)

In [None]:
frequency_records_for_t = []
for sample, sample_frequencies in frequencies_for_t.items():
    for pivot, sample_frequency in zip(pivots, sample_frequencies["frequencies"]):
        frequency_records_for_t.append({
            "strain": sample,
            "timepoint": float_to_datestring(pivot),
            "pivot": pivot,
            "frequency": sample_frequency
        })

In [None]:
frequency_df_for_t = pd.DataFrame(frequency_records_for_t)

In [None]:
frequency_df_for_t["timepoint"] = pd.to_datetime(frequency_df_for_t["timepoint"])

Repeat the above analysis to get observed frequencies at timepoint u. We ignore all projected frequencies from this later timepoint, however.

In [None]:
pivots_for_u = frequencies_for_u.pop("pivots")

In [None]:
projection_pivot_for_u = frequencies_for_u.pop("projection_pivot")

In [None]:
projection_pivot_index = pivots_for_u.index(projection_pivot_for_u)

In [None]:
pivots_for_u[:projection_pivot_index + 1]

In [None]:
frequency_records_for_u = []
for sample, sample_frequencies in frequencies_for_u.items():
    for pivot, sample_frequency in zip(pivots_for_u, sample_frequencies["frequencies"]):
        # Ignore projected frequencies from timepoint u.
        if pivot <= projection_pivot_for_u:
            frequency_records_for_u.append({
                "strain": sample,
                "timepoint": float_to_datestring(pivot),
                "pivot": pivot,
                "frequency": sample_frequency
            })

In [None]:
frequency_df_for_u = pd.DataFrame(frequency_records_for_u)

In [None]:
frequency_df_for_u["timepoint"] = pd.to_datetime(frequency_df_for_u["timepoint"])

In [None]:
frequency_df_for_u.head()

Annotate trees with frequencies at corresponding timepoints. For the tree at timepoint t, annotate both current and projected frequencies. For the tree at timepoint u, annotate the current frequencies.

In [None]:
pivots[projection_pivot_index_for_t]

In [None]:
projection_pivot_index_for_t

In [None]:
max_frequency = 0.5

In [None]:
for tip in tree_for_t.find_clades(terminal=True):
    tip.attr["frequency_at_t"] = min(frequencies_for_t[tip.name]["frequencies"][projection_pivot_index_for_t], max_frequency)
    tip.attr["projected_frequency_at_u"] = min(frequencies_for_t[tip.name]["frequencies"][-1], max_frequency)

In [None]:
projection_pivot

In [None]:
for tip in tree_for_u.find_clades(terminal=True):
    if tip.attr["num_date"] > projection_pivot:
        tip.attr["frequency_at_u"] = min(frequencies_for_u[tip.name]["frequencies"][projection_pivot_index], max_frequency)
    else:
        tip.attr["frequency_at_u"] = 0.0

In [None]:
tips_with_nonzero_frequencies = set()

for tip in tree_for_t.find_clades(terminal=True):
    if tip.attr["frequency_at_t"] > 0:
        tips_with_nonzero_frequencies.add(tip.name)

for tip in tree_for_u.find_clades(terminal=True):
    if tip.attr["frequency_at_u"] > 0:
        tips_with_nonzero_frequencies.add(tip.name)

In [None]:
len(tips_with_nonzero_frequencies)

## t-SNE to cluster sequences

Cluster sequences for tips in the latest tree, which should be a superset of tips in the earliest tree. We only consider tips collected after a cutoff two years before the latest sample at timepoint t. Tips are embedded with t-SNE using pairwise amino acid Hamming distances. Clusters are identified with DBSCAN in a two-dimensional embedding, and a separate one-dimensional embedding orders tips for plotting. This is a simple way of identifying sequences that are "close" to each other in a low dimensional space for comparison of tips within and between timepoints.

In [None]:
projected_frequency_by_sample_from_t = {
    node.name: node.attr.get("projected_frequency", 0.0)
    for node in tree_for_t.find_clades(terminal=True)
}

In [None]:
nodes = [
    node for node in tree_for_u.find_clades(terminal=True)
    if node.attr["num_date"] > earliest_date_to_plot
]

In [None]:
total_nodes = len(nodes)

In [None]:
total_nodes

In [None]:
distances = np.zeros((total_nodes, total_nodes))
for i, node_a in enumerate(nodes):
    node_a_array = np.frombuffer(node_a.attr["aa_sequence"].encode(), 'S1')
    
    for j, node_b in enumerate(nodes):
        if node_a.name == node_b.name:
            distance = 0.0
        elif distances[j, i] > 0:
            distance = distances[j, i]
        else:
            node_b_array = np.frombuffer(node_b.attr["aa_sequence"].encode(), 'S1')
            distance = (node_a_array != node_b_array).sum()
            
        distances[i, j] = distance

In [None]:
sns.heatmap(
    distances,
    cmap="cividis",
    robust=True,
    square=True,
    xticklabels=False,
    yticklabels=False
)

In [None]:
X_embedded = TSNE(n_components=2, learning_rate=400, metric="precomputed", random_state=314).fit_transform(distances)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.plot(X_embedded[:, 0], X_embedded[:, 1], ".", alpha=0.25)

In [None]:
clustering = DBSCAN(eps=10, min_samples=20).fit(X_embedded)

In [None]:
df = pd.DataFrame(X_embedded, columns=["dimension 0", "dimension 1"])

In [None]:
df["label"] = clustering.labels_

In [None]:
label_normalizer = mpl.colors.Normalize(df["label"].min(), df["label"].max())

In [None]:
cmap = list(reversed(sns.color_palette("Paired", n_colors=len(df["label"].unique()))))

In [None]:
df["color"] = df["label"].apply(lambda value: cmap[value])

In [None]:
cmap_for_tree = dict(df.loc[:, ["label", "color"]].values)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(
    df["dimension 0"],
    df["dimension 1"],
    alpha=0.25,
    c=df["color"]
)
ax.set_xlabel("dimension 0")
ax.set_ylabel("dimension 1")
plt.tight_layout()

In [None]:
X_embedded_1d = TSNE(n_components=1, learning_rate=500, metric="precomputed", random_state=314).fit_transform(distances)

In [None]:
X_embedded_1d.shape

Annotate nodes in both trees with ranks from t-SNE.

In [None]:
tree_t_nodes_by_name = {node.name: node for node in tree_for_t.find_clades(terminal=True)}

In [None]:
rank_records = []
for i, node in enumerate(nodes):
    node.attr["rank"] = X_embedded_1d[i, 0]
    node.attr["label"] = clustering.labels_[i]
    
    if node.name in tree_t_nodes_by_name:
        tree_t_nodes_by_name[node.name].attr["rank"] = X_embedded_1d[i, 0]
        tree_t_nodes_by_name[node.name].attr["label"] = clustering.labels_[i]
    
    rank_records.append({
        "strain": node.name,
        "rank": node.attr["rank"],
        "label": node.attr["label"]
    })

In [None]:
rank_df = pd.DataFrame(rank_records)

In [None]:
rank_normalizer = mpl.colors.Normalize(rank_df["rank"].min(), rank_df["rank"].max())

In [None]:
rank_df["color"] = rank_df["label"].apply(lambda value: cmap[value])

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 0.5))
ax.scatter(X_embedded_1d[:, 0], np.zeros_like(X_embedded_1d[:, 0]), marker=".", alpha=0.04, c=rank_df["color"].values.tolist())
ax.set_ylim(-0.001, 0.001)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 1))
ax.scatter(X_embedded_1d[:, 0], rank_df["label"], marker=".", alpha=0.02, c=rank_df["color"].values.tolist())
#ax.set_ylim(-0.001, 0.001)

## Annotate both sets of frequencies with t-SNE ranks and cluster labels

In [None]:
rank_frequency_df_for_t = frequency_df_for_t.merge(
    rank_df,
    on="strain"
).sort_values(["label", "strain", "timepoint"])

In [None]:
rank_frequency_df_for_t["ordinal_timepoint"] = rank_frequency_df_for_t["timepoint"].apply(lambda value: value.toordinal())

In [None]:
rank_frequency_df_for_u = frequency_df_for_u.merge(
    rank_df,
    on="strain"
).sort_values(["label", "strain", "timepoint"])

In [None]:
rank_frequency_df_for_u["ordinal_timepoint"] = rank_frequency_df_for_u["timepoint"].apply(lambda value: value.toordinal())

In [None]:
start_date = pd.to_datetime("2028-10-01").toordinal()
end_date = pd.to_datetime("2030-11-15").toordinal()
frequency_end_date = pd.to_datetime("2030-10-01").toordinal()

timepoint_t = pd.to_datetime(float_to_datestring(projection_pivot)).toordinal()
timepoint_u = pd.to_datetime(float_to_datestring(projection_pivot_for_u)).toordinal()

In [None]:
frequency_steps = [0, 0.25, 0.5, 0.75, 1.0]

## Plot tree

In [None]:
# Vertical positions of the tips in the tree at timepoint t, and the largest of them.
yvalues = list(map(lambda clade: clade.yvalue, tree_for_t.find_clades(terminal=True)))
y_span = max(yvalues)

In [None]:
fig = plt.figure(figsize=(12, 8), facecolor='w')
gs = gridspec.GridSpec(2, 2, height_ratios=[1, 0.5], width_ratios=[1, 1], hspace=0.25, wspace=0.1)

# Shared tick objects are kept for later cells that reference them. The axes in
# this figure each get their own instances (see `format_date_axis`), because a
# matplotlib locator stores a reference to the single axis it is attached to.
years = mdates.YearLocator()
years_fmt = mdates.DateFormatter("%y")
months = mdates.MonthLocator()


def format_date_axis(axis):
    """Tick `axis` by year (two-digit labels) and month, with a "Mon YY" cursor readout."""
    axis.xaxis.set_major_locator(mdates.YearLocator())
    axis.xaxis.set_major_formatter(mdates.DateFormatter("%y"))
    axis.xaxis.set_minor_locator(mdates.MonthLocator())
    axis.format_xdata = mdates.DateFormatter("%b %y")


def label_tree_timepoints(axis, label_t, label_u):
    """Write the timepoint labels above a tree panel and mark timepoints t and u."""
    for x_position, label in ((0.46, label_t), (0.94, label_u)):
        axis.text(x_position, 1.0, label,
                  horizontalalignment='center',
                  verticalalignment='center',
                  transform=axis.transAxes,
                  fontdict={"fontsize": 14})
    for timepoint in (timepoint_t, timepoint_u):
        axis.axvline(x=timepoint, ymax=0.96, color="#999999", linestyle="--", alpha=0.5)


def plot_stacked_frequencies(axis, rank_frequency_df, n_pivots, title):
    """Stack each strain's frequencies over its first `n_pivots` pivots, colored by cluster."""
    baseline = np.zeros(n_pivots)
    for _, strain_df in rank_frequency_df.groupby(["label", "strain"]):
        dates = strain_df["ordinal_timepoint"].values[:n_pivots]
        frequencies = strain_df["frequency"].values[:n_pivots]
        axis.fill_between(
            dates,
            baseline, baseline + frequencies,
            color=strain_df["color"].iloc[0]
        )
        baseline = baseline + frequencies

    axis.axvline(x=timepoint_t, color="#999999", linestyle="--")
    axis.axvline(x=timepoint_u, color="#999999", linestyle="--")

    axis.text(
        0.72,
        0.995,
        title,
        horizontalalignment="center",
        verticalalignment="center",
        transform=axis.transAxes,
        fontdict={"fontsize": 12}
    )

    axis.set_yticks(frequency_steps)
    axis.set_yticklabels(['{:3.0f}%'.format(step * 100) for step in frequency_steps])
    axis.set_ylabel("Frequency")
    axis.set_xlabel("Date")

    axis.set_xlim(start_date, end_date)
    axis.set_ylim(bottom=0.0)

    format_date_axis(axis)


# Tree plot for timepoint t, with projected tip frequencies at u.

tree_ax = fig.add_subplot(gs[0])
tree_ax, colorbar_ax = plot_tree_by_datetime(
    tree_for_t,
    color_by_trait="label",
    size_by_trait="frequency_at_t",
    ax=tree_ax,
    start_date=start_date,
    end_date=end_date,
    tip_size=tip_size,
    initial_branch_width=1,
    plot_projection_from_date=timepoint_t,
    plot_projection_to_date=timepoint_u,
    projection_attr="projected_frequency_at_u",
    cmap=cmap_for_tree
)
# Vertical window chosen by hand for this simulated dataset.
tree_ax.set_ylim(4000, 6700)
format_date_axis(tree_ax)
label_tree_timepoints(tree_ax, r"$\mathbf{x}(t)$", r"$\mathbf{\hat{x}}(u)$")

# Frequency plot for timepoint t (all pivots, including projected ones).

frequency_ax = fig.add_subplot(gs[2])
plot_stacked_frequencies(frequency_ax, rank_frequency_df_for_t, len(pivots), "Forecast")

# Tree plot for timepoint u.

tree_u_ax = fig.add_subplot(gs[1])
tree_u_ax, colorbar_u_ax = plot_tree_by_datetime(
    tree_for_u,
    color_by_trait="label",
    size_by_trait="frequency_at_u",
    ax=tree_u_ax,
    start_date=start_date,
    end_date=end_date,
    tip_size=tip_size,
    initial_branch_width=1,
    cmap=cmap_for_tree
)
# Vertical window chosen by hand for this simulated dataset.
tree_u_ax.set_ylim(4100, 6700)
format_date_axis(tree_u_ax)
label_tree_timepoints(tree_u_ax, r"$\mathbf{x}(t)$", r"$\mathbf{x}(u)$")

# Frequency plot for timepoint u (observed pivots up to the projection pivot only).

frequency_u_ax = fig.add_subplot(gs[3])
plot_stacked_frequencies(frequency_u_ax, rank_frequency_df_for_u, projection_pivot_index + 1, "Retrospective")

fig.autofmt_xdate(rotation=0, ha="center")

# Annotate panel labels.
panel_labels_dict = {
    "weight": "bold",
    "size": 14
}
for x_position, y_position, panel_label in ((0.0, 0.98, "A"), (0.0, 0.36, "B"), (0.5, 0.98, "C"), (0.5, 0.36, "D")):
    fig.text(x_position, y_position, panel_label, **panel_labels_dict)

gs.tight_layout(fig, h_pad=1.0)

fig.savefig(distance_model_figure)

In [None]:
projected_frequency_records = []
projected_frequency_at_u = []
projected_colors = []
for tip in tree_for_t.find_clades(terminal=True):
    if "projected_frequency_at_u" in tip.attr and tip.attr["projected_frequency_at_u"] > 1e-2:
        projected_frequency_at_u.append(tip.attr["projected_frequency_at_u"])
        projected_colors.append(cmap_for_tree[tip.attr["label"]])
        
        projected_frequency_records.append({
            "frequency": tip.attr["projected_frequency_at_u"],
            "group": tip.attr["label"]
        })

projected_frequency_df = pd.DataFrame(projected_frequency_records)

In [None]:
observed_frequency_records = []
observed_frequency_at_u = []
observed_colors = []
for tip in tree_for_u.find_clades(terminal=True):
    if "frequency_at_u" in tip.attr and tip.attr["frequency_at_u"] > 0.0:
        observed_frequency_at_u.append(tip.attr["frequency_at_u"])
        observed_colors.append(cmap_for_tree[tip.attr["label"]])
        
        observed_frequency_records.append({
            "frequency": tip.attr["frequency_at_u"],
            "group": tip.attr["label"]
        })

observed_frequency_df = pd.DataFrame(observed_frequency_records)

In [None]:
projected_frequency_arrays = []
projected_frequency_colors = []
for group, df in projected_frequency_df.groupby("group"):
    projected_frequency_arrays.append(df["frequency"].values)
    projected_frequency_colors.append(cmap_for_tree[group])

In [None]:
projected_frequency_rank = []
projected_frequency_frequencies = []
projected_frequency_colors = []
for index, row in projected_frequency_df.groupby("group")["frequency"].sum().sort_values(ascending=False).reset_index().iterrows():
    projected_frequency_rank.append(row["group"])
    projected_frequency_frequencies.append(row["frequency"])
    projected_frequency_colors.append(cmap_for_tree[row["group"]])

In [None]:
observed_frequency_arrays = []
observed_frequency_colors = []
for group, df in observed_frequency_df.groupby("group"):
    observed_frequency_arrays.append(df["frequency"].values)
    observed_frequency_colors.append(cmap_for_tree[group])

In [None]:
observed_frequency_rank = []
observed_frequency_frequencies = []
observed_frequency_colors = []
for index, row in observed_frequency_df.groupby("group")["frequency"].sum().sort_values(ascending=False).reset_index().iterrows():
    if row["frequency"] > 0.05:
        observed_frequency_rank.append(row["group"])
        observed_frequency_frequencies.append(row["frequency"])
        observed_frequency_colors.append(cmap_for_tree[row["group"]])

In [None]:
rank_to_index = {
    7: 0,
    6: 1,
    8: 2,
    4: 3
}

In [None]:
rank_normalizer = mpl.colors.Normalize(X_embedded_1d.min(), X_embedded_1d.max())

In [None]:
size_scaler = 1e3
default_size = 0.001
projection_attr = "projected_frequency"
projection_line_threshold = 1e-2
plot_projection_to_date = timepoint_u

fig = plt.figure(figsize=(12, 8), facecolor='w')
gs = gridspec.GridSpec(2, 2, height_ratios=[1, 0.5], width_ratios=[1, 1], hspace=0.25, wspace=0.1)

# Plot for timepoint t
tip_circles_x_for_t = []
tip_circles_y_for_t = []
tip_circles_sizes_for_t = []
tip_circles_colors_for_t = []
projection_line_segments = []

t_ax = fig.add_subplot(gs[0])
for node in tree_for_t.find_clades(terminal=True):
    if "rank" in node.attr:
        x = node.attr["collection_date_ordinal"]
        y = node.attr["rank"]
        tip_circles_x_for_t.append(x)
        tip_circles_y_for_t.append(y)
        tip_circles_sizes_for_t.append(size_scaler * np.sqrt(node.attr.get("frequency_at_t", default_size)))
        tip_circles_colors_for_t.append(mpl.cm.gist_gray(rank_normalizer(y)))
        
        if node.attr.get(projection_attr, 0.0) > projection_line_threshold:
            future_s = (size_scaler * np.sqrt(node.attr.get(projection_attr)))
            future_x = plot_projection_to_date + np.random.randint(-60, 0)
            future_y = y

            tip_circles_sizes_for_t.append(future_s)
            tip_circles_x_for_t.append(future_x)
            tip_circles_y_for_t.append(future_y)
            tip_circles_colors_for_t.append(mpl.cm.gist_gray(rank_normalizer(y)))

            projection_line_segments.append([(x + 1, y), (future_x, y)])

t_ax.scatter(
    tip_circles_x_for_t,
    tip_circles_y_for_t,
    s=tip_circles_sizes_for_t,
    facecolor=tip_circles_colors_for_t,
    edgecolors='#000000',
    linewidths=0.5,
    alpha=0.75,
    zorder=11
)

projection_lc = LineCollection(projection_line_segments, zorder=-10)
projection_lc.set_color("#cccccc")
projection_lc.set_linewidth(1)
projection_lc.set_linestyle("--")
projection_lc.set_alpha(0.5)
t_ax.add_collection(projection_lc)

t_ax.axvline(x=timepoint_t, linestyle="--", color="#999999")
t_ax.axvline(x=timepoint_u, linestyle="--", color="#999999")

t_ax.spines['top'].set_visible(False) ## no axes
t_ax.spines['right'].set_visible(False)
t_ax.spines['left'].set_visible(False)
t_ax.tick_params(axis='y',size=0)
t_ax.set_yticklabels([])

t_ax.xaxis.set_major_locator(years)
t_ax.xaxis.set_major_formatter(years_fmt)
t_ax.xaxis.set_minor_locator(months)
t_ax.format_xdata = mdates.DateFormatter("%b %y")

t_ax.set_xlim(start_date, end_date)

# Frequency plot for timepoint t

frequency_ax = fig.add_subplot(gs[2])

# Stack each strain's frequency trajectory on top of those drawn before it,
# shading each band by rank with the same colormap as the tip circles.
baseline = np.zeros_like(pivots)
for (rank, strain), strain_df in rank_frequency_df_for_t.groupby(["rank", "strain"]):
    stacked_top = baseline + strain_df["frequency"].values
    frequency_ax.fill_between(
        strain_df["ordinal_timepoint"].values,
        baseline,
        stacked_top,
        color=mpl.cm.gist_gray(rank_normalizer(rank)),
    )
    baseline = stacked_top

# Vertical dashed guides at timepoints t and u.
for marker_date in (timepoint_t, timepoint_u):
    frequency_ax.axvline(x=marker_date, color="#999999", linestyle="--")

# "Projection" label near the top edge, positioned in axes coordinates.
frequency_ax.text(
    0.72, 0.99, "Projection",
    horizontalalignment="center",
    verticalalignment="center",
    transform=frequency_ax.transAxes,
    fontdict={"fontsize": 10},
)

# Y ticks at the configured frequency steps, labeled as whole percentages.
frequency_ax.set_yticks(frequency_steps)
frequency_ax.set_yticklabels([f"{step * 100:3.0f}%" for step in frequency_steps])
frequency_ax.set_ylabel("Frequency")
frequency_ax.set_xlabel("Date")

frequency_ax.set_xlim(start_date, end_date)
frequency_ax.set_ylim(bottom=0.0)

# Date axis configured the same way as the tree panels.
frequency_ax.xaxis.set_major_locator(years)
frequency_ax.xaxis.set_major_formatter(years_fmt)
frequency_ax.xaxis.set_minor_locator(months)
frequency_ax.format_xdata = mdates.DateFormatter("%b %y")

# Plot for timepoint u
u_ax = fig.add_subplot(gs[1])

tip_circles_x_for_u = []
tip_circles_y_for_u = []
tip_circles_sizes_for_u = []
tip_circles_colors_for_u = []

# Collect ranked tips from both trees. Tips from the t tree are sized by their
# frequency at t, tips from the u tree by their frequency at u (falling back to
# `default_size`). Both sets use the same `size_scaler`, matching the projected
# circles in the t panel, so circle areas are comparable within this panel.
for panel_tree, panel_frequency_attr in (
    (tree_for_t, "frequency_at_t"),
    (tree_for_u, "frequency_at_u"),
):
    for node in panel_tree.find_clades(terminal=True):
        if "rank" not in node.attr:
            continue
        x = node.attr["collection_date_ordinal"]
        y = node.attr["rank"]
        tip_circles_x_for_u.append(x)
        tip_circles_y_for_u.append(y)
        tip_circles_sizes_for_u.append(
            size_scaler * np.sqrt(node.attr.get(panel_frequency_attr, default_size))
        )
        tip_circles_colors_for_u.append(mpl.cm.gist_gray(rank_normalizer(y)))

u_ax.scatter(
    tip_circles_x_for_u,
    tip_circles_y_for_u,
    s=tip_circles_sizes_for_u,
    facecolor=tip_circles_colors_for_u,
    edgecolors='#000000',
    linewidths=0.5,
    alpha=0.75,
    zorder=11
)

# Vertical dashed guides at timepoints t and u.
u_ax.axvline(x=timepoint_t, linestyle="--", color="#999999")
u_ax.axvline(x=timepoint_u, linestyle="--", color="#999999")

# Keep only the bottom (date) spine; hide the y tick marks and their labels.
u_ax.spines['top'].set_visible(False)
u_ax.spines['right'].set_visible(False)
u_ax.spines['left'].set_visible(False)
u_ax.tick_params(axis='y',size=0)
u_ax.set_yticklabels([])

u_ax.xaxis.set_major_locator(years)
u_ax.xaxis.set_major_formatter(years_fmt)
u_ax.xaxis.set_minor_locator(months)
u_ax.format_xdata = mdates.DateFormatter("%b %y")

# Reversed limits mirror this panel relative to the timepoint t panel.
u_ax.set_xlim(end_date, start_date)

# Frequency plot for timepoint u

frequency_u_ax = fig.add_subplot(gs[3])

# Only the first `projection_pivot_index + 1` points of each trajectory are
# drawn in this panel.
observed_slice = slice(None, projection_pivot_index + 1)

# Stack the sliced trajectories, shaded by rank. The baseline starts as the
# scalar 0.0 (fill_between accepts a scalar lower bound) and thereafter takes
# its length from the sliced frequencies themselves. Pre-sizing it from
# `pivots[2:]` tied its length to `pivots`, not to the arrays being stacked.
baseline_u = 0.0
for (rank, strain), strain_df in rank_frequency_df_for_u.groupby(["rank", "strain"]):
    observed_timepoints = strain_df["ordinal_timepoint"].values[observed_slice]
    observed_top = baseline_u + strain_df["frequency"].values[observed_slice]
    frequency_u_ax.fill_between(
        observed_timepoints,
        baseline_u,
        observed_top,
        color=mpl.cm.gist_gray(rank_normalizer(rank)),
    )
    baseline_u = observed_top

# Vertical dashed guides at timepoints t and u.
frequency_u_ax.axvline(x=timepoint_t, color="#999999", linestyle="--")
frequency_u_ax.axvline(x=timepoint_u, color="#999999", linestyle="--")

# "Observed" label near the top edge, positioned in axes coordinates.
frequency_u_ax.text(
    0.28,
    0.99,
    "Observed",
    horizontalalignment="center",
    verticalalignment="center",
    transform=frequency_u_ax.transAxes,
    fontdict={"fontsize": 10}
)

# Same tick positions as the t frequency panel; labels and tick marks are
# hidden further below.
frequency_u_ax.set_yticks(frequency_steps)
frequency_u_ax.set_xlabel("Date")

# Reversed limits mirror this panel relative to the t frequency panel.
frequency_u_ax.set_xlim(end_date, start_date)
frequency_u_ax.set_ylim(bottom=0.0)

frequency_u_ax.xaxis.set_major_locator(years)
frequency_u_ax.xaxis.set_major_formatter(years_fmt)
frequency_u_ax.xaxis.set_minor_locator(months)
frequency_u_ax.format_xdata = mdates.DateFormatter("%b %y")

# Keep only the bottom (date) spine; hide the y tick marks and their labels.
frequency_u_ax.spines['top'].set_visible(False)
frequency_u_ax.spines['right'].set_visible(False)
frequency_u_ax.spines['left'].set_visible(False)
frequency_u_ax.tick_params(axis='y',size=0)
frequency_u_ax.set_yticklabels([])

# Figure-level finishing: horizontal, centered date labels and a tight layout.
fig.autofmt_xdate(rotation=0, ha="center")
gs.tight_layout(fig, h_pad=1.0)