In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from src.config import BLD, SRC
from src.create_initial_states.task_build_full_params import (
    _convert_index_to_int_where_possible,
)
from src.simulation.plotting import plot_incidences, style_plot

In [2]:
# Seed NumPy's global generator so that the synthetic columns below are the
# same after every "Restart & Run All". Without it each run benchmarks the
# functions on different data.
SYNTHETIC_STATES_SEED = 0
np.random.seed(SYNTHETIC_STATES_SEED)

# Load the initial states and attach one fixed date to every row.
states = pd.read_parquet(BLD / "data" / "initial_states.parquet")
n_individuals = len(states)
states["date"] = pd.Timestamp("2021-03-01")

# Synthetic health / testing columns that the benchmarked functions read.
# The probabilities are arbitrary benchmark inputs, not calibrated values.
states["infectious"] = np.random.choice(
    [True, False], size=n_individuals, p=[0.99, 0.01]
)
states["symptomatic"] = np.random.choice(
    [True, False], size=n_individuals, p=[0.99, 0.01]
)
states["cd_infectious_true"] = np.random.choice(
    [-104, 0, -1, 3], size=n_individuals, p=[0.97, 0.01, 0.01, 0.01]
)
states["pending_test"] = np.random.choice(
    [True, False], size=n_individuals, p=[0.03, 0.97]
)
states["knows_immune"] = np.random.choice(
    [True, False], size=n_individuals, p=[0.07, 0.93]
)

# Baseline contacts: exactly one contact for every individual.
contacts = pd.Series(1, index=states.index)

# Reduce Contacts on Condition

- 35.5% of the runtime goes to `reduce_recurrent_model`.
- Within `reduce_recurrent_model`, it looks as if 85% of the time is attributed to the last line of the docstring. That's strange.

In [3]:
def reduce_contacts_on_condition(
    contacts, states, multiplier, condition, seed, is_recurrent
):
    """Reduce the contacts of all individuals who fulfill ``condition``.

    Individuals for whom the condition is False keep their original contacts.
    For everybody else the reduction depends on the type of contact model:

    - non-recurrent: the number of contacts is multiplied by ``multiplier``.
    - recurrent: the reduction is delegated to ``reduce_recurrent_model``.

    Args:
        contacts (pandas.Series): The series with contacts.
        states (pandas.DataFrame): The states of one day passed by sid.
        multiplier (float): The share of contacts that is maintained despite
            the condition being fulfilled.
        condition (str): Expression passed to ``states.eval``. It defines the
            subset of individuals whose contacts are reduced.
        seed (int): Passed to ``np.random.seed`` and forwarded to
            ``reduce_recurrent_model``.
        is_recurrent (bool): Whether the contact model is recurrent.

    Returns:
        pandas.Series: The reduced contacts.

    """
    np.random.seed(seed)
    candidate = (
        reduce_recurrent_model(states, contacts, seed, multiplier)
        if is_recurrent
        else multiplier * contacts
    )
    # Only rows fulfilling the condition take the reduced value.
    fulfills_condition = states.eval(condition)
    return candidate.where(fulfills_condition, contacts)


def reduce_recurrent_model(states, contacts, seed, multiplier):
    """Randomly cancel recurrent contacts so that only a share takes place.

    In recurrent models only ``contacts > 0`` matters, so scaling the values
    would not change who meets. Instead, every individual with a positive
    entry is independently kept (1) with probability ``multiplier`` and
    dropped (0) otherwise. Entries that are not positive are passed through
    unchanged.

    Args:
        states (pandas.DataFrame): The states. Their length and index define
            the length and index of the result.
        contacts (pandas.Series): The scheduled recurrent contacts.
        seed (int): Passed to ``np.random.seed`` before drawing.
        multiplier (float or pd.Series): Must be smaller or equal to one. If a
            Series is supplied the index must be dates; the entry for the
            date of ``states`` is used.

    Returns:
        pandas.Series: The reduced contacts, indexed like ``states``.

    """
    np.random.seed(seed)
    if isinstance(multiplier, pd.Series):
        multiplier = multiplier[get_date(states)]

    scheduled = contacts.to_numpy()
    # One draw per individual: 1 = still participates, 0 = stays away.
    participates = np.random.choice(
        [1, 0], size=len(states), p=[multiplier, 1 - multiplier]
    )
    outcome = np.where(scheduled > 0, participates, scheduled)
    return pd.Series(outcome, index=states.index)

