In [1]:
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.project import ParameterSet, get_project
from autumn.core.utils.display import pretty_print

In [2]:
# Silence all warnings (e.g. library deprecation notices) so cell outputs stay readable.
# NOTE(review): this hides every warning for the whole kernel session — consider narrowing it.
warnings.filterwarnings("ignore")
input_db = get_input_db()


def find_agebins(x, bins: List[float], labels: List[str]):
    """Assign each age in ``x`` to a left-closed age bin ``[lo, hi)``.

    Args:
        x: ages to classify (here, the ``start_age`` column of the population table).
        bins: bin edges; must contain one more entry than ``labels``. Use ``np.inf`` as the
            last edge for an open-ended top group.
        labels: display name for each bin.

    Returns:
        Categorical of bin labels; ages outside ``[bins[0], bins[-1])`` become NaN.
    """
    return pd.cut(x, bins, labels=labels, right=False)


# Observed Kiribati population, grouped into the same age bands as the model's
# age_breakpoints [0, 5, 15, 35, 50].
pop = input_db.query(table_name="population", conditions={"iso3": "KIR"})
# np.inf (not 100) as the top edge: with right=False an edge of 100 would turn any
# start_age >= 100 into NaN and silently drop it from the "50+" total.
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, np.inf], ["0-4", "5-14", "15-34", "35-49", "50+"]
)
# observed=True: only aggregate the (year, age_bin) combinations actually present.
pop = pop.groupby(["year", "age_bins"], observed=True)["population"].agg(["sum"]).reset_index()
# Round before casting so e.g. 1234.9999 becomes 1235 rather than being truncated to 1234.
pop["sum"] = pop["sum"].round().astype("int")
pop["type"] = "Data"

## Running an empty model

In [3]:
# Load the Kiribati application of the tb_dynamics model; its parameter sets and
# baseline model runner are used in the cells below.
project = get_project("tb_dynamics", "kiribati")

In [4]:
# Baseline parameter set; its "iso3" and "time" entries are reused later in the notebook.
base_params = project.param_set.baseline

In [5]:
# Show the baseline parameters for reference.
pretty_print(base_params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'age_mixing': {'age_adjust': True, 'source_iso3': 'FJI'},
  'age_specific_latency': { 'early_activation': { 0: 0.0066,
                                                  5: 0.0027,
                                                  15: 0.00027},
                            'late_activation': { 0: 1.9e-11,
                                                 5: 6.4e-06,
                                                 15: 3.3e-06},
                            'stabilisation': {0: 0.012, 5: 0.012, 15: 0.0054}},
  'contact_rate': 0.00178,
  'crude_birth_rate': 1.0,
  'crude_death_rate': 0.0008,
  'cumulative_start_time': 1990.0,
  'description': None,
  'infect_death_rate': 0.2,
  'infectious_seed': 1.0,
  'iso3': 'KIR',
  'progression_multiplier': 1.0,
  'rr_infection_latent': 0.21,
  'rr_infection_recovered': 1.0,
  'self_recovery_rate': 0.2,
  'start_population_size': 26000.0,
  'time': {'end': 2050.0, 'start': 1900.0, 'step': 0.1}}


In [104]:
# Parameter overrides applied on top of the baseline (passed with calibration_format=True).
# NOTE(review): these look like values from a calibration run — record which run they came from.
update_params = dict(
    contact_rate=0.002000099234094257,
    rr_infection_latent=0.20144223241352158,
    rr_infection_recovered=0.2028341387124968,
    start_population_size=11998.06440742042,
)
params = project.param_set.baseline.update(update_params, calibration_format=True)

In [105]:
# Run the model with the overridden `params` from the previous cell (not `base_params`).
model_0 = project.run_baseline_model(params)

In [106]:
# Per-compartment outputs over time (used below for the age-group breakdown).
df_0 = model_0.get_outputs_df()
# Derived outputs such as total_population (plotted against census data below).
derived_df_0 = model_0.get_derived_outputs_df()

In [107]:
# Inspect the derived outputs (pandas truncates the display; use .head()/.tail() for a compact view).
derived_df_0

Unnamed: 0,total_population,percentage_latent,prevalence_infectious,incidence_early,incidence_late,cumulative_deaths,cumulative_diseased,incidence
1800.0,4000.000000,0.000000,25.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1800.1,3996.643666,0.008750,24.193407,0.104161,0.000031,0.000000,0.000000,2.606997
1800.2,3993.458492,0.016894,23.800103,0.264023,0.000142,0.000000,0.000000,6.614934
1800.3,3990.441757,0.024728,23.631267,0.351004,0.000331,0.000000,0.000000,8.804427
1800.4,3987.590895,0.032405,23.592414,0.400867,0.000575,0.000000,0.000000,10.067277
...,...,...,...,...,...,...,...,...
2019.6,99433.460912,64.368944,1495.368220,528.337188,85.901515,8031.956582,16780.980904,617.738432
2019.7,99531.662180,64.373379,1494.696104,528.845527,85.986524,8061.702505,16842.464109,617.725091
2019.8,99629.858208,64.377646,1494.078180,529.499813,86.071415,8091.464943,16904.021232,617.858179
2019.9,99728.050098,64.381728,1493.528545,529.788528,86.156186,8121.245092,16965.615703,617.624342


