In [1]:
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.project import ParameterSet, get_project
from autumn.core.utils.display import pretty_print

In [2]:
warnings.filterwarnings("ignore")
input_db = get_input_db()


def find_agebins(x, bins: List[int], labels: List[str]):
    """Assign each value of ``x`` to a left-closed age bin.

    Args:
        x: Ages to bin (array-like or Series).
        bins: Bin edges; the intervals are ``[bins[i], bins[i + 1])``.
        labels: One label per interval, i.e. ``len(bins) - 1`` labels.

    Returns:
        A categorical of bin labels; values outside ``[bins[0], bins[-1])`` become NaN.
    """
    return pd.cut(x, bins, labels=labels, right=False)


# Observed Kiribati population by year and model age group.
pop = input_db.query(table_name="population", conditions={"iso3": "KIR"})
# NOTE(review): with right=False the last edge (100) is excluded, so any row with
# start_age >= 100 is binned as NaN and then dropped by the groupby below — confirm
# the population table has no such rows.
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, 100], ["0-4", "5-14", "15-34", "35-49", "50+"]
)
pop = pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()
# Round to the nearest person instead of truncating towards zero.
pop["sum"] = pop["sum"].round().astype("int")
pop["type"] = "Data"

## Running an empty model

In [5]:
# Load the "tb_dynamics" project configured for Kiribati.
project = get_project(
    "tb_dynamics",
    "kiribati",
)

In [6]:
# The project's baseline parameter set (displayed in the next cell).
base_params = (
    project.param_set
    .baseline
)

In [7]:
# Show the project's baseline parameter set.
pretty_print(base_params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'age_mixing': {'age_adjust': True, 'source_iso3': 'FJI'},
  'age_specific_latency': { 'early_activation': { 0: 0.0066,
                                                  5: 0.0027,
                                                  15: 0.00027},
                            'late_activation': { 0: 1.9e-11,
                                                 5: 6.4e-06,
                                                 15: 3.3e-06},
                            'stabilisation': {0: 0.012, 5: 0.012, 15: 0.0054}},
  'contact_rate': 0.00178,
  'crude_birth_rate': 1.0,
  'crude_death_rate': 0.0008,
  'cumulative_start_time': 1950.0,
  'description': None,
  'infect_death_rate': 0.2,
  'infectious_seed': 300.0,
  'iso3': 'KIR',
  'progression_multiplier': 1.0,
  'rr_infection_latent': 0.21,
  'rr_infection_recovered': 1.0,
  'self_recovery_rate': 0.2,
  'start_population_size': 33048.0,
  'time': {'end': 2020.0, 'start': 1950.0, 'step': 0.1}}


In [36]:
# Parameter overrides applied on top of the baseline (passed with calibration_format=True).
update_params = dict(
    contact_rate=0.00010000313310550405,
    rr_infection_latent=0.2005519434974716,
    rr_infection_recovered=0.20460496591598593,
    infectious_seed=1,
)
params = project.param_set.baseline.update(
    update_params,
    calibration_format=True,
)

In [37]:
# Run the model with the overridden parameters; its outputs are extracted further down.
model_0 = project.run_baseline_model(
    params,
)

In [38]:
# Show the parameters actually used for model_0 (baseline plus the overrides).
pretty_print(params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'age_mixing': {'age_adjust': True, 'source_iso3': 'FJI'},
  'age_specific_latency': { 'early_activation': { 0: 0.0066,
                                                  5: 0.0027,
                                                  15: 0.00027},
                            'late_activation': { 0: 1.9e-11,
                                                 5: 6.4e-06,
                                                 15: 3.3e-06},
                            'stabilisation': {0: 0.012, 5: 0.012, 15: 0.0054}},
  'contact_rate': 0.00010000313310550405,
  'crude_birth_rate': 1.0,
  'crude_death_rate': 0.0008,
  'cumulative_start_time': 1950.0,
  'description': None,
  'infect_death_rate': 0.2,
  'infectious_seed': 1,
  'iso3': 'KIR',
  'progression_multiplier': 1.0,
  'rr_infection_latent': 0.2005519434974716,
  'rr_infection_recovered': 0.20460496591598593,
  'self_recovery_rate': 0.2,
  'start_population_size': 33048.0,
  'time': {'end': 2020.0, 'start': 1950.0, 'step': 0.1}}

In [39]:
# Compartment-level outputs and derived outputs (totals, prevalence, incidence, ...),
# presumably both indexed by model time.
df_0, derived_df_0 = model_0.get_outputs_df(), model_0.get_derived_outputs_df()

In [40]:
# Display the derived outputs (one row per model time step).
derived_df_0

