In [17]:
import os
import shutil

from functools import partial
from time import time

import dask.dataframe as dd
import pandas as pd
import numpy as np 
import seaborn as sns
import matplotlib.pyplot as plt
from sid import get_colors
from src.config import SRC, BLD

from sid import get_simulate_func

from src.create_initial_states.create_initial_conditions import (  # noqa
    create_initial_conditions,
)
 
from src.policies.combine_policies_over_periods import get_october_to_christmas_policies
from src.policies.combine_policies_over_periods import get_enacted_policies_of_2021

from src.simulation.main_specification import load_simulation_inputs
from src.simulation.main_specification import SIMULATION_DEPENDENCIES
from src.simulation.main_specification import SCENARIO_START

from sid import get_msm_func
from src.manfred.shared import hash_array
from estimagic.batch_evaluators import joblib_batch_evaluator
from sid.msm import get_diag_weighting_matrix

from src.calculate_moments import smoothed_outcome_per_hundred_thousand_rki
from src.calculate_moments import smoothed_outcome_per_hundred_thousand_sim

from src.policies.policy_tools import combine_dictionaries

# Columns kept in the simulation output; passed to sid's simulate function via
# the "saved_columns" entry of the simulate inputs.
SAVED_COLUMNS = {
    "initial_states": ["age_group_rki"],
    "disease_states": ["newly_infected", "infectious", "ever_infected"],
    "time": ["date"],
    "other": ["new_known_case", "virus_strain", "n_has_infected", "pending_test"],
}

# Declutter every figure: hide the right and top spines and the legend frame.
for rc_key in ("axes.spines.right", "axes.spines.top", "legend.frameon"):
    plt.rcParams[rc_key] = False

print(SCENARIO_START.date())

2021-04-06


# Load the fall and spring simulation inputs

In [21]:
# spring

spring_start_date = pd.Timestamp("2021-02-05")
# One week of simulation (presumably 4 weeks is the alternative the author noted).
spring_end_date = spring_start_date + pd.Timedelta(weeks=1)
# The initial conditions are built from the 31 days preceding the simulation start.
spring_init_start = spring_start_date - pd.Timedelta(days=31)
spring_init_end = spring_start_date - pd.Timedelta(days=1)
print(spring_init_start.date(), spring_end_date.date())

virus_shares, spring_kwargs = load_simulation_inputs(
    SIMULATION_DEPENDENCIES, spring_init_start, spring_end_date, extend_ars_dfs=False
)
# Discard the "old" params bundled with the inputs; the params under evaluation
# are passed to the simulate function separately.
spring_kwargs.pop("params")

full_spring_simulate_inputs = {
    "duration": {"start": spring_start_date, "end": spring_end_date},
    "initial_conditions": create_initial_conditions(
        start=spring_init_start,
        end=spring_init_end,
        seed=3930,
        reporting_delay=5,
        virus_shares=virus_shares,
    ),
    "contact_policies": get_enacted_policies_of_2021(
        contact_models=spring_kwargs["contact_models"],
        scenario_start=SCENARIO_START,
    ),
    "saved_columns": SAVED_COLUMNS,
    **spring_kwargs,
}

2021-01-05 2021-02-12


In [22]:
# fall 

fall_start_date = pd.Timestamp("2020-12-13")  # alternative start: pd.Timestamp("2020-10-15")
fall_end_date = pd.Timestamp("2020-12-23")
# The initial conditions are built from the 31 days preceding the simulation start.
fall_init_start = fall_start_date - pd.Timedelta(days=31)
fall_init_end = fall_start_date - pd.Timedelta(days=1)
print(fall_init_start.date(), fall_end_date.date())

virus_shares, fall_kwargs = load_simulation_inputs(
    SIMULATION_DEPENDENCIES, fall_init_start, fall_end_date, extend_ars_dfs=False
)
# Discard the "old" params bundled with the inputs; the params under evaluation
# are passed to the simulate function separately.
fall_kwargs.pop("params")

