In [None]:
import itertools
import warnings
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from src.config import BLD, SRC
from src.estimation.gridsearch import (
    get_mask_around_diagonal,
    run_1d_gridsearch,
    run_2d_gridsearch,
)
from src.estimation.msm_criterion import (
    get_index_bundles,
    get_parallelizable_msm_criterion,
)
from src.manfred.shared import hash_array
from src.plotting.msm_plots import plot_estimation_moment, plot_infection_channels
from src.simulation.load_params import load_params
from src.simulation.load_simulation_inputs import load_simulation_inputs

# Switch for quick trial runs: when True, both simulation windows are cut down
# to five days after their start date.
DEBUG = False

FALL_SIM_START = pd.Timestamp("2020-09-15")
SPRING_SIM_START = pd.Timestamp("2021-02-10")

# Each window ends on its fixed calendar date unless DEBUG shortens it.
FALL_SIM_END = (
    FALL_SIM_START + pd.Timedelta(days=5) if DEBUG else pd.Timestamp("2020-12-22")
)
SPRING_SIM_END = (
    SPRING_SIM_START + pd.Timedelta(days=5) if DEBUG else pd.Timestamp("2021-05-14")
)

warnings.filterwarnings(
    "ignore", message="indexing past lexsort depth may impact performance."
)
# Load the snakeviz extension into the IPython session.
%load_ext snakeviz

# Load the parameters and set up the MSM criterion

In [None]:
# Start from the "baseline" parameter set.
params = load_params("baseline")

# Keyword arguments for the criterion factory, gathered in one place so the
# fall and spring simulation windows can be read side by side.
criterion_kwargs = {
    "prefix": "gridsearch",
    "fall_start_date": FALL_SIM_START,
    "fall_end_date": FALL_SIM_END,
    "spring_start_date": SPRING_SIM_START,
    "spring_end_date": SPRING_SIM_END,
    "mode": "combined",
    "debug": DEBUG,
}
pmsm = get_parallelizable_msm_criterion(**criterion_kwargs)

# Modify Params

In [None]:
# Look up the index bundles once and unpack the five that are adjusted below.
index_bundles = get_index_bundles(params)
hh_probs, school_probs, young_educ_probs, work_probs, other_probs = (
    index_bundles[name] for name in ("hh", "school", "young_educ", "work", "other")
)

# The vacation contact rows are not part of the bundles and are selected by
# their category.
vacation_probs = params.query("category == 'additional_other_vacation_contact'").index

# First index entry of every bundle, followed by the first vacation entry.
free_probs = [entries[0] for entries in [*index_bundles.values(), vacation_probs]]

In [None]:
# Values written into the "value" column, one per index bundle. The order
# matches the original sequence of assignments.
start_values = [
    (other_probs, 0.15875),
    (young_educ_probs, 0.005),
    (school_probs, 0.01),
    (work_probs, 0.1475),
    (hh_probs, 0.095),
    (vacation_probs, 0.5),
]
for bundle_index, start_value in start_values:
    params.loc[bundle_index, "value"] = start_value

# Display the rows listed in free_probs to check the assignments.
params.loc[free_probs]

In [None]:
# Display the full parameter table after the modifications above.
params

# Run estimation

In [None]:
# "2d" selects the two-dimensional grid search in the next cell; the 1D search
# uses loc1 and gridspec1 only.
dimensions = "1d"

# Shared third entry of both gridspecs.
# NOTE(review): gridspecs look like (lower, upper, n_points) -- confirm against
# src.estimation.gridsearch.
n_gridpoints = 1

# First search axis.
loc1, gridspec1 = other_probs, (0.15875, 0.15875, n_gridpoints)
# Second search axis; only used if 2d.
loc2, gridspec2 = work_probs, (0.14, 0.16, n_gridpoints)

# Both are passed to the grid search call and are set to the same number.
n_seeds = n_cores = 20

# Only forwarded to the 2D grid search.
mask = None

In [None]:
# Run the grid search selected by `dimensions`. Both helpers return the tuple
# (results, grid, best_index, fig) that the cells below inspect.
if dimensions == "2d":
    results, grid, best_index, fig = run_2d_gridsearch(
        func=pmsm,
        params=params,
        loc1=loc1,
        gridspec1=gridspec1,
        loc2=loc2,
        gridspec2=gridspec2,
        n_seeds=n_seeds,
        n_cores=n_cores,
        mask=mask,
    )
elif dimensions == "1d":
    results, grid, best_index, fig = run_1d_gridsearch(
        func=pmsm,
        params=params,
        loc=loc1,
        gridspec=gridspec1,
        n_seeds=n_seeds,
        n_cores=n_cores,
    )
else:
    # Fail fast: previously any other value (e.g. a typo such as "2D")
    # silently started the 1D search with loc1/gridspec1.
    raise ValueError(f"dimensions must be '1d' or '2d', got {dimensions!r}.")

In [None]:
# Persist the grid search results as a pickle. The path is relative, so the
# file lands in the kernel's current working directory.
pd.to_pickle(results, "results.pkl")

In [None]:
# Display the figure returned by the grid search.
fig

In [None]:
# `ix` picks the results entry to inspect and is reused by the plotting cells
# below. It defaults to the best grid point.
ix = best_index
selected_results = results[ix]
plot_estimation_moment(selected_results, "aggregated_infections_not_log")

In [None]:
# NOTE(review): this call is identical to the one in the previous cell --
# confirm whether a different moment was intended or the cell can be removed.
plot_estimation_moment(results[ix], "aggregated_infections_not_log")

In [None]:
# Plot the infection channels for the selected results entry with
# aggregate=True.
plot_infection_channels(results[ix], aggregate=True)

In [None]:
# Grid entry at the best index returned by the grid search.
grid[best_index]

In [None]:
# Best index as returned by the grid search.
best_index

In [None]:
# Average the "value" entry over every element of results[best_index].
criterion_values = [entry["value"] for entry in results[best_index]]
np.mean(criterion_values)
# 193 (presumably the mean observed in an earlier run -- TODO confirm)