In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# Use seaborn's "ticks" style for the figures drawn after this cell.
sns.set_style("ticks")

## Compare distances

Compare the model distances between the estimated and observed future populations when the estimated populations depend on different data delay conditions. We expect the "no delay" scenario to be closest to the true future on average, followed by the "ideal" and then the "realistic" conditions.

In [None]:
# Load the tab-delimited distances table (read_table defaults to a tab
# separator) and parse both timepoint columns as dates.
distances = pd.read_table(
    snakemake.input.distances,
    parse_dates=["initial_timepoint", "future_timepoint"],
)

In [None]:
distances

In [None]:
# Collect the distinct delay conditions present in the distances table.
delay_types = set(distances["delay_type"].unique())

In [None]:
delay_types

In [None]:
# Any delay type other than the "none" and "ideal" baselines is treated as the
# realistic delay condition. Sort the remaining names so the choice does not
# depend on set iteration order, and fail with an explicit message when the
# table contains no such condition.
realistic_delay_candidates = sorted(delay_types - {"none", "ideal"})
if not realistic_delay_candidates:
    raise ValueError(
        "Expected a delay type other than 'none' and 'ideal' in the distances table, "
        f"but found only: {sorted(delay_types)}"
    )

realistic_delay_type = realistic_delay_candidates[0]

In [None]:
realistic_delay_type

In [None]:
# Summarize the distance for every (horizon, delay type) pair, keeping groups
# in order of first appearance and rounding the statistics to two decimals.
summary_distances_by_delay_and_horizon = (
    distances
    .groupby(["horizon", "delay_type"], sort=False)
    .agg({"distance": ["mean", "median", "std"]})
    .round(2)
)

In [None]:
# Inspect the mean distance for each (horizon, delay type) group.
summary_distances_by_delay_and_horizon[("distance", "mean")]

In [None]:
# Show the column labels produced by the aggregation before they are flattened.
summary_distances_by_delay_and_horizon.columns

In [None]:
# Drop the outer column level so only the statistic names
# ("mean", "median", "std") remain as column labels.
summary_distances_by_delay_and_horizon.columns = (
    summary_distances_by_delay_and_horizon.columns.droplevel(0)
)

In [None]:
summary_distances_by_delay_and_horizon

In [None]:
# Combine each group's mean and standard deviation into one display string.
summary_distances_by_delay_and_horizon["mean_std"] = [
    f"{mean_distance:.2f} +/- {std_distance:.2f}"
    for mean_distance, std_distance in zip(
        summary_distances_by_delay_and_horizon["mean"],
        summary_distances_by_delay_and_horizon["std"],
    )
]

In [None]:
summary_distances_by_delay_and_horizon

In [None]:
# Reshape to one row per horizon and one column per delay type.
# The table being pivoted holds one row per (horizon, delay type) group, so
# "first" just returns that group's formatted string. It replaces an identity
# lambda, which is not a real reduction and relies on a pandas fallback path.
summary_distances_by_delay_and_horizon = summary_distances_by_delay_and_horizon.pivot_table(
    values=["mean_std"],
    index=["horizon"],
    columns=["delay_type"],
    aggfunc="first",
    sort=False,
)

In [None]:
summary_distances_by_delay_and_horizon

In [None]:
# Keep only the delay type (second element) of each two-level column label.
summary_distances_by_delay_and_horizon.columns = [
    delay_type
    for _, delay_type in summary_distances_by_delay_and_horizon.columns
]

In [None]:
# Turn the "horizon" index back into a regular column so each exported record
# carries the horizon value used by the LaTeX row template.
summary_distances_by_delay_and_horizon = summary_distances_by_delay_and_horizon.reset_index()

In [None]:
summary_distances_by_delay_and_horizon

In [None]:
# Build the pieces of the LaTeX summary table.
#
# The third data column is named after whichever realistic delay type was
# detected in the data. Both the header label and the row placeholder are
# derived from `realistic_delay_type`, so the row template always refers to a
# column that exists in the summary table rather than a hard-coded name.
# "observed" yields "Observed delay" / {observed}; "realistic" yields
# "Realistic delay" / {realistic}.
realistic_delay_label = f"{realistic_delay_type.capitalize()} delay"

table_template_header = (
    r"""
\begin{tabular*}{0.7\textwidth}{rrrr}
\toprule
          & \multicolumn{3}{c}{Distance to future (mean +/- std dev AAs)} \\
  Horizon & No delay & Ideal delay & """
    + realistic_delay_label
    + r""" \\
\midrule
"""
)

# The placeholders are filled with str.format from each summary record.
table_template_row = r"{horizon} & {none} & {ideal} & {" + realistic_delay_type + r"} \\"

table_template_footer = r"""
\bottomrule
\end{tabular*}
"""

