In [None]:
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd
import seaborn as sns

In [None]:
# Matplotlib default color-cycle entry used to draw each delay type.
color_by_delay_type = dict(
    none="C0",
    ideal="C1",
    realistic="C2",
)

In [None]:
# Effects of interventions on H3N2 forecasts. future_timepoint is parsed as a
# datetime so it can be matched against the clade frequency table's timepoints.
interventions = pd.read_csv(
    "../manuscript/tables/h3n2_effects_of_realistic_interventions.tsv",
    delimiter="\t",
    parse_dates=["future_timepoint"],
)

In [None]:
# Peek at the first rows of the interventions table.
interventions.head()

In [None]:
# Rows for the "improved surveillance" intervention only, ordered from the
# smallest to the largest difference_in_total_absolute_forecast_error and
# renumbered 0..n-1 (the previous index is kept as an "index" column).
sorted_surveillance_interventions = (
    interventions[interventions["intervention_name"] == "improved surveillance"]
    .sort_values(by="difference_in_total_absolute_forecast_error")
    .reset_index()
)

In [None]:
# Smallest difference_in_total_absolute_forecast_error values (the sort is ascending).
sorted_surveillance_interventions.head()

In [None]:
# Largest difference_in_total_absolute_forecast_error values (end of the ascending sort).
sorted_surveillance_interventions.tail()

In [None]:
# Future timepoint of the row with the smallest
# difference_in_total_absolute_forecast_error (first row of the ascending sort).
# Positional access expresses "first row" directly instead of relying on
# reset_index having produced a 0-based RangeIndex.
worst_future_timepoint = sorted_surveillance_interventions["future_timepoint"].iloc[0]

In [None]:
# Future timepoint selected as the "worst" case.
worst_future_timepoint

In [None]:
# Forecasts are made 12 months ahead (matching delta_month == 12 used when
# selecting frequencies), so the initial timepoint is 12 months earlier.
worst_initial_timepoint = worst_future_timepoint - pd.DateOffset(months=12)

In [None]:
# Initial timepoint one year before the worst future timepoint.
worst_initial_timepoint

In [None]:
# Future timepoint of the row with the largest
# difference_in_total_absolute_forecast_error (last row of the ascending sort).
# iloc[-1] states "last row" directly instead of computing the last label by hand.
best_future_timepoint = sorted_surveillance_interventions["future_timepoint"].iloc[-1]

In [None]:
# Future timepoint selected as the "best" case.
best_future_timepoint

In [None]:
# Forecasts are made 12 months ahead (matching delta_month == 12 used when
# selecting frequencies), so the initial timepoint is 12 months earlier.
best_initial_timepoint = best_future_timepoint - pd.DateOffset(months=12)

In [None]:
# Initial timepoint one year before the best future timepoint.
best_initial_timepoint

In [None]:
# Per-clade frequencies and forecasts for H3N2. Both timepoint columns are
# parsed as datetimes so they compare equal to the Timestamps selected above.
df = pd.read_csv(
    "../results/clade_frequencies_for_h3n2.tsv",
    delimiter="\t",
    parse_dates=["timepoint", "future_timepoint"],
)

In [None]:
# Peek at the first rows of the per-clade frequency table.
df.head()

In [None]:
# Ideal-delay forecasts for the worst (initial, future) timepoint pair,
# largest absolute_forecast_error first.
df[
    df["timepoint"].eq(worst_initial_timepoint)
    & df["future_timepoint"].eq(worst_future_timepoint)
    & df["delay_type"].eq("ideal")
].sort_values(by="absolute_forecast_error", ascending=False).head(5)

In [None]:
# Among the "ideal"-delay forecasts made at the worst initial timepoint for the
# worst future timepoint, pick the clade with the largest absolute_forecast_error.
worst_timepoint_ideal_forecasts = df.loc[
    (df["timepoint"] == worst_initial_timepoint) &
    (df["future_timepoint"] == worst_future_timepoint) &
    (df["delay_type"] == "ideal")
]
assert not worst_timepoint_ideal_forecasts.empty, (
    "No ideal-delay forecasts found for the worst timepoint pair."
)

