In [1]:
# Import necessary dependencies
import warnings
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core import inputs
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.inputs.social_mixing.queries import get_prem_mixing_matrices
from autumn.core.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.core.project import ParameterSet, get_project
from autumn.model_features.curve import scale_up_function

In [2]:
warnings.filterwarnings("ignore")

In [3]:
# Load the Kiribati TB-dynamics project; the cells below read its baseline parameter set.
PROJECT_MODEL, PROJECT_REGION = "tb_dynamics", "kiribati"
project = get_project(PROJECT_MODEL, PROJECT_REGION)
base_params = project.param_set.baseline

# Baseline population visualization
The birth and death rates were taken from the [United Nations Population Division website](https://population.un.org/wpp/Download/Standard/Mortality/). The death rate is stratified by the age groups 0-4, 5-14, 15-34, 35-49 and 50+.

## Crude birth rate

In [4]:
# Query the crude birth rate for the model's country from the input database and plot it over time.
birth_rates, years = inputs.get_crude_birth_rate(base_params["iso3"])
# The database stores births per 1,000 population; convert to births per person.
birth_rates = [b / 1000.0 for b in birth_rates]
birth_rates_series = pd.Series(birth_rates, index=years)
br_fig = px.line(birth_rates_series)
br_fig.update_traces(mode="markers+lines")
br_fig.update_layout(
    # Derive the year range from the data rather than hard-coding it.
    title=(
        f"Crude birth rate of Kiribati from {int(birth_rates_series.index.min())}"
        f" to {int(birth_rates_series.index.max())}"
    ),
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Crude birth rate (per person)",
    showlegend=False,
)
br_fig.show()

## Death rate by age group

In [5]:
# Death rates per age group, one column per stratum in base_params["age_breakpoints"].
age_breakpoints = base_params["age_breakpoints"]
death_rates, years = inputs.get_death_rates_by_agegroup(age_breakpoints, base_params["iso3"])
death_rate_by_age_group = pd.DataFrame.from_dict(death_rates, orient="index").T
# Build "lower-upper" labels from consecutive breakpoints, plus an open-ended "last+" group
# (e.g. [0, 5, 15, 35, 50] -> 0-4, 5-14, 15-34, 35-49, 50+). This keeps the labels in sync
# with the breakpoints instead of hard-coding them.
# Assumes death_rates has one entry per breakpoint, in ascending order.
death_rate_by_age_group.columns = [
    f"{int(lower)}-{int(upper) - 1}"
    for lower, upper in zip(age_breakpoints[:-1], age_breakpoints[1:])
] + [f"{int(age_breakpoints[-1])}+"]
death_rate_by_age_group.index = years
dr_fig = px.line(death_rate_by_age_group, x=death_rate_by_age_group.index, y=death_rate_by_age_group.columns)
dr_fig.update_traces(mode="markers+lines")
dr_fig.update_layout(
    # Derive the year range from the data rather than hard-coding it.
    title=(
        f"Death rate by age group of Kiribati from {int(death_rate_by_age_group.index.min())}"
        f" to {int(death_rate_by_age_group.index.max())}"
    ),
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Death rate",
    legend_title="",
)
dr_fig.show()

## Comparing modeled population with actual population

In [6]:
# Open the autumn input database; the population queries below read from it.
input_db = get_input_db()
def find_agebins(x, bins: List[int], labels: List[str]):
    """Assign each value in ``x`` to a labelled age interval.

    Intervals are closed on the left and open on the right, ``[bins[i], bins[i + 1])``.
    Any value equal to the final edge, or outside all edges, therefore gets NaN.

    Args:
        x: Values to bin (e.g. a ``start_age`` column).
        bins: Monotonically increasing interval edges.
        labels: One label per interval, i.e. ``len(bins) - 1`` labels.

    Returns:
        The result of ``pandas.cut`` (a categorical Series when ``x`` is a Series).
    """
    return pd.cut(x, bins=bins, labels=labels, right=False)

# Empirical population for the model's country (same iso3 as the rest of the notebook),
# aggregated into the model's age groups.
pop = input_db.query(table_name="population", conditions={"iso3": base_params["iso3"]})

# Use an open upper edge so every start_age >= 50 lands in "50+". A finite edge of 100 with
# right=False would silently drop rows with start_age >= 100 (e.g. a 100+ group, if present).
pop["age_bins"] = find_agebins(
    pop["start_age"], [0, 5, 15, 35, 50, np.inf], ["0-4", "5-14", "15-34", "35-49", "50+"]
)
# observed=False keeps one row per (year, age group) even for empty categories.
pop = pop.groupby(["year", "age_bins"], observed=False)["population"].agg(["sum"]).reset_index()
pop["sum"] = pop["sum"].astype("int")
pop["type"] = "Empirical"

In [7]:
# Run the baseline model and collect its outputs and derived outputs.
model_0 = project.run_baseline_model(base_params)
df_0 = model_0.get_outputs_df()
derived_df_0 = model_0.get_derived_outputs_df()

# Empirical all-ages population per year, for comparison with the modelled total.
total_pop = input_db.query(table_name="population", conditions={"iso3": base_params["iso3"]})
# np.inf keeps any start_age >= 100 in the "Total" bin instead of dropping it as NaN.
total_pop["age_bins"] = find_agebins(total_pop["start_age"], [0, np.inf], ["Total"])
total_pop = (
    total_pop.groupby(["year", "age_bins"], observed=False)["population"].agg(["sum"]).reset_index()
)

### Modeled vs actual population

In [8]:
# Modelled total population (line) against the empirical all-ages totals (red markers).
# Hidden-state guard: the age-structure cell below rebinds `total_pop` to a per-age-group frame
# that carries a "type" column. Re-running this cell after it would plot the wrong data.
assert "type" not in total_pop.columns, (
    "total_pop has been overwritten; re-run the model / total population cell first"
)
fig2_1 = px.line(
    derived_df_0,
    x=derived_df_0.index,
    y="population_size",
)
# Name the traces so the legend distinguishes model output from data.
fig2_1.update_traces(name="Modelled", showlegend=True)
fig2_2 = px.scatter(total_pop, x="year", y="sum")
fig2_2.update_traces(name="Actual", showlegend=True, marker=dict(color="red"))
fig2_3 = go.Figure(
    data=fig2_1.data + fig2_2.data,
)
fig2_3.update_layout(
    title="Modelled vs Actual Population",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
)
fig2_3.show()

In [9]:
# Modelled population by age group, sampled every 5 years to line up with the empirical estimates.
# The result is stacked with the empirical frame `pop` for the bar charts below.
# NOTE(review): only output columns whose names start with "susceptible" are counted. If the model
# has other compartments, this under-counts the modelled population per age group; confirm intent.
compartment_to_age_bin = {
    "susceptibleXage_0": "0-4",
    "susceptibleXage_5": "5-14",
    "susceptibleXage_15": "15-34",
    "susceptibleXage_35": "35-49",
    "susceptibleXage_50": "50+",
}
ryear = list(range(int(base_params["time"]["start"]), int(base_params["time"]["end"] + 1), 5))
modeled_pop = (
    df_0.filter(regex="^susceptible", axis=1)
    # Name the time index explicitly so the column is "year" regardless of the original index name.
    .rename_axis("year")
    .reset_index()
    .melt(id_vars="year", var_name="age_bins", value_name="sum")
    .assign(age_bins=lambda d: d["age_bins"].replace(compartment_to_age_bin))
    # .loc with a mask yields a new frame, so the astype/assign below never write into a view.
    .loc[lambda d: d["year"].isin(ryear)]
    .astype({"sum": "int", "year": "int"})
    .assign(type="Modeled")
    .reset_index(drop=True)
)
# NOTE: this rebinds `total_pop` (previously the all-ages empirical total) to the combined
# empirical + modelled per-age-group frame used by the bar charts.
total_pop = pd.concat([pop, modeled_pop], ignore_index=True)


In [10]:
# Stacked population bars by age group. The two-level x-axis (year, then data type) puts
# the empirical and modelled bars side by side for each year.
fig4 = go.Figure()
fig4.update_layout(
    template="plotly",
    title="Population by age group: empirical vs modelled",
    title_x=0.5,
    xaxis=dict(title_text="Year"),
    yaxis=dict(title_text="Population"),
    legend_title_text="Age group",
    barmode="stack",
)
for group in total_pop["age_bins"].unique():
    plot_df = total_pop[total_pop["age_bins"] == group]
    fig4.add_trace(
        go.Bar(x=[plot_df["year"], plot_df["type"]], y=plot_df["sum"], name=group)
    )
fig4.show()

In [11]:
# Same data as above, with empirical and modelled populations shown in separate facets.
fig5 = px.bar(
    total_pop,
    x="year",
    y="sum",
    color="age_bins",
    facet_col="type",
    labels={"year": "Year", "sum": "Population"},
    title="Population by age group",
)
fig5.update_layout(legend_title="", title_x=0.5)
# Facet titles default to "type=<value>"; keep only the value.
fig5.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig5.show()

## Mixing matrix visualization

In [36]:
# The model does not use the Prem mixing matrix directly. It builds a synthetic matrix from the
# source country's matrix (base_params["age_mixing"]["source_iso3"]; Fiji per the original note),
# using the model's age breakpoints and the age_adjust setting.
# NOTE(review): the original note said "follow Marshall mode of Romain et al." — confirm the reference.
age_mixing_matrices = build_synthetic_matrices(
    base_params["iso3"],
    base_params["age_mixing"]["source_iso3"],
    base_params["age_breakpoints"],
    base_params["age_mixing"]["age_adjust"],
    requested_locations=["all_locations"],
)
age_mixing_matrix = age_mixing_matrices["all_locations"]
age_group_labels = ["0-4", "5-14", "15-34", "35-49", "50+"]
tick_positions = list(range(len(age_group_labels)))
fig6 = px.imshow(np.flipud(np.transpose(age_mixing_matrix)))
fig6.update_layout(
    title="Synthetic age mixing matrix (all locations)",
    title_x=0.5,
    # ticktext is only honoured with tickmode="array" plus tickvals. The previous
    # tickmode="linear" silently ignored the labels and showed 0..4.
    xaxis=dict(tickmode="array", tickvals=tick_positions, ticktext=age_group_labels),
    # np.flipud reverses the rows, so image row k corresponds to age group len-1-k.
    yaxis=dict(
        autorange=True,
        tickmode="array",
        tickvals=tick_positions,
        ticktext=age_group_labels[::-1],
    ),
)
fig6.update_coloraxes(showscale=False)
fig6.show()
