In [None]:
from copy import copy 
from matplotlib import pyplot as plt
import datetime

import summer2
from summer.utils import ref_times_to_dti

from autumn.core.project import get_project, load_timeseries
from autumn.core.project.params import get_with_nested_key

# Targets represent data we are trying to fit to
from estival import targets as est
# We specify parameters using (Bayesian) priors
from estival import priors as esp
# Finally we combine these with our summer2 model in a BayesianCompartmentalModel (BCM)
from estival.model import BayesianCompartmentalModel

# Some preliminary code

### Load project, model and targets

In [None]:
# Load the "sm_covid2" project for France and take its first calibration target as the data to fit.
project = get_project("sm_covid2", "france")
death_target = project.calibration.targets[0]  # NOTE(review): assumes index 0 is the deaths target — confirm

# Build the model from the project's baseline parameter set.
default_params = project.param_set.baseline
m = project.build_model(default_params.to_dict()) 

In [None]:
# NOTE(review): this assigns death_target.data to itself, which appears to be a no-op as written.
# It looks like a placeholder for a transformation of the target series — confirm the intent or remove.
death_target.data =  death_target.data

In [None]:
# Calibration targets passed to the BayesianCompartmentalModel: a single normal target on "infection_deaths".
# The commented-out line is an alternative negative-binomial target for the same data.
# NOTE(review): stdev=100. is hard-coded with no stated justification — confirm it suits the scale of the data.
targets = [
    # est.NegativeBinomialTarget(name="infection_deaths", data=death_target.data, dispersion_param=100.)
    est.NormalTarget(name="infection_deaths", data=death_target.data, stdev=100.)
]

### Format priors to use in estival

In [None]:
def get_estival_uniform_priors(autumn_priors):
    """Translate autumn prior dictionaries into estival ``UniformPrior`` objects.

    Args:
        autumn_priors: Iterable of dicts; each is read for the keys
            "distribution", "param_name" and "distri_params".

    Returns:
        A list holding one ``esp.UniformPrior`` per input dict, in input order.

    Raises:
        AssertionError: If an entry's "distribution" is not "uniform".
    """

    def _as_uniform(spec):
        # Each entry is validated immediately before it is converted.
        assert spec["distribution"] == "uniform", "Only uniform priors are currently supported"
        return esp.UniformPrior(spec["param_name"], spec["distri_params"])

    return [_as_uniform(spec) for spec in autumn_priors]
    

In [None]:
# Convert every prior declared on the project's calibration into an estival prior.
priors = get_estival_uniform_priors(project.calibration.all_priors)

### Create BayesianCompartmentalModel object

In [None]:
# Combine the model, its baseline parameters, the priors and the targets into one BCM object.
bcm = BayesianCompartmentalModel(m, default_params.to_dict(), priors, targets)

### Helper function to plot model fit to data

In [None]:
def plot_fit(bcm, params, target_name="infection_deaths", ref_date=datetime.date(2019, 12, 31)):
    """Plot a model run against one calibration target and annotate the log-likelihood.

    Args:
        bcm: Object exposing ``targets[...]``, ``run(params)`` and
            ``loglikelihood(**params)`` (here a BayesianCompartmentalModel).
        params: Dict of parameter values; passed as-is to ``bcm.run`` and
            unpacked as keyword arguments into ``bcm.loglikelihood``.
        target_name: Key used both for ``bcm.targets`` and for the model's
            ``derived_outputs``. Defaults to "infection_deaths".
        ref_date: Reference date handed to ``ref_times_to_dti`` when
            re-indexing the target data. Defaults to 2019-12-31.

    Returns:
        None. The figure is drawn on the axes created by the model-output plot.
    """
    # Work on a copy so the target stored on the BCM keeps its original index.
    # NOTE(review): presumably ref_times_to_dti maps numeric model times to a
    # DatetimeIndex relative to ref_date — confirm against summer's docs.
    target_series = copy(bcm.targets[target_name].data)
    target_series.index = ref_times_to_dti(ref_date, target_series.index)

    # Draw both series on the same, explicitly referenced axes rather than
    # relying on pyplot's implicit "current axes".
    ax = bcm.run(params).derived_outputs[target_name].plot(label="model")
    target_series.plot(style='.', ax=ax, label="target data")
    ax.set_ylabel(target_name)
    # The upper-right corner is reserved for the log-likelihood annotation.
    ax.legend(loc="upper left")

    ll = bcm.loglikelihood(**params)
    ax.text(0.8, 0.9, f"ll={round(ll, 4)}", transform=ax.transAxes)

# Compare model fits

In [None]:
# Example parameter set used to illustrate a model fit.
example_params = {
    'contact_rate': 0.04104869232093793,
    'infectious_seed_time': 30.447925441292778,
    'age_stratification.ifr.multiplier': 0.8145117432172231,
    'voc_emergence.delta.new_voc_seed.time_from_gisaid_report': 18.52679245171894,
    'voc_emergence.omicron.new_voc_seed.time_from_gisaid_report': 0.0,
    'infection_deaths_dispersion_param': 62.54504840986004,
    'random_process.noise_sd': 0.5835215881701278,
    # The indexed 'random_process.delta_values(i)' entries, generated in index order (0..16).
    **{
        f'random_process.delta_values({idx})': value
        for idx, value in enumerate(
            [
                -1.306496185258189,
                0.11250765346332026,
                0.5009919549420943,
                -0.06375110616249913,
                -0.24028730546991617,
                0.2606697551298094,
                -0.13194034829544132,
                -0.3559750720412327,
                0.04754597441674813,
                0.12138039221640762,
                0.6411075116621299,
                10.0,
                20.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]
        )
    },
}

In [None]:
# Plot the model fit obtained with the example parameter set.
plot_fit(bcm, example_params)

In [None]:
# Variant of example_params with selected values overridden, for comparison against the example fit.
alternate_params = {
    **example_params,
    # Is the model sensitive to changes in the following parameters?

    # "contact_rate": 0.0,  # YES
    "random_process.delta_values(1)": -10.,  # NO
    # 'age_stratification.ifr.multiplier': 0., # YES
    # 'voc_emergence.delta.new_voc_seed.time_from_gisaid_report': 0., # YES
}

In [None]:
# Plot the model fit obtained with the modified parameter set, for comparison with the example fit.
plot_fit(bcm, alternate_params)