In [None]:
from matplotlib import pyplot
import os
from numpy import exp

from summer.utils import ref_times_to_dti
from autumn.core.plots.utils import REF_DATE
from autumn.core.project import get_project, load_timeseries
from autumn.settings.region import Region


In [None]:
# Load the "sm_covid2" project for the France region; every later cell works on this object.
project = get_project("sm_covid2", Region.FRANCE)

In [None]:
# Convert each target's 'times' entry to a datetime index anchored at REF_DATE.
# NOTE: all_targets is the same object as project.plots, so the entries are
# rewritten in place -- re-running this cell would convert the times a second time.
all_targets = project.plots
for target_name in all_targets:
    target_entry = all_targets[target_name]
    target_entry['times'] = ref_times_to_dti(REF_DATE, target_entry['times'])

### Run a calibration to initialise the calibration object so we can then access likelihood methods

In [None]:
# Outputs to plot during the run: every plotted target except the school ratio
# and any "_diff_" output.
outputs_to_plot = [
    name
    for name in project.plots
    if name != 'death_missed_school_ratio' and "_diff_" not in name
]

# Keep a handle on the calibration object, then run it so that its likelihood
# methods can be used afterwards.
# NOTE(review): the three positional 1s are passed through unchanged; their meaning
# is defined by calibration.run -- presumably minimal run settings, confirm.
cal = project.calibration
cal.run(project, 1, 1, 1, derived_outputs_to_plot=outputs_to_plot)

In [None]:
# Parameter sets to compare, keyed by run name; filled in by the cells below.
params_dict = {}

In [None]:
# Parameter set for the "MLE" run (presumably the maximum-likelihood estimate from a
# previous calibration -- confirm provenance). Keys use the flattened dotted format that
# is later passed to param_set.baseline.update(..., calibration_format=True).
params_dict["MLE"] = {
    "age_stratification.ifr.multiplier": 0.6251547198659229,
    "contact_rate": 0.03164074788281637,
    "infectious_seed_time": 19.04893117520397,
    "random_process.delta_values(1)": -0.18311593187612196,
    "random_process.delta_values(10)": 0.4789775730590282,
    "random_process.delta_values(11)": -0.31119184520705256,
    "random_process.delta_values(12)": -0.305381101610086,
    "random_process.delta_values(13)": 0.1470561631388123,
    "random_process.delta_values(14)": 0.13205418399338864,
    "random_process.delta_values(15)": 0.09366230292816935,
    "random_process.delta_values(16)": -0.2806266310059893,
    "random_process.delta_values(17)": 0.015928866798155195,
    "random_process.delta_values(18)": 0.08695028608645572,
    "random_process.delta_values(19)": 0.017006740282116795,
    "random_process.delta_values(2)": -0.13024041911952744,
    "random_process.delta_values(20)": -0.012000468861728564,
    "random_process.delta_values(21)": 0.1043431978884164,
    "random_process.delta_values(22)": 0.2812843842216628,
    "random_process.delta_values(23)": -0.3345567730471908,
    "random_process.delta_values(24)": 0.03645052677553595,
    "random_process.delta_values(25)": 0.07113806321627703,
    "random_process.delta_values(26)": -0.034573686994065156,
    "random_process.delta_values(27)": 0.1792438749670775,
    "random_process.delta_values(28)": -0.39841289667549495,
    "random_process.delta_values(29)": 0.22949830881622546,
    "random_process.delta_values(3)": -0.5668027993286127,
    "random_process.delta_values(30)": -0.11393810182829278,
    "random_process.delta_values(31)": -0.6309376285029871,
    "random_process.delta_values(32)": 0.3933673622046312,
    "random_process.delta_values(33)": 0.26901543234807423,
    "random_process.delta_values(34)": -0.019319087259375367,
    "random_process.delta_values(35)": -0.05520047347714474,
    "random_process.delta_values(36)": -0.17379319177880292,
    "random_process.delta_values(37)": -0.2029752912323639,
    "random_process.delta_values(38)": 0.32268039133362425,
    "random_process.delta_values(39)": -0.48266233497041133,
    "random_process.delta_values(4)": -0.5836969423870215,
    "random_process.delta_values(40)": -0.004616222494904898,
    "random_process.delta_values(41)": 0.9602355727696343,
    "random_process.delta_values(42)": -0.1889024490144937,
    "random_process.delta_values(43)": -0.027964054675037042,
    "random_process.delta_values(44)": -0.3891547841061338,
    "random_process.delta_values(45)": 0.44224519504957316,
    "random_process.delta_values(5)": 0.3545964410088338,
    "random_process.delta_values(6)": -0.16607901387788715,
    "random_process.delta_values(7)": 0.49994487800575405,
    "random_process.delta_values(8)": 0.3344108490198745,
    "random_process.delta_values(9)": -0.22013880066105562,
    "random_process.noise_sd": 0.62906649205768,
    "voc_emergence.delta.new_voc_seed.time_from_gisaid_report": 29.5352965161684
}

In [None]:
# Parameter set for the "manual" run: the MLE parameter set with a few hand-made
# changes. Only the overridden entries are listed, so the manual tweak is visible at
# a glance and the two sets cannot drift apart through copy-paste.
manual_overrides = {
    "random_process.delta_values(15)": 0.0,  # MLE value: 0.09366230292816935
}