full_fall_simulate_inputs = {
    "duration": {"start": fall_start_date, "end": fall_end_date},
    "initial_conditions": create_initial_conditions(
        start=fall_init_start,
        end=fall_init_end,
        seed=344490,
        reporting_delay=5,
        virus_shares=virus_shares,
    ),
    "contact_policies": get_october_to_christmas_policies(
        contact_models=fall_kwargs["contact_models"], educ_multiplier=0.8
    ),
    "saved_columns": SAVED_COLUMNS,
    **fall_kwargs,
}

2020-11-12 2020-12-23


# Build the criterion

In [23]:
def _simulate_wrapper(params, simulate):
    return simulate(params)["time_series"]


def build_and_evaluate_msm_func(
    params, seed, prefix, simulate_kwargs
):
    """Simulate one period with ``params`` and evaluate the MSM criterion on it.

    The empirical moments are the smoothed RKI case numbers per 100 000 by age
    group, cut to the simulated period. The temporary simulation directory is
    removed afterwards, also when the evaluation fails.

    Args:
        params (pandas.DataFrame): sid parameters. Its "value" column is hashed
            to build a unique name for the temporary simulation directory.
        seed (int): Seed passed to ``get_simulate_func``.
        prefix (str): Prefix of the temporary simulation directory name.
        simulate_kwargs (dict): Keyword arguments for ``get_simulate_func``
            except ``params``, ``path`` and ``seed``. Must contain
            ``"duration": {"start": ..., "end": ...}``; the same dates are used
            to select the empirical RKI data.

    Returns:
        The output of the msm function built by ``get_msm_func``.

    """
    start_date = simulate_kwargs["duration"]["start"]
    end_date = simulate_kwargs["duration"]["end"]
    empirical_moments, weight_mat = _get_empirical_moments_and_weighting_matrix(
        start_date, end_date
    )

    calc_moments = {
        "infections_by_age_group": partial(
            smoothed_outcome_per_hundred_thousand_sim,
            outcome="new_known_case",
            groupby="age_group_rki",
        ),
    }

    params_hash = hash_array(params["value"].to_numpy())
    # The pid keeps parallel workers that evaluate identical params from sharing
    # one directory.
    path = SRC / "exploration" / f"{prefix}_{params_hash}_{os.getpid()}"
    try:
        simulate = get_simulate_func(
            **simulate_kwargs,
            params=params,
            path=path / "simulation",
            seed=seed,
        )
        msm_func = get_msm_func(
            simulate=partial(_simulate_wrapper, simulate=simulate),
            calc_moments=calc_moments,
            empirical_moments=empirical_moments,
            # x * 1 returns its input unchanged, i.e. NaNs are not replaced.
            replace_nans=lambda x: x * 1,
            weighting_matrix=weight_mat,
        )
        return msm_func(params)
    finally:
        # Clean up even on failure so aborted evaluations leave no directories
        # behind; ignore_errors also avoids masking the original exception.
        shutil.rmtree(path, ignore_errors=True)


def _get_empirical_moments_and_weighting_matrix(start_date, end_date):
    """Build the RKI infection moments by age group and their diagonal weighting matrix."""
    rki_data = pd.read_pickle(BLD / "data" / "processed_time_series" / "rki.pkl")
    rki_data = rki_data.loc[start_date:end_date]
    age_group_info = pd.read_pickle(
        BLD / "data" / "population_structure" / "age_groups_rki.pkl"
    )
    empirical_moments = {
        "infections_by_age_group": smoothed_outcome_per_hundred_thousand_rki(
            df=rki_data,
            outcome="newly_infected",
            groupby="age_group_rki",
            window=7,
            min_periods=1,
            group_sizes=age_group_info["n"],
        )
    }

    # Each moment is weighted by the weight of its age group.
    age_weights = age_group_info["weight"].to_dict()
    moments_df = empirical_moments["infections_by_age_group"].to_frame()
    # NOTE(review): assumes level 1 of the moment index is the age group — confirm.
    moments_df["age_group"] = moments_df.index.get_level_values(1)
    moments_df["weights"] = moments_df["age_group"].replace(age_weights)
    weights = {"infections_by_age_group": moments_df["weights"]}

    weight_mat = get_diag_weighting_matrix(
        empirical_moments=empirical_moments,
        weights=weights,
    )
    return empirical_moments, weight_mat
    
    
