In [1]:
import os
import shutil

from functools import partial
from time import time

import dask.dataframe as dd
import pandas as pd
import numpy as np 
import seaborn as sns
import matplotlib.pyplot as plt
from sid import get_colors
from src.config import SRC, BLD

from sid import get_simulate_func

from src.create_initial_states.create_initial_conditions import (  # noqa
    create_initial_conditions,
)
 
from src.policies.combine_policies_over_periods import get_october_to_christmas_policies
from src.policies.combine_policies_over_periods import get_enacted_policies_of_2021

from src.simulation.main_specification import load_simulation_inputs
from src.simulation.main_specification import SIMULATION_DEPENDENCIES
from src.simulation.main_specification import SCENARIO_START

from sid import get_msm_func
from src.manfred.shared import hash_array
from estimagic.batch_evaluators import joblib_batch_evaluator
from sid.msm import get_diag_weighting_matrix

from src.calculate_moments import smoothed_outcome_per_hundred_thousand_rki
from src.calculate_moments import smoothed_outcome_per_hundred_thousand_sim

# Column selection that is forwarded to ``get_simulate_func`` through its
# ``saved_columns`` argument, grouped by the categories used as keys.
SAVED_COLUMNS = dict(
    initial_states=["age_group_rki"],
    disease_states=["newly_infected", "infectious", "ever_infected"],
    time=["date"],
    other=["new_known_case", "virus_strain", "n_has_infected", "pending_test"],
)


# Plot styling: no right/top spines and no frame around the legend.
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["legend.frameon"] = False

print(SCENARIO_START.date())

2021-04-01


# Load the fall and spring simulation inputs

In [2]:
# fall

# Simulation window of the fall part.
# NOTE(review): the trailing marker on the start date looks like a reminder
# that "2020-10-15" is the intended start -- confirm before a full run.
FALL_START_DATE = pd.Timestamp("2020-12-13")  ### pd.Timestamp("2020-10-15")
FALL_END_DATE = pd.Timestamp("2020-12-23")

# Window for the initial conditions: from 31 days before the start up to the
# day before the start.
FALL_INIT_START = FALL_START_DATE - pd.Timedelta(days=31)
FALL_INIT_END = FALL_START_DATE - pd.Timedelta(days=1)
print(FALL_INIT_START.date(), FALL_END_DATE.date())

virus_shares, FALL_KWARGS = load_simulation_inputs(
    SIMULATION_DEPENDENCIES,
    FALL_INIT_START,
    FALL_END_DATE,
    extend_ars_dfs=False,
)
# we don't want to pass the "old" params
del FALL_KWARGS["params"]

FALL_INITIAL_CONDITIONS = create_initial_conditions(
    virus_shares=virus_shares,
    start=FALL_INIT_START,
    end=FALL_INIT_END,
    reporting_delay=5,
    seed=344490,
)

FALL_POLICIES = get_october_to_christmas_policies(
    contact_models=FALL_KWARGS["contact_models"],
    educ_multiplier=0.8,
)

2020-11-12 2020-12-23


In [3]:
# spring

# Simulation window of the spring part.
# NOTE(review): the trailing marker on the end date looks like a reminder that
# four weeks is the intended length -- confirm before a full run.
SPRING_START_DATE = pd.Timestamp("2021-02-05")
SPRING_END_DATE = SPRING_START_DATE + pd.Timedelta(weeks=1) ### 4

# Window for the initial conditions: from 31 days before the start up to the
# day before the start.
SPRING_INIT_START = SPRING_START_DATE - pd.Timedelta(days=31)
SPRING_INIT_END = SPRING_START_DATE - pd.Timedelta(days=1)
print(SPRING_INIT_START.date(), SPRING_END_DATE.date())

virus_shares, SPRING_KWARGS = load_simulation_inputs(
    SIMULATION_DEPENDENCIES,
    SPRING_INIT_START,
    SPRING_END_DATE,
    extend_ars_dfs=True,
)
# we don't want to pass the "old" params
del SPRING_KWARGS["params"]

SPRING_INITIAL_CONDITIONS = create_initial_conditions(
    virus_shares=virus_shares,
    start=SPRING_INIT_START,
    end=SPRING_INIT_END,
    reporting_delay=5,
    seed=3930,
)

SPRING_POLICIES = get_enacted_policies_of_2021(
    scenario_start=SCENARIO_START,
    contact_models=SPRING_KWARGS["contact_models"],
)


2021-01-05 2021-02-12


# Build the criterion

