In [1]:
import itertools
import warnings
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from src.config import BLD, SRC
from src.estimation.gridsearch import (
    get_mask_around_diagonal,
    run_1d_gridsearch,
    run_2d_gridsearch,
)
from src.estimation.msm_criterion import (
    get_index_bundles,
    get_parallelizable_msm_criterion,
)
from src.manfred.shared import hash_array
from src.plotting.msm_plots import plot_estimation_moment, plot_infection_channels
from src.simulation.load_params import load_params
from src.simulation.load_simulation_inputs import load_simulation_inputs

# -----------------------------------------------------------------------------------

DEBUG = False
prefix = "delta_untl_june_2_3_infectious"

# -----------------------------------------------------------------------------------


FALL_SIM_START = pd.Timestamp("2021-04-11")  # spring end is 2021-04-10
FALL_SIM_END = pd.Timestamp("2021-07-01")

SPRING_SIM_START = pd.Timestamp("2021-07-02")  # summer end. not used
SPRING_SIM_END = pd.Timestamp("2021-08-26")  # not used

if DEBUG:
    FALL_SIM_END = FALL_SIM_START + pd.Timedelta(days=3)
    SPRING_SIM_END = SPRING_SIM_START + pd.Timedelta(days=3)

warnings.filterwarnings(
    "ignore", message="indexing past lexsort depth may impact performance."
)

# %load_ext snakeviz

# Params

In [2]:
params = load_params("baseline")

# set delta infectiousness
delta_factor = 2.3  # 1.6 originally
params.loc[("virus_strain", "delta", "factor"), "value"] = 1.67 * delta_factor
params.loc["infection_prob"] = params.loc[["infection_prob"]] * delta_factor / 1.6


delta_params = [("events", "delta_cases_per_100_000", "2021-06-01")]
rapid_reduc_params = [
    ("rapid_test_demand", "low_incidence_factor", "other_demand"),
    ("rapid_test_demand", "low_incidence_factor", "worker_demand"),
]

free_params = delta_params + rapid_reduc_params

# Set some free parameters-----------------------------------------------------------

# params.loc[delta_params] = 0.1

# -----------------------------------------------------------------------------------

params.loc[free_params]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,value
category,subcategory,name,Unnamed: 3_level_1
events,delta_cases_per_100_000,2021-06-01,0.1
rapid_test_demand,low_incidence_factor,other_demand,0.25
rapid_test_demand,low_incidence_factor,worker_demand,0.25


# Specify Grid

In [3]:
# -----------------------------------------------------------------------------------
dimensions = "1d"

n_gridpoints = 10 if not DEBUG else 1
loc1 = delta_params
gridspec1 = (0.02, 0.1, n_gridpoints)
print(np.linspace(*gridspec1).round(2))

OUT_PATH = BLD / "figures" / "calibration" / f"{prefix}_002_to_010"
OUT_PATH.mkdir(parents=True)

loc2 = None
gridspec2 = ("low", "high", n_gridpoints)

n_seeds = 20 if not DEBUG else 1
n_cores = 40 if not DEBUG else 1
mask = None

# -----------------------------------------------------------------------------------

[0.02 0.03 0.04 0.05 0.06 0.06 0.07 0.08 0.09 0.1 ]


In [4]:
initial_states_path = (
    str(BLD / "simulations" / "last_states") + "/verify_spring_baseline_{seed}.pkl"
)

# Build the Fitness Function

In [5]:
pmsm = get_parallelizable_msm_criterion(
    prefix=prefix,
    fall_start_date=FALL_SIM_START,
    fall_end_date=FALL_SIM_END,
    spring_start_date=SPRING_SIM_START,
    spring_end_date=SPRING_SIM_END,
    mode="fall",
    debug=DEBUG,
)

# Run estimation

In [None]:
start_time = pd.Timestamp.now()
print(start_time)

if dimensions == "2d":
    results, grid, best_index, fitness_plot = run_2d_gridsearch(
        func=pmsm,
        params=params,
        loc1=loc1,
        gridspec1=gridspec1,
        loc2=loc2,
        gridspec2=gridspec2,
        n_seeds=n_seeds,
        n_cores=n_cores,
        mask=mask,
        initial_states_path=initial_states_path,
    )
else:
    results, grid, best_index, fitness_plot = run_1d_gridsearch(
        func=pmsm,
        params=params,
        loc=loc1,
        gridspec=gridspec1,
        n_seeds=n_seeds,
        n_cores=n_cores,
        initial_states_path=initial_states_path,
    )

end_time = pd.Timestamp.now()

2021-09-02 12:33:20.174670


In [None]:
print(end_time - start_time)

In [None]:
pd.to_pickle(results, OUT_PATH / "results.pkl")

# Plot Criterion Values

In [None]:
fitness_plot.savefig(OUT_PATH / "x_to_criterion.pdf")
fitness_plot

# Plot Delta Shares

In [None]:
for g, res in zip(grid, results):
    fig = plot_estimation_moment(res, "aggregated_delta_share")
    fig.savefig(OUT_PATH / f"delta_share_{g:.2f}.pdf")

In [None]:
fig = plot_estimation_moment(results[best_index], "aggregated_delta_share")

# Plot Case Numbers

In [None]:
ix = best_index
fig = plot_estimation_moment(results[ix], "aggregated_infections_not_log")
fig.savefig(OUT_PATH / f"cases_best_{grid[ix]:.2f}.pdf")
fig

In [None]:
fig = plot_estimation_moment(results[ix], "infections_by_age_group")
fig.savefig(OUT_PATH / f"cases_best_{grid[ix]:.2f}_by_age_gorup.pdf")
fig