In [None]:
from importlib import reload

from tbh import runner_tools as rt 
from tbh import model as tbm 
# Re-import tbh.model so that edits made to the module since it was first
# imported are picked up without restarting the kernel.
reload(tbm)

# Default model configuration exposed by the runner tools.
model_config = rt.DEFAULT_MODEL_CONFIG

# params / priors / tv_params exactly as returned by the runner tools;
# only tv_params is needed to build the model, params is used when running it.
params, priors, tv_params = rt.get_parameters_and_priors()
model = tbm.get_tb_model(model_config, tv_params)

In [None]:
# Parameter values that replace the ones held in `params` for this run.
ow_params = {"raw_transmission_rate": 35.0}

# With dict operands, the right-hand side of `|` wins on duplicate keys,
# so the overrides above take precedence over the defaults.
model.run(params | ow_params)

In [None]:
# Year at which the modelled age distribution is compared with the population data
YEAR = 2020

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Model outputs as a DataFrame (presumably one column per compartment,
# indexed by model time — TODO confirm against get_outputs_df).
c_df = model.get_outputs_df()

# Age groups from the model configuration, as integers
age_groups = list(map(int, model_config['age_groups']))

# For each age group, sum every output column whose name ends with
# "Xage_<age>", giving that age group's total at each time point.
age_totals = {
    age: c_df[[name for name in c_df.columns if name.endswith(f"Xage_{age}")]].sum(axis=1)
    for age in age_groups
}

# One column per age group, one row per time point
age_totals_df = pd.DataFrame(age_totals)

# Divide each row by its own total so the age groups sum to one at every time
age_fractions = age_totals_df.divide(age_totals_df.sum(axis="columns"), axis="index")

# # Plot
# plt.figure(figsize=(10, 6))
# for age in age_groups:
#     plt.plot(age_fractions.index, age_fractions[age], label=f"Age {age}+")

# plt.title("Population Fractions by Age Group Over Time")
# plt.xlabel("Time")
# plt.ylabel("Fraction of Total Population")
# plt.legend()
# plt.grid(True)
# plt.tight_layout()
# plt.show()


In [None]:
from tbh.paths import DATA_FOLDER

age_bins = [int(a) for a in model_config['age_groups']]

# Define bin edges and labels
bin_edges = age_bins + [200]  # use 200 as an upper cap beyond realistic ages
bin_labels = age_bins  # label each bin by its lower bound


def _load_un_table(filename, iso3, start_time, bin_edges, bin_labels):
    """Read one UN csv file and tag each row with its model age group.

    Args:
        filename: Name of the csv file inside DATA_FOLDER.
        iso3: Value of the "ISO3_code" column to keep.
        start_time: Rows with "Time" earlier than this are discarded.
        bin_edges: Left-closed bin edges passed to pd.cut.
        bin_labels: Label given to each bin (its lower bound).

    Returns:
        A new DataFrame restricted to the requested country and period, with
        an integer "age_group" column. Rows whose "AgeGrpStart" falls outside
        the bins are dropped.
    """
    table = pd.read_csv(DATA_FOLDER / filename)
    keep = (table["ISO3_code"] == iso3) & (table["Time"] >= start_time)

    # .assign returns a fresh frame, so no column is ever written onto a
    # filtered slice (which is what raises SettingWithCopyWarning).
    table = table.loc[keep].assign(
        age_group=lambda df: pd.cut(
            df["AgeGrpStart"], bins=bin_edges, labels=bin_labels, right=False
        )
    )

    # Drop rows outside specified bins (age_group == NaN)
    table = table.dropna(subset=["age_group"])

    # Convert category labels back to integers
    return table.assign(age_group=table["age_group"].astype(int))


# Same filtering and binning for both datasets
pop_data = _load_un_table(
    "un_population.csv", model_config["iso3"], model_config["start_time"], bin_edges, bin_labels
)
mort_data = _load_un_table(
    "un_mortality.csv", model_config["iso3"], model_config["start_time"], bin_edges, bin_labels
)

# Aggregate by year and age group
pop_summary = pop_data.groupby(["Time", "age_group"])["PopTotal"].sum().reset_index()

In [None]:
# Population data for the comparison year, indexed by age group, with each
# age group's share of the total. Built with non-mutating calls so nothing is
# assigned onto a slice of pop_summary (avoids SettingWithCopyWarning).
d = (
    pop_summary.loc[pop_summary["Time"] == YEAR]
    .set_index("age_group")[["PopTotal"]]
    .assign(data_frac=lambda df: df["PopTotal"] / df["PopTotal"].sum())
)

# Modelled fractions for the same year. rename / rename_axis return a new
# Series instead of setting the name on the index object in place, which can
# leak into age_fractions' column index.
modelled_fracs = age_fractions.loc[YEAR].rename("model_frac").rename_axis("age_group")

# Inner join on age group, expressed as percentages
m = 100 * pd.merge(modelled_fracs, d[["data_frac"]], left_index=True, right_index=True)
m

In [None]:
# Side-by-side bars of modelled vs data age distribution (values are percentages)
ax = m.plot(kind='bar', width=0.8)
ax.set_xlabel("Age group (lower bound)")
ax.set_ylabel("Share of total population (%)")
# Trailing semicolon keeps the notebook from echoing the Text object repr
ax.set_title(f"Modelled vs data age distribution, {YEAR}");