In [32]:
import functools
import os
from time import time

import numpy as np
import pandas as pd
from estimagic import minimize
from sid import get_msm_func, get_simulate_func
from sid.msm import get_diag_weighting_matrix

from src.calculate_moments import (
    smoothed_outcome_per_hundred_thousand_rki,
    smoothed_outcome_per_hundred_thousand_sim,
)
from src.config import BLD, SRC
from src.contact_models.get_contact_models import get_all_contact_models
from src.create_initial_states.create_initial_conditions import (  # noqa
    create_initial_conditions,
)
from src.manfred.minimize_manfred_estimagic import minimize_manfred_estimagic
from src.manfred.shared import hash_array
from src.plotting.plot_msm_performance import plot_msm_performance
from src.plotting.policy_gantt_chart import make_gantt_chart_of_policy_dict
from src.policies.full_policy_blocks import (
    get_german_reopening_phase,
    get_hard_lockdown,
    get_only_educ_closed,
    get_soft_lockdown,
)
from src.policies.policy_tools import combine_dictionaries

In [2]:
# Estimation window used for the simulation duration and for slicing the
# empirical data.
ESTIMATION_START = pd.Timestamp("2020-08-15")

# End of the full estimation window. It has its own name so that it is not
# silently overwritten by the short window below.
FULL_ESTIMATION_END = pd.Timestamp("2020-12-05")

# End date that is actually in effect. To run on the full window, set
# ESTIMATION_END = FULL_ESTIMATION_END.
# NOTE(review): presumably shortened to keep test runs fast -- confirm.
ESTIMATION_END = pd.Timestamp("2020-08-20")

# Window for the initial conditions: starts 31 days before and ends one day
# before ESTIMATION_START.
INIT_START = ESTIMATION_START - pd.Timedelta(31, unit="D")
INIT_END = ESTIMATION_START - pd.Timedelta(1, unit="D")

In [3]:
# Build the initial conditions for the window INIT_START..INIT_END, i.e. the
# days leading up to ESTIMATION_START.
# NOTE(review): a fixed seed is passed, presumably for reproducibility -- confirm.
initial_conditions = create_initial_conditions(
    start=INIT_START, end=INIT_END, seed=3484
)

In [4]:
# Contact models; their names are used further down to select the matching
# infection_prob parameters. parallelizable_msm_func does not use this
# variable but calls get_all_contact_models() again on every evaluation.
contact_models = get_all_contact_models()

In [28]:
def _infection_prob_index(params, models):
    """Return the index of the infection_prob rows belonging to *models*.

    Args:
        params (pandas.DataFrame): Parameter frame whose index levels include
            "category" and "subcategory".
        models (iterable of str): Contact model names to select.

    Returns:
        pandas.Index: Index entries of ``params`` with category
        "infection_prob" and a subcategory contained in ``models``.

    """
    models = list(models)
    # ``@models`` binds the local list instead of pasting its repr into the
    # query string, so unusual model names cannot break the expression.
    return params.query(
        "category == 'infection_prob' & subcategory in @models"
    ).index


params = pd.read_pickle(BLD / "start_params.pkl")
initial_states = pd.read_parquet(BLD / "data" / "initial_states.parquet")

# Group the contact models by name so that each group can share one
# infection probability.
educ_models = [cm for cm in contact_models if "educ" in cm]
work_models = [cm for cm in contact_models if "work" in cm]
other_models = [cm for cm in contact_models if "other" in cm]
# "preschool" contains "school", so it has to be excluded explicitly.
school_models = [
    cm
    for cm in contact_models
    if "educ" in cm and "school" in cm and "preschool" not in cm
]

hh_probs = ("infection_prob", "households", "households")
educ_probs = _infection_prob_index(params, educ_models)
work_probs = _infection_prob_index(params, work_models)
other_probs = _infection_prob_index(params, other_models)
school_probs = _infection_prob_index(params, school_models)

other_educ_probs = [
    ("infection_prob", "educ_nursery_0", "educ_nursery_0"),
    ("infection_prob", "educ_preschool_0", "educ_preschool_0"),
]

# Start values. Order matters: all educ models are set first, then the school
# models (a subset of the educ models) are overwritten with a lower value.
params.loc[educ_probs, "value"] = 0.02
params.loc[school_probs, "value"] = 0.004
params.loc[other_probs, "value"] = 0.1
params.loc[work_probs, "value"] = 0.1
params.loc[hh_probs, "value"] = 0.2