# params, seed, prefix, simulate_kwargs, policies, start_date, 
# end_date, initial_conditions

    
def build_and_evaluate_combined_msm_func(
    params,
    seed,
    prefix,
    full_fall_simulate_inputs,
    full_spring_simulate_inputs,
    saved_columns=None,
):
    """Evaluate the MSM criterion on the fall and the spring period and combine them.

    Each period is simulated and evaluated separately. The combined value is the
    average of both criterion values, weighted by the period lengths.

    Args:
        params (pandas.DataFrame): sid parameters.
        seed (int): Seed of the fall simulation; spring uses ``seed + 100_000``.
        prefix (str): Prefix for the temporary simulation directories.
        full_fall_simulate_inputs (dict): Simulate inputs of the fall period.
        full_spring_simulate_inputs (dict): Simulate inputs of the spring period.
        saved_columns: Unused; the saved columns come from the "saved_columns"
            entry of the simulate inputs. Kept for backward compatibility.

    Returns:
        dict: "value" is the length-weighted criterion value. "empirical_moments"
        and "simulated_moments" are lists with the fall result first and the
        spring result second.

    """
    fall_msm_res = build_and_evaluate_msm_func(
        params=params,
        seed=seed,
        prefix=prefix + "_fall",
        simulate_kwargs=full_fall_simulate_inputs,
    )
    spring_msm_res = build_and_evaluate_msm_func(
        params=params,
        # Offset so spring does not reuse the fall seed.
        seed=seed + 100_000,
        prefix=prefix + "_spring",
        simulate_kwargs=full_spring_simulate_inputs,
    )

    # Weight each period by its share of the total simulated time.
    length_fall = _get_period_length(full_fall_simulate_inputs)
    length_spring = _get_period_length(full_spring_simulate_inputs)
    combined_length = length_fall + length_spring
    fall_weight = length_fall / combined_length
    spring_weight = length_spring / combined_length

    weighted_value = (
        fall_weight * fall_msm_res["value"] + spring_weight * spring_msm_res["value"]
    )

    period_results = [fall_msm_res, spring_msm_res]
    return {
        "value": weighted_value,
        "empirical_moments": [res["empirical_moments"] for res in period_results],
        "simulated_moments": [res["simulated_moments"] for res in period_results],
    }
   
    
def _get_period_length(inputs):
    return inputs["end_date"] - inputs["start_date"]
    
    
# Criterion for the batch evaluation: callers supply only ``params`` and ``seed``,
# so every other required argument of the combined criterion is bound here.
pmsm = partial(
    build_and_evaluate_combined_msm_func,
    prefix="gridsearch",
    full_fall_simulate_inputs=full_fall_simulate_inputs,
    full_spring_simulate_inputs=full_spring_simulate_inputs,
    saved_columns=SAVED_COLUMNS,
)

# Load the parameters and set the virus strain factors

In [24]:
params = pd.read_pickle(BLD / "params.pkl")
# Set only the "value" column: assigning to the whole row would overwrite every
# other column of params with the number as well.
params.loc[("virus_strain", "base_strain", "factor"), "value"] = 1.0
params.loc[("virus_strain", "b117", "factor"), "value"] = 1.67

In [None]:
before = time()

# NOTE(review): despite its name, combined_ts holds the combined MSM result dict
# ("value", "empirical_moments", "simulated_moments"), not a time series.
combined_ts = build_and_evaluate_combined_msm_func(
    params=params,
    seed=5471,
    prefix="test",
    full_fall_simulate_inputs=full_fall_simulate_inputs,
    full_spring_simulate_inputs=full_spring_simulate_inputs,
    saved_columns=SAVED_COLUMNS,
)

# Runtime in minutes, rounded after the conversion so the output stays readable.
print(round((time() - before) / 60, 1))


Too much endogenous test demand on 2020-11-16 (Monday). This is an indication that the share of symptomatic infections is too high orthat too many symptomatic people demand a test:

