# Plot frequency errors by delay type for natural A/H3N2 or simulated A/H3N2-like populations 

In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
sns.set_style("ticks")

In [None]:
# Inclusive lower bound on a clade's "no delay" frequency for it to count as a
# "small" clade below. Stored as a proportion; it is multiplied by 100 when
# shown as a percentage in axis labels.
small_frequency_threshold = snakemake.params.small_frequency_threshold

In [None]:
# Inclusive lower bound on a clade's "no delay" frequency for it to count as a
# "large" clade below, and the exclusive upper bound for "small" clades.
# Stored as a proportion, like small_frequency_threshold.
large_frequency_threshold = snakemake.params.large_frequency_threshold

In [None]:
# Name of the non-ideal delay scenario to compare against "none" and "ideal".
# It is interpolated into column names (f"frequency_{realistic_delay_type}",
# f"{realistic_delay_type}_error") and into plot labels.
realistic_delay_type = snakemake.params.realistic_delay_type

## Compare clade frequencies

Compare the frequencies of clades present at non-zero frequency in the "no delay" analysis to the frequencies of the same clades with "ideal" and either "observed" or "realistic" delays.

In [None]:
# Load only the tip attribute columns needed to compare clade frequencies
# across delay types.
tip_attribute_columns = [
    "delay_type",
    "timepoint",
    "clade_membership",
    "frequency",
    "delta_month",
]
all_clade_frequencies = pd.read_csv(
    snakemake.input.tip_attributes,
    sep="\t",
    usecols=tip_attribute_columns,
    parse_dates=["timepoint"],
)

Select only records for a single forecast horizon (delta month), since we want to analyze differences between initial frequencies of clades at each timepoint based on their delay scenario.

In [None]:
# Keep the records for the 12-month horizon only, then drop the now-constant
# horizon column and renumber the rows.
clades = (
    all_clade_frequencies
    .query("delta_month == 12")
    .drop(columns=["delta_month"])
    .reset_index(drop=True)
)

In [None]:
clades.head()

In [None]:
clades["delay_type"].value_counts()

In [None]:
# Independent copy of the records from the "no delay" analysis.
no_delay_clades = clades.loc[clades["delay_type"] == "none"].copy()

In [None]:
no_delay_clades.shape

In [None]:
(no_delay_clades["frequency"] < small_frequency_threshold).sum()

In [None]:
((no_delay_clades["frequency"] >= small_frequency_threshold) &
 (no_delay_clades["frequency"] < large_frequency_threshold)).sum()

In [None]:
(no_delay_clades["frequency"] >= large_frequency_threshold).sum()

In [None]:
# Reshape to one row per (timepoint, clade) with one frequency column per
# delay type. A clade with no record for a given delay type gets frequency 0.
wide_clade_frequencies = clades.pivot(
    index=["timepoint", "clade_membership"],
    columns=["delay_type"],
    values=["frequency"],
)
clades_by_delay = wide_clade_frequencies.fillna(0).reset_index()

In [None]:
clades_by_delay.head()

In [None]:
# Flatten the two-level column index into single names such as
# "frequency_none". Empty levels (from the former index columns) are dropped
# so that "timepoint" and "clade_membership" keep their plain names.
clades_by_delay.columns = [
    "_".join(filter(None, column_levels))
    for column_levels in clades_by_delay.columns
]

In [None]:
clades_by_delay.head()

In [None]:
# Total clade frequency per timepoint under each delay type.
clades_by_delay.groupby("timepoint")[
    ["frequency_none", "frequency_ideal", f"frequency_{realistic_delay_type}"]
].sum()

In [None]:
clades_by_delay.shape

In [None]:
# Number of (timepoint, clade) rows with non-zero frequency in the "no delay"
# analysis. Series.sum() counts the True values in one vectorized pass instead
# of iterating element by element with the builtin sum(); int() keeps the
# result a plain Python integer.
total_clades_no_delay = int((clades_by_delay["frequency_none"] > 0).sum())

In [None]:
total_clades_no_delay

