In [1]:
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.project import ParameterSet, get_project
from autumn.core.utils.display import pretty_print

In [2]:
# Suppress all warnings for the rest of the session. This runs before the
# database handle is created, so warnings raised while loading are hidden too.
# NOTE(review): a blanket "ignore" also hides warnings raised by later cells
# (e.g. from pandas) - consider narrowing it to specific categories.
warnings.filterwarnings("ignore")
# Handle to the input database; queried below for the "population" table.
input_db = get_input_db()
def find_agebins(x, bins: List[int], labels: List[str]):
    """Assign each value in ``x`` to a labelled, left-closed age interval.

    Args:
        x: Array-like of ages (e.g. a pandas Series) to be binned.
        bins: Ascending bin edges; interval ``i`` covers ``[bins[i], bins[i + 1])``.
        labels: One label per interval, i.e. ``len(bins) - 1`` entries.

    Returns:
        The result of ``pandas.cut``. Values outside the outermost edges
        (including the final edge itself, since intervals are right-open)
        come back as missing.
    """
    binned = pd.cut(x, bins=bins, labels=labels, right=False)
    return binned

# Population records for Kiribati (ISO3 "KIR") from the input database.
pop = input_db.query(table_name="population", conditions={"iso3": "KIR"})

# Label each record with the age band its start age falls in. The lower edges
# match the model's age breakpoints (0, 5, 15, 35, 50); intervals are
# left-closed, so a start age of 100 or more gets no band.
# `find_agebins` is defined once, earlier in this cell; the second, identical
# definition that used to sit here only shadowed it and has been removed.
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, 100], ["0-4", "5-14", "15-34", "35-49", "50+"]
)

# Total population per (year, age band), flattened back to ordinary columns.
pop = pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()
pop["sum"] = pop["sum"].astype("int")
# Tag the rows so they can be told apart from model output once combined.
pop["type"] = "Data"

## Running an empty model

In [5]:
# Look up the project identified by "tb_dynamics" / "kiribati"
# (presumably model name and region - confirm against get_project).
project = get_project("tb_dynamics", "kiribati")

In [6]:
# Baseline parameter set of the project. Used below to run the model and to
# read the ISO3 code and the simulation start/end times.
base_params = project.param_set.baseline

In [7]:
# Print the baseline parameters in a readable layout for inspection.
pretty_print(base_params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'age_mixing': {'age_adjust': True, 'source_iso3': 'FJI'},
  'age_specific_latency': { 'early_activation': { 0: 0.0066,
                                                  5: 0.0027,
                                                  15: 0.00027},
                            'late_activation': { 0: 1.9e-11,
                                                 5: 6.4e-06,
                                                 15: 3.3e-06},
                            'stabilisation': {0: 0.012, 5: 0.012, 15: 0.0054}},
  'contact_rate': 0.00178,
  'crude_birth_rate': 1.0,
  'crude_death_rate': 0.0008,
  'cumulative_start_time': 1950.0,
  'description': None,
  'infect_death_rate': 0.2,
  'infectious_seed': 300.0,
  'iso3': 'KIR',
  'progression_multiplier': 1.0,
  'rr_infection_latent': 0.21,
  'rr_infection_recovered': 1.0,
  'self_recovery_rate': 0.2,
  'start_population_size': 33048.0,
  'time': {'end': 2020.0, 'start': 1950.0, 'step': 0.1}}


In [8]:
# Run the model with the unmodified baseline parameters.
model_0 = project.run_baseline_model(base_params)

In [9]:
# Model outputs as a DataFrame; the columns starting with "susceptible" are
# used further down for the age-group comparison.
df_0 = model_0.get_outputs_df()
# Derived outputs as a DataFrame indexed by time (e.g. total_population,
# percentage_latent, incidence - see the table displayed in the next cell).
derived_df_0 = model_0.get_derived_outputs_df()

In [10]:
# Display the derived outputs table.
derived_df_0

Unnamed: 0,total_population,percentage_latent,prevalence_infectious,incidence_early,incidence_late,cumulative_deaths,cumulative_diseased,incidence
1950.0,33048.000000,0.000000,907.770516,0.000000,0.000000,6.000000,0.000000,0.000000
1950.1,33070.736637,0.331150,878.262699,30.220854,0.010682,11.904479,3.023154,91.414766
1950.2,33094.505804,0.638915,863.321236,76.791056,0.048014,17.666078,10.707061,232.180744
1950.3,33119.219860,0.933839,856.377123,102.344875,0.111748,23.359451,20.952723,309.356994
1950.4,33144.822410,1.221634,853.948904,117.087093,0.192984,29.026104,32.680731,353.841318
...,...,...,...,...,...,...,...,...
2019.6,84533.702060,64.362522,1495.410175,449.077675,73.020438,12954.012847,27630.207087,617.621257
2019.7,84617.169435,64.367008,1494.756695,449.091200,73.092783,12979.302311,27682.425485,617.113508
2019.8,84700.630762,64.371407,1494.088109,449.770405,73.165033,13004.605540,27734.719029,617.392614
2019.9,84784.088944,64.375604,1493.503231,450.295091,73.237182,13029.923092,27787.072256,617.488822