In [4]:
def parallelizable_msm_func(params, seed, prefix):
    """Evaluate the MSM criterion for one parameter vector in a scratch directory.

    The simulation output is written below ``SRC / "exploration"`` into a
    directory whose name combines ``prefix``, a hash of the parameter values
    and the id of the current process (presumably so that evaluations running
    in parallel do not share a directory). The directory is removed again
    after the evaluation, also when the evaluation raises.

    Args:
        params (pandas.DataFrame): parameters; the "value" column is hashed
            for the directory name and the frame is passed to the criterion.
        seed (int): seed forwarded to the fall and spring simulations.
        prefix (str): first part of the scratch directory name.

    Returns:
        The result of the function built by ``get_msm_func`` evaluated at
        ``params``. The notebook reads its "empirical_moments" and
        "simulated_moments" entries.

    """
    params_hash = hash_array(params["value"].to_numpy())
    path = SRC / "exploration" / f"{prefix}_{params_hash}_{os.getpid()}"

    calc_moments = {
        "infections_by_age_group": partial(
            smoothed_outcome_per_hundred_thousand_sim,
            outcome="new_known_case",
            groupby="age_group_rki",
        ),
    }

    rki_cases = pd.read_pickle(BLD / "data" / "processed_time_series" / "rki.pkl")
    age_group_info = pd.read_pickle(
        BLD / "data" / "population_structure" / "age_groups_rki.pkl"
    )

    empirical_moments = {
        "infections_by_age_group": smoothed_outcome_per_hundred_thousand_rki(
            df=rki_cases,
            outcome="newly_infected",
            groupby="age_group_rki",
            window=7,
            min_periods=1,
            group_sizes=age_group_info["n"],
        )
    }

    # Every empirical moment gets the weight of its age group, which is taken
    # from the second index level of the moment series.
    age_weights = age_group_info["weight"].to_dict()

    temp = empirical_moments["infections_by_age_group"].to_frame().copy(deep=True)
    temp["age_group"] = temp.index.get_level_values(1)
    temp["weights"] = temp["age_group"].replace(age_weights)

    weights = {"infections_by_age_group": temp["weights"]}

    weight_mat = get_diag_weighting_matrix(
        empirical_moments=empirical_moments,
        weights=weights,
    )

    msm = get_msm_func(
        simulate=partial(
            run_fall_and_spring_and_combine_their_time_series,
            seed=seed,
            path=path,
        ),
        calc_moments=calc_moments,
        empirical_moments=empirical_moments,
        replace_nans=lambda x: x * 1,
        weighting_matrix=weight_mat,
    )

    try:
        res = msm(params)
    except BaseException:
        # Do not leave the scratch directory behind when the evaluation fails.
        # ``ignore_errors`` keeps a missing directory from masking the
        # original exception.
        shutil.rmtree(path, ignore_errors=True)
        raise
    shutil.rmtree(path)
    return res

    
    
def run_fall_and_spring_and_combine_their_time_series(params, seed, path):
    """Simulate the fall and the spring period and stack their time series.

    Both simulate functions are built first and then run one after the other,
    fall before spring.

    Args:
        params (pandas.DataFrame): parameters handed to both simulations.
        seed (int): seed of the fall simulation; the spring simulation uses
            ``seed + 100_000``.
        path (pathlib.Path): parent directory; the two simulations write to
            its "fall_part" and "spring_part" sub-directories.

    Returns:
        The ``dd.concat`` of both "time_series" entries along axis 0, each
        indexed by "date".

    """
    # arguments that are identical for both periods
    shared_kwargs = {"params": params, "saved_columns": SAVED_COLUMNS}

    simulate_fall = get_simulate_func(
        **FALL_KWARGS,
        **shared_kwargs,
        contact_policies=FALL_POLICIES,
        duration={"start": FALL_START_DATE, "end": FALL_END_DATE},
        initial_conditions=FALL_INITIAL_CONDITIONS,
        path=path / "fall_part",
        seed=seed,
    )
    simulate_spring = get_simulate_func(
        **SPRING_KWARGS,
        **shared_kwargs,
        contact_policies=SPRING_POLICIES,
        duration={"start": SPRING_START_DATE, "end": SPRING_END_DATE},
        initial_conditions=SPRING_INITIAL_CONDITIONS,
        path=path / "spring_part",
        seed=seed + 100_000,
    )

    fall_ts = simulate_fall(params)["time_series"].set_index("date")
    spring_ts = simulate_spring(params)["time_series"].set_index("date")

    return dd.concat([fall_ts, spring_ts], axis=0)
    
