# Plot clade frequency errors by delay type and forecast horizon for a given population type

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
sns.set_style("ticks")

In [None]:
# Display figures at a reasonable default size.
mpl.rcParams['figure.figsize'] = (6, 4)

# Disable top and right spines.
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
    
# Display and save figures at higher resolution for presentations and manuscripts.
mpl.rcParams['savefig.dpi'] = 300
mpl.rcParams['figure.dpi'] = 150

# Display text at sizes large enough for presentations and manuscripts.
mpl.rcParams['font.weight'] = "normal"
mpl.rcParams['axes.labelweight'] = "normal"
mpl.rcParams['font.size'] = 14
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 12
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14

mpl.rc('text', usetex=False)

In [None]:
realistic_delay_type = snakemake.params.realistic_delay_type

In [None]:
delay_order = ("none", "ideal", realistic_delay_type)

In [None]:
large_frequency_lower_threshold = snakemake.params.large_frequency_lower_threshold

In [None]:
large_frequency_upper_threshold = snakemake.params.large_frequency_upper_threshold

## Load clade frequencies

In [None]:
frequencies = pd.read_csv(
    snakemake.input.frequencies,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"],
).rename(
    columns={"delta_month": "horizon"},
)

In [None]:
frequencies["horizon"] = frequencies["horizon"].astype(int)

In [None]:
frequencies.head()

In [None]:
frequencies["frequency_error"] = frequencies["observed_frequency"] - frequencies["projected_frequency"]

In [None]:
frequencies["abs_frequency_error"] = np.abs(frequencies["frequency_error"])

In [None]:
frequencies.shape

## Annotate all clades with their initial frequency without delay

In [None]:
# Initial frequency of each clade at each (timepoint, future_timepoint) pair
# under the "none" delay type.
frequencies_without_delay = frequencies.loc[
    frequencies["delay_type"] == "none",
    ("clade_membership", "timepoint", "future_timepoint", "frequency")
].drop_duplicates()

In [None]:
frequencies_without_delay.shape

In [None]:
frequencies_without_delay = frequencies_without_delay.rename(
    columns={"frequency": "frequency_without_delay"},
)

In [None]:
frequencies_without_delay.head()

In [None]:
frequencies.shape

In [None]:
# Left-join so that every row (for every delay type) gains the no-delay initial
# frequency of the same clade at the same timepoints.
# NOTE(review): if the "none" rows hold more than one distinct frequency per
# clade/timepoint key, this merge duplicates rows; consider validate="many_to_one".
frequencies = frequencies.merge(
    frequencies_without_delay,
    how="left",
    on=["timepoint", "future_timepoint", "clade_membership"],
)

In [None]:
# Rows with no matching no-delay record.
pd.isnull(frequencies["frequency_without_delay"]).sum()

In [None]:
# Clades with no no-delay match get a no-delay initial frequency of 0.
frequencies["frequency_without_delay"] = frequencies["frequency_without_delay"].fillna(0)

In [None]:
frequencies.head()

In [None]:
frequencies.loc[:, ["frequency_without_delay", "frequency"]]

In [None]:
frequencies[frequencies["frequency_without_delay"] < 0.1]

In [None]:
# Spot-check the no-delay rows for one specific pair of timepoints.
frequencies[
    (frequencies["timepoint"] == "2011-04-01") &
    (frequencies["future_timepoint"] == "2011-07-01") &
    (frequencies["delay_type"] == "none")
]

In [None]:
frequencies.shape

In [None]:
# Exploratory counts of rows by initial frequency range.
(frequencies["frequency"] >= 0.95).sum()

In [None]:
(frequencies["frequency"] >= 0.01).sum()

In [None]:
(frequencies["frequency"] >= 0.05).sum()

In [None]:
(frequencies["frequency"] >= 0.1).sum()

In [None]:
frequencies["frequency"].max()

In [None]:
frequencies["frequency"].min()

In [None]:
((frequencies["frequency"] > 0.1) & (frequencies["frequency"] < 0.95)).sum()

