# Fit linear model to "distance to the future" for simulated populations

## Imports and configuration

In [None]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

In [None]:
sns.set_style("ticks")

In [None]:
# Shared NumPy generator; it is handed to `pm.draw(..., random_seed=rng)` in later cells.
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)

# Notebook-wide ArviZ defaults:
#   - width of the credible (HDI) interval,
#   - information criterion used by `compare`,
#   - scale on which that criterion is reported.
for option, value in (
    ("stats.hdi_prob", 0.89),
    ("stats.information_criterion", "waic"),
    ("stats.ic_scale", "deviance"),
):
    az.rcParams[option] = value

## Load and prepare data for analysis

In [None]:
distances = pd.read_csv("results/model_inputs_for_simulated_populations.tsv", sep="\t")

In [None]:
distances.head()

In [None]:
# Lookup from a submission delay value to the label shown as the
# "Type of delay" in the figures.
delay_to_delay_type = dict(
    zip(
        (0, 1, 3),
        ("none", "ideal", "realistic"),
    )
)

In [None]:
distances["delay_type"] = distances["delay"].map(delay_to_delay_type)

In [None]:
distances.head()

Center submission delay and forecast horizon values (subtract each predictor's mean) for use in models.
Both predictors use units of "months".

In [None]:
horizon_mean = distances["horizon"].mean()

In [None]:
distances["horizon_c"] = (
    (distances["horizon"] - horizon_mean)
)

In [None]:
delay_mean = distances["delay"].mean()

In [None]:
distances["delay_c"] = (
    (distances["delay"] - delay_mean)
)

In [None]:
distances.head()

## Define model

In [None]:
with pm.Model() as model:
    # Priors for the linear model: intercept `a`, coefficient `b_s` for the
    # centered submission delay, and coefficient `b_h` for the centered
    # forecast horizon.
    a = pm.Normal("a", 0, 1.0)
    b_s = pm.Normal("b_s", 0, 0.1)
    b_h = pm.Normal("b_h", 0, 0.1)

    # Linear model on the log scale; exponentiating keeps the expected
    # distance positive for the gamma likelihood below.
    mu = pm.math.exp(a + (b_s * distances["delay_c"].values) + (b_h * distances["horizon_c"].values))

    # Prior for the standard deviation of the likelihood
    sigma = pm.Exponential("sigma", 0.5)

    # Likelihood of the observed distances
    distance = pm.Gamma("distance", mu=mu, sigma=sigma, observed=distances["distance"].values)

    # Sample from the priors. Seed the sampler with the notebook-wide seed so
    # the prior predictive check gives the same draws on every run.
    prior_samples = pm.sample_prior_predictive(100, random_seed=RANDOM_SEED)

### Prior predictive checks

In [None]:
prior = az.extract(prior_samples["prior"])

In [None]:
prior

In [None]:
factors = distances.loc[:, ["delay", "horizon", "delay_c", "horizon_c"]].drop_duplicates().reset_index(drop=True)

In [None]:
factors

Calculate the mean distance from the prior samples at each combination of delay and horizon values.

In [None]:
# Centered predictors as column vectors: one row per delay/horizon combination.
centered_delay = factors["delay_c"].values[:, np.newaxis]
centered_horizon = factors["horizon_c"].values[:, np.newaxis]

# Broadcast the prior draws (row vectors) against the predictor columns so the
# result has one row per factor combination and one column per prior draw.
prior_mu = np.exp(
    prior["a"].values[np.newaxis, :]
    + (prior["b_s"].values[np.newaxis, :] * centered_delay)
    + (prior["b_h"].values[np.newaxis, :] * centered_horizon)
)

In [None]:
prior_sigma = np.array([prior["sigma"].values] * factors.shape[0])

Sample distances from the gamma likelihood with the mean calculated above and the standard deviation (`sigma`) from the prior samples (repeated once for each combination of delay/horizon factors).

In [None]:
prior_gamma_dist = pm.Gamma.dist(mu=prior_mu, sigma=prior_sigma)

In [None]:
prior_distance = pm.draw(prior_gamma_dist, random_seed=rng)

In [None]:
# Long-format table of prior predictive draws: row i of `prior_distance`
# belongs to row i of `factors`, so walk the two in lockstep.
priors = pd.concat([
    pd.DataFrame({
        "delay": factor_delay,
        "horizon": factor_horizon,
        "distance": sampled_distances,
    })
    for factor_delay, factor_horizon, sampled_distances in zip(
        factors["delay"],
        factors["horizon"],
        prior_distance,
    )
])

Plot observed distances and those from the prior predictive samples, to determine whether the latter include realistic values.

In [None]:
distances.head()

In [None]:
priors["delay_type"] = priors["delay"].map(delay_to_delay_type)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), dpi=150)