# Guard against typos: an override must replace an existing MLE entry, not add a new key.
unknown_overrides = set(manual_overrides) - set(params_dict["MLE"])
assert not unknown_overrides, f"Overrides not present in the MLE set: {sorted(unknown_overrides)}"

# Overriding an existing key keeps its position, so key order matches the MLE dict.
params_dict["manual"] = {**params_dict["MLE"], **manual_overrides}

In [None]:
# Build a full parameter object per run by applying the flattened calibration-format
# values on top of the project's baseline parameters.
params = {
    run_name: project.param_set.baseline.update(params_dict[run_name], calibration_format=True)
    for run_name in ["MLE", "manual"]
}

# Run the baseline model once per parameter set and keep its derived outputs.
models, derived_dfs = {}, {}
for run, run_params in params.items():
    run_model = project.run_baseline_model(run_params)
    models[run] = run_model
    derived_dfs[run] = run_model.get_derived_outputs_df()


In [None]:
# Override the dispersion parameter of the first two calibration targets.
# NOTE(review): presumably read by cal.loglikelihood when the runs are scored -- confirm.
dispersion_params = [10, 60]
for i, dispersion in enumerate(dispersion_params):
    cal.targets[i].dispersion_param = dispersion

# Derived outputs to compare between the two runs.
# Other candidates: "hospital_admissions", "hospital_occupancy", "icu_admissions", "icu_occupancy"
outputs = ["infection_deaths", "cumulative_infection_deaths", "transformed_random_process"]

# Apply the style before any figure is created so that every figure, including the
# first one, is built under the same settings.
pyplot.style.use("ggplot")

# Dates at which each random-process delta value starts; identical for every plot,
# so they are computed once outside the loop.
rp_config = params["MLE"]['random_process']
start_time = rp_config['time']['start']
step = rp_config['time']['step']
n_deltas = len(rp_config['delta_values'])
rp_dates = ref_times_to_dti(REF_DATE, [start_time + i * step for i in range(n_deltas)])

for output in outputs:
    fig = pyplot.figure(figsize=(12, 8))
    axis = fig.add_subplot()

    # Draw both runs on this figure's axes explicitly rather than relying on
    # pyplot's "current axes" state.
    for run in ["MLE", "manual"]:
        derived_dfs[run][output].plot(ax=axis, label=run)
    axis.set_title(output)

    # Overlay the target data points when this output has a target.
    if output in all_targets:
        axis.scatter(all_targets[output]['times'], all_targets[output]['values'], color="k", s=5, alpha=0.5, zorder=10)

    # Add random process indicators: the index of each delta value at its start date.
    # Labels sit at y=1 on the random-process plot and at y=0 on the other plots.
    y_text = 1. if output == "transformed_random_process" else 0.
    for i, date in enumerate(rp_dates):
        axis.text(x=date, y=y_text, s=str(i), ha="center", va="top")

    axis.legend()

# Score both runs with the calibration object's likelihood and prior methods.
rp = cal.random_process
log_likelihoods, log_priors, log_posteriors = {}, {}, {}
for run in ("MLE", "manual"):
    # The random process object must carry this run's delta values before scoring.
    rp.delta_values = params[run]["random_process"]["delta_values"]
    run_log_likelihood = cal.loglikelihood(params_dict[run])
    # The random-process log-likelihood is counted as part of the prior here.
    run_log_prior = cal.logprior(params_dict[run]) + rp.evaluate_rp_loglikelihood()

    log_likelihoods[run] = run_log_likelihood
    log_priors[run] = run_log_prior
    log_posteriors[run] = run_log_likelihood + run_log_prior

print(f"log_priors: {log_priors}")
print(f"log_likelihoods: {log_likelihoods}")
print(f"log_posteriors: {log_posteriors}")

# Pick the run with the higher log-posterior (ties go to "manual").
if log_posteriors["manual"] >= log_posteriors["MLE"]:
    best = "manual"
else:
    best = "MLE"
print(f"Best run is: {best}")

# exp(-|difference in log-posterior|) is the ratio of the worse posterior density to the
# better one, reported as a percentage.
posterior_gap = abs(log_posteriors['MLE'] - log_posteriors['manual'])
accept_proba = 100 * exp(-posterior_gap)
print(f"Proba of transition to worse run: {round(accept_proba)}%")

# Check negative binomial distribution

In [None]:
from scipy import stats
from matplotlib import pyplot as plt
# Negative binomial with mean `mu` and dispersion `n`, evaluated on a coarse grid of counts.
n = 100  # this is the dispersion parameter
mu = 400  # the approximate model output values

p = mu / (mu + n)

# 100 grid points spaced by round(mu / 50), starting at 0.
grid_step = round(mu / 50)
x = [grid_step * i for i in range(100)]

# scipy's nbinom(n, q) has mean n * (1 - q) / q, so q = 1 - p = n / (mu + n) gives mean mu.
nbinom_prob = 1.0 - p
y = []
for count in x:
    y.append(stats.nbinom.pmf(round(count), n, nbinom_prob))

In [None]:
# Plot the negative binomial pmf on its own axes, with a title and axis labels so the
# figure can be read without looking at the cell that computed x and y.
nb_fig, nb_axis = plt.subplots()
nb_axis.plot(x, y)
nb_axis.set_title(f"Negative binomial pmf (mean={mu}, dispersion={n})")
nb_axis.set_xlabel("count")
nb_axis.set_ylabel("probability mass");