In [None]:
from matplotlib import pyplot
import os
from numpy import exp

from summer.utils import ref_times_to_dti
from autumn.core.plots.utils import REF_DATE
from autumn.core.project import get_project, load_timeseries
from autumn.settings.region import Region


In [None]:
# Load the "sm_covid2" project for France; later cells use its param set, plots and calibration.
model_name = "sm_covid2"
project_region = Region.FRANCE
project = get_project(model_name, project_region)

In [None]:
# Target data keyed by output name, with each target's "times" converted from numeric
# reference times to dates (via ref_times_to_dti) for plotting against derived outputs.
# A new dict is built instead of writing into project.plots. Re-running this cell therefore
# never re-converts already-converted dates, and project.plots keeps its original times.
all_targets = {
    target_name: {**target_data, "times": ref_times_to_dti(REF_DATE, target_data["times"])}
    for target_name, target_data in project.plots.items()
}

### Run a calibration to initialise the calibration object, so that its likelihood and prior methods can be accessed below

In [None]:
# Plot every output listed in project.plots except the school-ratio output and "_diff_" outputs.
plotted_outputs = [
    output_name
    for output_name in project.plots
    if output_name != "death_missed_school_ratio" and "_diff_" not in output_name
]

# NOTE(review): the three positional 1s are presumably minimal run settings (e.g. duration /
# run number / chain count) — TODO confirm against Calibration.run's signature.
project.calibration.run(project, 1, 1, 1, derived_outputs_to_plot=plotted_outputs)
cal = project.calibration

In [None]:
# Calibration-format parameter overrides, keyed by run label ("MLE", "manual").
params_dict = dict()

In [None]:
# Parameter set labelled "MLE" (presumably the maximum-likelihood estimate from an earlier
# calibration — TODO confirm source run), keyed by calibration-format parameter names.
# Each key appears exactly once: in a dict literal a repeated key silently overrides the
# earlier entry, so a duplicate makes edits to the first occurrence ineffective.
params_dict["MLE"] = {
    "contact_rate": 0.02995841309828505,
    "age_stratification.ifr.multiplier": 0.5230259999804419,
    "infectious_seed_time": 11.30050215485564,
    "random_process.delta_values(1)": -1.0750580960048195,
    "random_process.delta_values(2)": 0.10748516201969016,
    "random_process.delta_values(3)": -0.02592371505135227,
    "random_process.delta_values(4)": 0.14610440885209375,
    "random_process.delta_values(5)": 0.5837441720378771,
    "random_process.delta_values(6)": 0.2156486490935401,
    "random_process.delta_values(7)": -0.5831829534201922,
    "random_process.delta_values(8)": 0.2863577626289051,
    "random_process.delta_values(9)": -0.10010896231338395,
    "random_process.delta_values(10)": 0.06635744365345397,
    "random_process.delta_values(11)": -0.32018571814026453,
    "random_process.delta_values(12)": 0.6020370141695954,
    "random_process.delta_values(13)": -0.06375841140441674,
    "random_process.delta_values(14)": -1.1297818666973602,
    "random_process.delta_values(15)": 0.7890782568598431,
    "random_process.delta_values(16)": 0.19870559032908952,
    "random_process.delta_values(17)": -0.37801029096815,
    "random_process.delta_values(18)": -0.39898763083626965,
    "random_process.delta_values(19)": 0.41529007067837176,
    "random_process.delta_values(20)": 0.6853125885986571,
    "random_process.delta_values(21)": -0.5415914511333617,

    "random_process.noise_sd": 0.5147739301564118,
    "voc_emergence.delta.new_voc_seed.time_from_gisaid_report": 20.986017544884035,
}

In [None]:
# Hand-tuned variant of the "MLE" parameter set. It matches "MLE" except for the
# random-process delta values overridden below; trailing comments give the MLE value.
# Starting from a copy of "MLE" avoids a second hand-maintained copy of every shared entry.
# Editing the "MLE" cell therefore also changes the non-overridden entries here.
params_dict["manual"] = {
    **params_dict["MLE"],
    "random_process.delta_values(7)": -0.4,  # MLE: -0.5831829534201922
    "random_process.delta_values(8)": -0.2,  # MLE: 0.2863577626289051
    "random_process.delta_values(9)": 0.2,  # MLE: -0.10010896231338395
    "random_process.delta_values(10)": 0.1,  # MLE: 0.06635744365345397
    "random_process.delta_values(11)": -0.2,  # MLE: -0.32018571814026453
    "random_process.delta_values(12)": 0.5,  # MLE: 0.6020370141695954
    "random_process.delta_values(13)": -0.2,  # MLE: -0.06375841140441674
    "random_process.delta_values(16)": 0.3,  # MLE: 0.19870559032908952
    "random_process.delta_values(17)": -0.5,  # MLE: -0.37801029096815
    "random_process.delta_values(19)": 0.6,  # MLE: 0.41529007067837176
    "random_process.delta_values(20)": 0.6,  # MLE: 0.6853125885986571
    "random_process.delta_values(21)": -0.7,  # MLE: -0.5415914511333617
}