delay_type_order = ("none", "ideal", "realistic")

# Left panel shows the observed distances, right panel the prior predictive
# draws; both panels are drawn the same way.
for panel_ax, panel_data, panel_title in (
    (ax1, distances, "Observed"),
    (ax2, priors, "Prior predictive"),
):
    # Grey boxes (outlier markers hidden) with the individual values on top.
    sns.boxplot(
        data=panel_data,
        x="horizon",
        y="distance",
        hue="delay_type",
        hue_order=delay_type_order,
        color="#CCCCCC",
        fliersize=0.0,
        ax=panel_ax,
    )
    sns.stripplot(
        data=panel_data,
        x="horizon",
        y="distance",
        hue="delay_type",
        hue_order=delay_type_order,
        alpha=0.35,
        dodge=True,
        ax=panel_ax,
    )

    panel_ax.set_ylim(bottom=-0.25)
    panel_ax.set_xlabel("Forecast horizon (months)")
    panel_ax.set_ylabel("Distance to the future (AAs)")
    panel_ax.set_title(panel_title)

# Two layers were drawn with the same three hue levels; skip the first three
# legend entries so each delay type is listed once.
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="upper left",
    title="Type of delay",
    frameon=False,
)

# The left panel's legend serves both panels.
ax2.get_legend().remove()

sns.despine()

## Fit the model to the data

In [None]:
with model:
    # Sample from the posterior: four chains, each with 5,000 tuning steps
    # followed by 10,000 retained draws. Seed the sampler with the
    # notebook-wide seed so the fit can be reproduced exactly.
    trace = pm.sample(
        draws=10000,
        tune=5000,
        chains=4,
        cores=4,
        random_seed=RANDOM_SEED,
    )

In [None]:
az.plot_energy(trace)

In [None]:
az.plot_trace(trace, combined=True)
plt.tight_layout()

In [None]:
az.plot_posterior(trace)

In [None]:
az.plot_forest(trace, var_names=["b_s", "b_h"], combined=True)

In [None]:
az.summary(trace)

### Posterior checks

In [None]:
full_posterior = az.extract(trace["posterior"])

In [None]:
full_posterior

In [None]:
posterior = full_posterior.sel(draw=slice(None, None, 10))

In [None]:
# Centered predictors as column vectors: one row per delay/horizon combination.
centered_delay = factors["delay_c"].values[:, np.newaxis]
centered_horizon = factors["horizon_c"].values[:, np.newaxis]

# Broadcast the thinned posterior draws (row vectors) against the predictor
# columns: one row per factor combination, one column per posterior draw.
posterior_mu = np.exp(
    posterior["a"].values[np.newaxis, :]
    + (posterior["b_s"].values[np.newaxis, :] * centered_delay)
    + (posterior["b_h"].values[np.newaxis, :] * centered_horizon)
)

In [None]:
posterior_mu

In [None]:
posterior_mu.shape

In [None]:
posterior_sigma = np.array([posterior["sigma"].values] * factors.shape[0])

In [None]:
posterior_gamma_dist = pm.Gamma.dist(mu=posterior_mu, sigma=posterior_sigma)

In [None]:
posterior_distance = pm.draw(posterior_gamma_dist, random_seed=rng)

In [None]:
# Long-format table of posterior draws: row i of `posterior_mu` and
# `posterior_distance` belongs to row i of `factors`, so walk them in lockstep.
posteriors = pd.concat([
    pd.DataFrame({
        "delay": factor_delay,
        "horizon": factor_horizon,
        "mu": expected_distances,
        "distance": sampled_distances,
    })
    for factor_delay, factor_horizon, expected_distances, sampled_distances in zip(
        factors["delay"],
        factors["horizon"],
        posterior_mu,
        posterior_distance,
    )
])

