# Plot clade frequency errors by delay type and forecast horizon for natural H3N2 populations 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Apply seaborn's "ticks" style globally so every figure below shares it.
sns.set_style(style="ticks")

## Load clade frequencies

In [None]:
frequencies = pd.read_csv(
    "../results/clade_frequencies_for_h3n2.tsv",
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"],
).rename(
    columns={"delta_month": "horizon"},
)

In [None]:
frequencies["horizon"] = frequencies["horizon"].astype(int)

In [None]:
frequencies.head()

In [None]:
frequencies["frequency_error"] = frequencies["observed_frequency"] - frequencies["projected_frequency"]

In [None]:
frequencies["abs_frequency_error"] = np.abs(frequencies["frequency_error"])

In [None]:
frequencies.shape

## Annotate all clades with their initial frequency without delay

In [None]:
frequencies_without_delay = frequencies.loc[
    frequencies["delay_type"] == "none",
    ("clade_membership", "timepoint", "frequency")
].drop_duplicates()

In [None]:
frequencies_without_delay.shape

In [None]:
frequencies_without_delay = frequencies_without_delay.rename(
    columns={"frequency": "frequency_without_delay"},
)

In [None]:
frequencies_without_delay.head()

In [None]:
frequencies.shape

In [None]:
frequencies = frequencies.merge(
    frequencies_without_delay,
    how="left",
    on=["timepoint", "clade_membership"],
)

In [None]:
pd.isnull(frequencies["frequency_without_delay"]).sum()

In [None]:
frequencies["frequency_without_delay"] = frequencies["frequency_without_delay"].fillna(0)

In [None]:
frequencies.head()

In [None]:
frequencies.loc[:, ["frequency_without_delay", "frequency"]]

In [None]:
((frequencies["frequency"] > 0.1) & (frequencies["frequency"] < 0.95)).sum()

In [None]:
((frequencies["frequency_without_delay"] > 0.1) & (frequencies["frequency_without_delay"] < 0.95)).sum()

In [None]:
((frequencies["frequency_without_delay"] == 0) & (frequencies["observed_frequency"] == 0)).sum()

In [None]:
((frequencies["frequency"] == 0) & (frequencies["observed_frequency"] == 0)).sum()

In [None]:
distinct_large_clades_with_delay = set(frequencies.loc[
    (frequencies["frequency"] > 0.1) & (frequencies["frequency"] < 0.95),
    "clade_membership"
].drop_duplicates().values)

In [None]:
distinct_large_clades_without_delay = set(frequencies.loc[
    (frequencies["frequency_without_delay"] > 0.1) & (frequencies["frequency_without_delay"] < 0.95),
    "clade_membership"
].drop_duplicates().values)

In [None]:
len(distinct_large_clades_with_delay)

In [None]:
len(distinct_large_clades_without_delay)

In [None]:
distinct_large_clades_with_delay - distinct_large_clades_without_delay

In [None]:
distinct_large_clades_without_delay - distinct_large_clades_with_delay

In [None]:
frequencies[frequencies["clade_membership"].isin(distinct_large_clades_with_delay - distinct_large_clades_without_delay)]

## Plot clade frequency errors by delay type and forecast horizon

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)

# Arguments shared by the box and strip layers so both group the data identically.
delay_type_order = ("none", "ideal", "observed")
error_plot_kwargs = dict(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=delay_type_order,
    data=frequencies,
    ax=ax,
)

# Grey boxes without fliers; the strip plot shows the individual points instead.
sns.boxplot(color="#CCCCCC", fliersize=0.0, **error_plot_kwargs)
sns.stripplot(alpha=0.35, dodge=True, **error_plot_kwargs)

# Zero-error reference line drawn behind the data.
ax.axhline(y=0, color="#000000", zorder=-10, linewidth=1)

# Explain the sign of the error at the top and bottom of the axes.
for text_y, caption in ((0.95, "underestimated"), (0.05, "overestimated")):
    ax.text(
        0.5,
        text_y,
        caption,
        horizontalalignment='center',
        verticalalignment='center',
        transform=ax.transAxes,
    )

# Keep only the last three legend entries (presumably the strip-plot markers;
# the box plot registers its own entries first).
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="lower left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Clade frequency error (all clades)")

sns.despine()
plt.tight_layout()

Plot clade frequency errors for large clades only, i.e., clades whose initial frequency without delay lies between the thresholds defined below.

In [None]:
large_frequency_lower_threshold = 0.1

In [None]:
large_frequency_upper_threshold = 0.95

In [None]:
large_frequencies = frequencies.query(
    f"(frequency_without_delay >= {large_frequency_lower_threshold}) & (frequency_without_delay <= {large_frequency_upper_threshold})"
)

In [None]:
large_frequencies["frequency_without_delay"].describe()

In [None]:
large_frequencies.shape

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)