# idxmax returns the first row label holding the maximum, so ties are broken
# deterministically by row order (sort_values' default quicksort is not stable).
worst_clade = worst_timepoint_ideal_forecasts.at[
    worst_timepoint_ideal_forecasts["absolute_forecast_error"].idxmax(),
    "clade_membership",
]

In [None]:
# Clade with the largest ideal-delay absolute_forecast_error at the worst timepoint pair.
worst_clade

In [None]:
# Frequency trajectory of the worst clade over all timepoints, taken from the
# no-delay rows at the 12-month horizon (presumably one row per timepoint —
# confirm). Sorted by timepoint so the "o-" line in the plot connects points
# chronologically regardless of the row order in the input TSV.
worst_frequencies = df.loc[
    (
        (df["delay_type"] == "none") &
        (df["delta_month"] == 12) &
        (df["clade_membership"] == worst_clade)
    ),
    [
        "timepoint",
        "frequency",
    ],
].sort_values("timepoint")

In [None]:
# Forecasts for the worst clade at the worst (initial, future) timepoint pair
# from every delay type, keeping only the columns needed to draw each forecast
# as a segment from (timepoint, frequency) to (future_timepoint, projected_frequency).
worst_forecasts = df.loc[
    df["clade_membership"].eq(worst_clade)
    & df["future_timepoint"].eq(worst_future_timepoint)
    & df["timepoint"].eq(worst_initial_timepoint),
    ["timepoint", "future_timepoint", "frequency", "projected_frequency", "delay_type"],
]

In [None]:
# (rows, columns) of the worst clade's forecasts; presumably one row per delay_type — confirm.
worst_forecasts.shape

In [None]:
# First rows of the worst clade's forecasts as a labeled DataFrame; the raw
# `.values` object array displayed the same data without column names.
worst_forecasts.head()

In [None]:
# Worst case: the worst clade's no-delay frequency trajectory (grey) with one
# forecast segment per delay type, each drawn from the frequency at the initial
# timepoint to the projected frequency at the worst future timepoint (dashed line).
fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=120)

ax.plot(
    worst_frequencies["timepoint"],
    worst_frequencies["frequency"],
    "o-",
    color="#999999",
)

# Named fields instead of positional unpacking of `.values`, so the segment
# endpoints do not depend on column order.
for forecast in worst_forecasts.itertuples(index=False):
    ax.plot(
        [forecast.timepoint, forecast.future_timepoint],
        [forecast.frequency, forecast.projected_frequency],
        color=color_by_delay_type[forecast.delay_type],
    )

ax.axvline(
    x=worst_future_timepoint,
    color="#999999",
    linestyle="--",
    zorder=-10,
)

# Segment colors encode delay type, so the figure needs a legend; proxy
# Line2D handles give exactly one entry per delay type.
ax.legend(
    handles=[
        Line2D([], [], color="#999999", marker="o", label="frequency (delay: none)"),
        *(
            Line2D([], [], color=color, label=f"forecast (delay: {delay_type})")
            for delay_type, color in color_by_delay_type.items()
        ),
    ],
    frameon=False,
)

ax.set_title(
    f"Clade {worst_clade}: forecasts from {worst_initial_timepoint:%Y-%m-%d} "
    f"to {worst_future_timepoint:%Y-%m-%d}"
)
ax.set_xlabel("Date")
ax.set_ylabel("Frequency")

sns.despine()

In [None]:
# Ideal-delay forecasts for the best (initial, future) timepoint pair, limited
# to clades with frequency in [0.2, 0.9), smallest absolute_forecast_error first.
df[
    df["timepoint"].eq(best_initial_timepoint)
    & df["future_timepoint"].eq(best_future_timepoint)
    & df["delay_type"].eq("ideal")
    & df["frequency"].ge(0.2)
    & df["frequency"].lt(0.9)
].sort_values(by="absolute_forecast_error", ascending=True).head(5)