In [None]:
# Render one LaTeX row per horizon, then write the header, the rows, and the
# footer to the output file, each followed by a newline.
table_rows = [
    table_template_row.format(**record)
    for record in summary_distances_by_delay_and_horizon.to_dict(orient="records")
]

with open(snakemake.output.distances_summary_table, "w", encoding="utf-8") as table_file:
    for chunk in (table_template_header, *table_rows, table_template_footer):
        table_file.write(chunk + "\n")

In [None]:
# Fixed hue order for the plots below: no delay, ideal delay, then the
# realistic delay type detected from the data.
delay_order = ("none", "ideal", realistic_delay_type)

In [None]:
# Distance to the future by forecast horizon and delay type: white violins
# with quartile lines, overlaid with the individual points.
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

sns.violinplot(
    x="horizon",
    y="distance",
    hue="delay_type",
    hue_order=delay_order,
    data=distances,
    palette=["#FFFFFF"] * len(delay_order),
    inner="quartile",
    cut=0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="distance",
    hue="delay_type",
    data=distances,
    hue_order=delay_order,
    alpha=0.25,
    ax=ax,
    dodge=True,
)

ax.set_ylim(bottom=0)

# Make legend markers fully opaque; the strip plot itself is drawn with alpha.
handles, labels = ax.get_legend_handles_labels()
for handle in handles:
    handle.set_alpha(1)