# Criterion with the directory prefix fixed; ``params`` and ``seed`` remain
# free and are supplied per evaluation.
pmsm = partial(parallelizable_msm_func, prefix="gridsearch")

# Params

In [5]:
# NOTE(review): ``read_pickle`` executes arbitrary code on load; only use it
# on files produced by this project's own build.
params = pd.read_pickle(BLD / "params.pkl")
# Overwrite the strain factors. Only the "value" column is assigned: indexing
# with the bare 3-tuple would set every column of the row to the number.
params.loc[("virus_strain", "base_strain", "factor"), "value"] = 1.0
params.loc[("virus_strain", "b117", "factor"), "value"] = 1.67

In [None]:
# Run the combined fall + spring simulation once and report the runtime.
before = time()

combined_ts = run_fall_and_spring_and_combine_their_time_series(
    params=params, seed=5471, path=BLD / "test")

# Elapsed minutes, rounded after the conversion from seconds. Rounding the
# seconds first and dividing afterwards prints an unrounded float.
print(round((time() - before) / 60, 2))

age_group_rki      0-4     5-14     15-34     35-59     60-79   80-100
demand         31098.0  68684.0  168161.0  281470.0  171230.0  35656.0
target demand      0.0      0.0       0.0       0.0       0.0      0.0
difference         inf      inf       inf       inf       inf      inf



























age_group_rki   0-4   5-14  15-34   35-59  60-79  80-100
demand         7.00  22.00  81.00  111.00  40.00    28.0
target demand  4.00  15.00  71.00   99.00  45.00    31.0
difference     0.75   0.47   0.14    0.12  -0.11    -0.1
age_group_rki   0-4   5-14   15-34   35-59  60-79  80-100
demand         5.00  28.00  118.00  158.00   48.0   37.00
target demand  3.00  12.00   58.00   81.00   37.0   26.00
difference     0.67   1.33    1.03    0.95    0.3    0.42
age_group_rki  0-4   5-14   15-34   35-59  60-79  80-100
demand         8.0  29.00  132.00  148.00  80.00    36.0
target demand  5.0  20.00   92.00  128.00  58.00    40.0
difference     0.6   0.45    0.43    0.16   0.38    -0.1
a

age_group_rki   0-4   5-14   15-34   35-59  60-79  80-100
demand         5.00  26.00  135.00  189.00  77.00    41.0
target demand  4.00  16.00   70.00  103.00  50.00    41.0
difference     0.25   0.62    0.93    0.83   0.54     0.0
age_group_rki   0-4   5-14   15-34   35-59  60-79  80-100
demand         12.0  37.00  136.00  201.00  75.00   43.00
target demand   3.0  13.00   56.00   81.00  39.00   32.00
difference      3.0   1.85    1.43    1.48   0.92    0.34
age_group_rki   0-4  5-14   15-34   35-59  60-79  80-100
demand         12.0  40.0  144.00  170.00   68.0   63.00
target demand   5.0  20.0   87.00  128.00   62.0   51.00
difference      1.4   1.0    0.66    0.33    0.1    0.24
age_group_rki  0-4   5-14   15-34   35-59  60-79  80-100
demand         7.0  43.00  122.00  164.00  73.00   67.00
target demand  5.0  21.00   91.00  134.00  64.00   53.00
difference     0.4   1.05    0.34    0.22   0.14    0.26
age_group_rki    0-4   5-14   15-34   35-59  60-79  80-100
demand         11.00 

age_group_rki    0-4   5-14   15-34   35-59   60-79  80-100
demand         14.00  50.00  216.00  291.00  127.00   95.00
target demand   6.00  26.00  116.00  173.00   85.00   76.00
difference      1.33   0.92    0.86    0.68    0.49    0.25
age_group_rki    0-4   5-14   15-34   35-59   60-79  80-100
demand         17.00  50.00  227.00  319.00  120.00   89.00
target demand   7.00  26.00  120.00  179.00   88.00   79.00
difference      1.43   0.92    0.89    0.78    0.36    0.13
age_group_rki    0-4   5-14   15-34   35-59   60-79  80-100
demand         12.00  52.00  203.00  304.00  129.00   99.00
target demand   7.00  29.00  132.00  196.00   96.00   86.00
difference      0.71   0.79    0.54    0.55    0.34    0.15
age_group_rki      0-4     5-14     15-34     35-59     60-79   80-100
demand         31354.0  68939.0  167518.0  281142.0  171513.0  35833.0
target demand      0.0      0.0       0.0       0.0       0.0      0.0
difference         inf      inf       inf       inf       inf      

