# PoPS Global Model: Forecast 
Use this notebook to run the model with parameter sets sampled from a distribution generated in the previous model calibration step. These sampled parameter sets produce a forecast that propagates parameter uncertainty across multiple stochastic model runs.

Run this notebook after notebooks 0, 1, 2, and 3b. We also recommend running notebook 3a first to check for and troubleshoot issues.

## Set up workspace from env and configuration files 

First, import needed packages.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import dotenv
import os
import json

import warnings
warnings.filterwarnings(action='once')

Navigate to the main repository directory.

In [None]:
os.chdir("../")

Import needed PoPS Global functions.

In [None]:
from pandemic.multirun_helpers import write_commands, generate_param_samples

Read in path variables from .env.

In [None]:
# Load variables and paths from .env
dotenv.load_dotenv(".env")

# Read environmental variables
input_dir = os.getenv("INPUT_PATH")
out_dir = os.getenv("OUTPUT_PATH")
sim_name = os.getenv("SIM_NAME")
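If a variable is missing from .env, os.getenv returns None. The error then surfaces later as a confusing path such as "None/config_None.json". The sketch below fails early instead. require_env is a hypothetical helper, not part of PoPS Global. It only checks the three variable names this cell reads.

```python
import os


def require_env(names):
    """Return the requested environment variables, raising if any are unset or empty."""
    values = {name: os.getenv(name) for name in names}
    missing = [name for name, value in values.items() if not value]
    if missing:
        raise EnvironmentError(f"Missing variables in .env: {', '.join(missing)}")
    return values


# Usage in the notebook, after dotenv.load_dotenv(".env"):
# paths = require_env(["INPUT_PATH", "OUTPUT_PATH", "SIM_NAME"])
# input_dir, out_dir, sim_name = paths["INPUT_PATH"], paths["OUTPUT_PATH"], paths["SIM_NAME"]
```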

Read in parameters from the simulation's config file (config_{sim_name}.json).

In [None]:
config_json_path = f"{out_dir}/config_{sim_name}.json"

with open(config_json_path) as json_file:
    config = json.load(json_file)

coi = config["coi"]
sim_years = config["sim_years"]
validation_method = config["validation_method"]

run_name = f"{sim_name}_calibrate"

## Use the summary stats from the grid search to generate a parameter distribution

Read summary statistics from file.

In [None]:
stats_dir = f"{out_dir}/summary_stats/{run_name}"

col_dict = {
    "start_max": "start",
    "alpha_max": "alpha",
    "beta_max": "beta",
    "lamda_max": "lamda",
}

agg_df = pd.read_csv(f"{stats_dir}/summary_stats_bySample.csv").rename(columns=col_dict)

In [None]:
# Create folder to save forecast figures

fig_dir = f"{stats_dir}/figs/forecast/"

os.makedirs(fig_dir, exist_ok=True)

### Set a performance threshold
Select a threshold percentile value (on F-beta) to determine which samples are used to fit the distribution.

The visualizations below help you explore possible thresholds and how each affects the number of samples included, the corresponding F-beta cut-off value, and the distribution of the included parameters.

If leave-one-out cross-validation is used, a plot is produced for each F-beta: the overall F-beta and one per location in the validation data.
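The threshold is applied to F-beta, which is the weighted harmonic mean of precision and recall. The sketch below shows the standard formula computed from counts of true positives, false positives, and false negatives. This notebook does not state how PoPS Global defines those counts or which β it uses, so fbeta_score and its arguments are illustrative assumptions. Also note that the β in F-beta is a weighting constant. It is unrelated to the model parameter beta.

```python
def fbeta_score(tp, fp, fn, beta=1.0):
    """Standard F-beta from true positive, false positive, and false negative counts.

    beta > 1 weights recall more heavily than precision; beta < 1 does the opposite.
    """
    if tp == 0:
        return 0.0  # no true positives: precision and recall are both 0 (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Hypothetical counts: 2 correct predictions, 0 false alarms, 2 misses
print(round(fbeta_score(tp=2, fp=0, fn=2, beta=1.0), 4))  # 0.6667
```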

In [None]:
# Extract the fbeta columns
fbeta_cols = [fbeta_col for fbeta_col in agg_df.columns if "fbeta" in fbeta_col and "mean" in fbeta_col]

# Set up an empty dictionary of lists to store the results
min_fbetas = {}
min_fbetas["quantile"] = []
min_fbetas["count"] = []

for fbeta_col in fbeta_cols:
    min_fbetas[fbeta_col] = []

# Loop through the quantile thresholds
for val in range(70, 100):
    min_fbetas["quantile"] += [val]
    for fbeta_col in fbeta_cols:
        subset = agg_df.loc[agg_df[fbeta_col] >= agg_df[fbeta_col].quantile(val / 100)]   
        min_fbetas[fbeta_col] += [subset[fbeta_col].min()]
    # Count is consistent across metrics
    min_fbetas["count"] += [len(subset.index)]     

# Convert to dataframe indexed by quantile
sample_stats = pd.DataFrame(min_fbetas).set_index("quantile")

In [None]:
# Set a percentile threshold (0-100); adjust based on the plots below

quant_threshold = 90
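The percentile threshold keeps every sample whose F-beta is at or above that quantile of all samples. A threshold of 90 therefore keeps roughly the top 10% of samples. The toy example below illustrates this with 100 hypothetical, evenly spaced scores. It uses the same comparison as the notebook, so exactly 10 samples are flagged "top". flag_top is a hypothetical helper written for this illustration.

```python
import numpy as np
import pandas as pd


def flag_top(df, col, quant_threshold):
    """Label rows whose `col` value is at or above the given percentile (0-100)."""
    cutoff = df[col].quantile(quant_threshold / 100)
    return np.where(df[col] >= cutoff, "top", "low")


# Hypothetical toy data: 100 samples with F-beta scores 0.00, 0.01, ..., 0.99
toy = pd.DataFrame({"fbeta_mean": np.arange(100) / 100})
toy["top_fbeta_mean"] = flag_top(toy, "fbeta_mean", 90)
print((toy["top_fbeta_mean"] == "top").sum())  # 10
```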

Visual: How many samples and what Fbeta scores are captured with each threshold?

In [None]:
width = len(fbeta_cols) + 1
fig, axs = plt.subplots(1, width, figsize=(4*width, 4))

sample_stats["count"].plot(ax=axs[0])
axs[0].vlines(
    quant_threshold,
    ymin=sample_stats["count"].min(),
    ymax=sample_stats["count"].max(),
    linestyle="dashed",
    color="firebrick",
)
axs[0].set_title("Count")

for i, fbeta_col in enumerate(fbeta_cols):

    sample_stats[fbeta_col].plot(ax=axs[i+1])
    axs[i+1].vlines(
        quant_threshold,
        ymin=sample_stats[fbeta_col].min(),
        ymax=sample_stats[fbeta_col].max(),
        linestyle="dashed",
        color="firebrick",
    )
    axs[i+1].set_title(f"Fbeta {' '.join(fbeta_col.split('_')[1:])}")

plt.savefig(f"{fig_dir}/sample_threshold.png", bbox_inches="tight")
plt.show()

Visual: What do the distributions of alpha, lamda, and beta look like with that threshold?

In [None]:
for fbeta_col in fbeta_cols:
    agg_df[f"top_{fbeta_col}"] = np.where(
        agg_df[fbeta_col] >= agg_df[fbeta_col].quantile(quant_threshold / 100), "top", "low"
    )

Visualize separation by parameter.

In [None]:
# Alpha by year
for fbeta_col in fbeta_cols:
    ax = sns.relplot(
        x="alpha",
        y=fbeta_col,
        col="start",
        hue=f"top_{fbeta_col}",
        palette="rocket",
        data=agg_df,
        edgecolor="black",
        linewidth=0.5,
        s=100,
    )
    plt.savefig(f"{fig_dir}/top_alpha_{fbeta_col}_start.png", bbox_inches="tight")
    plt.show()

In [None]:
# Lamda by year
for fbeta_col in fbeta_cols:
    ax = sns.relplot(
        x="lamda",
        y=fbeta_col,
        col="start",
        hue=f"top_{fbeta_col}",
        palette="rocket",
        data=agg_df,
        edgecolor="black",
        linewidth=0.5,
        s=100,
    )

    plt.savefig(f"{fig_dir}/top_lambda_{fbeta_col}_start.png", bbox_inches="tight")
    plt.show()

In [None]:
# Beta by year
for fbeta_col in fbeta_cols:
    ax = sns.relplot(
        x="beta",
        y=fbeta_col,
        col="start",
        hue=f"top_{fbeta_col}",
        palette="rocket",
        data=agg_df,
        edgecolor="black",
        linewidth=0.5,
        s=100,
    )

    plt.savefig(f"{fig_dir}/top_beta_{fbeta_col}_start.png", bbox_inches="tight")
    plt.show()

Combine the top parameter sets into a single dataset and visualize their overall distribution.

- If the validation method is "loo" (leave-one-out), the sampled parameters are fit from the parameter sets above the quantile threshold for each omitted validation location's F-beta. Parameter sets that exceed this threshold for multiple locations are repeated in the combined set.
- If no validation method is used, the sampled parameters are fit from the overall sample F-beta.
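The pooling described above can be illustrated on toy data. The sketch below stacks the "top" rows for each F-beta column, the same way the next cell does. A parameter set flagged "top" for k columns appears k times and so carries k times the weight when a distribution is fit later. The column names fbeta_no_A_mean and fbeta_no_B_mean and the helper pool_top_samples are hypothetical.

```python
import pandas as pd


def pool_top_samples(agg_df, fbeta_cols, params=("start", "alpha", "beta", "lamda")):
    """Stack the 'top' parameter sets for each F-beta column into one table."""
    frames = [
        agg_df.loc[agg_df[f"top_{col}"] == "top", [*params, col]].rename(columns={col: "fbeta"})
        for col in fbeta_cols
    ]
    return pd.concat(frames, ignore_index=True)


# Hypothetical toy data: sample 0 is "top" for both omitted locations
toy = pd.DataFrame({
    "start": [2000, 2000, 2001],
    "alpha": [0.1, 0.2, 0.3],
    "beta": [1.0, 1.0, 1.0],
    "lamda": [5.0, 6.0, 7.0],
    "fbeta_no_A_mean": [0.9, 0.8, 0.1],
    "fbeta_no_B_mean": [0.9, 0.1, 0.8],
    "top_fbeta_no_A_mean": ["top", "top", "low"],
    "top_fbeta_no_B_mean": ["top", "low", "top"],
})
pooled = pool_top_samples(toy, ["fbeta_no_A_mean", "fbeta_no_B_mean"])
print(len(pooled))  # 4: sample 0 appears twice
```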

In [None]:
# Create dataset of the top samples

if validation_method == "loo":
    # Eliminate the sample fbeta column
    fbeta_cols = [fbeta_col for fbeta_col in fbeta_cols if "no" in fbeta_col]

top_samples = pd.DataFrame()

for fbeta_col in fbeta_cols:
    top_samples = pd.concat(
        [
            top_samples,
            (
                agg_df.loc[agg_df[f"top_{fbeta_col}"] == "top", 
                ["start", "alpha", "beta", "lamda", fbeta_col]]
                .rename(columns={fbeta_col: "fbeta"})
            )
        ]
    )

top_samples = top_samples.reset_index(drop=True)

In [None]:
# Top parameter distribution plot
ax = sns.relplot(
    x="alpha",
    y="lamda",
    col="start",
    hue="fbeta",
    palette="mako_r",
    data=top_samples,
)

plt.savefig(f"{fig_dir}/top_param_distributions.png", bbox_inches="tight")
plt.show()

### Generate a multivariate normal distribution and sample parameters
Fit a multivariate normal distribution to the samples above your threshold, then randomly draw new parameter sets from it.

In [None]:
# How many distinct parameter samples do you want to generate?
n_samples = 10

In [None]:
# Fits a separate distribution per year

samples_to_run = generate_param_samples(top_samples, n_samples)
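generate_param_samples is part of PoPS Global, and its internals are not shown in this notebook. To illustrate what "a separate distribution per year" could mean, the hypothetical sample_mvn_by_start below fits a multivariate normal to alpha, beta, and lamda within each start year. It then draws new parameter sets, allocating draws across years in proportion to how often each year appears in top_samples. The allocation rule, the column list, and the seed argument are assumptions. The real helper may differ, for example by clipping draws to valid parameter ranges, which this sketch does not do.

```python
import numpy as np
import pandas as pd


def sample_mvn_by_start(top_samples, n_samples, params=("alpha", "beta", "lamda"), seed=None):
    """Hypothetical sketch: fit one multivariate normal per start year and draw parameter sets.

    Each start year needs at least two top samples so a covariance can be estimated.
    """
    rng = np.random.default_rng(seed)
    # Assign each draw to a start year in proportion to that year's share of top samples
    shares = top_samples["start"].value_counts(normalize=True)
    years = rng.choice(shares.index.to_numpy(), size=n_samples, p=shares.to_numpy())

    draws = []
    for start in np.unique(years):
        subset = top_samples.loc[top_samples["start"] == start, list(params)].to_numpy()
        mean = subset.mean(axis=0)
        cov = np.cov(subset, rowvar=False)  # columns are variables
        n = int((years == start).sum())
        values = rng.multivariate_normal(mean, cov, size=n)
        df = pd.DataFrame(values, columns=list(params))
        df.insert(0, "start", start)
        draws.append(df)
    return pd.concat(draws, ignore_index=True)
```

With a fixed seed, repeated calls return the same draws, which can help when comparing forecasts.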

In [None]:
# Save sampled parameters to .csv as a backup/for later use

samples_to_run.to_csv(f"{stats_dir}/sampled_param_sets.csv", index=False)

Visualize the parameter distributions that will be run.

In [None]:
# Plot to visually examine the parameter posterior distributions

ax = sns.jointplot(
    x="alpha", y="lamda", hue="start", data=samples_to_run, palette="deep", alpha=0.6
)
plt.savefig(f"{fig_dir}/posterior_param_dist.png", bbox_inches="tight")
plt.show()

## Run the model forecast

First, write out the model commands for the new sampled parameter sets. One run is conducted with each parameter sample.

In [None]:
commands_forecast = ""

for index, row in samples_to_run.iterrows():
    commands_forecast += write_commands(
        row, start_run=0, end_run=0, run_type="forecast"
    )

In [None]:
# If you will run these on HPC or later, write them to file
with open(f"{stats_dir}/commands.txt", "w") as f1:
    f1.write(commands_forecast)

Run the cell below to execute all model runs. These must complete before you can calculate the summary statistics. This may take some time (approximately 2-5 minutes per run per core, depending on your computer and the number of time steps in your simulation), so plan accordingly!


In [None]:
# Run model from script

# Skip blank entries, such as the empty string left by a trailing newline
for command in commands_forecast.splitlines():
    if command.strip():
        ! {command}
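The loop above runs the commands one after another. If your machine has several cores, a sketch like the one below could run them concurrently using Python's standard subprocess and concurrent.futures modules. Threads are enough here because each model run is its own operating-system process. The sketch assumes each line of commands_forecast is an independent shell command and that the runs do not write to the same files. run_commands and max_workers are hypothetical names, not part of PoPS Global.

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor


def run_commands(commands, max_workers=4):
    """Run shell commands concurrently; return their exit codes in input order."""
    # Drop blank lines, e.g. the empty string left by a trailing newline
    commands = [cmd for cmd in commands if cmd.strip()]

    def run_one(cmd):
        return subprocess.run(cmd, shell=True).returncode

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, commands))


# Usage in the notebook (set max_workers to the number of cores you want to use):
# exit_codes = run_commands(commands_forecast.splitlines(), max_workers=4)
# failed = [i for i, code in enumerate(exit_codes) if code != 0]
```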


These runs write their output to "outputs/{sim_name}_forecast/".

Calculate summary statistics on the completed runs. This step also runs in parallel, so run time will vary with the number of cores you use.

In [None]:
# Generate summary stats
# Note: The summary stats may generate a "warning" from the pandas library. This should not cause any errors.

! python pandemic/get_stats.py forecast

## Review model summary statistics

You can now summarize the model runs with a single set of summary statistics.

In [None]:
run_name = f"{sim_name}_forecast"
stats_dir = f"{out_dir}/summary_stats/{run_name}"

agg_df = pd.read_csv(f"{stats_dir}/summary_stats_bySample.csv").rename(columns=col_dict)

agg_df

In [None]:
print(
    f"Final forecast summary results: \n\n"
    f"F-beta = {round(agg_df.loc[0,'fbeta_mean'],4)}"
)

for year in sim_years:
    print(
        f"Probability of intro. to {coi} by {year}: "
        f"{round(agg_df.loc[0, [col for col in agg_df.columns if f'prob_by_{year}' in col]].values[0],4)}"
        )
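The probabilities printed above come from get_stats.py, and this notebook does not show how that script computes them. One common way to estimate "probability of introduction by year Y" from stochastic runs is the fraction of runs in which the first introduction happened in or before Y. The sketch below shows that estimate under that assumption. prob_intro_by_year and its inputs are hypothetical and are not read from the model outputs.

```python
import numpy as np


def prob_intro_by_year(first_intro_years, years):
    """Fraction of runs whose first introduction occurs in or before each year.

    `first_intro_years` holds one entry per run; use np.nan for runs with no introduction.
    """
    first = np.asarray(first_intro_years, dtype=float)
    # Comparisons with NaN are False, so runs without an introduction never count
    return {year: float(np.mean(first <= year)) for year in years}


# Hypothetical example: four runs, one of which never introduces the pest
print(prob_intro_by_year([2020, 2022, np.nan, 2021], [2020, 2021, 2022]))
# {2020: 0.25, 2021: 0.5, 2022: 0.75}
```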

## Next step: Visualize forecast

Use notebook 4 to visualize the full results of your forecast simulation. 