In [None]:
((frequencies["frequency_without_delay"] > 0.1) & (frequencies["frequency_without_delay"] < 0.95)).sum()

In [None]:
# Rows with zero initial frequency (no delay) that are also unobserved at the future timepoint.
((frequencies["frequency_without_delay"] == 0) & (frequencies["observed_frequency"] == 0)).sum()

In [None]:
((frequencies["frequency"] == 0) & (frequencies["observed_frequency"] == 0)).sum()

In [None]:
# Distinct clades whose initial frequency is strictly between 10% and 95%,
# using `frequency` from every delay type (including "none").
# NOTE(review): these exploratory cutoffs are hard-coded and may differ from
# the configured large_frequency_*_threshold values.
distinct_large_clades_with_delay = set(frequencies.loc[
    (frequencies["frequency"] > 0.1) & (frequencies["frequency"] < 0.95),
    "clade_membership"
].drop_duplicates().values)

In [None]:
# Same selection, but by the no-delay initial frequency.
distinct_large_clades_without_delay = set(frequencies.loc[
    (frequencies["frequency_without_delay"] > 0.1) & (frequencies["frequency_without_delay"] < 0.95),
    "clade_membership"
].drop_duplicates().values)

In [None]:
len(distinct_large_clades_with_delay)

In [None]:
len(distinct_large_clades_without_delay)

In [None]:
distinct_large_clades_with_delay - distinct_large_clades_without_delay

In [None]:
distinct_large_clades_without_delay - distinct_large_clades_with_delay

In [None]:
# Rows for clades that are "large" by `frequency` but not by the no-delay frequency.
frequencies[frequencies["clade_membership"].isin(distinct_large_clades_with_delay - distinct_large_clades_without_delay)]

## Plot clade frequency errors by delay type and forecast horizon

In [None]:
# Signed clade frequency error (observed - projected) for all clades, grouped by
# forecast horizon and split by delay type.
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=200)

# Grey boxes summarize each group. Outlier markers are hidden (fliersize=0)
# because the strip plot below draws every point.
sns.boxplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=delay_order,
    data=frequencies,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    data=frequencies,
    hue_order=delay_order,
    alpha=0.35,
    ax=ax,
    dodge=True,
)

# Zero-error reference line drawn behind the data.
ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

# Positive error means observed > projected, so the top of the axes is "underestimated".
ax.text(
    0.5,
    0.95,
    "underestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

ax.text(
    0.5,
    0.05,
    "overestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

handles, labels = ax.get_legend_handles_labels()

# Both plots add one legend entry per delay type (three each). Keep only the
# last three, presumably the strip plot's, to avoid duplicate entries.
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="lower left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Clade frequency error (all clades)")

sns.despine()
plt.tight_layout()

In [None]:
# Mean, median, and standard deviation of signed error per horizon and delay type (all clades).
frequencies.groupby([
    "horizon",
    "delay_type"
]).aggregate({
    "frequency_error": ["mean", "median", "std"],
})

Plot clade frequency errors for larger clades only.

In [None]:
large_frequency_lower_threshold

In [None]:
large_frequency_upper_threshold

In [None]:
# Keep rows whose no-delay initial frequency lies within the configured
# thresholds (inclusive). `.copy()` makes this an independent frame, since
# columns are added to it later.
large_frequencies = frequencies.query(
    f"(frequency_without_delay >= {large_frequency_lower_threshold}) & (frequency_without_delay <= {large_frequency_upper_threshold})"
).copy()

In [None]:
large_frequencies["frequency_without_delay"].describe()

In [None]:
large_frequencies.shape

In [None]:
# Signed clade frequency error for large clades by horizon and delay type.
# This figure is saved as a pipeline output.
fig, ax = plt.subplots(1, 1, figsize=(7, 5), dpi=200)

sns.boxplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=delay_order,
    data=large_frequencies,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    data=large_frequencies,
    hue_order=delay_order,
    alpha=0.35,
    ax=ax,
    dodge=True,
)

ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

