In [None]:
import pandas as pd
from plotly import graph_objects as go
import nevergrad as ng
from summer2 import AgeStratification, Overwrite

from estival.wrappers.nevergrad import optimize_model
import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput
from summer2.functions.time import get_sigmoidal_interpolation_function

from tb_incubator.demographics import add_extra_crude_birth_flow
from tb_incubator.constants import set_project_base_path
from tb_incubator.input import get_birth_rate, get_pop_death_data, get_death_rates

# Route pandas' .plot accessor through plotly instead of the default backend.
pd.set_option("plotting.backend", "plotly")
# Register the project base directory (given here as a relative path).
project_paths = set_project_base_path("../tb_incubator/")

In [None]:
# Load birth-rate data via the project's input helper.
# NOTE(review): the structure and units of the returned object are not visible here — confirm in tb_incubator.input.
birth_rates = get_birth_rate()

In [None]:
# Load the age-stratified population and deaths data.
pop_death = get_pop_death_data()

# Total population per value of the first index level (presumably year — confirm),
# obtained by summing the "population" column over the remaining index levels.
target_pops = pop_death.groupby(level=[0])["population"].sum()

In [None]:
# Load death rates via the project's input helper.
# Later cells read `.index` as the time points and index columns by the integer
# lower bound of each age group, so the result is used as a DataFrame of that shape.
death_rates = get_death_rates()

In [None]:
# Visual check of the loaded death rates (rendered with the plotly backend set in the first cell).
death_rates.plot()

In [None]:
# Arbitrary base model construction
# Five compartments, of which only "infectious" is flagged as infectious.
model_comps = ["susceptible", "early latent", "late latent", "infectious", "recovered"]
# Values passed as the model's `times` argument (presumably start and end year — confirm against summer2).
model_times = [1850.0, 2024.0]
model = CompartmentalModel(
    times=model_times,
    compartments=model_comps,
    infectious_compartments=["infectious"],
)
# "susceptible" is seeded from the named parameter "starting population", whose value
# is supplied when the model is run; "infectious" starts at zero.
init_pops = {"susceptible": Parameter("starting population"), "infectious": 0.0}
model.set_initial_population(init_pops)

In [None]:
# TB transitions: placeholder TB-related flows (the TB death flow below is currently disabled)
#model.add_death_flow("TB death", Parameter("death rate"), "infectious")

In [None]:
# Demographic transitions
# The universal death rate of 1.0 is only a placeholder: the age stratification applied
# in a later cell overwrites the "population_death" flow with age-specific rate functions.
model.add_universal_death_flows("population_death", 1.0)
# Replacement births enter "susceptible" (presumably matching total deaths — confirm against summer2 docs).
model.add_replacement_birth_flow("replacement_birth", "susceptible")
# Extra births into "susceptible", governed by the "population growth rate" parameter.
add_extra_crude_birth_flow(model, "extra_birth", Parameter("population growth rate"), "susceptible")

In [None]:
# Arbitrary starting values for the two named model parameters.
# They are used for the first exploratory run; the calibrated values replace them
# when the model is re-run with the optimised parameters.
params = {
    "population growth rate": 0.1,
    "starting population": 1e7,
}

In [None]:
# Stratify every compartment by age and replace the placeholder death rate with
# age-specific, time-varying death rate functions.
agegroup_request = [[0, 4], [5, 14], [15, 34], [35, 49], [50, 100]]
# Each stratum is identified by the lower bound of its age band.
age_strata = [lower for lower, _ in agegroup_request]
strat = AgeStratification("age", age_strata, model_comps)

# The time points are shared by every age group, so read them once.
years = death_rates.index
# One interpolated rate function per age group, built from that group's column of
# death_rates, replacing (Overwrite) the unstratified "population_death" rate.
death_adjs = {
    str(age): Overwrite(get_sigmoidal_interpolation_function(years, death_rates[age]))
    for age in age_strata
}
strat.set_flow_adjustments("population_death", death_adjs)
model.stratify_with(strat)

In [None]:
# Derived outputs: the total population plus one output per age stratum.
model.request_output_for_compartments("total_population", model_comps)
for age_start in age_strata:
    # Output name is "<lower age bound>_population", restricted to that age stratum.
    model.request_output_for_compartments(
        f"{age_start}_population",
        model_comps,
        strata={"age": str(age_start)},
    )

In [None]:
# Exploratory run with the arbitrary starting parameters.
model.run(parameters=params)
age_pops = model.get_derived_outputs_df()
# Stack only the per-age outputs: the derived outputs also contain "total_population",
# and including it in a stacked area chart would count the whole population twice.
age_pop_cols = [f"{s}_population" for s in age_strata]
age_pops[age_pop_cols].plot.area()

In [None]:
# Calibration setup: uniform priors on both free parameters and a single target
# comparing modelled "total_population" against the observed population series.
prior_bounds = {
    "population growth rate": (0.005, 0.03),
    "starting population": (1e6, 3e7),
}
priors = [esp.UniformPrior(name, bounds) for name, bounds in prior_bounds.items()]

pop_target = est.NegativeBinomialTarget("total_population", target_pops, dispersion_param=100.0)
targets = [pop_target]

# Bundle model, default parameters, priors and targets for calibration.
bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [None]:
# Set up optimisation
budget = 1000
opt_class = ng.optimizers.NGOpt
# Initial guess handed to the optimiser; both values lie inside the prior bounds above.
start_params = {"population growth rate": 0.01, "starting population": 1.5e6}
# Build the runner once. Previously a first runner was created and immediately
# overwritten, and the runner actually used was constructed without the budget.
orunner = optimize_model(
    bcm,
    budget=budget,
    opt_class=opt_class,
    suggested=start_params,
    init_method="midpoint",
)

In [None]:
# Optimise
rec = orunner.minimize(budget)
# The second element of the recommendation's value holds the parameters keyed by name.
map_params = rec.value[1]
# Look priors up by parameter name rather than by position, so the report stays
# correct even if the optimiser returns parameters in a different order from `priors`.
prior_by_name = {prior.name: prior for prior in priors}
print("Best candidate parameters:")
for param, value in map_params.items():
    print(f"\t{param}: {round(value, 3)} (within bound {prior_by_name[param].bounds()})")

In [None]:
# Run with optimised parameters
# Dict union: entries in map_params take precedence over the defaults in params.
model.run(parameters=params | map_params)

In [None]:
# Compare the calibrated total population against the target data.
outputs = model.get_derived_outputs_df()

# Modelled trajectory drawn with the default scatter mode; targets as small black markers.
modelled_trace = go.Scatter(x=outputs.index, y=outputs["total_population"], name="modelled")
target_trace = go.Scatter(
    x=target_pops.index,
    y=target_pops,
    name="target",
    mode="markers",
    marker={"color": "black", "size": 2.0},
)

fig = go.Figure(data=[modelled_trace, target_trace])
# Fixed y-axis range; update_layout returns the figure, so the cell displays it.
fig.update_layout(yaxis={"range": [0.0, 3e8]})