In [133]:
import pandas as pd
from plotly import graph_objects as go
import nevergrad as ng
from summer2 import AgeStratification, Overwrite

from estival.wrappers.nevergrad import optimize_model
import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, Function, DerivedOutput
from summer2.functions.time import get_sigmoidal_interpolation_function

from tb_incubator.demographics import add_extra_crude_birth_flow
from tb_incubator.constants import set_project_base_path
from tb_incubator.input import get_birth_rate, get_pop_death_data, get_death_rates

pd.options.plotting.backend = "plotly"
project_paths = set_project_base_path("../tb_incubator/")

In [134]:
# load birth data
birth_rates = get_birth_rate()


Columns (2,3,4,7) have mixed types. Specify dtype option on import or set low_memory=False.



In [135]:
# load age-stratified, population & deaths data
pop_death = get_pop_death_data()

target_pops = pop_death.groupby(level=[0]).sum()["population"]


Columns (3,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111) have mixed types. Specify dtype option on import or set low_memory=False.


Columns (3,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111) have mixed types. Specify dtype option on import or set low_memory=False.



In [136]:
# load death rates
death_rates = get_death_rates()


Columns (3,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111) have mixed types. Specify dtype option on import or set low_memory=False.


Columns (3,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111) have mixed types. Specify dtype option on import or set low_memory=False.



In [137]:
death_rates.plot()

In [138]:
# Arbitrary base model construction
model_comps = ["susceptible", "early latent", "late latent", "infectious", "recovered"]
model_times = [1850.0, 2024.0]
model = CompartmentalModel(
    times=model_times,
    compartments=model_comps,
    infectious_compartments=["infectious"],
)
init_pops = {"susceptible": 1e6, "infectious": 0.0}
model.set_initial_population(init_pops)

In [139]:
# TB transitions, some meaningless TB-related flows
#model.add_death_flow("TB death", Parameter("death rate"), "infectious")

In [140]:
# Demographic transitions
model.add_universal_death_flows("population_death", 1.0)
model.add_replacement_birth_flow("replacement_birth", "susceptible")


In [141]:
def get_entry_rate_from_pops(start_time, end_time, start_pop, end_pop): 
    pop_change = end_pop - start_pop
    duration = end_time - start_time
    rate = pop_change/duration
    return rate

In [142]:
total_population_per_year = pop_death.groupby('year')['population'].sum()
start_pop = total_population_per_year.get(1950)
end_pop = total_population_per_year.get(1951)


num_entry = get_entry_rate_from_pops(1950, 1951, start_pop, end_pop)

In [144]:
# Arbitrary epidemiological parameter
params = {"starting population" : 1e7}

In [145]:
# Apply age stratification with age-specific death rate functions of time
agegroup_request = [[0, 4], [5, 14], [15, 34], [35, 49], [50, 100]]
age_strata = [i[0] for i in agegroup_request]
strat = AgeStratification("age", age_strata, model_comps)
death_adjs = {}
for age in age_strata:
    years = death_rates.index
    rates = death_rates[age]
    pop_death_func = get_sigmoidal_interpolation_function(years, rates)
    death_adjs[str(age)] = Overwrite(pop_death_func)
strat.set_flow_adjustments("population_death", death_adjs)
model.stratify_with(strat)

In [146]:
model.add_importation_flow("births", num_imported=num_entry, dest="susceptible", split_imports=True, dest_strata={"age": "5"})

In [147]:
# Track population
model.request_output_for_compartments("total_population", model_comps)
for s in age_strata:
    model.request_output_for_compartments(f"{s}_population", model_comps, strata={"age": str(s)})

In [148]:
model.run(parameters=params)
age_pops = model.get_derived_outputs_df()
age_pops.plot.area()

In [149]:
# Prepare calibration model
priors = [
    esp.UniformPrior("population growth rate", (0.005, 0.03)),
    esp.UniformPrior("starting population", (1e6, 3e7)),
]
targets = [est.NegativeBinomialTarget("total_population", target_pops, dispersion_param=100.0)]
bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [150]:
# Set up optimisation
budget = 1000
opt_class = ng.optimizers.NGOpt
orunner = optimize_model(bcm, opt_class=opt_class, budget=budget)
start_params = {"population growth rate": 0.01, "starting population": 1.5e6}
orunner = optimize_model(bcm, opt_class=opt_class, suggested=start_params, init_method="midpoint")

In [151]:
# Optimise
rec = orunner.minimize(budget)
map_params = rec.value[1]
print("Best candidate parameters:")
for i_param, param in enumerate(map_params):
    print(f"\t{param}: {round(map_params[param], 3)} (within bound {priors[i_param].bounds()}")

Best candidate parameters:
	population growth rate: 0.009 (within bound (0.005, 0.03)
	starting population: 1529635.294 (within bound (1000000.0, 30000000.0)


In [152]:
# Run with optimised parameters
model.run(parameters=params | map_params)

In [153]:
# Inspect outputs
outputs = model.get_derived_outputs_df()
fig = go.Figure()
fig.add_trace(go.Scatter(x=outputs.index, y=outputs["total_population"], name="modelled"))
fig.add_trace(go.Scatter(x=target_pops.index, y=target_pops, name="target", mode="markers", marker=dict(color="black", size=2.0)))
fig.update_layout(yaxis={"range": [0.0, 3e8]})