# Signed clade frequency errors for large clades only, grouped by horizon and delay type.
sns.boxplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=("none", "ideal", "observed"),
    data=large_frequencies,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    data=large_frequencies,
    hue_order=("none", "ideal", "observed"),
    alpha=0.35,
    ax=ax,
    dodge=True,
)

# Zero-error reference line drawn behind the data.
ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

# Errors are observed - projected, so positive values are underestimates.
ax.text(
    0.5,
    0.95,
    "underestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

ax.text(
    0.5,
    0.05,
    "overestimated",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

# Keep only the last three legend entries (presumably the strip-plot markers;
# the box plot registers its own entries first).
handles, labels = ax.get_legend_handles_labels()

ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="lower left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
# Show the thresholds as whole percentages. round() rather than int(): int()
# truncates, so a threshold like 0.57 (0.57 * 100 == 56.99999999999999) would be
# labelled 56%.
ax.set_ylabel(r"""Clade frequency error
({large_frequency_lower_threshold}% $\leq$ initial frequency $\leq$ {large_frequency_upper_threshold}%)""".format(
    large_frequency_lower_threshold=round(large_frequency_lower_threshold * 100),
    large_frequency_upper_threshold=round(large_frequency_upper_threshold * 100),
))

sns.despine()
plt.tight_layout()

In [None]:
large_frequencies.groupby(["horizon", "delay_type"], sort=False).agg({
    "frequency_error": ["mean", "median", "std"],
    "abs_frequency_error": ["mean", "median", "std"],
})

In [None]:
large_frequencies.head()

## Plot absolute clade frequency errors

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)

# Absolute clade frequency errors for large clades, grouped by horizon and delay type.
sns.boxplot(
    x="horizon",
    y="abs_frequency_error",
    hue="delay_type",
    hue_order=("none", "ideal", "observed"),
    data=large_frequencies,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="abs_frequency_error",
    hue="delay_type",
    data=large_frequencies,
    hue_order=("none", "ideal", "observed"),
    alpha=0.35,
    ax=ax,
    dodge=True,
)

# Zero-error reference line drawn behind the data.
ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

# Keep only the last three legend entries (presumably the strip-plot markers;
# the box plot registers its own entries first).
handles, labels = ax.get_legend_handles_labels()

ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="upper left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
# Show the thresholds as whole percentages. round() rather than int(): int()
# truncates, so a threshold like 0.57 (0.57 * 100 == 56.99999999999999) would be
# labelled 56%.
ax.set_ylabel(r"""Absolute clade frequency error
({large_frequency_lower_threshold}% $\leq$ initial frequency $\leq$ {large_frequency_upper_threshold}%)""".format(
    large_frequency_lower_threshold=round(large_frequency_lower_threshold * 100),
    large_frequency_upper_threshold=round(large_frequency_upper_threshold * 100),
))

sns.despine()
plt.tight_layout()

## Plot difference in mean absolute error (MAE) by horizon and delay type vs. the same horizon without delay

In [None]:
large_mae_frequencies = large_frequencies.groupby(["horizon", "delay_type", "timepoint"])["abs_frequency_error"].mean().reset_index()

In [None]:
large_mae_frequencies

In [None]:
large_mae_frequencies_by_delays = large_mae_frequencies.pivot(
    index=["horizon", "timepoint"],
    values=["abs_frequency_error"],
    columns=["delay_type"],
).fillna(0)

In [None]:
large_mae_frequencies_by_delays.head()

In [None]:
large_mae_frequencies_by_delays.columns = ["ideal", "none", "observed"]

In [None]:
large_mae_frequencies_by_delays = large_mae_frequencies_by_delays.reset_index()

In [None]:
large_mae_frequencies_by_delays

In [None]:
large_mae_frequencies_by_delays["ideal_mae_difference"] = large_mae_frequencies_by_delays["none"] - large_mae_frequencies_by_delays["ideal"]

In [None]:
large_mae_frequencies_by_delays["observed_mae_difference"] = large_mae_frequencies_by_delays["none"] - large_mae_frequencies_by_delays["observed"]

In [None]:
large_mae_frequency_differences = large_mae_frequencies_by_delays.melt(
    id_vars=["horizon", "timepoint"],
    value_vars=["ideal_mae_difference", "observed_mae_difference"],
    var_name="delay_type",
    value_name="mae_difference",
)

In [None]:
large_mae_frequency_differences["delay_type"] = large_mae_frequency_differences["delay_type"].apply(lambda delay: delay.split("_")[0])

In [None]:
large_mae_frequency_differences

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)

# Arguments shared by the box and strip layers: MAE difference per horizon,
# split by delay type.
mae_difference_kwargs = dict(
    x="horizon",
    y="mae_difference",
    hue="delay_type",
    hue_order=("ideal", "observed"),
    data=large_mae_frequency_differences,
    ax=ax,
)
sns.boxplot(color="#CCCCCC", fliersize=0.0, **mae_difference_kwargs)
sns.stripplot(alpha=0.35, dodge=True, **mae_difference_kwargs)