age_group_rki  0-4   5-14   15-34   35-59  60-79  80-100
demand         6.0  31.00   91.00  101.00  28.00   17.00
target demand  6.0  23.00  105.00  146.00  66.00   46.00
difference     0.0   0.35   -0.13   -0.31  -0.58   -0.63



Too much endogenous test demand on 2020-11-17 (Tuesday). This is an indication that the share of symptomatic infections is too high orthat too many symptomatic people demand a test:

age_group_rki   0-4   5-14   15-34   35-59  60-79  80-100
demand         9.00  29.00  124.00  144.00  23.00   25.00
target demand  4.00  15.00   71.00   99.00  45.00   31.00
difference     1.25   0.93    0.75    0.45  -0.49   -0.19



Too much endogenous test demand on 2020-11-18 (Wednesday). This is an indication that the share of symptomatic infections is too high orthat too many symptomatic people


Too much endogenous test demand on 2020-12-02 (Wednesday). This is an indication that the share of symptomatic infections is too high orthat too many symptomatic people demand a test:

age_group_rki    0-4   5-14   15-34   35-59  60-79  80-100
demand         11.00  47.00  173.00  222.00  73.00   74.00
target demand   3.00  15.00   68.00   99.00  50.00   45.00
difference      2.67   2.13    1.54    1.24   0.46    0.64



Too much endogenous test demand on 2020-12-03 (Thursday). This is an indication that the share of symptomatic infections is too high orthat too many symptomatic people demand a test:

age_group_rki   0-4   5-14   15-34   35-59   60-79  80-100
demand         10.0  37.00  157.00  232.00  106.00   58.00
target demand   5.0  22.00  102.00  148.00   76.00   67.00
difference      1.0   0.68    0.54    0.57    0.39   -0.13



Too much endogenous test demand on 2020-12-04 (Friday). This is an indication that the share of symptomatic infections is too high orthat too many sympt

In [None]:
# NOTE(review): combined_ts is the dict returned by build_and_evaluate_combined_msm_func
# (keys "value", "empirical_moments", "simulated_moments"), so indexing it with
# "new_known_case" appears to raise a KeyError. This looks left over from a version
# that returned the simulated time series — TODO confirm and update.
n_new_cases = combined_ts["new_known_case"].groupby('date').sum().compute()

In [None]:
# Plot the daily new known cases computed in the previous cell.
ax = sns.lineplot(x=n_new_cases.index, y=n_new_cases)
ax.set(title="New known cases per day", xlabel="Date", ylabel="New known cases");

In [16]:
# Inspect which simulation inputs load_simulation_inputs returned for the fall
# period (after "params" was popped).
fall_kwargs.keys()

dict_keys(['testing_demand_models', 'testing_allocation_models', 'testing_processing_models', 'initial_states', 'contact_models', 'susceptibility_factor_model', 'virus_strains'])

# Example usage: evaluate the criterion for several seeds in parallel

In [None]:
# One criterion evaluation per seed (300_000 and 400_000), all with the same params.
arguments = [{"params": params, "seed": seed_base * 100_000} for seed_base in [3, 4]]

In [None]:
before = time()

# Evaluate the criterion for every entry of ``arguments`` in parallel.
results = joblib_batch_evaluator(
    func=pmsm,
    arguments=arguments,
    n_cores=3,
    unpack_symbol="**",
)

# Runtime in minutes (about 7.5 min on the author's machine).
print(round((time() - before) / 60, 1))

In [None]:
# rearrange to be able to plot
iteration = {}
iteration["empirical_moments"] = results[0]["empirical_moments"]
iteration["simulated_moments"] = {}
for key in ["infections_by_age_group"]:
    mom_list = [res["simulated_moments"][key] for res in results]
    iteration["simulated_moments"][key] = mom_list