age_group_rki   0-4   5-14   15-34   35-59   60-79  80-100
demand         12.0  28.00  159.00  212.00  101.00   70.00
target demand   3.0   9.00   65.00   97.00   58.00   60.00
difference      3.0   2.11    1.45    1.19    0.74    0.17
age_group_rki  0-4   5-14   15-34  35-59  60-79  80-100
demand         8.0  20.00  134.00  204.0  78.00   72.00
target demand  2.0   7.00   57.00   85.0  51.00   53.00
difference     3.0   1.86    1.35    1.4   0.53    0.36
age_group_rki  0-4   5-14  15-34   35-59  60-79  80-100
demand         7.0  18.00  142.0  198.00  97.00   70.00
target demand  2.0   7.00   49.0   78.00  48.00   51.00
difference     2.5   1.57    1.9    1.54   1.02    0.37
age_group_rki   0-4   5-14   15-34   35-59  60-79  80-100
demand         14.0  17.00  153.00  225.00  96.00   89.00
target demand   2.0   6.00   43.00   68.00  42.00   44.00
difference      6.0   1.83    2.56    2.31   1.29    1.02
age_group_rki   0-4   5-14   15-34   35-59   60-79  80-100
demand         11.0  19.0

age_group_rki  0-4   5-14   15-34   35-59  60-79  80-100
demand         8.0  19.00  102.00  152.00  66.00   50.00
target demand  2.0   4.00   29.00   44.00  28.00   26.00
difference     3.0   3.75    2.52    2.45   1.36    0.92
age_group_rki  0-4  5-14   15-34   35-59  60-79  80-100
demand         8.0  31.0  106.00  147.00  66.00   55.00
target demand  1.0   2.0   18.00   28.00  17.00   17.00
difference     7.0  14.5    4.89    4.25   2.88    2.24
age_group_rki   0-4   5-14  15-34  35-59  60-79  80-100
demand         12.0  15.00  92.00  116.0  64.00   37.00
target demand   2.0   4.00  26.00   40.0  25.00   24.00
difference      5.0   2.75   2.54    1.9   1.56    0.54
Resume the simulation...
age_group_rki    0-4   5-14  15-34   35-59   60-79  80-100
demand         15.00  56.00  209.0  327.00  119.00   83.00
target demand   7.00  29.00  131.0  196.00   96.00   86.00
difference      1.14   0.93    0.6    0.67    0.24   -0.03
age_group_rki   0-4   5-14  15-34   35-59   60-79  80-100
deman

In [None]:
# Sum of "new_known_case" per date; ``.compute()`` evaluates the dask result.
n_new_cases = (
    combined_ts["new_known_case"]
    .groupby("date")
    .sum()
    .compute()
)

In [None]:
# Daily sum of new known cases over the combined fall and spring time series.
sns.lineplot(
    x=n_new_cases.index,
    y=n_new_cases,
)

# Example Usage

In [None]:
# One keyword-argument dict per evaluation; the seeds are 3e5 and 4e5 as ints.
arguments = [
    {"params": params, "seed": int(seed_base * 1e5)} for seed_base in [3, 4]
]

In [None]:
# Evaluate the criterion for every entry of ``arguments`` and report the
# runtime.
before = time()

results = joblib_batch_evaluator(
    func=pmsm,
    arguments=arguments,
    n_cores=3, ###
    unpack_symbol="**",
)

# 7.5 min on my machine
# Elapsed minutes, rounded after the conversion from seconds. Rounding the
# seconds first and dividing afterwards prints an unrounded float.
print(round((time() - before) / 60, 2))

In [None]:
# rearrange to be able to plot
iteration = {}
iteration["empirical_moments"] = results[0]["empirical_moments"]
iteration["simulated_moments"] = {}
for key in ["infections_by_age_group"]:
    mom_list = [res["simulated_moments"][key] for res in results]
    iteration["simulated_moments"][key] = mom_list

