# Plot clade frequency errors by delay type and forecast horizon for natural H3N2 populations 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Apply seaborn's "ticks" style to the figures drawn in this notebook.
sns.set_style("ticks")

## Load clade frequencies

In [None]:
# Read the tab-delimited clade frequency table, parsing both timepoint
# columns as dates.
frequencies = pd.read_csv(
    "../results/clade_frequencies_for_h3n2.tsv",
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"],
)

# The rest of the notebook refers to the month offset as the "horizon".
frequencies = frequencies.rename(columns={"delta_month": "horizon"})

In [None]:
# Store the horizon as an integer.
# NOTE(review): presumably so the x-axis shows whole-month labels -- confirm.
frequencies["horizon"] = frequencies["horizon"].astype(int)

In [None]:
# Preview the first rows of the loaded table.
frequencies.head()

In [None]:
# Signed error per row: observed minus projected frequency.
frequencies["frequency_error"] = (
    frequencies["observed_frequency"]
    - frequencies["projected_frequency"]
)

In [None]:
# Magnitude of the signed error, used later to compute mean absolute errors.
frequencies["abs_frequency_error"] = frequencies["frequency_error"].abs()

In [None]:
# Show the (rows, columns) dimensions of the full table.
frequencies.shape

## Plot clade frequency errors by delay type and forecast horizon

In [None]:
# Box plot of clade frequency errors per forecast horizon, grouped by delay
# type, with the individual errors overlaid as dodged, translucent points.
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

delay_type_order = ("none", "ideal", "observed")

# Arguments shared by the box and strip layers so both use the same grouping.
shared_plot_kwargs = dict(
    data=frequencies,
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=delay_type_order,
    ax=ax,
)

# Grey boxes with zero-size flier markers; the strip layer shows every point.
sns.boxplot(color="#CCCCCC", fliersize=0.0, **shared_plot_kwargs)
sns.stripplot(alpha=0.35, dodge=True, **shared_plot_kwargs)

# Zero-error reference line placed behind the data.
ax.axhline(y=0, color="#000000", linewidth=1, zorder=-10)

# Keep only the legend entries after the first three.
# NOTE(review): presumably the first three come from the box layer and the
# rest from the strip layer -- confirm against the installed seaborn version.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles[3:],
    labels=labels[3:],
    title="Delay type",
    loc="upper left",
    frameon=False,
)

ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Clade frequency error\n(without delay - with delay)")

sns.despine()
plt.tight_layout()

Plot clade frequency errors for larger clades only, i.e., rows whose `frequency` is at least `large_frequency_threshold`.

In [None]:
# Minimum value of the `frequency` column for a row to count as a "large" clade.
large_frequency_threshold = 0.1

In [None]:
# Keep only rows whose frequency meets the large-clade threshold.
large_frequencies = frequencies[
    frequencies["frequency"] >= large_frequency_threshold
]

In [None]:
# Show how many rows remain after the threshold filter.
large_frequencies.shape

In [None]:
# Box plot of clade frequency errors per forecast horizon, grouped by delay
# type, restricted to `large_frequencies`; individual errors are overlaid as
# dodged, translucent points.
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

delay_type_order = ("none", "ideal", "observed")

# Arguments shared by the box and strip layers so both use the same grouping.
shared_plot_kwargs = dict(
    data=large_frequencies,
    x="horizon",
    y="frequency_error",
    hue="delay_type",
    hue_order=delay_type_order,
    ax=ax,
)

# Grey boxes with zero-size flier markers; the strip layer shows every point.
sns.boxplot(color="#CCCCCC", fliersize=0.0, **shared_plot_kwargs)
sns.stripplot(alpha=0.35, dodge=True, **shared_plot_kwargs)

# Zero-error reference line placed behind the data.
ax.axhline(y=0, color="#000000", linewidth=1, zorder=-10)

# Skip one legend entry per delay type so each type is listed only once.
# NOTE(review): presumably the skipped entries come from the box layer and the
# rest from the strip layer -- confirm against the installed seaborn version.
handles, labels = ax.get_legend_handles_labels()
first_kept_entry = len(delay_type_order)
ax.legend(
    handles=handles[first_kept_entry:],
    labels=labels[first_kept_entry:],
    title="Delay type",
    loc="lower left",
    frameon=False,
)

ax.set_xlabel("Forecast horizon (months)")

# Format the threshold as a percentage with `:g` rather than `int(...)`:
# int() truncates float products such as 0.29 * 100 == 28.999..., which would
# print 28 instead of 29, while `:g` rounds to the intended value and also
# keeps fractional percentages (e.g. 12.5).
ax.set_ylabel(
    r"""Clade frequency error
($\geq${threshold_percent:g}% without delay - with delay)""".format(
        threshold_percent=large_frequency_threshold * 100,
    )
)