In [None]:
# Population records for the country the model was configured with.
total_pop = input_db.query(
    table_name="population",
    conditions={"iso3": base_params["iso3"]},
)
# A single [0, 100) bin labelled "Total" puts every record with a start age
# in that range into one group.
total_pop["age_bins"] = find_agebins(
    total_pop["start_age"],
    bins=[0, 100],
    labels=["Total"],
)
# One row per year, with the summed population in the "sum" column.
total_pop = (
    total_pop
    .groupby(["year", "age_bins"])["population"]
    .agg(["sum"])
    .reset_index()
)

In [None]:
# Modelled total population over time, drawn as a line.
fig2_1 = px.line(
    derived_df_0,
    x=derived_df_0.index,
    y="total_population",
)
# Give the trace a name and a legend entry. Without this the combined figure
# has no legend, so the reader cannot tell which series is the model.
fig2_1.update_traces(name="Modelled", showlegend=True)

# Population data from the input database, drawn as red markers.
fig2_2 = px.scatter(total_pop, x="year", y="sum")
fig2_2.update_traces(marker=dict(color="red"), name="Actual", showlegend=True)

# Overlay both series in one figure.
fig2_3 = go.Figure(
    data=fig2_1.data + fig2_2.data,
)
fig2_3.update_layout(
    title="Modelled vs Actual Population", title_x=0.5, xaxis_title="Year", yaxis_title="Population"
)
fig2_3.show()

## Modeled population by age group

In [None]:
# Display label for each susceptible-by-age output column.
age_labels = {
    "susceptibleXage_0": "0-4",
    "susceptibleXage_5": "5-14",
    "susceptibleXage_15": "15-34",
    "susceptibleXage_35": "35-49",
    "susceptibleXage_50": "50+",
}

# Every fifth year from the model start time to its end time (inclusive).
ryear = list(range(int(base_params["time"]["start"]), int(base_params["time"]["end"] + 1), 5))

# Reshape the susceptible columns from wide (one column per age group) to
# long (one row per time point and age group).
# NOTE(review): only compartments whose name starts with "susceptible" are
# counted. If this is meant to be the whole modelled population per age group,
# the other compartments need to be summed in as well - confirm intent.
modeled_pop = (
    df_0.filter(regex="^susceptible", axis=1)
    .reset_index()
    .rename(columns={"index": "year"})
    .melt(id_vars="year")
    .rename(columns={"variable": "age_bins", "value": "sum"})
)
modeled_pop["age_bins"] = modeled_pop["age_bins"].replace(age_labels)

# Keep only the selected years. The explicit copy makes the assignments below
# write to an independent frame instead of a slice of the melted one, which
# would raise SettingWithCopyWarning.
modeled_pop = modeled_pop[modeled_pop["year"].isin(ryear)].copy()
modeled_pop["sum"] = modeled_pop["sum"].astype("int")  # truncates the fractional part
modeled_pop["year"] = modeled_pop["year"].astype("int")
# Tag the rows so they can be told apart from the "Data" rows once combined.
modeled_pop["type"] = "Modeled"
modeled_pop = modeled_pop.reset_index(drop=True)


In [None]:
# Stack the "Data" rows and the "Modeled" rows into one long frame.
# NOTE(review): this rebinds `total_pop`, which an earlier cell defined as the
# all-ages totals plotted in the "Modelled vs Actual Population" figure.
# Re-running that figure cell after this one would plot this combined frame
# instead; consider a distinct name here.
total_pop = pd.concat([pop,modeled_pop])

In [None]:
# Stacked bars of population by age group, with data and model side by side.
fig4 = go.Figure()
fig4.update_layout(
    template="plotly",
    xaxis={"title_text": "Year"},
    yaxis={"title_text": "Population"},
    barmode="stack",
)

# One bar trace per age group, so each group gets its own legend entry.
for age_group in total_pop["age_bins"].unique():
    rows = total_pop.loc[total_pop["age_bins"] == age_group]
    # x is given as two parallel arrays (year, then type), which plotly
    # renders as a two-level category axis.
    bars = go.Bar(x=[rows["year"], rows["type"]], y=rows["sum"], name=age_group)
    fig4.add_trace(bars)

fig4

In [None]:
# The same comparison as facets: one bar panel per value of "type", with bars
# coloured by age group.
fig5 = px.bar(
    total_pop,
    x="year",
    y="sum",
    color="age_bins",
    facet_col="type",
    labels={"year": "Year", "sum": "Population", "age_bins": "Age group"},
)
# Shorten each annotation to the text after its last "=" (e.g. "type=Data"
# becomes "Data").
fig5.for_each_annotation(lambda note: note.update(text=note.text.rpartition("=")[2]))
fig5.show()