In [None]:
import pandas as pd
from plotly import graph_objects as go
import nevergrad as ng
from summer2 import AgeStratification, Overwrite

from estival.wrappers.nevergrad import optimize_model
import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput
from summer2.functions.time import get_sigmoidal_interpolation_function

from tb_incubator.demographics import add_extra_crude_birth_flow
from tb_incubator.constants import set_project_base_path

# Route every pandas `.plot()` call in this notebook through plotly.
pd.options.plotting.backend = "plotly"
# Project path lookup; its "DATA_PATH" entry is read when the input CSVs are loaded.
# NOTE(review): the base path is relative, so it presumably depends on the directory
# the notebook is run from — confirm before running from elsewhere.
project_paths = set_project_base_path("../incubator2024/tb_incubator")

In [None]:
# Can move this out of the notebook now
def get_age_groups_in_range(data_ages, lower_limit, upper_limit):
    """Pick the closed age bands whose starting age falls inside a range.

    Args:
        data_ages: Iterable of age-band labels such as "5-9"; labels
            containing "+" (open-ended bands) are always skipped.
        lower_limit: Smallest starting age to accept (inclusive).
        upper_limit: Largest starting age to accept (inclusive).

    Returns:
        List of the matching labels, in the iteration order of `data_ages`.
    """
    selected = []
    for band in data_ages:
        # Open-ended bands have no numeric upper bound and are handled by the caller
        if "+" in band:
            continue
        band_start = int(band.split("-")[0])
        if lower_limit <= band_start <= upper_limit:
            selected.append(band)
    return selected

In [None]:
# UN birth data: read the table indexed by its first column, then keep the "value" series
data_path = project_paths["DATA_PATH"]
birth_data = pd.read_csv(data_path / "id_birth.csv", index_col=0)
birth_rates = birth_data["value"]

In [None]:
# Load population and deaths
# The "year", "age", "population" and "deaths" columns of this table are the ones
# read when the demographic frame is built.
un_pop_data = pd.read_csv(data_path / "id_pop_deaths.csv")

In [None]:
# Process demographic data into a frame indexed by (year, age)
un_pops = un_pop_data.set_index(["year", "age"])[["population", "deaths"]]

# All-age totals per year
yearly_totals = un_pops.groupby(level=[0]).sum()
target_pops = yearly_totals["population"]
target_deaths = yearly_totals["deaths"]

# Distinct labels present on each index level
years = set(un_pops.index.get_level_values(0))
data_ages = set(un_pops.index.get_level_values(1))

In [None]:
# Map each modelled age group (keyed by its lower bound) to the data age bands it covers
agegroup_request = [[0, 4], [5, 14], [15, 34], [35, 49], [50, 74]]
agegroup_map = {}
for lower, upper in agegroup_request:
    agegroup_map[lower] = get_age_groups_in_range(data_ages, lower, upper)
# The open-ended "75+" band is skipped by the helper, so attach it to the oldest group
oldest_group_start = agegroup_request[-1][0]
agegroup_map[oldest_group_start].append("75+")

In [None]:
# Calculate death rates
# For each modelled age group, sum deaths and population over the data age bands it
# covers and divide, giving one rate per year. The age mask depends only on the age
# group, so it is built once per group and all years are aggregated in one groupby,
# instead of filling an empty frame one (year, age group) cell at a time.
data_age_level = un_pops.index.get_level_values(1)
rates_by_agegroup = {}
for agegroup, data_groups in agegroup_map.items():
    group_totals = un_pops.loc[data_age_level.isin(data_groups)].groupby(level=0).sum()
    rates_by_agegroup[agegroup] = group_totals["deaths"] / group_totals["population"]
# Rows are years, columns are the modelled age groups (keyed by lower bound)
mapped_rates = pd.DataFrame(rates_by_agegroup)
# Restrict to the years for which birth rates are available
death_rates = mapped_rates.loc[birth_rates.index]

In [None]:
# Quick visual check of the mapped death rates before they are passed to the model
death_rates.plot()

In [None]:
# Arbitrary base model construction
# Five compartments, with "infectious" as the only infectious compartment.
model_comps = ["susceptible", "early latent", "late latent", "infectious", "recovered"]
# Start and end of the simulation period
model_times = [1850.0, 2024.0]
model = CompartmentalModel(
    times=model_times,
    compartments=model_comps,
    infectious_compartments=["infectious"],
)
# Susceptible starts at the "starting population" parameter (one of the calibrated
# priors); infectious is explicitly set to zero.
init_pops = {"susceptible": Parameter("starting population"), "infectious": 0.0}
model.set_initial_population(init_pops)