In [4]:
%%timeit

# Benchmark the non-recurrent branch: contacts of rows matching the condition
# are multiplied by the multiplier.
reduce_contacts_on_condition(
    contacts,
    states,
    multiplier=0.5,
    condition="occupation == 'working'",
    seed=99,
    is_recurrent=False,
)

28.5 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
%%timeit

# Benchmark the recurrent branch, which delegates to reduce_recurrent_model
# before the condition mask is applied.
reduce_contacts_on_condition(
    contacts,
    states,
    multiplier=0.5,
    condition="occupation == 'working'",
    seed=99,
    is_recurrent=True,
)

46.5 ms ± 314 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
%%timeit

# Benchmark reduce_recurrent_model on its own. The multiplier is a float, so
# the date lookup for Series multipliers is not exercised.
reduce_recurrent_model(states, contacts, 323, 0.5)

23.4 ms ± 428 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Find Educ Workers With Zero Students

- Somehow the wrong lines, which do not belong to the function at all, ended up in the text file
  for "find_educ_workers_with_zero_students"
  
  -> I restarted the notebook 
  
- in `_find_size_zero_classes` 80% of the time goes to the `class_sizes = ... .groupby` line

In [7]:
def _find_educ_workers_with_zero_students(contacts, states, group_id_column):
    """Flag educ_workers whose class / group contains no attending children.

    Args:
        contacts (pandas.Series): Contacts, passed on to
            ``_find_size_zero_classes``.
        states (pandas.DataFrame): Must contain the boolean column
            "educ_worker" and ``group_id_column``.
        group_id_column (str): Name of the column holding the group ids.

    Returns:
        has_no_class (pandas.Series): boolean Series with the
            same index as states. True for educ_workers whose classes / groups
            don't have any children in them.

    """
    empty_group_ids = _find_size_zero_classes(contacts, states, group_id_column)
    in_empty_group = states[group_id_column].isin(empty_group_ids)
    return states["educ_worker"] & in_empty_group


def _find_size_zero_classes(contacts, states, col):
    students_group_ids = states[col][~states["educ_worker"]]
    students_contacts = contacts[~states["educ_worker"]]
    # the .drop(-1) is needed because we use -1 instead of NaN to identify
    # individuals not participating in a recurrent contact model
    class_sizes = students_contacts.groupby(students_group_ids).sum().drop(-1)
    size_zero_classes = class_sizes[class_sizes == 0].index
    return size_zero_classes

In [8]:
# Boolean Series used as the contacts input below: True for rows whose
# occupation is "school" or "school_teacher".
school_contacts = states["occupation"].isin(["school", "school_teacher"])

In [9]:
%%timeit

# Benchmark the full lookup, which includes the call to
# _find_size_zero_classes.
_find_educ_workers_with_zero_students(
    contacts=school_contacts, states=states, group_id_column="school_group_id_1"
)

44.3 ms ± 499 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit

# Benchmark the helper alone. Note that this uses "school_group_id_2", while
# the cell above uses "school_group_id_1".
_find_size_zero_classes(
    contacts=school_contacts, states=states, col="school_group_id_2"
)

43 ms ± 568 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Reduce Work Model

- 1/3 of the runtime went into asserting that the Länder in the states and in the thresholds fit together 

    $\Rightarrow$ I moved that assertion to check_initial_states and replaced it here with a check
    that there are no NaN after the `.map` operation. That check takes 8% of the time
    and could be removed.
    
    
$\Rightarrow$ The `.where` statements take ~40% of the time, the `.map` statement 30%.

In [11]:
from sid.time import get_date