params.loc["infection_prob"]

  result = self._run_cell(


Unnamed: 0_level_0,Unnamed: 1_level_0,value
subcategory,name,Unnamed: 2_level_1
households,households,0.2
educ_school_0,educ_school_0,0.004
educ_school_1,educ_school_1,0.004
educ_school_2,educ_school_2,0.004
educ_preschool_0,educ_preschool_0,0.02
educ_nursery_0,educ_nursery_0,0.02
work_non_recurrent,work_non_recurrent,0.1
work_recurrent_daily,work_recurrent_daily,0.1
work_recurrent_weekly_0,work_recurrent_weekly_0,0.1
work_recurrent_weekly_1,work_recurrent_weekly_1,0.1


In [6]:
def get_estimation_policies(contact_models):
    """Assemble the contact policies used during estimation.

    One reopening block (2020-04-23 to 2020-09-30) is followed by three
    soft-lockdown blocks that together cover 2020-10-01 to 2020-12-20. The
    individual policy dictionaries are merged with ``combine_dictionaries``.

    Args:
        contact_models: Contact models, passed unchanged to every policy
            block constructor.

    Returns:
        The result of ``combine_dictionaries`` applied to the four blocks.

    """
    # The multipliers reached at the end of the reopening phase are reused
    # unchanged for the block that directly follows it.
    end_of_reopening = {"educ": 0.8, "work": 0.6, "other": 0.7}

    reopening = get_german_reopening_phase(
        contact_models=contact_models,
        block_info={
            "start_date": "2020-04-23",
            "end_date": "2020-09-30",
            "prefix": "reopening",
        },
        start_multipliers={"educ": 0.5, "work": 0.2, "other": 0.3},
        end_multipliers=end_of_reopening,
        educ_switching_date="2020-08-01",
    )

    # (start_date, end_date, prefix, multipliers) of each soft-lockdown block.
    soft_lockdown_specs = [
        ("2020-10-01", "2020-10-20", "after_reopening", end_of_reopening),
        (
            "2020-10-21",
            "2020-11-01",
            "anticipate_lockdown_light",
            {"educ": 0.8, "work": 0.6, "other": 0.55},
        ),
        (
            "2020-11-02",
            "2020-12-20",
            "lockdown_light",
            {"educ": 0.7, "work": 0.4, "other": 0.4},
        ),
    ]
    soft_lockdowns = [
        get_soft_lockdown(
            contact_models=contact_models,
            block_info={"start_date": start, "end_date": end, "prefix": prefix},
            multipliers=multipliers,
        )
        for start, end, prefix, multipliers in soft_lockdown_specs
    ]

    return combine_dictionaries([reopening, *soft_lockdowns])

In [7]:
def parallelizable_msm_func(params, initial_states, initial_conditions, prefix):
    """Evaluate the MSM criterion at ``params``.

    The simulate function, the empirical moments and the weighting matrix are
    all rebuilt inside the call, so the only inputs are the arguments and
    module-level configuration (ESTIMATION_START, ESTIMATION_END, BLD, SRC).

    Args:
        params (pandas.DataFrame): Parameters; the "value" column is hashed
            to name the simulation output directory.
        initial_states: Passed on to ``get_simulate_func``.
        initial_conditions: Passed on to ``get_simulate_func``.
        prefix (str): Leading part of the output directory name.

    Returns:
        Whatever the function built by ``get_msm_func`` returns for ``params``.

    """
    moment_key = "infections_by_age_group"

    # The output directory combines prefix, parameter hash and process id.
    params_hash = hash_array(params["value"].to_numpy())
    models = get_all_contact_models()

    simulate = get_simulate_func(
        params=params,
        initial_states=initial_states,
        contact_models=models,
        contact_policies=get_estimation_policies(models),
        duration={"start": ESTIMATION_START, "end": ESTIMATION_END},
        initial_conditions=initial_conditions,
        path=SRC / "exploration" / f"{prefix}_{params_hash}_{os.getpid()}",
        saved_columns={
            "initial_states": ["age_group_rki"],
            "disease_states": ["newly_infected"],
            "time": ["date"],
            "other": ["new_known_case"],
        },
    )

    def simulate_time_series(params):
        # Only the "time_series" entry of the simulation output is used.
        return simulate(params)["time_series"]

    simulated_moment_funcs = {
        moment_key: functools.partial(
            smoothed_outcome_per_hundred_thousand_sim,
            outcome="newly_infected",
            groupby="age_group_rki",
        ),
    }

    # Empirical counterpart, restricted to the estimation window.
    data_dir = BLD / "data"
    rki = pd.read_pickle(data_dir / "processed_time_series" / "rki.pkl")
    rki = rki.loc[ESTIMATION_START:ESTIMATION_END]
    age_groups = pd.read_pickle(
        data_dir / "population_structure" / "age_groups_rki.pkl"
    )

    observed = smoothed_outcome_per_hundred_thousand_rki(
        df=rki,
        outcome="newly_infected",
        groupby="age_group_rki",
        window=7,
        min_periods=1,
        group_sizes=age_groups["n"],
    )
    # NOTE(review): the factor 4 is an unexplained constant; presumably it
    # scales reported cases up to total infections -- confirm and name it.
    empirical_moments = {moment_key: observed * 4}

    # Every moment gets the weight of its age group (second index level).
    weight_by_age_group = age_groups["weight"].to_dict()
    weight_frame = (
        empirical_moments[moment_key]
        .to_frame()
        .copy(deep=True)
        .assign(age_group=lambda df: df.index.get_level_values(1))
        .assign(weights=lambda df: df["age_group"].replace(weight_by_age_group))
    )

    weighting_matrix = get_diag_weighting_matrix(
        empirical_moments=empirical_moments,
        weights={moment_key: weight_frame["weights"]},
    )

    msm = get_msm_func(
        simulate=simulate_time_series,
        calc_moments=simulated_moment_funcs,
        empirical_moments=empirical_moments,
        # Multiplying by one leaves the values, NaNs included, as they are.
        replace_nans=lambda x: x * 1,
        weighting_matrix=weighting_matrix,
    )

    return msm(params)


# Criterion with the data arguments bound, so that it only takes ``params``;
# this is the callable evaluated directly and handed to ``minimize`` below.
pmsm = functools.partial(
    parallelizable_msm_func,
    initial_states=initial_states,
    initial_conditions=initial_conditions,
    prefix="gridsearch",
)

In [9]:
# Evaluate the criterion once at the current parameter values; this runs a
# full simulation over the estimation window.
msm_res = pmsm(params)

Start the simulation...
2020-08-20: 100%|██████████| 6/6 [01:20<00:00, 13.34s/it]


In [15]:
# Inspect the "root_contributions" entry of the criterion output as a plain
# NumPy array.
msm_res["root_contributions"].to_numpy()

array([-0.18855007,  1.40796903,  0.14586316,  0.0749582 , -0.19440492,
       -0.04399734, -0.20213811,  1.52336832,  0.28286716,  0.17561829,
       -0.19440492, -0.04399734, -0.18824108,  1.56187129,  0.32768411,
        0.20262782, -0.19440492, -0.04399734, -0.18824108,  1.56187129,
        0.32768411,  0.20262782, -0.19440492, -0.04399734, -0.17459124,
        1.67037231,  0.47347946,  0.27813292, -0.19440492, -0.04399734,
       -0.15411648,  1.91545837,  0.71924021,  0.46089234, -0.19440492,
       -0.04399734])

In [None]:
# Plot the criterion output with the project's plotting helper.
plot_msm_performance(msm_res)

In [29]:
# Constraints handed to the optimizer.
constraints = [
    # Everything that is not an infection probability is marked "fixed".
    {"query": "category != 'infection_prob'", "type": "fixed"},
    # Each group below gets one "equality" constraint:
    # nursery and preschool, ...
    {"loc": other_educ_probs, "type": "equality"},
    # ... all "other" contact models, ...
    {"loc": other_probs, "type": "equality"},
    # ... all school models, ...
    {"loc": school_probs, "type": "equality"},
    # ... and all work models.
    {"loc": work_probs, "type": "equality"},
]

In [34]:
# Options forwarded to the manfred optimizer.
# NOTE(review): the list-valued options all have three entries, presumably
# one per optimization stage -- confirm against minimize_manfred_estimagic.
algo_options = {
    "step_sizes": [0.1, 0.05, 0.02],
    "max_step_sizes": [0.3, 0.2, 0.2],
    "linesearch_n_points": 12,
    "gradient_weight": 0.5,
    "noise_n_evaluations_per_x": [50, 90, 120],
    "convergence_relative_params_tolerance": 0.001,
    "direction_window": 3,
    "batch_evaluator_options": {"n_cores": 12},
}

# Bounds: a common range for all infection probabilities first, then a tighter
# upper bound for the educ models and a looser one for households.
params.loc["infection_prob", "lower_bound"] = 0.002
params.loc["infection_prob", "upper_bound"] = 0.15
params.loc[educ_probs, "upper_bound"] = 0.05
params.loc[hh_probs, "upper_bound"] = 0.25

res = minimize(
    criterion=pmsm,
    params=params,
    algorithm=minimize_manfred_estimagic,
    algo_options=algo_options,
    logging="first_manfred_attempt.db",
    # Fixed typo: this was ``contstaints``, which raised a NameError.
    constraints=constraints,
)

  result = self._run_cell(


Unnamed: 0_level_0,Unnamed: 1_level_0,value,lower_bound,upper_bound
subcategory,name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
households,households,0.2,0.002,0.25
educ_school_0,educ_school_0,0.004,0.002,0.05
educ_school_1,educ_school_1,0.004,0.002,0.05
educ_school_2,educ_school_2,0.004,0.002,0.05
educ_preschool_0,educ_preschool_0,0.02,0.002,0.05
educ_nursery_0,educ_nursery_0,0.02,0.002,0.05
work_non_recurrent,work_non_recurrent,0.1,0.002,0.15
work_recurrent_daily,work_recurrent_daily,0.1,0.002,0.15
work_recurrent_weekly_0,work_recurrent_weekly_0,0.1,0.002,0.15
work_recurrent_weekly_1,work_recurrent_weekly_1,0.1,0.002,0.15