In [None]:
# Number of (timepoint, clade) rows with non-zero frequency in the "ideal"
# delay analysis. Series.sum() counts the True values in one vectorized pass
# instead of iterating element by element with the builtin sum(); int() keeps
# the result a plain Python integer.
total_clades_ideal_delay = int((clades_by_delay["frequency_ideal"] > 0).sum())

In [None]:
total_clades_ideal_delay

In [None]:
total_clades_ideal_delay / total_clades_no_delay

In [None]:
# Number of (timepoint, clade) rows with non-zero frequency under the
# configured realistic delay type. Series.sum() counts the True values in one
# vectorized pass instead of iterating element by element with the builtin
# sum(); int() keeps the result a plain Python integer.
total_clades_realistic_delay = int(
    (clades_by_delay[f"frequency_{realistic_delay_type}"] > 0).sum()
)

In [None]:
total_clades_realistic_delay

In [None]:
total_clades_realistic_delay / total_clades_no_delay

## Plot clade frequencies and frequency errors by delay type

In [None]:
clades_by_delay

In [None]:
# Largest clade frequency observed under any of the three delay types.
frequency_columns = [
    "frequency_none",
    "frequency_ideal",
    f"frequency_{realistic_delay_type}",
]
max_clade_frequency = clades_by_delay[frequency_columns].max().max()

In [None]:
max_clade_frequency

In [None]:
max_clade_frequency_threshold = max_clade_frequency + (max_clade_frequency / 10)

In [None]:
x_clades = y_clades = np.linspace(0, max_clade_frequency_threshold, 10)

In [None]:
# Scatter clade frequencies without delay (x-axis) against frequencies of the
# same clades under each delay type (y-axis), one panel per delay type.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), dpi=200)

for panel_ax, delay_label in ((ax1, "ideal"), (ax2, realistic_delay_type)):
    panel_ax.plot(
        clades_by_delay["frequency_none"],
        clades_by_delay[f"frequency_{delay_label}"],
        "o",
        color="#999999",
        alpha=0.25,
    )

    # Identity line drawn behind the points: clades on this line have the
    # same frequency with and without the delay.
    panel_ax.plot(
        x_clades,
        y_clades,
        color="#000000",
        alpha=0.25,
        zorder=-10,
    )

    panel_ax.set_xlabel("Clade frequency without delay")
    panel_ax.set_ylabel(f"Clade frequency with {delay_label} delay")
    panel_ax.set_aspect("equal", "box")

sns.despine()
plt.tight_layout()

In [None]:
# Error is "without delay" minus "with delay": positive values mean the
# ideal-delay analysis underestimated the clade's frequency, negative values
# mean it overestimated it.
clades_by_delay["ideal_error"] = clades_by_delay["frequency_none"] - clades_by_delay["frequency_ideal"]

In [None]:
# Error is "without delay" minus "with delay": positive values mean the
# analysis with the configured realistic delay type underestimated the clade's
# frequency, negative values mean it overestimated it.
clades_by_delay[f"{realistic_delay_type}_error"] = clades_by_delay["frequency_none"] - clades_by_delay[f"frequency_{realistic_delay_type}"]

In [None]:
bins_min = min(clades_by_delay["ideal_error"].min(), clades_by_delay[f"{realistic_delay_type}_error"].min())

In [None]:
bins_max = max(clades_by_delay["ideal_error"].max(), clades_by_delay[f"{realistic_delay_type}_error"].max())

In [None]:
bins_min

In [None]:
bins_max

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].describe()

In [None]:
bins_min_all = clades_by_delay[f"{realistic_delay_type}_error"].mean() - clades_by_delay[f"{realistic_delay_type}_error"].std()

In [None]:
bins_max_all = clades_by_delay[f"{realistic_delay_type}_error"].mean() + clades_by_delay[f"{realistic_delay_type}_error"].std()

In [None]:
bins_min_all

In [None]:
bins_max_all

In [None]:
all_clades_bins = np.arange(bins_min_all - 0.005, bins_max_all + 0.005, 0.0025)

In [None]:
# Histogram of clade frequency errors for all clades, one distribution per
# delay type on shared bins.
fig, ax = plt.subplots(1, 1, figsize=(8, 2), dpi=200)