In [None]:
def plot_msm_performance(iteration, key="infections_by_age_group"):
    """Plot the moment performance contrasting empirical and simulated moments.

    The first panel shows the moments aggregated over the age groups with the
    "weight" column of ``age_groups_rki.pkl``. Below it there is one panel per
    age group. Every panel shows the single simulation runs as thin lines plus
    the mean over the runs and the empirical moments as thick lines.

    Args:
        iteration (dict): estimagic optimization iteration. Its
            "empirical_moments" entry maps moment names to Series and its
            "simulated_moments" entry maps moment names to lists of Series,
            one per simulation run.
        key (str): name of the moment that is plotted.

    Returns:
        fig, axes

    """
    colors = get_colors(palette="categorical", number=2)
    age_groups = ["0-4", "5-14", "15-34", "35-59", "60-79", "80-100"]

    # one panel for the aggregate plus one per age group
    fig, axes = plt.subplots(
        nrows=len(age_groups) + 1, figsize=(8, 16), sharex=False
    )

    emp_mom_sr = iteration["empirical_moments"][key]
    emp_mom = _convert_to_dataframe_with_age_groups_as_columns(emp_mom_sr)

    sim_moms = iteration["simulated_moments"][key]
    # Convert every run only once; the frames are used by all panels.
    sim_mom_dfs = [
        _convert_to_dataframe_with_age_groups_as_columns(mom) for mom in sim_moms
    ]
    mean_sim_mom = _convert_to_dataframe_with_age_groups_as_columns(
        pd.concat(sim_moms, axis=1).mean(axis=1)
    )
    # restrict the empirical moments to the simulated dates
    emp_mom = emp_mom.loc[mean_sim_mom.index]

    for age_group, ax in zip(age_groups, axes[1:]):
        for mom_df in sim_mom_dfs:
            sns.lineplot(
                x=mom_df.index,
                y=mom_df[age_group],
                color=colors[0],
                alpha=0.4,
                linewidth=0.8,
                ax=ax,
            )
        sns.lineplot(
            x=mean_sim_mom.index,
            y=mean_sim_mom[age_group],
            label="simulated",
            color=colors[0],
            ax=ax,
            linewidth=2.5,
        )

        sns.lineplot(
            x=emp_mom.index,
            y=emp_mom[age_group],
            label="empirical",
            color=colors[1],
            ax=ax,
            linewidth=2.5,
        )
        ax.set_title(f"Goodness of Fit: {age_group}")
        ax.set_ylabel("Infections per 100 000")

    # add overall fitness plot
    ax = axes[0]

    age_group_info = pd.read_pickle(
        BLD / "data" / "population_structure" / "age_groups_rki.pkl"
    )
    age_weights = age_group_info["weight"]

    # The multiplication aligns the weights with the age group columns.
    aggregated_emp_mom = (emp_mom * age_weights).sum(axis=1)
    mean_agg_sim_mom = (mean_sim_mom * age_weights).sum(axis=1)
    agg_sim_moms = [(mom_df * age_weights).sum(axis=1) for mom_df in sim_mom_dfs]

    for mom in agg_sim_moms:
        sns.lineplot(
            x=mom.index,
            y=mom,
            color=colors[0],
            alpha=0.4,
            linewidth=0.8,
            ax=ax,
        )

    sns.lineplot(
        x=aggregated_emp_mom.index,
        y=aggregated_emp_mom,
        label="empirical",
        color=colors[1],
        ax=ax,
        linewidth=2.5,
    )

    sns.lineplot(
        x=mean_agg_sim_mom.index,
        y=mean_agg_sim_mom,
        label="simulated",
        color=colors[0],
        ax=ax,
        linewidth=2.5,
    )
    ax.set_title("Overall Goodness of Fit")
    ax.set_ylabel("Infections per 100 000")

    for ax in axes:
        ax.xaxis.set_major_locator(plt.MaxNLocator(8))

    fig.tight_layout()
    return fig, axes


def _convert_to_dataframe_with_age_groups_as_columns(sr):
    """Reshape a moment Series with string labels into a date x group frame.

    Every index label is expected to be a string in which the date is the
    first single-quoted part and the group is the part after the last comma,
    e.g. ``"(Timestamp('2020-12-13 00:00:00'), '0-4')"``.

    Args:
        sr (pandas.Series): moments with such string labels as index. The
            input is not modified.

    Returns:
        pandas.DataFrame: the values with the parsed dates as index ("date")
        and the groups as columns ("group").

    """
    labels = list(sr.index)
    frame = sr.rename("value").to_frame().assign(
        # text between the first pair of single quotes
        date=[label.split("'", 2)[1] for label in labels],
        # text after the last comma without quote, bracket and blanks
        group=[label.rsplit(",", 1)[1].strip("') ") for label in labels],
    )
    frame["date"] = pd.to_datetime(frame["date"])
    indexed = frame.set_index(["date", "group"])
    return indexed["value"].unstack()

In [None]:
# Draw the goodness-of-fit panels; ``ax`` holds the array of all axes that the
# function returns.
fig, ax = plot_msm_performance(iteration=iteration)