In [None]:
def plot_msm_performance(
    iteration,
    key="infections_by_age_group",
    age_groups=("0-4", "5-14", "15-34", "35-59", "60-79", "80-100"),
):
    """Plot the moment performance contrasting empirical and simulated moments.

    The first panel shows the age-weighted sum over all age groups; each further
    panel shows one age group. Thin lines are individual simulation runs and the
    thick simulated line is their mean.

    Args:
        iteration (dict): Has keys "empirical_moments" and "simulated_moments".
            ``iteration["empirical_moments"][key]`` is a pandas Series and
            ``iteration["simulated_moments"][key]`` a list of pandas Series, one
            per simulation run.
        key (str): Name of the moment to plot.
        age_groups (sequence of str): Age groups to plot, one panel each.

    Returns:
        fig, axes: The figure and a 1d array of its axes; axes[0] is the
        aggregate panel.

    """
    colors = get_colors(palette="categorical", number=2)

    nrows = len(age_groups) + 1
    fig, axes = plt.subplots(
        nrows=nrows, figsize=(8, 16 * nrows / 7), sharex=False, squeeze=False
    )
    # squeeze=False always returns a 2d array; keep the single column.
    axes = axes[:, 0]

    sim_moms = iteration["simulated_moments"][key]
    # One wide frame (dates x age groups) per simulation run.
    sim_mom_dfs = [_convert_to_dataframe_with_age_groups_as_columns(mom) for mom in sim_moms]
    mean_sim_mom = _convert_to_dataframe_with_age_groups_as_columns(
        pd.concat(sim_moms, axis=1).mean(axis=1)
    )
    emp_mom = _convert_to_dataframe_with_age_groups_as_columns(
        iteration["empirical_moments"][key]
    )
    # Only compare dates that were simulated.
    emp_mom = emp_mom.loc[mean_sim_mom.index]

    for age_group, ax in zip(age_groups, axes[1:]):
        _plot_simulated_vs_empirical(
            ax=ax,
            sim_runs=[df[age_group] for df in sim_mom_dfs],
            sim_mean=mean_sim_mom[age_group],
            empirical=emp_mom[age_group],
            colors=colors,
        )
        ax.set_title(f"Goodness of Fit: {age_group}")
        ax.set_ylabel("Infections per 100 000")

    # Overall fit: weight each age group's moment and sum over the age groups.
    age_group_info = pd.read_pickle(
        BLD / "data" / "population_structure" / "age_groups_rki.pkl"
    )
    age_weights = age_group_info["weight"]
    ax = axes[0]
    _plot_simulated_vs_empirical(
        ax=ax,
        sim_runs=[(df * age_weights).sum(axis=1) for df in sim_mom_dfs],
        sim_mean=(mean_sim_mom * age_weights).sum(axis=1),
        empirical=(emp_mom * age_weights).sum(axis=1),
        colors=colors,
    )
    ax.set_title("Overall Goodness of Fit")
    ax.set_ylabel("Infections per 100 000")

    for ax in axes:
        ax.xaxis.set_major_locator(plt.MaxNLocator(8))

    fig.tight_layout()
    return fig, axes


def _plot_simulated_vs_empirical(ax, sim_runs, sim_mean, empirical, colors):
    """Draw thin lines per simulation run, a thick mean line and the empirical line."""
    for run in sim_runs:
        sns.lineplot(
            x=run.index,
            y=run,
            color=colors[0],
            alpha=0.4,
            linewidth=0.8,
            ax=ax,
        )
    sns.lineplot(
        x=sim_mean.index,
        y=sim_mean,
        label="simulated",
        color=colors[0],
        ax=ax,
        linewidth=2.5,
    )
    sns.lineplot(
        x=empirical.index,
        y=empirical,
        label="empirical",
        color=colors[1],
        ax=ax,
        linewidth=2.5,
    )


def _convert_to_dataframe_with_age_groups_as_columns(sr):
    sr = sr.copy()
    sr.name = "value"
    df = sr.to_frame()
    df["date"] = list(map(lambda x: x.split("'", 2)[1], df.index))
    df["date"] = pd.to_datetime(df["date"])
    df["group"] = list(map(lambda x: x.rsplit(",", 1)[1].strip("') "), df.index))
    df.set_index(["date", "group"], inplace=True)
    return df["value"].unstack()

In [None]:
# plot_msm_performance returns the figure and an array of axes (one per panel).
fig, axes = plot_msm_performance(iteration)