# Reference line at zero error, labelled as the "none" delay type.
ax.axvline(x=0, label="none", color="C0")

for delay_label, delay_color in (("ideal", "C1"), (realistic_delay_type, "C2")):
    ax.hist(
        clades_by_delay[f"{delay_label}_error"],
        bins=all_clades_bins,
        label=delay_label,
        alpha=0.5,
        color=delay_color,
    )

ax.set_xlabel("Clade frequency error (without delay - with delay)")
ax.set_ylabel("Number of clades")

# Errors are "without delay - with delay", so the left (negative) side holds
# clades overestimated with delay and the right side holds underestimated ones.
for text_x, annotation in ((0.25, "overestimated"), (0.75, "underestimated")):
    ax.text(
        text_x,
        0.25,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
    )

ax.legend(title="Delay type", frameon=False)

plt.tight_layout()
sns.despine()

In [None]:
clades_by_delay["ideal_error"].median()

In [None]:
clades_by_delay["ideal_error"].mean()

In [None]:
clades_by_delay["ideal_error"].std()

In [None]:
clades_by_delay["ideal_error"].var()

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].median()

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].mean()

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].std()

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].var()

In [None]:
clades_by_delay[f"{realistic_delay_type}_error"].var() / clades_by_delay["ideal_error"].var()

## Define small and large clades

Small clades have an initial frequency at or above a minimum threshold and below the minimum frequency for large clades. Large clades have a higher minimum frequency.

In [None]:
# Small clades: "no delay" frequency in [small threshold, large threshold).
no_delay_frequency = clades_by_delay["frequency_none"]
is_small_clade = (
    (no_delay_frequency >= small_frequency_threshold)
    & (no_delay_frequency < large_frequency_threshold)
)
small_clades = clades_by_delay[is_small_clade]

In [None]:
small_clades.shape

In [None]:
# Histogram bin edges (width 0.0025) spanning the range of errors for small
# clades under the configured realistic delay type.
#
# The error column is named after realistic_delay_type, so build the column
# name from that parameter. Hard-coding "realistic_error" raises a KeyError
# whenever the notebook runs with another delay type (e.g., "observed").
small_clades_realistic_errors = small_clades[f"{realistic_delay_type}_error"]
small_clades_bins = np.arange(
    small_clades_realistic_errors.min(),
    small_clades_realistic_errors.max() + 0.005,
    0.0025,
)

In [None]:
large_clades_bins = np.arange(bins_min, bins_max + 0.005, 0.005)

In [None]:
# Large clades: "no delay" frequency at or above the large threshold.
large_clades = clades_by_delay[
    clades_by_delay["frequency_none"] >= large_frequency_threshold
]

In [None]:
large_clades.shape

In [None]:
clades_by_delay.shape

In [None]:
clades_by_delay.query("frequency_none < 0.01").shape

In [None]:
clades_by_delay.query("frequency_none < 0.01").shape[0] / clades_by_delay.shape[0]

In [None]:
clades_by_delay.query("frequency_none < 0.01").groupby("timepoint")["frequency_none"].sum().sort_values()

In [None]:
clades_by_delay.query("frequency_none < 0.01").groupby("timepoint")["frequency_none"].sum().sort_values().mean()

In [None]:
clades_by_delay.query("frequency_none >= 0.01").shape

In [None]:
clades_by_delay.query("frequency_none >= 0.05").shape

In [None]:
clades_by_delay.query("frequency_none >= 0.1").shape

In [None]:
clades_by_delay.query("frequency_none >= 0.15").shape

In [None]:
large_clades

In [None]:
# Histogram of clade frequency errors for small clades, one distribution per
# delay type on shared bins.
fig, ax = plt.subplots(1, 1, figsize=(8, 2.5), dpi=200)

# Reference line at zero error, labelled as the "none" delay type.
ax.axvline(x=0, label="none", color="C0")

for delay_label, delay_color in (("ideal", "C1"), (realistic_delay_type, "C2")):
    ax.hist(
        small_clades[f"{delay_label}_error"],
        bins=small_clades_bins,
        label=delay_label,
        alpha=0.5,
        color=delay_color,
    )

ax.set_xlabel("Clade frequency error (without delay - with delay)")