In [None]:
# Among the "ideal"-delay forecasts made at the best initial timepoint for the
# best future timepoint, restricted to clades with frequency in [0.2, 0.9)
# (presumably to skip rare and near-fixed clades — confirm), pick the clade
# with the smallest absolute_forecast_error.
best_timepoint_ideal_forecasts = df.loc[
    (df["timepoint"] == best_initial_timepoint) &
    (df["future_timepoint"] == best_future_timepoint) &
    (df["delay_type"] == "ideal") &
    (df["frequency"] >= 0.2) &
    (df["frequency"] < 0.9)
]
assert not best_timepoint_ideal_forecasts.empty, (
    "No ideal-delay forecasts with frequency in [0.2, 0.9) for the best timepoint pair."
)

# idxmin returns the first row label holding the minimum, so ties are broken
# deterministically by row order (sort_values' default quicksort is not stable).
best_clade = best_timepoint_ideal_forecasts.at[
    best_timepoint_ideal_forecasts["absolute_forecast_error"].idxmin(),
    "clade_membership",
]

In [None]:
# Clade (frequency in [0.2, 0.9)) with the smallest ideal-delay absolute_forecast_error at the best timepoint pair.
best_clade

In [None]:
# Frequency trajectory of the best clade over all timepoints, taken from the
# no-delay rows at the 12-month horizon (presumably one row per timepoint —
# confirm). Sorted by timepoint so the "o-" line in the plot connects points
# chronologically regardless of the row order in the input TSV.
best_frequencies = df.loc[
    (
        (df["delay_type"] == "none") &
        (df["delta_month"] == 12) &
        (df["clade_membership"] == best_clade)
    ),
    [
        "timepoint",
        "frequency",
    ],
].sort_values("timepoint")

In [None]:
# Forecasts for the best clade at the best (initial, future) timepoint pair
# from every delay type, keeping only the columns needed to draw each forecast
# as a segment from (timepoint, frequency) to (future_timepoint, projected_frequency).
best_forecasts = df.loc[
    df["clade_membership"].eq(best_clade)
    & df["future_timepoint"].eq(best_future_timepoint)
    & df["timepoint"].eq(best_initial_timepoint),
    ["timepoint", "future_timepoint", "frequency", "projected_frequency", "delay_type"],
]

In [None]:
# All forecast rows for the best clade at the best timepoint pair (presumably one per delay_type — confirm).
best_forecasts

In [None]:
# Best case: the best clade's no-delay frequency trajectory (grey) with one
# forecast segment per delay type, each drawn from the frequency at the initial
# timepoint to the projected frequency at the best future timepoint (dashed line).
fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=120)

ax.plot(
    best_frequencies["timepoint"],
    best_frequencies["frequency"],
    "o-",
    color="#999999",
)

# Named fields instead of positional unpacking of `.values`, so the segment
# endpoints do not depend on column order.
for forecast in best_forecasts.itertuples(index=False):
    ax.plot(
        [forecast.timepoint, forecast.future_timepoint],
        [forecast.frequency, forecast.projected_frequency],
        color=color_by_delay_type[forecast.delay_type],
    )

ax.axvline(
    x=best_future_timepoint,
    color="#999999",
    linestyle="--",
    zorder=-10,
)

# Segment colors encode delay type, so the figure needs a legend; proxy
# Line2D handles give exactly one entry per delay type.
ax.legend(
    handles=[
        Line2D([], [], color="#999999", marker="o", label="frequency (delay: none)"),
        *(
            Line2D([], [], color=color, label=f"forecast (delay: {delay_type})")
            for delay_type, color in color_by_delay_type.items()
        ),
    ],
    frameon=False,
)

ax.set_title(
    f"Clade {best_clade}: forecasts from {best_initial_timepoint:%Y-%m-%d} "
    f"to {best_future_timepoint:%Y-%m-%d}"
)
ax.set_xlabel("Date")
ax.set_ylabel("Frequency")

sns.despine()