In [None]:
# TB transitions: a single placeholder TB-related flow for now.
# Death flow out of the "infectious" compartment at the "death rate" parameter.
model.add_death_flow("TB death", Parameter("death rate"), "infectious")

In [None]:
# Demographic transitions
# The 1.0 rate is a placeholder: the age stratification overwrites "population_death"
# with age-specific, time-varying rates.
model.add_universal_death_flows("population_death", 1.0)
# Replacement births enter the "susceptible" compartment
model.add_replacement_birth_flow("replacement_birth", "susceptible")
# Project helper adding a further birth flow into "susceptible", driven by the
# "population growth rate" parameter (one of the calibrated priors).
# NOTE(review): presumably this is what lets the population grow beyond replacement —
# confirm against tb_incubator.demographics.
add_extra_crude_birth_flow(model, "extra_birth", Parameter("population growth rate"), "susceptible")

In [None]:
# Arbitrary epidemiological parameter
# Fixed (not calibrated) value for the "death rate" parameter used by the "TB death" flow
params = {"death rate": 0.1}

In [None]:
# Apply age stratification with age-specific death rate functions of time
# Strata are the lower bounds of the requested age groups
age_strata = [i[0] for i in agegroup_request]
strat = AgeStratification("age", age_strata, model_comps)
# The interpolation x-values are the same for every stratum, so take them once.
# Deliberately not named `years`: that name is already bound to the set of data years
# built from the UN data, and rebinding it here would silently change the behaviour of
# any earlier cell that is re-run afterwards.
death_rate_years = death_rates.index
# For each stratum, replace the placeholder "population_death" rate with a
# sigmoidal interpolation of that age group's death rates over time
death_adjs = {
    str(age): Overwrite(get_sigmoidal_interpolation_function(death_rate_years, death_rates[age]))
    for age in age_strata
}
strat.set_flow_adjustments("population_death", death_adjs)
model.stratify_with(strat)

In [None]:
# Track population
# "total_population" covers every model compartment; it is the output the calibration
# target is compared against. The trailing semicolon suppresses the cell's output.
model.request_output_for_compartments("total_population", model_comps);

In [None]:
# Prepare calibration model
# Uniform priors on the two demographic parameters referenced by the model
priors = [
    esp.UniformPrior("population growth rate", (0.005, 0.03)),
    esp.UniformPrior("starting population", (1e6, 3e7)),
]
# Fit the modelled "total_population" output to the yearly population totals from the data
targets = [est.NegativeBinomialTarget("total_population", target_pops, dispersion_param=100.0)]
# `params` supplies the parameter that is not calibrated ("death rate")
bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [None]:
# Set up optimisation
budget = 1000
opt_class = ng.optimizers.NGOpt
# Starting point handed to the optimiser; both values lie inside the prior bounds
start_params = {"population growth rate": 0.01, "starting population": 1.5e6}
# Build the runner once. Previously a first runner was created with `budget` and then
# immediately overwritten by a second one created without it; the single call below
# keeps the suggested start and init method and passes the same budget that is later
# given to `minimize`.
orunner = optimize_model(
    bcm,
    budget=budget,
    opt_class=opt_class,
    suggested=start_params,
    init_method="midpoint",
)

In [None]:
# Optimise
rec = orunner.minimize(budget)
# The recommendation's value is indexed at [1] to get the parameter-name -> value mapping
map_params = rec.value[1]
print("Best candidate parameters:")
# Priors are paired with parameters by position.
# NOTE(review): this assumes map_params preserves the order of `priors` — confirm.
for i_param, (param, value) in enumerate(map_params.items()):
    # Closing parenthesis added: the message previously ended with an unbalanced "("
    print(f"\t{param}: {round(value, 3)} (within bound {priors[i_param].bounds()})")

In [None]:
# Run with optimised parameters
# `params | map_params` merges the fixed parameter with the optimised values; on a key
# clash the optimised value wins (dict union, Python 3.9+).
model.run(parameters=params | map_params)

In [None]:
# Inspect outputs: modelled total population as a line, data targets as small black markers
outputs = model.get_derived_outputs_df()
modelled_trace = go.Scatter(x=outputs.index, y=outputs["total_population"], name="modelled")
target_trace = go.Scatter(
    x=target_pops.index,
    y=target_pops,
    name="target",
    mode="markers",
    marker={"color": "black", "size": 2.0},
)
fig = go.Figure(data=[modelled_trace, target_trace])
# Fixed y-axis range; as the last expression this also displays the figure
fig.update_layout(yaxis={"range": [0.0, 3e8]})