In [None]:
posteriors["delay_type"] = posteriors["delay"].map(delay_to_delay_type)

In [None]:
posteriors

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), dpi=150)

delay_type_order = ("none", "ideal", "realistic")

# Left panel shows the observed distances, right panel the posterior
# predictive draws; both panels are drawn the same way.
for panel_ax, panel_data, panel_title in (
    (ax1, distances, "Observed"),
    (ax2, posteriors, "Posterior predictive"),
):
    # Grey boxes (outlier markers hidden) with the individual values on top.
    sns.boxplot(
        data=panel_data,
        x="horizon",
        y="distance",
        hue="delay_type",
        hue_order=delay_type_order,
        color="#CCCCCC",
        fliersize=0.0,
        ax=panel_ax,
    )
    sns.stripplot(
        data=panel_data,
        x="horizon",
        y="distance",
        hue="delay_type",
        hue_order=delay_type_order,
        alpha=0.35,
        dodge=True,
        ax=panel_ax,
    )

    panel_ax.set_ylim(bottom=-0.25)
    panel_ax.set_xlabel("Forecast horizon (months)")
    panel_ax.set_ylabel("Distance to the future (AAs)")
    panel_ax.set_title(panel_title)

# Two layers were drawn with the same three hue levels; skip the first three
# legend entries so each delay type is listed once.
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles[3:],
    labels=labels[3:],
    loc="upper left",
    title="Type of delay",
    frameon=False,
)

# The left panel's legend serves both panels.
ax2.get_legend().remove()

sns.despine()

## Plot distance to the present represented by intercept term

Calculate distance to the present without delays by setting delay=0 and horizon=0.

In [None]:
mu_to_present_no_delay = np.exp(posterior["a"] + posterior["b_s"] * (0 - delay_mean) + posterior["b_h"] * (0 - horizon_mean))

In [None]:
mu_to_present_no_delay

In [None]:
mu_to_present_no_delay.values

In [None]:
posterior["sigma"]

In [None]:
gamma_dist_to_present_no_delay = pm.Gamma.dist(mu=mu_to_present_no_delay.values, sigma=posterior["sigma"].values)
distance_to_present_no_delay = pm.draw(gamma_dist_to_present_no_delay, random_seed=rng)

In [None]:
distance_to_present_no_delay

In [None]:
distance_to_present_no_delay.shape

In [None]:
# Histogram bin edges for the simulated distances. `np.arange` excludes its
# stop value, so extend the stop by one full bin width: this guarantees the
# last edge is not below the largest draw. A smaller pad can leave the largest
# draws outside every bin, and the histogram would silently omit them.
bin_width = 0.25
bins = np.arange(
    distance_to_present_no_delay.min(),
    distance_to_present_no_delay.max() + bin_width,
    bin_width,
)

In [None]:
mean_average_distance_to_present = distance_to_present_no_delay.mean()

In [None]:
mean_average_distance_to_present

In [None]:
median_average_distance_to_present = np.median(distance_to_present_no_delay)

In [None]:
median_average_distance_to_present

In [None]:
lower_hpdi_average_distance_to_present, upper_hpdi_average_distance_to_present = az.hdi(distance_to_present_no_delay, 0.89)

In [None]:
lower_hpdi_average_distance_to_present

In [None]:
upper_hpdi_average_distance_to_present

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=150)

# Distribution of simulated distances to the present without delays.
ax.hist(distance_to_present_no_delay, bins=bins, color="#CCCCCC")

# Mark the median, then both HPDI bounds with dashed lines.
ax.axvline(x=median_average_distance_to_present, color="#000000")
for hpdi_bound in (
    lower_hpdi_average_distance_to_present,
    upper_hpdi_average_distance_to_present,
):
    ax.axvline(x=hpdi_bound, color="#000000", linestyle="--")

ax.set_xlim(left=0)
ax.set_xlabel("Average distance to the present without delays (AAs)")
ax.set_ylabel("Number of posterior samples")

sns.despine()

## Plot effect of delays on distance to present

Create a sequence of submission delays (in months) and center it using the same mean delay that was subtracted from the model inputs.

In [None]:
delay_seq = np.arange(0, 4)

In [None]:
delay_seq_c = (delay_seq - delay_mean)