In [108]:
# Observed total population per year: a single "Total" bin spanning every age.
# np.inf as the top edge keeps any start_age >= 100 (an edge of 100 with right=False would drop it).
# NOTE: `total_pop` is reassigned further down (observed + modelled age groups); re-run this
# cell before the "Modelled vs Data" plot if cells are executed out of order.
total_pop = input_db.query(table_name="population", conditions={"iso3": base_params['iso3']})
total_pop["age_bins"] = find_agebins(total_pop["start_age"], [0, np.inf], ["Total"])
total_pop = (
    total_pop.groupby(["year", "age_bins"], observed=True)["population"]
    .agg(["sum"])
    .reset_index()
)

In [109]:
# Modelled total population (line) against observed census totals (red markers).
fig2_1 = px.line(
    derived_df_0,
    x=derived_df_0.index,
    y="total_population",
)
# px hides the legend for single-trace figures; name and show each trace so the
# combined figure identifies which series is the model and which is the data.
fig2_1.update_traces(name="Modelled", showlegend=True)
fig2_2 = px.scatter(total_pop, x="year", y="sum")
fig2_2.update_traces(marker=dict(color="red"), name="Data", showlegend=True)
fig2_3 = go.Figure(
    data=fig2_1.data + fig2_2.data,
)
fig2_3.update_layout(
    title="Modelled vs Data", title_x=0.5, xaxis_title="Year", yaxis_title="Population"
)
fig2_3.show()

## Modelled vs observed population by age group

In [None]:
# Modelled population by age group. Every compartment within an age stratum is summed so the
# totals are comparable with the observed population (which counts everyone, not only
# susceptibles). Stratified columns are named like "susceptibleXage_0" ("<compartment>Xage_<age>").
age_labels = {0: "0-4", 5: "5-14", 15: "15-34", 35: "35-49", 50: "50+"}
age_strata = df_0.filter(regex=r"Xage_\d+", axis=1)
stratum_age = age_strata.columns.str.extract(r"Xage_(\d+)", expand=False).astype(int)
# Group columns by their age stratum (transpose so the grouping runs over rows).
modeled_by_age = age_strata.T.groupby(stratum_age.to_numpy()).sum().T
modeled_by_age.columns = [age_labels.get(age, str(age)) for age in modeled_by_age.columns]

modeled_pop = (
    modeled_by_age.rename_axis("year")
    .reset_index()
    .melt(id_vars="year", var_name="age_bins", value_name="sum")
)

# Keep one snapshot every 5 years across the simulated period.
ryear = list(range(int(base_params["time"]["start"]), int(base_params["time"]["end"] + 1), 5))
# Model times are floats that may carry rounding error (e.g. 1905.0000000001); round before
# matching whole years so those snapshots are not silently dropped.
modeled_pop = modeled_pop[modeled_pop["year"].round(6).isin(ryear)]
modeled_pop = modeled_pop.assign(
    sum=modeled_pop["sum"].round().astype("int"),
    year=modeled_pop["year"].round().astype("int"),
    type="Modeled",
).reset_index(drop=True)


In [None]:
# Stack observed and modelled age-group populations for side-by-side plotting.
# ignore_index=True gives a unique 0..n-1 index (both inputs carry their own 0-based index).
# NOTE: this reuses the name `total_pop`, replacing the yearly totals built for the
# "Modelled vs Data" plot earlier in the notebook.
total_pop = pd.concat([pop, modeled_pop], ignore_index=True)

In [None]:
# Stacked bars per age group on a two-level x axis: each year is paired with its
# source ("Data" or "Modeled") so observed and modelled bars sit side by side.
fig4 = go.Figure(
    layout=dict(
        template="plotly",
        xaxis=dict(title_text="Year"),
        yaxis=dict(title_text="Population"),
        barmode="stack",
    )
)

for age_group in total_pop["age_bins"].unique():
    group_rows = total_pop.loc[total_pop["age_bins"] == age_group]
    age_bar = go.Bar(
        x=[group_rows["year"], group_rows["type"]],
        y=group_rows["sum"],
        name=age_group,
    )
    fig4.add_trace(age_bar)

fig4

In [None]:
# Same data as above, faceted into separate "Data" and "Modeled" panels.
axis_labels = {'year': 'Year', 'sum': 'Population', 'age_bins': 'Age group'}
fig5 = px.bar(
    total_pop,
    x="year",
    y="sum",
    color="age_bins",
    facet_col="type",
    labels=axis_labels,
)


def _strip_facet_prefix(annotation):
    """Shorten a facet title such as "type=Data" to just "Data"."""
    annotation.update(text=annotation.text.rpartition("=")[2])


fig5.for_each_annotation(_strip_facet_prefix)
fig5.show()