# Notebook to compare Penn CHIME fits against all-parameter fits

## Imports

In [None]:
from os import environ
from datetime import date

from pandas import DataFrame

from penn_chime.parameters import Parameters, Disposition
from penn_chime.models import (
    SimSirModel,
    sim_sir,
    calculate_admits,
    calculate_dispositions,
)

from models import sir_step, sihr_step, one_minus_logistic_fcn, FitFcn

## Set up Penn CHIME model

In [None]:
# Penn CHIME inputs for the reference run. Each disposition is built from a
# `days` value and a `rate` via Disposition.create.
p = Parameters(
    # Population / market inputs.
    population=3600000,
    market_share=0.15,
    current_hospitalized=69,
    date_first_hospitalized=date(2020, 3, 7),
    recovered=0,
    # Epidemic dynamics.
    doubling_time=4.0,
    infectious_days=14,
    relative_contact_rate=0.3,
    # Dispositions.
    hospitalized=Disposition.create(days=7, rate=0.025),
    icu=Disposition.create(days=9, rate=0.0075),
    ventilated=Disposition.create(days=10, rate=0.005),
    # Simulation horizon.
    n_days=100,
)

# Clear the doubling time after construction so the model is driven by
# date_first_hospitalized instead.
# NOTE(review): presumably SimSirModel derives the growth rate from that date
# when doubling_time is None; confirm against penn_chime.
p.doubling_time = None
simsir = SimSirModel(p)


## Tests

### Check that model agrees with Penn CHIME if no policies are in place

Calculate S, I, H, R with no policies in place (a single constant beta for the whole run)

In [None]:
# Number of days covered by Penn CHIME's own run (inclusive day range).
n_days = simsir.raw_df.day.max() - simsir.raw_df.day.min() + 1

# A single policy that keeps Penn CHIME's beta for all n_days, i.e. no
# interventions. NOTE(review): presumably each policy is a (beta, n_days)
# pair; confirm against penn_chime.models.sim_sir.
policies = [(simsir.beta, n_days)]
# Re-run the SIR simulation with the same initial state and gamma as Penn
# CHIME. NOTE(review): the fifth argument is presumably the starting day, so
# -simsir.i_day starts the run i_day days before day 0; TODO confirm.
raw_df = DataFrame(
    sim_sir(
        simsir.susceptible,
        simsir.infected,
        p.recovered,
        simsir.gamma,
        -simsir.i_day,
        policies,
    )
)

# Called only for its effect on raw_df (the return value is discarded).
# Presumably it adds disposition columns such as "hospitalized", which
# day0.hospitalized reads below. market_share=1.0 leaves the counts unscaled
# by market share.
calculate_dispositions(raw_df, simsir.rates, market_share=1.0)

# First simulated row, with NaNs replaced by 0. It seeds the initial values
# of the new fit function.
day0 = raw_df.iloc[0].fillna(0)

raw_df.head()

Compute the same quantities using the new fit function

In [None]:
# Fit-function parameters, seeded from Penn CHIME's first simulated day and its
# beta/gamma so both models start from the same state.
pars = dict(
    beta_i=simsir.beta,
    gamma_i=simsir.gamma,
    initial_susceptible=day0.susceptible,
    initial_infected=day0.infected,
    initial_hospitalized=day0.hospitalized,
    initial_recovered=day0.recovered,
    hospitalization_rate=simsir.rates["hospitalized"],
)
# Iterate once per row of the Penn CHIME simulation.
x = {"n_iter": len(raw_df)}

# Evaluate the SIR step function and collect the compartments compared below.
f = FitFcn(sir_step, columns=["susceptible", "infected", "hospitalized", "recovered"])
y = f(x, pars)

Check that the difference is consistent with zero (the mean difference of each column lies within one standard deviation)

In [None]:
# Column-wise residuals between the new fit function and Penn CHIME. They are
# computed once and reused for both statistics.
residuals = y - raw_df[f.columns]
mean = residuals.mean()
sdev = residuals.std()
# "Consistent with zero": each column's mean residual lies within one standard
# deviation. Use <= so an exact match (mean == sdev == 0) passes; with a strict
# < the perfect-agreement case would fail on 0 < 0.
assert (mean.abs() <= sdev).all(), mean
mean