# Zero line drawn behind the data: above it, the delayed forecast had lower MAE.
ax.axhline(y=0, color="#000000", zorder=-10, linewidth=1)

# Two hue levels per layer; keep only the last two legend entries.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles[2:],
    labels=labels[2:],
    loc="lower left",
    title="Delay type",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Difference in MAE without and with delay")

sns.despine()
plt.tight_layout()

## Plot effect of interventions on absolute clade frequency error by timepoint and clade

In [None]:
status_quo = large_frequencies.query("(horizon == 12) & (delay_type == 'observed')").loc[
    :,
    ("future_timepoint", "clade_membership", "abs_frequency_error")
].copy()

In [None]:
status_quo["intervention"] = "status_quo"

In [None]:
status_quo.head()

In [None]:
status_quo.shape

In [None]:
improved_vaccine_dev = large_frequencies.query("(horizon == 6) & (delay_type == 'observed')").loc[
    :,
    ("future_timepoint", "clade_membership", "abs_frequency_error")
].copy()

In [None]:
improved_vaccine_dev["intervention"] = "improved_vaccine"

In [None]:
improved_vaccine_dev.head()

In [None]:
improved_vaccine_dev.shape

In [None]:
improved_surveillance = large_frequencies.query("(horizon == 12) & (delay_type == 'ideal')").loc[
    :,
    ("future_timepoint", "clade_membership", "abs_frequency_error")
].copy()

In [None]:
improved_surveillance["intervention"] = "improved_surveillance"

In [None]:
improved_surveillance.head()

In [None]:
improved_vaccine_and_surveillance = large_frequencies.query("(horizon == 6) & (delay_type == 'ideal')").loc[
    :,
    ("future_timepoint", "clade_membership", "abs_frequency_error")
].copy()

In [None]:
improved_vaccine_and_surveillance["intervention"] = "improved_vaccine_and_surveillance"

In [None]:
improved_vaccine_and_surveillance.head()

In [None]:
interventions = pd.concat([
    status_quo,
    improved_vaccine_dev,
    improved_surveillance,
    improved_vaccine_and_surveillance,
])

In [None]:
interventions.head()

In [None]:
interventions_by_timepoint_clade = interventions.pivot_table(
    index=["future_timepoint", "clade_membership"],
    columns=["intervention"],
    values="abs_frequency_error",
).dropna()

In [None]:
interventions_by_timepoint_clade.head()

In [None]:
interventions_by_timepoint_clade["status_quo_vs_improved_vaccine"] = (
    interventions_by_timepoint_clade["status_quo"] - interventions_by_timepoint_clade["improved_vaccine"]
)

In [None]:
interventions_by_timepoint_clade["status_quo_vs_improved_surveillance"] = (
    interventions_by_timepoint_clade["status_quo"] - interventions_by_timepoint_clade["improved_surveillance"]
)

In [None]:
interventions_by_timepoint_clade["status_quo_vs_improved_vaccine_and_surveillance"] = (
    interventions_by_timepoint_clade["status_quo"] - interventions_by_timepoint_clade["improved_vaccine_and_surveillance"]
)

In [None]:
interventions_by_timepoint_clade.reset_index()

In [None]:
differences_in_error_by_intervention = interventions_by_timepoint_clade.reset_index().melt(
    id_vars=[
        "future_timepoint",
        "clade_membership",
    ],
    value_vars=[
        "status_quo_vs_improved_vaccine",
        "status_quo_vs_improved_surveillance",
        "status_quo_vs_improved_vaccine_and_surveillance",
    ],
    value_name="difference_in_abs_frequency_error",
)

In [None]:
differences_in_error_by_intervention.head()

In [None]:
differences_in_error_by_intervention["intervention_name"] = differences_in_error_by_intervention["intervention"].apply(
    lambda intervention: " ".join(intervention.replace("status_quo_vs_", "").split("_"))
)

In [None]:
differences_in_error_by_intervention.shape

In [None]:
intervention_order = [
    "improved vaccine",
    "improved surveillance",
    "improved vaccine and surveillance",
]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=150)

sns.boxplot(
    x="intervention_name",
    y="difference_in_abs_frequency_error",
    data=differences_in_error_by_intervention,
    order=intervention_order,
    color="#CCCCCC",
    fliersize=0.0,
    ax=ax,
)
sns.stripplot(
    x="intervention_name",
    y="difference_in_abs_frequency_error",
    data=differences_in_error_by_intervention,
    order=intervention_order,
    color="#000000",    
    alpha=0.35,
    ax=ax,
    dodge=True,
)

ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

ax.set_xlabel("Intervention")
ax.set_ylabel("Difference in absolute clade frequency\nerror per clade and timepoint\n(status quo - intervention)")

sns.despine()
plt.tight_layout()

In [None]:
differences_in_error_by_intervention["intervention_name"].drop_duplicates()

In [None]:
differences_in_error_by_intervention.groupby("intervention_name", sort=False).agg({
    "difference_in_abs_frequency_error": ["mean", "median", "std"]
})