In [None]:
import pandas as pd
from plotly import graph_objects as go
import nevergrad as ng
from summer2 import AgeStratification, Overwrite

from estival.wrappers.nevergrad import optimize_model
import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, Function, DerivedOutput
from summer2.functions.time import get_sigmoidal_interpolation_function

from tb_incubator.demographics import add_extra_crude_birth_flow
from tb_incubator.constants import set_project_base_path
from tb_incubator.input import get_birth_rate, get_pop_death_data, get_death_rates

# Notebook configuration: route pandas `.plot` calls through plotly, and register the
# tb_incubator base path (presumably consumed by the tb_incubator.input loaders — TODO confirm).
pd.set_option("plotting.backend", "plotly")
project_paths = set_project_base_path("../tb_incubator/")

In [None]:
# Load birth rate data
# NOTE(review): `birth_rates` is not used anywhere in this section — population entry is derived
# from year-on-year changes in total population instead. Confirm whether this load is still needed.
birth_rates = get_birth_rate()

In [None]:
# Load age-stratified population and death data
pop_death = get_pop_death_data()
# Total population per value of the first index level (presumably year — TODO confirm), used as the
# calibration/comparison target. Select the "population" column *before* summing so the other
# columns (possibly non-numeric) are not aggregated needlessly.
target_pops = pop_death.groupby(level=[0])["population"].sum()

In [None]:
# Load death rates. The age stratification cell below uses the index as the interpolation
# x-values (years) and looks up one column per age-group lower bound (0, 5, 15, 35, 50).
death_rates = get_death_rates()

In [None]:
# Arbitrary base model construction.
# The initial population is empty: everyone enters via the importation ("births") flow added later.
model_comps = [
    "susceptible",
    "early latent",
    "late latent",
    "infectious",
    "recovered",
]
model_times = [1850.0, 2024.0]
model = CompartmentalModel(
    compartments=model_comps,
    infectious_compartments=["infectious"],
    times=model_times,
)
model.set_initial_population({})

In [None]:
# TB transitions: placeholder TB-related flows only — add more here to test the demographic set-up.
# Currently just a death flow out of "infectious", with its rate taken from the "death rate"
# parameter supplied to model.run.
model.add_death_flow("TB death", Parameter("death rate"), "infectious")

In [None]:
# Demographic transitions
# Background deaths from every compartment. The 1.0 rate is a placeholder: the age stratification
# cell overwrites it per age stratum with time-varying death rates.
model.add_universal_death_flows("population_death", 1.0)  # Placeholder to overwrite later
# Replacement births into "susceptible" — presumably re-adding the population lost to deaths
# (TODO confirm against summer2 semantics), so net growth has to come from the importation flow.
model.add_replacement_birth_flow("replacement_birth", "susceptible")

In [None]:
# Total population by year, and the length of the run-in period between the model start
# and the first year of population data.
total_pop_by_year = pop_death.groupby("year")["population"].sum()
pop_start_year = total_pop_by_year.index[0]  # groupby sorts its keys, so this is the earliest year
start_period = pop_start_year - model_times[0]
# The run-in entry rate divides the first observed population by start_period, so the data must
# start strictly after the model does (otherwise: division by zero or a negative entry rate).
assert start_period > 0, (
    f"First population data year ({pop_start_year}) must be after the model start ({model_times[0]})"
)

In [None]:
# Inspect the total population series that the entry rates are derived from
# (`pop_entry` is only defined in the next cell, so it cannot be displayed here).
total_pop_by_year.tail()

In [None]:
# Calculate population entry rates (per year) and convert them to a function of time.
# Net change between consecutive data points, divided by the gap in years, so non-annual data
# also give a per-year rate (for annual data this is just the year-on-year difference).
year_gaps = total_pop_by_year.index.to_series().diff().to_numpy()
pop_entry = (total_pop_by_year.diff() / year_gaps).dropna()
# During the run-in the whole first observed population is assumed to enter at a constant rate.
pop_entry[pop_start_year] = total_pop_by_year[pop_start_year] / start_period
pop_entry = pop_entry.sort_index()
entry_rate = get_sigmoidal_interpolation_function(pop_entry.index, pop_entry)

In [None]:
# Inspect the sorted entry rates that `entry_rate` interpolates. Do not recompute `pop_entry` here:
# an unsorted recomputation would no longer match the series actually used for the interpolation.
pop_entry

In [None]:
# Apply age stratification with age-specific death rate functions of time.
# Only the lower bound of each requested age group is used as the stratum name.
agegroup_request = [[0, 4], [5, 14], [15, 34], [35, 49], [50, 100]]
age_strata = [group[0] for group in agegroup_request]
strat = AgeStratification("age", age_strata, model_comps)

# Fail early with a clear message rather than a bare KeyError inside the comprehension.
missing_ages = set(age_strata) - set(death_rates.columns)
assert not missing_ages, f"death_rates has no column for age strata {sorted(missing_ages)}"

# Overwrite the placeholder universal death rate in each stratum with a sigmoidal interpolation
# of that age group's death rates over the data years (index hoisted out of the loop; the
# comprehension keeps per-stratum temporaries out of the notebook namespace).
death_years = death_rates.index
death_adjs = {
    str(age): Overwrite(get_sigmoidal_interpolation_function(death_years, death_rates[age]))
    for age in age_strata
}
strat.set_flow_adjustments("population_death", death_adjs)
model.stratify_with(strat)

In [None]:
# Add the population entry rate as an importation flow into susceptibles in the youngest age stratum
# (split imports in case the susceptible compartments are further stratified later).
# NOTE(review): `entry_rate` is interpolated from changes in total population, not from birth data,
# so this "births" flow is net growth on top of the replacement births — consider renaming.
model.add_importation_flow("births", entry_rate, dest="susceptible", split_imports=True, dest_strata={"age": "0"})

In [None]:
# Track the total population of each age stratum as a derived output.
# Output names are strings, like the "age" stratum names used in the strata filter
# (the original passed the int lower bound as the output name).
age_pop_outputs = [
    model.request_output_for_compartments(str(s), model_comps, strata={"age": str(s)})
    for s in age_strata
]

In [None]:
# Run with an arbitrary TB death rate, then compare the stacked modelled population by age
# stratum against the target total population from the data.
model.run({"death rate": 0.01})
fig = model.get_derived_outputs_df().plot.area()
fig.add_trace(go.Scatter(x=target_pops.index, y=target_pops, name="target", mode="markers", marker=dict(color="black", size=2.0)))
fig.update_layout(
    title="Modelled population by age stratum vs. target total population",
    xaxis_title="Year",
    yaxis_title="Population",
)
fig