# Show the frequency thresholds as whole percentages in the y-axis label.
small_threshold_percent = int(small_frequency_threshold * 100)
large_threshold_percent = int(large_frequency_threshold * 100)
ax.set_ylabel(
    "Number of clades\n"
    + rf"({small_threshold_percent}% $\leq$ frequency $<${large_threshold_percent}%)"
)

# Errors are "without delay - with delay", so the left (negative) side holds
# clades overestimated with delay and the right side holds underestimated ones.
for text_x, annotation in ((0.25, "overestimated"), (0.75, "underestimated")):
    ax.text(
        text_x,
        0.25,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
    )

ax.legend(title="Delay type", frameon=False)

plt.tight_layout()
sns.despine()

In [None]:
small_clades["ideal_error"].min()

In [None]:
small_clades["ideal_error"].max()

In [None]:
small_clades["ideal_error"].median()

In [None]:
small_clades["ideal_error"].mean()

In [None]:
small_clades["ideal_error"].std()

In [None]:
small_clades["ideal_error"].var()

In [None]:
(small_clades["ideal_error"] > 0).sum()

In [None]:
(small_clades["ideal_error"] > 0).sum() / small_clades.shape[0]

In [None]:
(small_clades[f"{realistic_delay_type}_error"] > 0).sum()

In [None]:
(small_clades[f"{realistic_delay_type}_error"] > 0).sum() / small_clades.shape[0]

In [None]:
small_clades[f"{realistic_delay_type}_error"].min()

In [None]:
small_clades[f"{realistic_delay_type}_error"].max()

In [None]:
small_clades[f"{realistic_delay_type}_error"].median()

In [None]:
small_clades[f"{realistic_delay_type}_error"].mean()

In [None]:
small_clades[f"{realistic_delay_type}_error"].std()

In [None]:
small_clades[f"{realistic_delay_type}_error"].std() * 3

In [None]:
small_clades[f"{realistic_delay_type}_error"].var()

In [None]:
# Ratio of error variance under the configured realistic delay type to error
# variance under the ideal delay, among small clades. Both variances must come
# from small_clades; dividing by the large-clade variance mixed two different
# subsets, unlike the matching large-clade ratio computed later.
small_clades[f"{realistic_delay_type}_error"].var() / small_clades["ideal_error"].var()

In [None]:
small_clades["ideal_error"].std() / small_clades[f"{realistic_delay_type}_error"].std()

In [None]:
# Histogram of clade frequency errors for large clades, one distribution per
# delay type on shared bins.
fig, ax = plt.subplots(1, 1, figsize=(8, 2), dpi=200)

# Reference line at zero error, labelled as the "none" delay type.
ax.axvline(x=0, label="none", color="C0")

for delay_label, delay_color in (("ideal", "C1"), (realistic_delay_type, "C2")):
    ax.hist(
        large_clades[f"{delay_label}_error"],
        bins=large_clades_bins,
        label=delay_label,
        alpha=0.5,
        color=delay_color,
    )

ax.set_xlabel("Clade frequency error (without delay - with delay)")

# Show the frequency threshold as a whole percentage in the y-axis label.
large_threshold_percent = int(large_frequency_threshold * 100)
ax.set_ylabel(
    "Number of clades\n"
    + rf"($\geq${large_threshold_percent}% frequency)"
)

# Errors are "without delay - with delay", so the left (negative) side holds
# clades overestimated with delay and the right side holds underestimated ones.
for text_x, annotation in ((0.25, "overestimated"), (0.75, "underestimated")):
    ax.text(
        text_x,
        0.25,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
    )

ax.legend(title="Delay type", frameon=False)

plt.tight_layout()
sns.despine()

In [None]:
large_clades["ideal_error"].min()

In [None]:
large_clades["ideal_error"].max()

In [None]:
large_clades["ideal_error"].median()

In [None]:
large_clades["ideal_error"].mean()

In [None]:
large_clades["ideal_error"].std()

In [None]:
large_clades["ideal_error"].var()

In [None]:
large_clades[f"{realistic_delay_type}_error"].min()