Unnamed: 0,total_population,percentage_latent,prevalence_infectious,incidence_early,incidence_late,cumulative_deaths,cumulative_diseased,incidence
1950.0,33048.000000,0.000000,3.025902,0.000000,0.000000,0.020000,0.000000,0.000000
1950.1,33076.613440,0.000062,2.895769,0.005693,0.000002,0.039578,0.000570,0.017218
1950.2,33106.127713,0.000119,2.773816,0.014346,0.000009,0.058339,0.002005,0.043361
1950.3,33136.527320,0.000172,2.658479,0.018793,0.000021,0.076332,0.003886,0.056777
1950.4,33167.797058,0.000222,2.548805,0.020974,0.000036,0.093595,0.005987,0.063343
...,...,...,...,...,...,...,...,...
2019.6,108938.130647,0.000052,0.000168,0.000004,0.000069,0.527270,0.092842,0.000067
2019.7,109084.050573,0.000052,0.000167,0.000005,0.000069,0.527273,0.092850,0.000067
2019.8,109230.045320,0.000052,0.000166,0.000005,0.000068,0.527277,0.092857,0.000067
2019.9,109376.115776,0.000051,0.000165,0.000004,0.000068,0.527281,0.092864,0.000066


In [None]:
# Observed population summed over all ages (a single "Total" bin) per year.
# NOTE: the name `total_pop` is reassigned further down to the age-stratified
# data/model table, so re-run this cell before re-running the next plot.
total_pop = (
    input_db.query(table_name="population", conditions={"iso3": base_params["iso3"]})
    .assign(age_bins=lambda frame: find_agebins(frame["start_age"], [0, 100], ["Total"]))
    .groupby(["year", "age_bins"])["population"]
    .agg(["sum"])
    .reset_index()
)

In [None]:
# Overlay the modelled total population (line) on the observed totals (red markers).
fig2_1 = px.line(derived_df_0, x=derived_df_0.index, y="total_population")
fig2_2 = px.scatter(total_pop, x="year", y="sum").update_traces(marker={"color": "red"})
fig2_3 = go.Figure(data=[*fig2_1.data, *fig2_2.data])
fig2_3.update_layout(
    title="Modelled vs Actual Population",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
)
fig2_3.show()

## Modeled population by age group

In [None]:
# Modelled population by age group, sampled every 5 years so it can be shown next to
# the observed data in `pop`.
# NOTE(review): only the "susceptible*" compartments are counted, so people in any other
# compartment are not part of these totals — confirm this is the intended comparison.
ryear = list(range(int(base_params["time"]["start"]), int(base_params["time"]["end"] + 1), 5))
modeled_pop = (
    df_0.filter(regex="^susceptible", axis=1)
    # Name the time index explicitly so reset_index() yields a "year" column whether or
    # not the index already has a name (renaming "index" only worked for an unnamed one).
    .rename_axis("year")
    .reset_index()
    .melt(id_vars="year", var_name="age_bins", value_name="sum")
    .assign(
        age_bins=lambda d: d["age_bins"].replace(
            {
                "susceptibleXage_0": "0-4",
                "susceptibleXage_5": "5-14",
                "susceptibleXage_15": "15-34",
                "susceptibleXage_35": "35-49",
                "susceptibleXage_50": "50+",
            }
        )
    )
    # Model times are floats (step 0.1); round before matching whole years so
    # floating-point error cannot silently drop a sample year.
    .loc[lambda d: d["year"].round(6).isin(ryear)]
    # Building new columns via assign (not slice assignment) avoids SettingWithCopy issues.
    .assign(
        year=lambda d: d["year"].round().astype(int),
        # Round to the nearest person instead of truncating towards zero.
        sum=lambda d: d["sum"].round().astype(int),
        type="Modeled",
    )
    .reset_index(drop=True)
)


In [None]:
# Stack observed (type "Data") and modelled (type "Modeled") age-group populations into
# one long table for plotting; ignore_index avoids duplicate row labels from the two parts.
# NOTE: this reuses the name `total_pop`, which earlier held the all-ages observed totals;
# re-run that earlier cell before re-plotting the total-population comparison.
total_pop = pd.concat([pop, modeled_pop], ignore_index=True)

In [None]:
fig4 = go.Figure()
fig4.update_layout(
    template="plotly",
    # Title set for consistency with the total-population figure above.
    title="Population by age group: data vs model",
    title_x=0.5,
    xaxis=dict(title_text="Year"),
    yaxis=dict(title_text="Population"),
    barmode="stack",
)

# One stacked bar trace per age group; passing [year, type] as x gives a two-level
# category axis so "Data" and "Modeled" bars sit side by side within each year.
for group in total_pop["age_bins"].unique():
    plot_df = total_pop[total_pop["age_bins"] == group]
    fig4.add_trace(
        go.Bar(x=[plot_df["year"], plot_df["type"]], y=plot_df["sum"], name=group)
    )

fig4

In [None]:
# Same table as above, drawn as one facet per `type` value with bars coloured by age group.
fig5 = px.bar(
    total_pop,
    x="year",
    y="sum",
    color="age_bins",
    facet_col="type",
    labels=dict(year="Year", sum="Population", age_bins="Age group"),
)
# Facet titles look like "type=<value>"; keep only the value.
fig5.for_each_annotation(lambda ann: ann.update(text=ann.text.rsplit("=", 1)[-1]))
fig5.show()