In [12]:
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.project import ParameterSet, get_project
from autumn.core.utils.display import pretty_print

In [13]:
# Silence every warning for the rest of the session so the rendered outputs stay uncluttered.
warnings.filterwarnings(action="ignore")

## Modelled Population

In [14]:
# Open the inputs database once; later cells reuse this handle.
input_db = get_input_db()

# Fetch every row of the "population" table whose iso3 code is "KIR".
pop = input_db.query(
    table_name="population",
    conditions={"iso3": "KIR"},
)

def find_agebins(x, bins: List[int], labels: List[str]):
    """Assign each value in ``x`` to a labelled, left-closed band.

    Args:
        x: Values to bin (the notebook passes a column of starting ages).
        bins: Ascending bin edges; band ``i`` covers ``[bins[i], bins[i + 1])``.
        labels: One label per band, i.e. ``len(bins) - 1`` entries.

    Returns:
        The categorical result of ``pd.cut``. Values outside every band,
        including any value equal to the last edge, come back as NaN.
    """
    banded = pd.cut(x, bins=bins, right=False, labels=labels)
    return banded


# Band the starting ages into the model's age groups. The top edge is open-ended
# (np.inf) so that a band starting at age 100 or above is counted in "50+"; with a
# closed top edge of 100 such rows fall outside every bin, become NaN and are
# silently dropped by the groupby below.
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, np.inf], ["0-4", "5-14", "15-34", "35-49", "50+"]
)
# Total population per (year, age group); the aggregated column is named "sum".
pop = pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()


In [15]:
# Area chart of the summed population over time, one coloured band per age group.
# update_layout returns the figure itself, so the title centring is chained on.
fig = px.area(
    pop,
    x="year",
    y="sum",
    color="age_bins",
    title="Population of Kiribati over time",
    labels={"sum": "Population", "year": "Year", "age_bins": "Age Group"},
).update_layout(title_x=0.5)
fig.show()

## Running an empty model

In [16]:
# Look up the project identified by ("tb_dynamics", "kiribati").
# NOTE(review): presumably (model name, region name) — confirm against get_project.
project = get_project("tb_dynamics", "kiribati")

In [17]:
# Use the baseline entry of the project's parameter set as the run parameters.
base_params = project.param_set.baseline

In [18]:
# Print the baseline parameters in a readable layout so they can be inspected before running.
pretty_print(base_params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'crude_birth_rate': 1.0,
  'crude_death_rate': 1,
  'description': None,
  'infectious_seed': 1.0,
  'iso3': 'KIR',
  'start_population_size': 33048.0,
  'time': {'end': 2020.0, 'start': 1950.0, 'step': 0.1}}


In [7]:
# Run the project's baseline model with the baseline parameters.
model_0 = project.run_baseline_model(base_params)

In [8]:
# Collect the model's derived outputs as a DataFrame.
df_0 = model_0.get_derived_outputs_df()

In [9]:
# Display the derived outputs: a population_size column indexed by model time.
df_0

Unnamed: 0,population_size
1950.0,33048.000000
1950.1,33020.432104
1950.2,32994.274254
1950.3,32969.506010
1950.4,32946.107319
...,...
2019.6,96542.446175
2019.7,96671.813172
2019.8,96801.245971
2019.9,96930.745367


In [10]:
# Re-query the raw Kiribati rows; `pop` can no longer be used here because it was
# replaced by its per-age-group summary earlier in the notebook.
total_pop = input_db.query(table_name="population", conditions={"iso3": "KIR"})
# A single open-ended band so every row contributes to the total. With a closed top
# edge of 100, any band starting at age 100 or above would become NaN and be
# silently dropped by the groupby below.
total_pop["age_bins"] = find_agebins(total_pop["start_age"], [0, np.inf], ["Total"])
# Total population per year; the aggregated column is named "sum".
total_pop = total_pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()


In [11]:
# Modelled population over time, drawn as a line.
fig1 = px.line(df_0, x=df_0.index, y="population_size")

# Yearly totals from the inputs database, drawn as red markers.
fig2 = px.scatter(total_pop, x="year", y="sum").update_traces(marker={"color": "red"})

# Overlay both sets of traces on one figure and label it.
fig3 = go.Figure(data=fig1.data + fig2.data).update_layout(
    title="Modelled vs Actual Population",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
)
fig3.show()