In [None]:
large_clades[f"{realistic_delay_type}_error"].max()

In [None]:
large_clades[f"{realistic_delay_type}_error"].median()

In [None]:
large_clades[f"{realistic_delay_type}_error"].mean()

In [None]:
large_clades[f"{realistic_delay_type}_error"].std()

In [None]:
large_clades[f"{realistic_delay_type}_error"].std() * 3

In [None]:
large_clades[f"{realistic_delay_type}_error"].var()

In [None]:
large_clades[f"{realistic_delay_type}_error"].var() / large_clades["ideal_error"].var()

In [None]:
large_clades["ideal_error"].std() / large_clades[f"{realistic_delay_type}_error"].std()

In [None]:
# For large clades, plot the frequency error under each delay type against
# the clade's frequency without delay. Panels share both axes.
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    dpi=150,
    sharex=True,
    sharey=True,
)

error_panels = (
    (ax1, "ideal", "C1"),
    (ax2, realistic_delay_type, "C2"),
)
for panel_ax, delay_label, delay_color in error_panels:
    panel_ax.plot(
        large_clades["frequency_none"],
        large_clades[f"{delay_label}_error"],
        "o",
        alpha=0.25,
        color=delay_color,
    )

    # Zero-error reference line drawn behind the points.
    panel_ax.axhline(y=0, color="#000000", zorder=-10, linewidth=1)

    panel_ax.set_xlabel("Clade frequency")
    panel_ax.set_ylabel(f"Clade frequency error\nwith {delay_label} delay")

sns.despine()
plt.tight_layout()

Composite figure for manuscript.

In [None]:
# Lay out the composite figure on a 3x2 grid: two panels side by side on the
# top row and one full-width panel on each of the two rows below.
fig = plt.figure(figsize=(6.5, 8), dpi=200)
gs = GridSpec(3, 2, figure=fig)

ax1, ax2 = (fig.add_subplot(gs[0, column]) for column in range(2))
ax3, ax4 = (fig.add_subplot(gs[row, :]) for row in (1, 2))

# Shared values

# Find the largest clade frequency under any delay type, convert it to a
# percentage, and round it up to the next multiple of 10. This is the upper
# bound for the axis ticks in the frequency panels.
frequency_columns = (
    "frequency_none",
    "frequency_ideal",
    f"frequency_{realistic_delay_type}",
)
largest_clade_frequency = max(
    clades_by_delay[column].max() for column in frequency_columns
)
max_clade_frequency = int(np.ceil(largest_clade_frequency * 10) * 10)

clade_frequency_ticks = list(range(0, max_clade_frequency + 1, 10))
clade_frequency_tick_labels = [f"{tick:.0f}%" for tick in clade_frequency_ticks]

# Coordinates of the identity line, in percent.
x_clades = y_clades = np.linspace(0, max_clade_frequency, 10)

# Panels A and B: clade frequency without delay against clade frequency with
# the ideal delay (A) and with the configured realistic delay type (B), both
# shown in percent.
for panel_ax, delay_label in ((ax1, "ideal"), (ax2, realistic_delay_type)):
    panel_ax.plot(
        clades_by_delay["frequency_none"] * 100,
        clades_by_delay[f"frequency_{delay_label}"] * 100,
        "o",
        color="#999999",
        zorder=10,
        alpha=0.25,
    )

    # Identity line drawn behind the points.
    panel_ax.plot(
        x_clades,
        y_clades,
        color="#000000",
        zorder=-10,
        alpha=0.25,
    )

    panel_ax.set_xlabel("Clade frequency without delay")
    panel_ax.set_ylabel(f"Clade frequency\nwith {delay_label} delay")

    # Use the same percent ticks on both axes of both panels.
    panel_ax.set_xticks(
        ticks=clade_frequency_ticks,
        labels=clade_frequency_tick_labels,
    )
    panel_ax.set_yticks(
        ticks=clade_frequency_ticks,
        labels=clade_frequency_tick_labels,
    )

# Panel C: error distributions for small clades, in percent.

# Reference line at zero error, labelled as the "none" delay type.
ax3.axvline(x=0, label="none", color="C0", zorder=-10)