# Positive error (observed > projected) means the forecast underestimated the clade.
ax.text(
    0.5,
    0.97,
    "underestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

ax.text(
    0.5,
    0.05,
    "overestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

handles, labels = ax.get_legend_handles_labels()

# Drop the duplicate legend entries contributed by the box plot.
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="lower left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
# Threshold is stored as a proportion and shown as a whole-number percent.
# NOTE(review): the label states only the lower bound, although the subset is
# also capped at large_frequency_upper_threshold.
ax.set_ylabel(r"""Clade frequency error
({large_frequency_lower_threshold}% $\leq$ initial frequency)""".format(
    large_frequency_lower_threshold=int(large_frequency_lower_threshold * 100),
))

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.forecast_frequency_errors_by_delay_and_horizon)

In [None]:
table_template_header = r"""
\begin{tabular*}{0.7\textwidth}{rrrrr}
\toprule
        &            & \multicolumn{3}{c}{Clade frequency error} \\
Horizon & Delay type & Mean & Median & Std Dev \\
\midrule
"""

absolute_table_template_header = r"""
\begin{tabular*}{0.7\textwidth}{rrrrr}
\toprule
        &            & \multicolumn{3}{c}{Absolute clade frequency error} \\
Horizon & Delay type & Mean & Median & Std Dev \\
\midrule
"""

total_absolute_table_template_header = r"""
\begin{tabular*}{0.7\textwidth}{rrrrr}
\toprule
        &            & \multicolumn{3}{c}{Total absolute clade frequency error} \\
Horizon & Delay type & Mean & Median & Std Dev \\
\midrule
"""

table_template_row = r"{horizon} & {delay_type} & {mean:.2f} & {median:.2f} & {std:.2f} \\"

table_template_footer = r"""
\bottomrule
\end{tabular*}
"""

In [None]:
large_frequencies_errors_summary = large_frequencies.groupby(["horizon", "delay_type"], sort=False).agg({
    "frequency_error": ["mean", "median", "std"],
}).round(2)

In [None]:
large_frequencies_errors_summary.columns = [
    column[1]
    for column in large_frequencies_errors_summary.columns
]

In [None]:
large_frequencies_errors_summary = large_frequencies_errors_summary.reset_index()

In [None]:
with open(snakemake.output.forecast_frequency_errors_summary_table, "w", encoding="utf-8") as oh:
    oh.write(table_template_header + "\n")
    
    for record in large_frequencies_errors_summary.to_dict(orient="records"):
        oh.write(table_template_row.format(**record) + "\n")
    
    oh.write(table_template_footer + "\n")

In [None]:
large_frequencies_errors_summary

## Plot absolute clade frequency errors

In [None]:
# Absolute clade frequency error for large clades by horizon and delay type.
# This figure is saved as a pipeline output.
fig, ax = plt.subplots(1, 1, figsize=(7, 5), dpi=200)

sns.boxplot(
    x="horizon",
    y="abs_frequency_error",
    hue="delay_type",
    hue_order=delay_order,
    data=large_frequencies,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="abs_frequency_error",
    hue="delay_type",
    data=large_frequencies,
    hue_order=delay_order,
    alpha=0.35,
    ax=ax,
    dodge=True,
)

ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

handles, labels = ax.get_legend_handles_labels()

# Drop the duplicate legend entries contributed by the box plot.
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="upper left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel(r"""Absolute clade frequency error
({large_frequency_lower_threshold}% $\leq$ initial frequency)""".format(
    large_frequency_lower_threshold=int(large_frequency_lower_threshold * 100),
))

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.absolute_forecast_frequency_errors_by_delay_and_horizon)

In [None]:
# Absolute-error statistics for large clades, with groups kept in first-seen order.
large_frequencies_abs_errors_summary = large_frequencies.groupby(["horizon", "delay_type"], sort=False).agg({
    "abs_frequency_error": ["mean", "median", "std"],
}).round(2)

In [None]:
# Flatten ("abs_frequency_error", statistic) columns to the statistic name
# expected by table_template_row.
large_frequencies_abs_errors_summary.columns = [
    column[1]
    for column in large_frequencies_abs_errors_summary.columns
]

In [None]:
large_frequencies_abs_errors_summary = large_frequencies_abs_errors_summary.reset_index()

In [None]:
# Write the LaTeX summary table: header, one row per group, footer.
with open(snakemake.output.absolute_forecast_frequency_errors_summary_table, "w", encoding="utf-8") as oh:
    oh.write(absolute_table_template_header + "\n")

    for record in large_frequencies_abs_errors_summary.to_dict(orient="records"):
        oh.write(table_template_row.format(**record) + "\n")

    oh.write(table_template_footer + "\n")

In [None]:
large_frequencies_abs_errors_summary

## Plot total absolute forecast errors

Sum absolute forecast errors across all large clades (those within the initial frequency thresholds) for each pair of initial and future timepoints, and plot the totals by horizon and delay type.

In [None]:
large_frequencies

In [None]:
# Total absolute error per forecast: sum the large clades' absolute errors
# within each (timepoint, future_timepoint, horizon, delay type) group.
total_absolute_forecast_errors = large_frequencies.groupby([
    "timepoint",
    "future_timepoint",
    "horizon",
    "delay_type",
]).aggregate({
    "abs_frequency_error": "sum",
}).reset_index().rename(
    columns={
        "abs_frequency_error": "total_absolute_forecast_error",
    }
)

In [None]:
total_absolute_forecast_errors

In [None]:
# Distribution of per-forecast total absolute error by horizon and delay type.
# This figure is saved as a pipeline output.
fig, ax = plt.subplots(1, 1, figsize=(7, 5), dpi=200)

sns.boxplot(
    x="horizon",
    y="total_absolute_forecast_error",
    hue="delay_type",
    hue_order=delay_order,
    data=total_absolute_forecast_errors,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="total_absolute_forecast_error",
    hue="delay_type",
    data=total_absolute_forecast_errors,
    hue_order=delay_order,
    alpha=0.35,
    ax=ax,
    dodge=True,
)

ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

handles, labels = ax.get_legend_handles_labels()

# Drop the duplicate legend entries contributed by the box plot.
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="upper left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel(r"""Total absolute clade frequency error
({large_frequency_lower_threshold}% $\leq$ initial frequency)""".format(
    large_frequency_lower_threshold=int(large_frequency_lower_threshold * 100),
))

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.total_absolute_forecast_frequency_errors_by_delay_and_horizon)

In [None]:
# Statistics of total absolute error across forecasts, per horizon and delay type.
large_frequencies_total_abs_errors_summary = total_absolute_forecast_errors.groupby(
    ["horizon", "delay_type"],
    sort=False,
).agg({
    "total_absolute_forecast_error": ["mean", "median", "std"],
}).round(2)

In [None]:
# Flatten to the statistic names expected by table_template_row.
large_frequencies_total_abs_errors_summary.columns = [
    column[1]
    for column in large_frequencies_total_abs_errors_summary.columns
]

In [None]:
large_frequencies_total_abs_errors_summary = large_frequencies_total_abs_errors_summary.reset_index()

In [None]:
# Write the LaTeX summary table: header, one row per group, footer.
with open(snakemake.output.total_absolute_forecast_frequency_errors_summary_table, "w", encoding="utf-8") as oh:
    oh.write(total_absolute_table_template_header + "\n")

    for record in large_frequencies_total_abs_errors_summary.to_dict(orient="records"):
        oh.write(table_template_row.format(**record) + "\n")

    oh.write(table_template_footer + "\n")

In [None]:
large_frequencies_total_abs_errors_summary

## Plot effect of interventions on total absolute clade frequency error per future timepoint (summed across clades)

In [None]:
total_absolute_forecast_errors.head()

In [None]:
# Status quo: 12-month horizon with the realistic delay.
# Only future_timepoint is kept as a key so interventions can be paired
# by the timepoint being forecast.
status_quo = total_absolute_forecast_errors.query(
    f"(horizon == 12) & (delay_type == '{realistic_delay_type}')"
).loc[
    :,
    ("future_timepoint", "total_absolute_forecast_error")
].copy()

In [None]:
status_quo["intervention"] = "status_quo"

In [None]:
status_quo.head()

In [None]:
status_quo.shape

In [None]:
# Improved vaccine development, modeled here as a shorter (6-month) horizon
# with the realistic delay.
improved_vaccine_dev = total_absolute_forecast_errors.query(
    f"(horizon == 6) & (delay_type == '{realistic_delay_type}')"
).loc[
    :,
    ("future_timepoint", "total_absolute_forecast_error")
].copy()

In [None]:
improved_vaccine_dev["intervention"] = "improved_vaccine"

In [None]:
improved_vaccine_dev.head()

In [None]:
improved_vaccine_dev.shape

In [None]:
# Improved surveillance, modeled here as the "ideal" delay at a 12-month horizon.
improved_surveillance = total_absolute_forecast_errors.query(
    "(horizon == 12) & (delay_type == 'ideal')"
).loc[
    :,
    ("future_timepoint", "total_absolute_forecast_error")
].copy()

In [None]:
improved_surveillance["intervention"] = "improved_surveillance"

In [None]:
improved_surveillance.head()

In [None]:
improved_surveillance.shape

In [None]:
# Both interventions: "ideal" delay at a 6-month horizon.
improved_vaccine_and_surveillance = total_absolute_forecast_errors.query(
    "(horizon == 6) & (delay_type == 'ideal')"
).loc[
    :,
    ("future_timepoint", "total_absolute_forecast_error")
].copy()

In [None]:
improved_vaccine_and_surveillance["intervention"] = "improved_vaccine_and_surveillance"

In [None]:
improved_vaccine_and_surveillance.head()

In [None]:
improved_vaccine_and_surveillance.shape

In [None]:
# Long table of (future_timepoint, total error, intervention label) rows
# for all four scenarios.
interventions = pd.concat([
    status_quo,
    improved_vaccine_dev,
    improved_surveillance,
    improved_vaccine_and_surveillance,
])

In [None]:
interventions.head()

In [None]:
interventions.pivot_table(
    index=["future_timepoint"],
    columns=["intervention"],
    values="total_absolute_forecast_error",
)

In [None]:
interventions_by_timepoint = interventions.pivot_table(
    index=["future_timepoint"],
    columns=["intervention"],
    values="total_absolute_forecast_error",
).dropna()

In [None]:
interventions_by_timepoint

In [None]:
interventions_by_timepoint["status_quo_vs_improved_vaccine"] = (
    interventions_by_timepoint["status_quo"] - interventions_by_timepoint["improved_vaccine"]
)

In [None]:
interventions_by_timepoint["status_quo_vs_improved_surveillance"] = (
    interventions_by_timepoint["status_quo"] - interventions_by_timepoint["improved_surveillance"]
)

In [None]:
interventions_by_timepoint["status_quo_vs_improved_vaccine_and_surveillance"] = (
    interventions_by_timepoint["status_quo"] - interventions_by_timepoint["improved_vaccine_and_surveillance"]
)

In [None]:
interventions_by_timepoint.reset_index()

In [None]:
differences_in_error_by_intervention = interventions_by_timepoint.reset_index().melt(
    id_vars=[
        "future_timepoint",
    ],
    value_vars=[
        "status_quo_vs_improved_vaccine",
        "status_quo_vs_improved_surveillance",
        "status_quo_vs_improved_vaccine_and_surveillance",
    ],
    value_name="difference_in_total_absolute_forecast_error",
)

In [None]:
differences_in_error_by_intervention.head()

In [None]:
differences_in_error_by_intervention["intervention_name"] = differences_in_error_by_intervention["intervention"].apply(
    lambda intervention: " ".join(intervention.replace("status_quo_vs_", "").split("_"))
)

In [None]:
differences_in_error_by_intervention.shape

In [None]:
intervention_order = [
    "improved vaccine",
    "improved surveillance",
    "improved vaccine and surveillance",
]

In [None]:
# Paired per-timepoint differences (status quo - intervention) for each
# intervention. Values above the zero line favor the intervention.
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=200)

sns.boxplot(
    x="intervention_name",
    y="difference_in_total_absolute_forecast_error",
    data=differences_in_error_by_intervention,
    order=intervention_order,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="intervention_name",
    y="difference_in_total_absolute_forecast_error",
    data=differences_in_error_by_intervention,
    order=intervention_order,
    color="#000000",
    alpha=0.35,
    ax=ax,
    dodge=True,
)

ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

ax.set_xlabel("Intervention")
ax.set_ylabel("Difference in total absolute clade frequency\nerror per timepoint (status quo - intervention)")

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.effects_of_realistic_interventions)

In [None]:
differences_in_error_by_intervention["intervention_name"].drop_duplicates()

In [None]:
# Export the plotted data as a tab-delimited source table.
differences_in_error_by_intervention.to_csv(
    snakemake.output.effects_of_realistic_interventions_source_table,
    sep="\t",
    index=False,
)

In [None]:
differences_in_error_by_intervention["intervention_improved_forecast"] = (
    differences_in_error_by_intervention["difference_in_total_absolute_forecast_error"] > 0
)

In [None]:
differences_in_error_by_intervention.head()

In [None]:
intervention_summary = differences_in_error_by_intervention.groupby("intervention_name", sort=False).agg({
    "difference_in_total_absolute_forecast_error": ["mean", "median", "std"],
    "intervention_improved_forecast": ["sum", "count"],
}).round(2)

In [None]:
intervention_summary

In [None]:
intervention_summary.columns = [
    column[1]
    for column in intervention_summary.columns
]

In [None]:
intervention_summary = intervention_summary.reset_index()

In [None]:
intervention_summary["proportion_improved"] = (intervention_summary["sum"] / intervention_summary["count"]).round(2)

In [None]:
intervention_summary = intervention_summary.drop(
    columns=["count"],
).rename(
    columns={"sum": "total_improved"}
)

In [None]:
intervention_summary

In [None]:
intervention_table_template_header = r"""
\begin{tabular*}{1.0\textwidth}{rrrrrr}
\toprule
             & \multicolumn{3}{c}{Total absolute clade frequency error} & \multicolumn{2}{c}{Forecasts improved} \\
Intervention & Mean & Median & Std Dev & Total & Proportion \\
\midrule
"""

intervention_table_template_row = r"{intervention_name} & {mean:.2f} & {median:.2f} & {std:.2f} & {total_improved} & {proportion_improved:.2f} \\"

intervention_table_template_footer = r"""
\bottomrule
\end{tabular*}
"""

In [None]:
with open(snakemake.output.effects_of_realistic_interventions_summary_table, "w", encoding="utf-8") as oh:
    oh.write(intervention_table_template_header + "\n")
    
    for record in intervention_summary.to_dict(orient="records"):
        oh.write(intervention_table_template_row.format(**record) + "\n")
    
    oh.write(intervention_table_template_footer + "\n")

Plot frequency errors by initial frequency.

In [None]:
large_frequencies

In [None]:
sns.lmplot(
    data=large_frequencies,
    x="frequency",
    y="forecast_error",
    hue="delay_type",
    col="horizon",
    col_wrap=2,
)

In [None]:
sns.lmplot(
    data=large_frequencies,
    x="frequency",
    y="absolute_forecast_error",
    hue="delay_type",
    col="horizon",
    col_wrap=2,
)

Calculate the error in the initial frequency per clade and timepoint. Then plot the forecast error as a function of the initial frequency error.

In [None]:
large_frequencies["initial_frequency_error"] = (
    large_frequencies["frequency_without_delay"] - large_frequencies["frequency"]
)

In [None]:
large_frequencies

In [None]:
sns.lmplot(
    data=large_frequencies.query("delay_type != 'none'"),
    x="initial_frequency_error",
    y="forecast_error",
    hue="delay_type",
    col="horizon",
    col_wrap=2,
)