# Model run: Forecast 
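Use this notebook to fit a parameter distribution to the grid-search (calibration) results, sample parameter sets from that distribution, and run and evaluate a forecast with them.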

This notebook can be run after notebooks 0, 1, 2, and 3b. We recommend also running notebook 3a first, to check for and troubleshoot issues.

In [None]:
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import dotenv
import os 
import json 

In [None]:
# Work from the project root (one level above this notebook's directory) so
# relative paths such as '.env', 'config.json' and the `pandemic` package resolve.
# The root is resolved only once per kernel session, so re-running this cell
# does not keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("../")
os.chdir(PROJECT_ROOT)

In [None]:
from pandemic.multirun_helpers import write_commands, generate_param_samples
# import summary stats run 

In [None]:
# Load variables and paths from .env into the process environment.
dotenv.load_dotenv('.env')

# Read environmental variables
input_dir = os.getenv('INPUT_PATH')
out_dir = os.getenv('OUTPUT_PATH')

# Every output path in this notebook is built as f"{out_dir}/...". Fail fast
# here instead of silently reading/writing under a directory named "None".
if out_dir is None:
    raise RuntimeError("OUTPUT_PATH is not set; add it to .env or the environment.")


### Using the summary stats from the grid search to fit a parameter distribution

In [None]:
# Read the simulation settings from config.json.
with open('config.json') as config_file:
    config = json.load(config_file)

sim_name = config['sim_name']
total_runs = config["run_count"]

# The calibration (grid-search) run's summary stats live under the output dir.
run_name = f"{sim_name}_calibrate"
stats_dir = f"{out_dir}/summary_stats/{run_name}"

In [None]:
# Map the aggregated summary-stat column names onto the short parameter names
# used by every later cell in this notebook.
col_dict = {
    "start_max": "start",
    "alpha_max": "alpha",
    "beta_max": "beta",
    "lamda_max": "lamda",
    "count_known_countries_time_window_fbeta_mean": "fbeta",
}

agg_df = (
    pd.read_csv(f"{stats_dir}/summary_stats_bySample.csv")
    .rename(columns=col_dict)
    )

# rename() silently ignores keys that are absent, so check explicitly that the
# columns the plots and the parameter sampling below rely on are all present.
missing_cols = set(col_dict.values()) - set(agg_df.columns)
assert not missing_cols, f"summary_stats_bySample.csv is missing columns: {sorted(missing_cols)}"

### Exploring possible quantile thresholds for fbeta

In [None]:
# Percentile threshold (0 - 100) on fbeta: samples at or above this percentile
# are labelled "top" further down. Adjust based on the plots below, which sweep
# percentiles 70-99, so the dashed marker is only visible inside that range.
quant_threshold = 90
assert 0 <= quant_threshold <= 100, "quant_threshold must be a percentage between 0 and 100"

In [None]:
# For each candidate percentile cut-off (70-99), record how many samples are at
# or above it and the lowest fbeta among those kept samples.
count_vals, min_fbeta = [], []
fbeta_scores = agg_df["fbeta"]
for pct in range(70, 100):
    kept = fbeta_scores[fbeta_scores >= fbeta_scores.quantile(pct / 100)]
    count_vals.append(kept.size)
    min_fbeta.append(kept.min())
    

In [None]:
# One row per candidate percentile (matching the sweep above), indexed by it.
sweep_columns = {
    "quantile": list(range(70, 100)),
    "count": count_vals,
    "min_fbeta": min_fbeta,
}
sample_stats = pd.DataFrame(sweep_columns)
sample_stats = sample_stats.set_index("quantile")
    

Visual: How many samples and what Fbeta scores are captured with each threshold?

In [None]:
# Left: number of samples kept at each percentile cut-off.
# Right: lowest fbeta among the kept samples.
# The dashed line marks the currently selected quant_threshold.
fig, (ax_count, ax_fbeta) = plt.subplots(1, 2, figsize=(10, 5))

sample_stats["count"].plot(ax=ax_count)
ax_count.vlines(quant_threshold, ymin=sample_stats["count"].min(), ymax=sample_stats["count"].max(), linestyle='dashed', color="firebrick")
ax_count.set(title="Count", xlabel="fbeta quantile threshold (%)", ylabel="Samples at or above threshold")

sample_stats["min_fbeta"].plot(ax=ax_fbeta)
ax_fbeta.vlines(quant_threshold, ymin=sample_stats["min_fbeta"].min(), ymax=sample_stats["min_fbeta"].max(), linestyle='dashed', color="firebrick")
ax_fbeta.set(title="Fbeta", xlabel="fbeta quantile threshold (%)", ylabel="Minimum fbeta of kept samples")

plt.show()

Visual: What do the distributions of alpha, lamda, and beta look like with that threshold?

In [None]:
# Label each sample "top" if its fbeta is at or above the chosen percentile,
# otherwise "low".
fbeta_cutoff = agg_df["fbeta"].quantile(quant_threshold / 100)
agg_df["top"] = np.where(agg_df["fbeta"] >= fbeta_cutoff, "top", "low")

In [None]:
# fbeta against each parameter (alpha, lamda, beta), one panel per start year,
# coloured by whether the sample is in the "top" or "low" fbeta group.
# One loop replaces three copy-pasted cells that differed only in the x column.
for param in ["alpha", "lamda", "beta"]:
    # relplot returns a FacetGrid (not an Axes), hence the name.
    param_grid = sns.relplot(x=param, y="fbeta", col="start", hue="top", palette="rocket", data=agg_df, edgecolor="black", linewidth=0.5, s=100)
    plt.show()

In [None]:
# Alpha vs. lamda for the "top" samples only, one panel per start year,
# coloured by fbeta.
top_samples = agg_df[agg_df["top"] == "top"]
top_grid = sns.relplot(data=top_samples, x="alpha", y="lamda", col="start", hue="fbeta", palette="mako_r")
plt.show()

### Generating the multivariate normal distribution and sampled parameters

In [None]:
# Number of distinct parameter samples to generate (passed to
# generate_param_samples below).
n_samples = 1_000


In [None]:
# Fits a separate distribution per year and draws parameter sets from it
# (how many is controlled by n_samples).
# NOTE(review): agg_df is passed whole, including the "top"/"low" label column;
# presumably generate_param_samples restricts the fit to the "top" rows — confirm.

# Seed numpy's global RNG so the sampled parameter sets are reproducible when
# the notebook is re-run.
# NOTE(review): this only has an effect if generate_param_samples draws from
# numpy's global random state — confirm against its implementation.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

samples_to_run = generate_param_samples(agg_df, n_samples)


In [None]:
# Keep a copy of the sampled parameter sets on disk, as a backup or for later runs.
sampled_params_path = f"{stats_dir}/sampled_param_sets.csv"
samples_to_run.to_csv(sampled_params_path)


Visualize the parameter distributions that will be run.

In [None]:
# Joint distribution of the sampled alpha and lamda values, coloured by start
# year, to visually check the parameter sets that will be run.
joint_grid = sns.jointplot(
    data=samples_to_run,
    x="alpha",
    y="lamda",
    hue="start",
    palette="deep",
    alpha=0.4,
)
plt.show()


## Writing the sampled parameters out as run commands

In [None]:
# Build one command string per sampled parameter set (run_type "forecast",
# start_run=0, end_run=0) and concatenate them into a single string.
commands_forecast = "".join(
    write_commands(param_row, start_run=0, end_run=0, run_type="forecast")
    for _, param_row in samples_to_run.iterrows()
)



In [None]:
# Set to True if you will run the commands on HPC or later: they are then also
# written to <stats_dir>/commands.txt. Off by default, matching the notebook's
# previous behaviour.
WRITE_COMMANDS_FILE = False

if WRITE_COMMANDS_FILE:
    # `with` guarantees the file is closed even if the write fails.
    with open(f"{stats_dir}/commands.txt", "w") as commands_file:
        commands_file.write(commands_forecast)


## Run forecast with sampled parameters

In [None]:
# Run model here
for command in commands_forecast.split('\n'):
    ! {command}

# Write to outdir/run_name + _forecast

In [None]:
# TODO: Generate summary stats for the forecast runs.

# TODO: Update the summary stats script so that,
# if the output path ends in "_forecast", it
# aggregates on run rather than on sample.

## Review model summary statistics

In [None]:
# TODO: Read the forecast summary stats csv here.

## Next step: Visualize forecast

Use notebook 4 to visualize the results of your forecast simulation. 