In [None]:
import pandas as pd

from estival.wrappers.nevergrad import optimize_model

from tb_incubator.plotting import plot_model_vs_actual, display_plot
from tb_incubator.input import load_targets, load_param_info
from tb_incubator.calibrate import get_bcm

from multiprocessing import cpu_count
pd.options.plotting.backend = "plotly"

In [None]:
# Prefix used when exporting this notebook's figures
file_prefix = "opt3"

# Parameter values and calibration targets from the project inputs
params = load_param_info()["value"]
all_targets = load_targets()

# Calibration model with improved detection, Xpert sensitivity and COVID effects enabled
bcm = get_bcm(
    params,
    improved_detection=True,
    xpert_sensitivity=True,
    covid_effects=True,
)

In [None]:
def calibrate_model_with_optimisation(bcm, budget=4000, num_workers=None):
    """
    Calibrate a model by optimisation.

    Calibration is performed using the estival package, which wraps the
    optimisation methods provided by the nevergrad package. The nevergrad
    ``TwoPointsDE`` differential-evolution optimiser is used.

    Args:
        bcm: the calibration model object (type BayesianCompartmentalModel)
        budget: optimisation budget. The same value configures the optimiser
            (``optimize_model(..., budget=...)``) and drives the run
            (``minimize(budget)``). Defaults to 4000, the previously
            hard-coded value.
        num_workers: number of parallel workers handed to the optimiser.
            ``None`` (default) uses ``multiprocessing.cpu_count()``, as before.

    Returns:
        The optimised parameter set, i.e. ``value[1]`` of the optimiser's
        recommendation (presumably the keyword-argument dict of a nevergrad
        Instrumentation; it is what gets passed to ``bcm.run``).

    Raises:
        ValueError: if ``budget`` or ``num_workers`` is less than 1.
    """
    if num_workers is None:
        num_workers = cpu_count()
    # Fail fast, before any optimiser is built, on settings that cannot run
    if budget < 1:
        raise ValueError(f"budget must be at least 1, got {budget!r}")
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers!r}")

    # Local import kept from the original: nevergrad is only needed here
    from nevergrad.optimization.differentialevolution import TwoPointsDE

    orunner = optimize_model(bcm, opt_class=TwoPointsDE, num_workers=num_workers, budget=budget)
    # One budget value for both set-up and run, so the two cannot drift apart
    rec = orunner.minimize(budget)
    optimised_params = rec.value[1]

    return optimised_params

In [None]:
# Run the optimisation (4000-step budget); this is the expensive cell
optimised_params = calibrate_model_with_optimisation(bcm)
# Display the result as a labelled one-column table instead of a raw dict repr;
# optimised_params itself stays the object returned above and is what later cells use
pd.Series(optimised_params, name="optimised value")

In [None]:
# Run the model with the optimised parameter set and keep its derived outputs for plotting
res = bcm.run(optimised_params)
outs = res.derived_outputs

In [None]:
# Notifications: model output versus target data, x-axis limited to 2000-2024
plot = plot_model_vs_actual(
    outs,
    all_targets['notif'],
    "notification",
    "Notification",
    "Modelled vs Data",
    "Actual data",
).update_xaxes(range=[2000, 2024])
plot
# export: display_plot(plot, f"{file_prefix}_notif_opt", "svg")

In [None]:
# Prevalence: model output versus target data, x-axis limited to 2003-2024
plot = plot_model_vs_actual(
    outs,
    all_targets['prevalence'],
    "prevalence",
    "Prevalence",
    "Modelled vs Data",
    "Actual data",
).update_xaxes(range=[2003, 2024])
plot
# export: display_plot(plot, f"{file_prefix}_prev_opt", "svg")

In [None]:
# Incidence: model output versus target data, x-axis limited to 2000-2024.
# The label is capitalised to match the "Notification" / "Prevalence" plots.
plot = plot_model_vs_actual(
    outs, all_targets['incidence'], "incidence", "Incidence", "Modelled vs Data", "Actual data"
)
plot.update_xaxes(range=[2000, 2024])
plot
#display_plot(plot, f"{file_prefix}_inc_opt", "svg")

In [None]:
# Modelled `percentage_latent` derived output over 2000-2024
fig = outs["percentage_latent"].plot()
fig.update_xaxes(range=[2000, 2024])
# The pandas plotly backend otherwise uses generic axis labels; give the figure
# a title and labelled axes so it stands alone when exported
fig.update_layout(
    title="Modelled percentage latent",
    xaxis_title="Year",
    yaxis_title="Percentage latent",
)

fig
#display_plot(fig, f"{file_prefix}_ltbi_opt", "svg")