In [None]:
delay_seq_c

In [None]:
# Per-delay summaries, filled in the order of `delay_seq_c`.
mu_median_for_delays = []
mu_lower_hpdi = []
mu_upper_hpdi = []
distance_lower_hpdi = []
distance_upper_hpdi = []

# The horizon is held at zero months, expressed on the centered scale.
present_horizon_c = 0 - horizon_mean
sigma_draws = posterior["sigma"].values

for delay_c in delay_seq_c:
    # Expected distance for this delay, one value per posterior draw.
    log_mu = posterior["a"] + posterior["b_s"] * delay_c + posterior["b_h"] * present_horizon_c
    mu = np.exp(log_mu).values

    mu_hpdi = az.hdi(mu, 0.89)
    mu_median_for_delays.append(np.median(mu))
    mu_lower_hpdi.append(mu_hpdi[0])
    mu_upper_hpdi.append(mu_hpdi[1])

    # Simulate one distance per posterior draw from the gamma likelihood and
    # summarize the spread of those simulated values.
    distance_for_delay = pm.draw(
        pm.Gamma.dist(mu=mu, sigma=sigma_draws),
        random_seed=rng,
    )
    distance_hpdi = az.hdi(distance_for_delay, 0.89)
    distance_lower_hpdi.append(distance_hpdi[0])
    distance_upper_hpdi.append(distance_hpdi[1])

In [None]:
mu_median_for_delays

In [None]:
mu_lower_hpdi

In [None]:
mu_upper_hpdi

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=200)

# Median expected distance at each submission delay.
ax.plot(delay_seq, mu_median_for_delays, "o-", color="#000000")

# Darker band: HPDI of the expected distance. Lighter band behind it: HPDI of
# the simulated distances.
for band_lower, band_upper, band_color, band_zorder in (
    (mu_lower_hpdi, mu_upper_hpdi, "#999999", -10),
    (distance_lower_hpdi, distance_upper_hpdi, "#CCCCCC", -20),
):
    ax.fill_between(
        delay_seq,
        y1=band_lower,
        y2=band_upper,
        color=band_color,
        alpha=0.75,
        zorder=band_zorder,
    )

ax.set_xticks(delay_seq)
ax.set_ylim(bottom=0)

ax.set_xlabel("Submission delay (months)")
ax.set_ylabel("Distance to the present (AAs)")

sns.despine()

## Plot effect of horizons and delays

In [None]:
horizon_seq = np.arange(3, 13, 3)

In [None]:
horizon_seq

In [None]:
delay_seq = np.array([0, 1, 3])

In [None]:
# Simulated distances (one DataFrame per horizon/delay pair) and the matching
# summary statistics (one record per pair).
posterior_simulated_distances = []
records = []

sigma_draws = posterior["sigma"].values

for horizon in horizon_seq:
    horizon_c = (horizon - horizon_mean)

    for delay in delay_seq:
        delay_c = (delay - delay_mean)

        # NOTE: the loop scalars below deliberately use names that differ from
        # `mu_lower_hpdi`, `mu_upper_hpdi`, `distance_lower_hpdi`,
        # `distance_upper_hpdi` (per-delay lists built in the previous section)
        # and from `mu` / `distance` (variables bound in the model cell), so
        # this cell does not overwrite them.
        expected_distance = np.exp(
            posterior["a"] + posterior["b_s"] * delay_c + posterior["b_h"] * horizon_c
        ).values
        expected_median = np.median(expected_distance)
        expected_hpdi_lower, expected_hpdi_upper = az.hdi(expected_distance, 0.89)

        # Simulate one distance per posterior draw from the gamma likelihood.
        simulated_distance = pm.draw(
            pm.Gamma.dist(mu=expected_distance, sigma=sigma_draws),
            random_seed=rng,
        )
        simulated_hpdi_lower, simulated_hpdi_upper = az.hdi(simulated_distance, 0.89)

        posterior_simulated_distances.append(
            pd.DataFrame({
                "horizon": horizon,
                "delay": delay,
                "distance": simulated_distance,
            })
        )

        records.append({
            "horizon": horizon,
            "delay": delay,
            "mu_median": expected_median,
            "mu_lower_hpdi": expected_hpdi_lower,
            "mu_upper_hpdi": expected_hpdi_upper,
            "distance_lower_hpdi": simulated_hpdi_lower,
            "distance_upper_hpdi": simulated_hpdi_upper,
        })

