In [1]:
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.project import ParameterSet, get_project
from autumn.core.utils.display import pretty_print

In [2]:
# Silence library warnings (e.g. pandas deprecation notices) so the rendered
# notebook stays readable. NOTE(review): this hides *all* warnings — consider
# narrowing it to specific categories.
warnings.filterwarnings("ignore")
input_db = get_input_db()


def find_agebins(x, bins: List[float], labels: List[str]):
    """Assign each value of ``x`` to a left-closed bin.

    Args:
        x: Values to bin (e.g. the ``start_age`` column).
        bins: Monotonic bin edges; bin ``i`` covers ``[bins[i], bins[i + 1])``.
            Use ``np.inf`` as the last edge for an open-ended top bin.
        labels: One label per bin (``len(bins) - 1`` entries).

    Returns:
        Categorical of labels; values outside ``[bins[0], bins[-1])`` become NaN.
    """
    return pd.cut(x, bins, labels=labels, right=False)


pop = input_db.query(table_name="population", conditions={"iso3": "KIR"})
# Open-ended top edge: with a finite edge of 100 and right=False, any row with
# start_age >= 100 would become NaN and vanish from the groupby below.
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, np.inf], ["0-4", "5-14", "15-34", "35-49", "50+"]
)
pop = pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()
pop["sum"] = pop["sum"].round().astype("int")
pop["type"] = "Empirical"

## Running an empty model

In [3]:
# Load the Kiribati application of the `tb_dynamics` model.
project = get_project("tb_dynamics", "kiribati")

In [4]:
# Baseline parameter set for this project (displayed in the next cell).
base_params = project.param_set.baseline

In [5]:
# Show the baseline parameters: time window, iso3, demographic rates and age breakpoints.
base_params

Params{'time': {'start': 1950.0, 'end': 2020.0, 'step': 0.1}, 'iso3': 'KIR', 'crude_birth_rate': 1.0, 'start_population_size': 33048.0, 'infectious_seed': 1.0, 'age_breakpoints': [0, 5, 15, 35, 50], 'crude_death_rate': 0.001, 'description': None}

In [6]:
# Run the model with the baseline parameters; its outputs are pulled out below
# via `get_outputs_df()` / `get_derived_outputs_df()`.
model_0 = project.run_baseline_model(base_params)

['0', '5', '15', '35', '50']
[0, 5, 15, 35, 50]


In [7]:
# Compartment sizes (one column per compartment/age stratum) and derived
# outputs such as `total_population`, both indexed by model time.
df_0, derived_df_0 = model_0.get_outputs_df(), model_0.get_derived_outputs_df()

In [8]:
# Empirical total population (all ages in a single "Total" bin), compared
# against the model's `total_population` derived output in the next figure.
# NOTE(review): `total_pop` is reassigned further down to the combined
# empirical/modeled age-group table; re-run this cell before re-plotting totals.
total_pop = input_db.query(table_name="population", conditions={"iso3": base_params["iso3"]})
# Open-ended top edge so rows with start_age >= 100 are not dropped as NaN.
total_pop["age_bins"] = find_agebins(total_pop["start_age"], [0, np.inf], ["Total"])
total_pop = total_pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()

In [9]:
# Compartment sizes over time: one column per compartment x age stratum
# (e.g. `susceptibleXage_0`), one row per model time step.
df_0

Unnamed: 0,susceptibleXage_0,susceptibleXage_5,susceptibleXage_15,susceptibleXage_35,susceptibleXage_50,early_latentXage_0,early_latentXage_5,early_latentXage_15,early_latentXage_35,early_latentXage_50,...,on_treatmentXage_0,on_treatmentXage_5,on_treatmentXage_15,on_treatmentXage_35,on_treatmentXage_50,recoveredXage_0,recoveredXage_5,recoveredXage_15,recoveredXage_35,recoveredXage_50
1950.0,5509.833273,7943.759622,12179.631445,3543.892762,3869.882898,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1950.1,5518.543787,7969.982994,12188.611414,3576.519806,3821.978754,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1950.2,5527.162559,7996.103150,12197.800470,3608.933832,3775.173480,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1950.3,5535.696083,8022.119353,12207.196387,3641.137563,3729.445327,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1950.4,5544.150501,8048.031097,12216.796923,3673.133687,3684.773082,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2019.6,13646.493356,24614.947215,38403.258841,22168.144655,10106.903163,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019.7,13663.533381,24639.387911,38450.066696,22202.306463,10130.375599,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019.8,13680.622012,24663.922642,38496.876455,22236.459095,10153.787483,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019.9,13697.758440,24688.551426,38543.689046,22270.602629,10177.139493,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [10]:
# Overlay the modelled total population (line) on the empirical totals (red markers).
fig2_1 = px.line(derived_df_0, x=derived_df_0.index, y="total_population")
fig2_2 = px.scatter(total_pop, x="year", y="sum")
fig2_2.update_traces(marker_color="red")