sns.despine()
plt.tight_layout()

In [None]:
# Regress frequency error on clade frequency for the large clades: one fit per
# delay type (hue) and one panel per horizon, wrapped two panels per row.
sns.lmplot(
    data=large_frequencies,
    x="frequency",
    y="frequency_error",
    hue="delay_type",
    col="horizon",
    col_wrap=2,
    height=4,
)

In [None]:
# Display the filtered table for inspection.
large_frequencies

In [None]:
# Mean absolute error of the large clades for each combination of horizon,
# delay type and timepoint, returned as a flat table.
large_mae_frequencies = (
    large_frequencies
    .groupby(["horizon", "delay_type", "timepoint"])["abs_frequency_error"]
    .mean()
    .reset_index()
)

In [None]:
# Display the per-group mean absolute errors for inspection.
large_mae_frequencies

In [None]:
# Reshape to one row per (horizon, timepoint) with one MAE column per delay
# type.
large_mae_frequencies_by_delays = large_mae_frequencies.pivot(
    columns=["delay_type"],
    index=["horizon", "timepoint"],
    values=["abs_frequency_error"],
)

# Combinations missing for a delay type become 0.
# NOTE(review): a filled 0 reads as a perfect forecast in the later MAE
# differences -- confirm this is intended rather than dropping those rows.
large_mae_frequencies_by_delays = large_mae_frequencies_by_delays.fillna(0)

In [None]:
# Preview the pivoted table before its columns are renamed.
large_mae_frequencies_by_delays.head()

In [None]:
# Flatten the pivoted columns to the bare delay-type names.
# The names are read from the innermost column level instead of a hard-coded
# ["ideal", "none", "observed"] list, so each column keeps the label of the
# delay type it holds even if the set or order of delay types in the data
# changes (a fixed list would silently mislabel or raise on a length mismatch).
large_mae_frequencies_by_delays.columns = list(
    large_mae_frequencies_by_delays.columns.get_level_values(-1)
)

In [None]:
# Move the (horizon, timepoint) index back into regular columns.
large_mae_frequencies_by_delays = large_mae_frequencies_by_delays.reset_index()

In [None]:
# Display the wide MAE table for inspection.
large_mae_frequencies_by_delays

In [None]:
# MAE with no delay minus MAE under the "ideal" delay type.
large_mae_frequencies_by_delays["ideal_mae_difference"] = (
    large_mae_frequencies_by_delays["none"]
    - large_mae_frequencies_by_delays["ideal"]
)

In [None]:
# MAE with no delay minus MAE under the "observed" delay type.
large_mae_frequencies_by_delays["observed_mae_difference"] = (
    large_mae_frequencies_by_delays["none"]
    - large_mae_frequencies_by_delays["observed"]
)

In [None]:
# Stack the two difference columns into long form: one row per
# (horizon, timepoint, delay type) with the difference in `mae_difference`.
large_mae_frequency_differences = pd.melt(
    large_mae_frequencies_by_delays,
    id_vars=["horizon", "timepoint"],
    value_vars=["ideal_mae_difference", "observed_mae_difference"],
    var_name="delay_type",
    value_name="mae_difference",
)

In [None]:
# Reduce each label to the text before its first underscore, e.g.
# "ideal_mae_difference" -> "ideal".
large_mae_frequency_differences["delay_type"] = (
    large_mae_frequency_differences["delay_type"].str.split("_").str[0]
)

In [None]:
# Display the long-form MAE differences for inspection.
large_mae_frequency_differences

In [None]:
# Box plot of the MAE difference (no delay minus each delay type) per forecast
# horizon, with the individual values overlaid as dodged, translucent points.
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

delay_type_order = ("ideal", "observed")

# Arguments shared by the box and strip layers so both use the same grouping.
shared_plot_kwargs = dict(
    data=large_mae_frequency_differences,
    x="horizon",
    y="mae_difference",
    hue="delay_type",
    hue_order=delay_type_order,
    ax=ax,
)

# Grey boxes with zero-size flier markers; the strip layer shows every point.
sns.boxplot(color="#CCCCCC", fliersize=0.0, **shared_plot_kwargs)
sns.stripplot(alpha=0.35, dodge=True, **shared_plot_kwargs)

# Zero-difference reference line placed behind the data.
ax.axhline(y=0, color="#000000", linewidth=1, zorder=-10)

# Keep only the legend entries after the first two.
# NOTE(review): presumably the first two come from the box layer and the rest
# from the strip layer -- confirm against the installed seaborn version.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles[2:],
    labels=labels[2:],
    title="Delay type",
    loc="lower left",
    frameon=False,
)

ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Difference in MAE without and with delay")

sns.despine()
plt.tight_layout()