def reduce_work_model(states, contacts, seed, multiplier, is_recurrent):  # noqa: U100
    """Reduce contacts for the working population.

    Individuals keep their work contacts only if their
    "work_contact_priority" is strictly above ``1 - multiplier``. Everybody
    else gets 0 (non-recurrent model) or False (recurrent model).

    Args:
        states (pandas.DataFrame): Must contain "work_contact_priority" and,
            if the multiplier differs by federal state, the column "state".
        contacts (pandas.Series): The work contacts.
        seed (int): Unused. Only part of the signature.
        multiplier (float, pandas.Series, pandas.DataFrame):
            share of workers that have work contacts.
            If it is a Series or DataFrame, the index must be dates.
            If it is a DataFrame the columns must be the values of
            the "state" column in the states.
        is_recurrent (bool): True if the contact model is recurrent.

    Returns:
        pandas.Series: The reduced work contacts.

    Raises:
        AssertionError: If the multiplier is outside [0, 1] or if a value in
            the "state" column has no multiplier.

    """
    if isinstance(multiplier, (pd.Series, pd.DataFrame)):
        multiplier = multiplier.loc[get_date(states)]

    if isinstance(multiplier, (float, int)):
        is_valid = 0 <= multiplier <= 1
    else:
        is_valid = (multiplier >= 0).all() and (multiplier <= 1).all()
    # The message, including its date lookup, is only built when the check
    # fails, so the common valid case pays nothing for it.
    assert is_valid, f"Work multiplier not in [0, 1] on {get_date(states)}"

    threshold = 1 - multiplier
    if isinstance(threshold, pd.Series):
        # One threshold per federal state: broadcast it to the individuals.
        threshold = states["state"].map(threshold.get)
        # this assert could be skipped because we check in
        # task_check_initial_states that the federal state names overlap.
        assert threshold.notnull().all()

    above_threshold = states["work_contact_priority"] > threshold

    # Recurrent models are switched off with False, count models with 0.
    fill_value = False if is_recurrent else 0
    return contacts.where(above_threshold, fill_value)

In [12]:
%%timeit

# Benchmark with a scalar multiplier. The date lookup for Series / DataFrame
# multipliers and the per-state `.map` are not exercised by this call.
reduce_work_model(
    states=states, contacts=contacts, seed=111, multiplier=0.4, is_recurrent=True
)

37.3 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Demand Test

- 70% of time goes to `_scale_demand_up_or_down`

- of `_scale_demand_up_or_down` >90% go to `_decrease_test_demand`

- of `_decrease_test_demand`: 93.7% go to `states[demanded].query(f"age_group_rki == '{group}'").index`

- of `_increase_test_demand`: >90% go to `infected_untested = states.index[states.eval(selection_string) & ~demanded]`

In [13]:
import warnings


def _decrease_test_demand(demanded, states, n_to_remove, group):
    """Decrease the number of tests demanded in an age group by a certain number.

    This is called when the endogenously demanded tests (symptomatics + educ workers)
    already exceed the designated number of positive tests in an age group.

    """
    demanded = demanded.copy(deep=True)

    demanding_test_in_age_group = (
        states[demanded].query(f"age_group_rki == '{group}'").index
    )
    drawn = np.random.choice(
        a=demanding_test_in_age_group, size=n_to_remove, replace=False
    )
    demanded.loc[drawn] = False
    return demanded


def _increase_test_demand(demanded, states, n_undemanded_tests, group):
    """Randomly increase the number of tests demanded in an age group.
    This is the case where we have additional positive tests to distribute.

    """
    demanded = demanded.copy(deep=True)

    right_age_group = f"(age_group_rki == '{group}')"
    currently_infected = "(infectious | symptomatic | (cd_infectious_true >= 0))"
    untested = "(~pending_test & ~knows_immune)"
    selection_string = right_age_group + " & " + currently_infected + " & " + untested

    infected_untested = states.index[states.eval(selection_string) & ~demanded]

    if len(infected_untested) >= n_undemanded_tests:
        drawn = np.random.choice(infected_untested, n_undemanded_tests, replace=False)
    else:
        date = get_date(states)
        warnings.warn(
            f"\n\nThe implied share_known_cases for age group {group} is >1 "
            f"on date {date.date()} ({date.day_name()}).\n\n"
        )
        drawn = infected_untested
    demanded.loc[drawn] = True
    return demanded

In [14]:
# Benchmark input: every individual demands a test.
demanded = pd.Series(True, index=states.index)

In [15]:
%%timeit

# Benchmark withdrawing 30 demanded tests from the "80-100" age group.
# `demanded` is True for everyone, so `states[demanded]` selects every row.
_decrease_test_demand(demanded, states, n_to_remove=30, group="80-100")

239 ms ± 2.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit

# NOTE: `demanded` is True for everyone, so `~demanded` leaves no candidates
# and this call always takes the warning branch (see the output below). The
# random draw is therefore not part of what is timed here.
_increase_test_demand(demanded, states, n_undemanded_tests=20, group="80-100")


The implied share_known_cases for age group 80-100 is >1 on date 2021-03-01 (Monday).




11.2 ms ± 329 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