# Both layers register one legend entry per delay type. Skip the first
# len(delay_order) entries (presumably the white violins, which are drawn
# first) so the legend shows only the colored strip plot markers.
n_delay_types = len(delay_order)
ax.legend(
    handles=handles[n_delay_types:],
    labels=labels[n_delay_types:],
    loc="upper left",
    title="Type of delay",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Distance to the future (AAs)")

sns.despine()

plt.tight_layout()

plt.savefig(snakemake.output.distances_figure)

In [None]:
# Optimal distance to the future by forecast horizon and delay type: white
# violins with quartile lines, overlaid with the individual points.
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

sns.violinplot(
    x="horizon",
    y="optimal_distance",
    hue="delay_type",
    hue_order=delay_order,
    data=distances,
    palette=["#FFFFFF"] * len(delay_order),
    inner="quartile",
    cut=0,
    ax=ax,
)
sns.stripplot(
    x="horizon",
    y="optimal_distance",
    hue="delay_type",
    data=distances,
    hue_order=delay_order,
    alpha=0.25,
    ax=ax,
    dodge=True,
)

ax.set_ylim(bottom=0)

# Make legend markers fully opaque; the strip plot itself is drawn with alpha.
handles, labels = ax.get_legend_handles_labels()
for handle in handles:
    handle.set_alpha(1)

# Both layers register one legend entry per delay type. Skip the first
# len(delay_order) entries (presumably the white violins, which are drawn
# first) so the legend shows only the colored strip plot markers.
n_delay_types = len(delay_order)
ax.legend(
    handles=handles[n_delay_types:],
    labels=labels[n_delay_types:],
    loc="upper left",
    title="Type of delay",
    frameon=False,
)
ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Optimal distance to the future (AAs)")

sns.despine()

plt.tight_layout()

plt.savefig(snakemake.output.optimal_distances_figure)

## Plot effects of interventions on distances to the future

In [None]:
distances

### Compare effects of interventions on distance to the future for fitness metric

In [None]:
# Status quo: 12-month forecast horizon under the realistic delay condition.
status_quo = distances.loc[
    (distances["horizon"] == 12) & (distances["delay_type"] == realistic_delay_type),
    ["future_timepoint", "distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
status_quo = status_quo.assign(intervention="status_quo")

In [None]:
# Improved vaccine development: 6-month forecast horizon under the realistic
# delay condition.
improved_vaccine_dev = distances.loc[
    (distances["horizon"] == 6) & (distances["delay_type"] == realistic_delay_type),
    ["future_timepoint", "distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_vaccine_dev = improved_vaccine_dev.assign(intervention="improved_vaccine")

In [None]:
# Improved surveillance: 12-month forecast horizon under the ideal delay condition.
improved_surveillance = distances.loc[
    (distances["horizon"] == 12) & (distances["delay_type"] == "ideal"),
    ["future_timepoint", "distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_surveillance = improved_surveillance.assign(intervention="improved_surveillance")

In [None]:
# Both improvements: 6-month forecast horizon under the ideal delay condition.
improved_vaccine_and_surveillance = distances.loc[
    (distances["horizon"] == 6) & (distances["delay_type"] == "ideal"),
    ["future_timepoint", "distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_vaccine_and_surveillance = improved_vaccine_and_surveillance.assign(
    intervention="improved_vaccine_and_surveillance"
)

In [None]:
# Stack the four scenarios into one long table with one row per
# (future timepoint, scenario) observation.
interventions = pd.concat([
    status_quo,
    improved_vaccine_dev,
    improved_surveillance,
    improved_vaccine_and_surveillance,
])

In [None]:
interventions.head()

In [None]:
# One row per future timepoint and one column per scenario. pivot_table's
# default aggregation averages repeated entries. Timepoints missing any
# scenario are dropped so every difference below is defined.
interventions_by_timepoint = pd.pivot_table(
    interventions,
    values="distance",
    index=["future_timepoint"],
    columns=["intervention"],
).dropna()

In [None]:
interventions_by_timepoint

In [None]:
# Per-timepoint improvement of each intervention over the status quo
# (status quo minus intervention, so positive means the intervention is closer
# to the future).
for intervention_key in (
    "improved_vaccine",
    "improved_surveillance",
    "improved_vaccine_and_surveillance",
):
    interventions_by_timepoint[f"status_quo_vs_{intervention_key}"] = (
        interventions_by_timepoint["status_quo"] - interventions_by_timepoint[intervention_key]
    )

In [None]:
# Convert the three difference columns to long format: one row per
# (future timepoint, comparison). The next cell reads an "intervention"
# column, so name the variable column explicitly instead of relying on melt
# falling back to the leftover `columns.name` from the pivot table.
differences_by_intervention = interventions_by_timepoint.reset_index().melt(
    id_vars=[
        "future_timepoint",
    ],
    value_vars=[
        "status_quo_vs_improved_vaccine",
        "status_quo_vs_improved_surveillance",
        "status_quo_vs_improved_vaccine_and_surveillance",
    ],
    var_name="intervention",
    value_name="difference_in_distance",
)

In [None]:
# Human-readable label: strip the "status_quo_vs_" prefix and replace the
# remaining underscores with spaces.
differences_by_intervention["intervention_name"] = (
    differences_by_intervention["intervention"]
    .str.replace("status_quo_vs_", "", regex=False)
    .str.replace("_", " ", regex=False)
)

In [None]:
differences_by_intervention

In [None]:
# Left-to-right order of the interventions on the x-axis of the plot below;
# the entries match the labels stored in the "intervention_name" column.
intervention_order = [
    "improved vaccine",
    "improved surveillance",
    "improved vaccine and surveillance",
]

In [None]:
# Per-timepoint change in distance to the future for each intervention
# relative to the status quo: white violins with quartile lines, overlaid with
# the individual timepoints.
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=300)

sns.violinplot(
    x="intervention_name",
    y="difference_in_distance",
    data=differences_by_intervention,
    order=intervention_order,
    color="#FFFFFF",
    cut=0,
    inner="quartile",
    ax=ax,
)
sns.stripplot(
    x="intervention_name",
    y="difference_in_distance",
    data=differences_by_intervention,
    order=intervention_order,
    color="#000000",
    alpha=0.35,
    ax=ax,
    dodge=True,
)

# Reference line at "no change", drawn behind the data.
ax.axhline(
    y=0,
    color="#000000",
    zorder=-10,
    linewidth=1,
)

ax.set_xlabel("Intervention")
ax.set_ylabel("Difference in distance to future per timepoint\n(status quo - intervention)")

# Positive differences mean the intervention's estimate is closer to the
# future than the status quo, so the top of the axes is labeled "more accurate".
ax.text(
    0.5,
    0.95,
    "more accurate",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

ax.text(
    0.5,
    0.05,
    "less accurate",
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax.transAxes,
)

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.effects_of_realistic_interventions)

In [None]:
# Show the ten smallest (status quo - intervention) differences, i.e. the
# timepoints where an intervention helped least or did worse than the status quo.
differences_by_intervention.sort_values("difference_in_distance").head(10)

### Compare effects of interventions on optimal distance to the future

In [None]:
# Status quo: 12-month forecast horizon under the realistic delay condition,
# now measured by the optimal distance.
status_quo = distances.loc[
    (distances["horizon"] == 12) & (distances["delay_type"] == realistic_delay_type),
    ["future_timepoint", "optimal_distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
status_quo = status_quo.assign(intervention="status_quo")

In [None]:
# Improved vaccine development: 6-month forecast horizon under the realistic
# delay condition, measured by the optimal distance.
improved_vaccine_dev = distances.loc[
    (distances["horizon"] == 6) & (distances["delay_type"] == realistic_delay_type),
    ["future_timepoint", "optimal_distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_vaccine_dev = improved_vaccine_dev.assign(intervention="improved_vaccine")

In [None]:
# Improved surveillance: 12-month forecast horizon under the ideal delay
# condition, measured by the optimal distance.
improved_surveillance = distances.loc[
    (distances["horizon"] == 12) & (distances["delay_type"] == "ideal"),
    ["future_timepoint", "optimal_distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_surveillance = improved_surveillance.assign(intervention="improved_surveillance")

In [None]:
# Both improvements: 6-month forecast horizon under the ideal delay condition,
# measured by the optimal distance.
improved_vaccine_and_surveillance = distances.loc[
    (distances["horizon"] == 6) & (distances["delay_type"] == "ideal"),
    ["future_timepoint", "optimal_distance"],
].copy()

In [None]:
# Tag every row with the scenario name used later as a pivot column.
improved_vaccine_and_surveillance = improved_vaccine_and_surveillance.assign(
    intervention="improved_vaccine_and_surveillance"
)

In [None]:
# Stack the four optimal-distance scenarios into one long table with one row
# per (future timepoint, scenario) observation.
interventions = pd.concat([
    status_quo,
    improved_vaccine_dev,
    improved_surveillance,
    improved_vaccine_and_surveillance,
])

In [None]:
interventions.head()

In [None]:
# One row per future timepoint and one column per scenario. pivot_table's
# default aggregation averages repeated entries. Timepoints missing any
# scenario are dropped so every difference below is defined.
interventions_by_timepoint = pd.pivot_table(
    interventions,
    values="optimal_distance",
    index=["future_timepoint"],
    columns=["intervention"],
).dropna()

In [None]:
interventions_by_timepoint

In [None]:
# Per-timepoint improvement of each intervention over the status quo
# (status quo minus intervention, so positive means the intervention is closer
# to the future).
for intervention_key in (
    "improved_vaccine",
    "improved_surveillance",
    "improved_vaccine_and_surveillance",
):
    interventions_by_timepoint[f"status_quo_vs_{intervention_key}"] = (
        interventions_by_timepoint["status_quo"] - interventions_by_timepoint[intervention_key]
    )

In [None]:
# Convert the three difference columns to long format: one row per
# (future timepoint, comparison). The next cell reads an "intervention"
# column, so name the variable column explicitly instead of relying on melt
# falling back to the leftover `columns.name` from the pivot table.
differences_by_intervention = interventions_by_timepoint.reset_index().melt(
    id_vars=[
        "future_timepoint",
    ],
    value_vars=[
        "status_quo_vs_improved_vaccine",
        "status_quo_vs_improved_surveillance",
        "status_quo_vs_improved_vaccine_and_surveillance",
    ],
    var_name="intervention",
    value_name="difference_in_optimal_distance",
)

In [None]:
# Human-readable label: strip the "status_quo_vs_" prefix and replace the
# remaining underscores with spaces.
differences_by_intervention["intervention_name"] = (
    differences_by_intervention["intervention"]
    .str.replace("status_quo_vs_", "", regex=False)
    .str.replace("_", " ", regex=False)
)

In [None]:
differences_by_intervention

In [None]:
# Left-to-right order of the interventions on the x-axis of the plot below;
# the entries match the labels stored in the "intervention_name" column.
intervention_order = [
    "improved vaccine",
    "improved surveillance",
    "improved vaccine and surveillance",
]

In [None]:
# Per-timepoint change in optimal distance to the future for each intervention
# relative to the status quo: white violins with quartile lines, overlaid with
# the individual timepoints.
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=300)

# Arguments shared by the violin and strip layers.
shared_plot_args = dict(
    x="intervention_name",
    y="difference_in_optimal_distance",
    data=differences_by_intervention,
    order=intervention_order,
    ax=ax,
)
sns.violinplot(color="#FFFFFF", cut=0, inner="quartile", **shared_plot_args)
sns.stripplot(color="#000000", alpha=0.35, dodge=True, **shared_plot_args)

# Reference line at "no change", drawn behind the data.
ax.axhline(y=0, color="#000000", zorder=-10, linewidth=1)

ax.set_xlabel("Intervention")
ax.set_ylabel("Difference in optimal distance to future per timepoint\n(status quo - intervention)")

# Leave room below the smallest difference for the "less accurate" annotation.
lowest_difference = differences_by_intervention["difference_in_optimal_distance"].min()
ax.set_ylim(bottom=lowest_difference - 0.5)

# Annotate the top and bottom of the axes (positions are in axes coordinates).
for y_position, annotation in ((0.95, "more accurate"), (0.05, "less accurate")):
    ax.text(
        0.5,
        y_position,
        annotation,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
    )

sns.despine()
plt.tight_layout()

plt.savefig(snakemake.output.optimal_effects_of_realistic_interventions)

In [None]:
# Average (status quo - intervention) difference in optimal distance for each intervention.
differences_by_intervention.groupby("intervention_name")["difference_in_optimal_distance"].mean()