fig2_3 = go.Figure(data=fig2_1.data + fig2_2.data)
fig2_3.update_layout(
    title={"text": "Modelled vs Actual Population", "x": 0.5},
    xaxis={"title": {"text": "Year"}},
    yaxis={"title": {"text": "Population"}},
)
fig2_3.show()

## Modeled population by age group

In [None]:
# Modeled population by age group, sampled every 5 years to line up with the
# empirical table `pop` (columns: year, age_bins, sum, type).
AGE_BIN_LABELS = {"0": "0-4", "5": "5-14", "15": "15-34", "35": "35-49", "50": "50+"}

# Sum *every* compartment within each age stratum (columns "<compartment>Xage_<lower>").
# Counting only `susceptible` undercounts the population as soon as anyone is
# latent, infectious, on treatment or recovered. Columns without an age suffix
# extract to NaN and are dropped by groupby. sort=False keeps the 0, 5, 15, ... order.
age_of_column = df_0.columns.str.extract(r"Xage_(\d+)", expand=False)
modeled_by_age = df_0.T.groupby(age_of_column.to_numpy(), sort=False).sum().T

modeled_pop = (
    modeled_by_age.rename(columns=AGE_BIN_LABELS)
    .rename_axis("year")
    .reset_index()
    .melt(id_vars="year", var_name="age_bins", value_name="sum")
)

# Model times are floats on a fractional step grid, so match whole report years
# with a tolerance rather than exact float equality, and round (not truncate)
# before casting to int.
ryear = list(range(int(base_params["time"]["start"]), int(base_params["time"]["end"] + 1), 5))
nearest_year = modeled_pop["year"].round()
is_report_year = nearest_year.isin(ryear) & np.isclose(
    modeled_pop["year"], nearest_year, rtol=0, atol=1e-6
)
modeled_pop = (
    modeled_pop[is_report_year]
    .assign(
        year=lambda d: d["year"].round().astype("int"),
        sum=lambda d: d["sum"].round().astype("int"),
        type="Modeled",
    )
    .reset_index(drop=True)
)


In [None]:
# Stack empirical and modeled age-group populations into one long table
# (columns: year, age_bins, sum, type) for side-by-side plotting.
# ignore_index avoids duplicated row labels coming from the two source frames.
# NOTE(review): this reuses the name `total_pop`, replacing the all-ages
# empirical total used by the "Modelled vs Actual Population" figure.
total_pop = pd.concat([pop, modeled_pop], ignore_index=True)

In [None]:
# Stacked bars per age group on a two-level x-axis (year, then data source),
# so empirical and modeled populations sit side by side for each year.
fig4 = go.Figure()
fig4.update_layout(
    template="plotly",
    title="Empirical vs modeled population by age group",
    title_x=0.5,
    xaxis=dict(title_text="Year"),
    yaxis=dict(title_text="Population"),
    legend_title_text="Age group",
    barmode="stack",
)

# One trace per age group; sort=False keeps first-appearance order (as unique() did).
for group, plot_df in total_pop.groupby("age_bins", sort=False):
    fig4.add_trace(
        go.Bar(x=[plot_df["year"], plot_df["type"]], y=plot_df["sum"], name=group)
    )

fig4

In [None]:
# Same data as above, with one facet per data source (column `type`).
fig5 = px.bar(
    total_pop,
    x="year",
    y="sum",
    color="age_bins",
    facet_col="type",
    labels=dict(year="Year", sum="Population", age_bins="Age group"),
)
# Facet titles default to "type=<value>"; keep only the text after the last "=".
fig5.for_each_annotation(lambda ann: ann.update(text=ann.text.rpartition("=")[2]))
fig5.show()