small_clade_errors = (
    ("ideal", "C1", small_clades["ideal_error"]),
    (realistic_delay_type, "C2", small_clades[f"{realistic_delay_type}_error"]),
)

for delay_label, delay_color, delay_errors in small_clade_errors:
    ax3.hist(
        delay_errors * 100,
        bins=small_clades_bins * 100,
        label=delay_label,
        alpha=0.5,
        color=delay_color,
    )

# Dashed line at the median error of each delay type.
for delay_label, delay_color, delay_errors in small_clade_errors:
    ax3.axvline(
        x=delay_errors.median() * 100,
        color=delay_color,
        zorder=-10,
        linestyle="dashed",
        linewidth=1,
    )

ax3.set_xlabel("Clade frequency error (without delay - with delay)")

# Show the frequency thresholds as whole percentages in the y-axis label.
small_threshold_percent = int(small_frequency_threshold * 100)
large_threshold_percent = int(large_frequency_threshold * 100)
ax3.set_ylabel(
    "Number of clades\n"
    + rf"({small_threshold_percent}% $\leq$ frequency $<${large_threshold_percent}%)"
)

# Errors are "without delay - with delay", so the left (negative) side holds
# clades overestimated with delay and the right side holds underestimated ones.
for text_x, annotation in ((0.15, "overestimated"), (0.85, "underestimated")):
    ax3.text(
        text_x,
        0.3,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax3.transAxes,
    )

ax3.legend(title="Delay type", frameon=False)

# Relabel the existing x-axis ticks as whole percentages.
ax3_xticks = ax3.get_xticks()
ax3_tick_labels = list(map("{:.0f}%".format, ax3_xticks))
ax3.set_xticks(ax3_xticks, ax3_tick_labels)

# Panel D: error distributions for large clades, in percent.

# Reference line at zero error, labelled as the "none" delay type.
ax4.axvline(x=0, label="none", color="C0", zorder=-10)

large_clade_errors = (
    ("ideal", "C1", large_clades["ideal_error"]),
    (realistic_delay_type, "C2", large_clades[f"{realistic_delay_type}_error"]),
)

for delay_label, delay_color, delay_errors in large_clade_errors:
    ax4.hist(
        delay_errors * 100,
        bins=large_clades_bins * 100,
        label=delay_label,
        alpha=0.5,
        color=delay_color,
    )

# Dashed line at the median error of each delay type.
for delay_label, delay_color, delay_errors in large_clade_errors:
    ax4.axvline(
        x=delay_errors.median() * 100,
        color=delay_color,
        zorder=-10,
        linestyle="dashed",
        linewidth=1,
    )

ax4.set_xlabel("Clade frequency error (without delay - with delay)")

# Show the frequency threshold as a whole percentage in the y-axis label.
large_threshold_percent = int(large_frequency_threshold * 100)
ax4.set_ylabel(
    "Number of clades\n"
    + rf"($\geq${large_threshold_percent}% frequency)"
)

# Errors are "without delay - with delay", so the left (negative) side holds
# clades overestimated with delay and the right side holds underestimated ones.
for text_x, annotation in ((0.15, "overestimated"), (0.85, "underestimated")):
    ax4.text(
        text_x,
        0.3,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax4.transAxes,
    )

# Relabel the existing x-axis ticks as whole percentages.
ax4_xticks = ax4.get_xticks()
ax4_tick_labels = list(map("{:.0f}%".format, ax4_xticks))
ax4.set_xticks(ax4_xticks, ax4_tick_labels)

# Add bold panel letters in figure coordinates, one per panel.
panel_labels_dict = {"weight": "bold", "size": 14}
panel_label_positions = (
    (0.0, 0.97, "A"),
    (0.5, 0.97, "B"),
    (0.0, 0.64, "C"),
    (0.0, 0.31, "D"),
)
for label_x, label_y, panel_label in panel_label_positions:
    plt.figtext(label_x, label_y, panel_label, **panel_labels_dict)

sns.despine()

# Tighten the grid layout, then write the figure to the workflow's output path.
gs.tight_layout(fig, pad=0.5)
plt.savefig(snakemake.output.current_frequency_errors_by_delay)