In [None]:
df = pd.DataFrame(records)

In [None]:
df

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=200)

# One line plus two shaded bands per delay: the darker band is the HPDI of the
# expected distance, the lighter band behind it is the HPDI of the simulated
# distances.
for delay_months, series_color, series_label in (
    (0, "C0", "no delay"),
    (1, "C1", "ideal delay"),
    (3, "C2", "realistic delay"),
):
    delay_df = df.loc[df["delay"] == delay_months]

    ax.plot(
        delay_df["horizon"],
        delay_df["mu_median"],
        "o-",
        color=series_color,
        label=series_label,
    )

    for lower_column, upper_column, band_alpha, band_zorder in (
        ("mu_lower_hpdi", "mu_upper_hpdi", 0.5, -10),
        ("distance_lower_hpdi", "distance_upper_hpdi", 0.25, -20),
    ):
        ax.fill_between(
            delay_df["horizon"],
            y1=delay_df[lower_column],
            y2=delay_df[upper_column],
            color=series_color,
            alpha=band_alpha,
            zorder=band_zorder,
        )

ax.set_xticks(horizon_seq)
ax.set_ylim(bottom=0)

ax.set_xlabel("Forecast horizon (months)")
ax.set_ylabel("Distance to the future (AAs)")

ax.legend(loc="upper left", frameon=False)

sns.despine()

## Plot improvements under different realistic scenarios

We consider three possible realistic scenarios for future practice of influenza genomic surveillance and vaccine development:

1. A change in vaccine development reduces the required forecast horizon from 12 months to 6 months (e.g., through adoption of mRNA-based vaccines).
2. A change in genomic surveillance capacity and policy reduces the average submission delay of genomes to GISAID from 3 months to 1 month.
3. Both changes to vaccine development and genomic surveillance occur at once.

In [None]:
posterior_simulated_distance_df = pd.concat(posterior_simulated_distances)

In [None]:
posterior_simulated_distance_df.head()

Get difference in distances for scenario 1 where delay=3 and horizon=12 or 6.

In [None]:
# Scenario 1: the delay stays at 3 months while the horizon drops from 12 to
# 6 months. Each selection keeps the row index of its per-combination frame,
# so the subtraction pairs rows with equal index labels.
simulated = posterior_simulated_distance_df
has_delay_3 = simulated["delay"] == 3

status_quo_distance = simulated.loc[has_delay_3 & (simulated["horizon"] == 12), "distance"]
shorter_horizon_distance = simulated.loc[has_delay_3 & (simulated["horizon"] == 6), "distance"]

scenario_1_improvement = status_quo_distance - shorter_horizon_distance

In [None]:
scenario_1_improvement_median = scenario_1_improvement.median()

In [None]:
scenario_1_improvement_median

In [None]:
scenario_1_improvement_lower_hdi, scenario_1_improvement_upper_hdi = az.hdi(scenario_1_improvement.values, 0.89)

In [None]:
scenario_1_improvement_lower_hdi

In [None]:
scenario_1_improvement_upper_hdi

Get difference in distances for scenario 2 where horizon=12 and delay=3 or 1.

In [None]:
# Scenario 2: the horizon stays at 12 months while the delay drops from 3 to
# 1 month. Each selection keeps the row index of its per-combination frame,
# so the subtraction pairs rows with equal index labels.
simulated = posterior_simulated_distance_df
has_horizon_12 = simulated["horizon"] == 12

status_quo_distance = simulated.loc[(simulated["delay"] == 3) & has_horizon_12, "distance"]
shorter_delay_distance = simulated.loc[(simulated["delay"] == 1) & has_horizon_12, "distance"]

scenario_2_improvement = status_quo_distance - shorter_delay_distance

In [None]:
scenario_2_improvement_median = scenario_2_improvement.median()

In [None]:
scenario_2_improvement_median

In [None]:
scenario_2_improvement_lower_hdi, scenario_2_improvement_upper_hdi = az.hdi(scenario_2_improvement.values, 0.89)

In [None]:
scenario_2_improvement_lower_hdi