In [None]:
# Turn each run's calibration-format overrides into a full parameter set (baseline + overrides),
# run the model once per set, and keep each run's derived-outputs dataframe for plotting.
run_names = ["MLE", "manual"]
params = {
    run: project.param_set.baseline.update(params_dict[run], calibration_format=True)
    for run in run_names
}

models, derived_dfs = {}, {}
for run in run_names:
    models[run] = project.run_baseline_model(params[run])
    # Read outputs from this run's model (no bare `model` name exists in this notebook).
    derived_dfs[run] = models[run].get_derived_outputs_df()


In [None]:

# One figure per output: both runs' derived outputs, with target data overlaid as points when
# available and each random-process period's start labelled with its index.
# Further outputs that can be added: "hospital_admissions", "hospital_occupancy",
# "icu_admissions", "icu_occupancy".
outputs = ["infection_deaths", "cumulative_infection_deaths", "transformed_random_process"]

pyplot.style.use("ggplot")

# The random-process period dates are the same for every figure, so compute them once (one
# vectorised conversion). They come from the "MLE" parameter set; neither params_dict entry
# overrides the random-process time settings.
rp_time = params["MLE"]["random_process"]["time"]
n_rp_values = len(params["MLE"]["random_process"]["delta_values"])
rp_dates = ref_times_to_dti(
    REF_DATE, [rp_time["start"] + i * rp_time["step"] for i in range(n_rp_values)]
)

for output in outputs:
    fig, axis = pyplot.subplots(figsize=(12, 8))

    for run in ["MLE", "manual"]:
        # Draw on this figure's axis explicitly rather than relying on pyplot's "current" axes.
        derived_dfs[run][output].plot(ax=axis, label=run)
    axis.set_title(output)

    if output in all_targets:
        axis.scatter(
            all_targets[output]["times"],
            all_targets[output]["values"],
            color="k", s=5, alpha=0.5, zorder=10,
        )

    # Random-process indicators: index labels placed at y=1 for the transformed random
    # process and at y=0 for the other outputs.
    y_text = 1. if output == "transformed_random_process" else 0.
    for i, date in enumerate(rp_dates):
        axis.text(x=date, y=y_text, s=str(i), ha="center", va="top")

    axis.legend()

# Score both parameter sets with the calibration object's likelihood and prior methods.
# rp.evaluate_rp_loglikelihood() takes no arguments, so it presumably reads rp.delta_values.
# Those values are therefore set per run, and the original value is restored afterwards so
# `cal` is not left in the state of whichever run was evaluated last.
log_likelihoods, log_priors, log_posteriors = {}, {}, {}
rp = cal.random_process
original_delta_values = rp.delta_values
try:
    for run in ["MLE", "manual"]:
        rp.delta_values = params[run]["random_process"]["delta_values"]
        log_likelihoods[run] = cal.loglikelihood(params_dict[run])
        log_priors[run] = cal.logprior(params_dict[run]) + rp.evaluate_rp_loglikelihood()
        log_posteriors[run] = log_likelihoods[run] + log_priors[run]
finally:
    rp.delta_values = original_delta_values

print(f"log_priors: {log_priors}")
print(f"log_likelihoods: {log_likelihoods}")
print(f"log_posteriors: {log_posteriors}")

# Higher log posterior wins; ties go to "manual".
best = "manual" if log_posteriors["manual"] >= log_posteriors["MLE"] else "MLE"
print(f"Best run is: {best}")

# Acceptance probability for a move from the better run to the worse one:
# exp(-|difference in log posteriors|), expressed as a percentage (always <= 100).
accept_proba = 100 * exp(-abs(log_posteriors["MLE"] - log_posteriors["manual"]))
print(f"Proba of transition to worse run: {round(accept_proba)}%")