In [None]:
scenario_2_improvement_upper_hdi

Finally, get difference in distances for scenario 3 where horizon/delay is either 12/3 or 6/1.

In [None]:
# Scenario 3: both changes at once, from delay 3 / horizon 12 to delay 1 /
# horizon 6. Each selection keeps the row index of its per-combination frame,
# so the subtraction pairs rows with equal index labels.
simulated = posterior_simulated_distance_df

status_quo_distance = simulated.loc[
    (simulated["delay"] == 3) & (simulated["horizon"] == 12),
    "distance",
]
both_improved_distance = simulated.loc[
    (simulated["delay"] == 1) & (simulated["horizon"] == 6),
    "distance",
]

scenario_3_improvement = status_quo_distance - both_improved_distance

In [None]:
scenario_3_improvement_median = scenario_3_improvement.median()

In [None]:
scenario_3_improvement_median

In [None]:
scenario_1_improvement_median + scenario_2_improvement_median

In [None]:
scenario_3_improvement_lower_hdi, scenario_3_improvement_upper_hdi = az.hdi(scenario_3_improvement.values, 0.89)

In [None]:
scenario_3_improvement_lower_hdi

In [None]:
scenario_3_improvement_upper_hdi

In [None]:
# Bin edges shared by the three scenario histograms. `np.arange` excludes its
# stop value, so extend the stop by one full bin width: this guarantees the
# last edge is not below the largest improvement. A smaller pad can leave the
# largest values outside every bin, and the histograms would silently omit them.
bin_width = 0.25
bins = np.arange(
    min(scenario_1_improvement.min(), scenario_2_improvement.min(), scenario_3_improvement.min()),
    max(scenario_1_improvement.max(), scenario_2_improvement.max(), scenario_3_improvement.max()) + bin_width,
    bin_width,
)

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(8, 5), dpi=150, sharex=True, sharey=True)
all_axes = axes.flatten()
ax1, ax2, ax3 = all_axes

# One entry per scenario: axis, simulated improvements, their median, the HDI
# bounds, and the panel title. Keeping each scenario's statistics together
# ensures a panel can only be annotated with its own values (the scenario 2
# panel previously drew scenario 3's upper HDI bound).
scenario_panels = (
    (
        ax1,
        scenario_1_improvement,
        scenario_1_improvement_median,
        scenario_1_improvement_lower_hdi,
        scenario_1_improvement_upper_hdi,
        "Improved vaccine development (6-month horizon, 3-month delay)",
    ),
    (
        ax2,
        scenario_2_improvement,
        scenario_2_improvement_median,
        scenario_2_improvement_lower_hdi,
        scenario_2_improvement_upper_hdi,
        "Improved genomic surveillance (12-month horizon, 1-month delay)",
    ),
    (
        ax3,
        scenario_3_improvement,
        scenario_3_improvement_median,
        scenario_3_improvement_lower_hdi,
        scenario_3_improvement_upper_hdi,
        "Improved vaccine and surveillance (6-month horizon, 1-month delay)",
    ),
)

for panel_ax, improvement, median, lower_hdi, upper_hdi, panel_title in scenario_panels:
    panel_ax.hist(
        improvement,
        bins=bins,
        color="#999999",
    )

    # Thin black reference line at zero (no improvement), drawn behind the bars.
    panel_ax.axvline(x=0, color="#000000", zorder=-10, linewidth=1)

    # Solid red line at the median, dashed red lines at the HDI bounds.
    panel_ax.axvline(x=median, color="red")
    panel_ax.axvline(x=lower_hdi, color="red", linestyle="--")
    panel_ax.axvline(x=upper_hdi, color="red", linestyle="--")

    # Text summary of the same statistics, positioned in axes coordinates.
    panel_ax.text(
        0.92,
        0.5,
        f"{median:.2f} ({lower_hdi:.2f}, {upper_hdi:.2f}) AAs",
        horizontalalignment='right',
        verticalalignment='center',
        transform=panel_ax.transAxes,
    )

    panel_ax.set_title(panel_title)

# Axis labels are set once: y label on the middle panel, x label on the bottom.
ax2.set_ylabel("Number of posterior samples")
ax3.set_xlabel("Reduction in distance to the future (AAs)")

sns.despine()
plt.tight_layout()