In [8]:
import warnings
from typing import List

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.project import get_project
from autumn.settings import Region, Models
from autumn.core import inputs
from autumn.core.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.core.utils.display import pretty_print

In [9]:
# Suppress every warning category for a cleaner rendered notebook.
# NOTE(review): this also hides deprecation warnings — consider narrowing the filter.
warnings.simplefilter("ignore")
# Route pandas `.plot()` calls through plotly instead of matplotlib.
pd.set_option("plotting.backend", "plotly")

In [10]:
# Select the model/region pair and load its project definition.
# NOTE(review): reload=True presumably forces the project config to be re-read — confirm.
model = Models.TBD
region = Region.KIRIBATI
p = get_project(model, region, reload=True)

# Baseline parameter set used throughout the rest of the notebook.
baseline_params = p.param_set.baseline

In [11]:
# Display the full baseline parameter set loaded above.
pretty_print(baseline_params)

Params
{ 'age_breakpoints': [0, 5, 15, 35, 50],
  'age_mixing': {'age_adjust': True, 'source_iso3': 'FJI'},
  'age_specific_latency': { 'early_activation': { 0: 0.0066,
                                                  5: 0.0027,
                                                  15: 0.00027},
                            'late_activation': { 0: 1.9e-11,
                                                 5: 6.4e-06,
                                                 15: 3.3e-06},
                            'stabilisation': {0: 0.012, 5: 0.012, 15: 0.0054}},
  'contact_rate': 0.00178,
  'crude_birth_rate': 1.0,
  'crude_death_rate': 0.0008,
  'cumulative_start_time': 1990.0,
  'description': None,
  'infect_death_rate': 0.2,
  'infectious_seed': 1.0,
  'iso3': 'KIR',
  'progression_multiplier': 1.0,
  'rr_infection_latent': 0.21,
  'rr_infection_recovered': 1.0,
  'self_recovery_rate': 0.2,
  'start_population_size': 33048.0,
  'time': {'end': 2020.0, 'start': 1950.0, 'step': 0.1}}


In [12]:
# Target country code and the proxy-country mixing settings.
iso3 = baseline_params["iso3"]
source_iso3 = baseline_params["age_mixing"]["source_iso3"]
age_adjust = baseline_params["age_mixing"]["age_adjust"]

# Human-readable label for each modelled age group, keyed by its lower breakpoint.
# Derived from the parameters (instead of typed by hand) so the labels can never
# drift from `age_breakpoints`. Upper bounds are shown inclusive in whole years
# ("5-14"); the last group is open-ended ("50+").
_age_breakpoints = baseline_params["age_breakpoints"]
age_string_map = {
    lower: f"{lower}-{upper - 1}" if upper is not None else f"{lower}+"
    for lower, upper in zip(_age_breakpoints, [*_age_breakpoints[1:], None])
}

## Population


In [13]:
# Lower edges of the age groups the model is stratified by.
modelled_age_groups = baseline_params["age_breakpoints"]
print("Modelled age groups are " + str(modelled_age_groups))

Modelled age groups are [0, 5, 15, 35, 50]


In [14]:
# Raw population rows for the target country; the columns used below are
# `year`, `start_age` and `population`.
input_db = get_input_db()
# Query with the configured `iso3` (not a hard-coded code) so the notebook
# follows the project's region.
pop = input_db.query(table_name="population", conditions={"iso3": iso3})

# Total population per year, summed across all age rows.
total_pop_fig = pop.groupby("year")["population"].sum().plot()
total_pop_fig.update_layout(
    title=f"Total population of {region.title()}",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
    showlegend=False,
)
total_pop_fig.show()

In [15]:
def find_agebins(x, bins: List[int], labels: List[str]):
    """Assign each value of ``x`` to a labelled, left-closed bin.

    Thin wrapper around ``pd.cut`` with ``right=False``. Each interval is
    ``[bins[i], bins[i + 1])``, so a value equal to an interval's upper edge
    falls into the next bin. Values outside ``[bins[0], bins[-1])`` become NaN.

    Args:
        x: Values to bin (e.g. a Series of ages).
        bins: Monotonically increasing bin edges (one more edge than labels).
        labels: Label for each interval, in order.

    Returns:
        Categorical bin labels aligned with ``x``.
    """
    binned = pd.cut(x, bins=bins, labels=labels, right=False)
    return binned


# Aggregate the raw population rows into the modelled age groups.
# - Bin edges and labels come from `modelled_age_groups` / `age_string_map`, so
#   they always match the model stratification.
# - The last bin is open-ended (+inf). With right-open intervals, a closed upper
#   edge would silently drop rows starting exactly at that edge (e.g. a 100+ band).
# NOTE(review): `pop` is still replaced by the aggregate (the next cell plots it),
# so re-running this cell requires re-running the query cell first.
age_group_labels = [age_string_map[age] for age in modelled_age_groups]
pop = (
    pop.assign(
        age_bins=find_agebins(
            pop["start_age"], [*modelled_age_groups, float("inf")], age_group_labels
        )
    )
    .groupby(["year", "age_bins"], observed=False)["population"]
    .agg(population="sum")
    .reset_index()
    .astype({"population": "int"})
)

In [16]:
# Population per year, one bar per modelled age group.
age_plot = px.bar(
    pop,
    x="year",
    y="population",
    color="age_bins",
    barmode="group",
    labels={"year": "Year", "population": "Population", "age_bins": "Age group"},
    title=f"Population by age group of {region.title()}",
)
age_plot.update_layout(title_x=0.5)
age_plot.show()

### Birth rate

In [17]:
# Crude birth rates for the target country. The database stores them per 1000
# population, so rescale to a per-person rate before plotting.
birth_rates, years = inputs.get_crude_birth_rate(iso3)
birth_rates = [rate / 1000.0 for rate in birth_rates]
birth_rates_series = pd.Series(birth_rates, index=years)

br_fig = px.line(birth_rates_series).update_traces(mode="markers+lines")
# Kept as the cell's last expression so Jupyter renders the returned figure.
br_fig.update_layout(
    showlegend=False,
    title_x=0.5,
    title="Crude birth rate of Kiribati from 1950 to 2020",
    yaxis_title="Crude birth rate",
    xaxis_title="Year",
)

### Death rate

In [18]:
# Query and visualize the death rate of each modelled age group from the database.
death_rates, years = inputs.get_death_rates_by_agegroup(modelled_age_groups, iso3)
# One column per age group, one row per year.
death_rate_by_age_group = pd.DataFrame.from_dict(death_rates, orient="index").T
# NOTE(review): labels are assigned by position, assuming `death_rates` lists its
# age groups in breakpoint order (the same order as `age_string_map`) — confirm.
death_rate_by_age_group.columns = list(age_string_map.values())
death_rate_by_age_group.index = years

dr_fig = px.line(
    death_rate_by_age_group,
    x=death_rate_by_age_group.index,
    y=list(death_rate_by_age_group.columns),
)
dr_fig.update_traces(mode="markers+lines")
dr_fig.update_layout(
    title=f"Death rate by age group of {region.title()} from 1950 to 2020",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Death rate",
    legend_title="Age group",
)
dr_fig.show()

## Mixing matrix

In [19]:
# Display names for proxy (source) countries; falls back to the ISO3 code for
# proxies not listed here.
proxy_country_names = {"FJI": "FIJI"}
proxy_name = proxy_country_names.get(source_iso3, source_iso3)

print(f"Target country: {region.upper()} ({iso3})")
print(f"Proxy country: {source_iso3} ({proxy_name})")
# Report the actual `age_adjust` setting (passed to build_synthetic_matrices
# below) instead of an unconditional message.
if age_adjust:
    print("Always age-adjusted to target population")
else:
    print("Not age-adjusted to target population")

Target country: KIRIBATI (KIR)
Proxy country: FJI (FIJI)
Always age-adjusted to target population


Note: the proxy country has no location-specific (detailed) mixing matrices, so only the combined `all_locations` matrix is requested.

In [20]:
# The proxy country has no location-specific matrices, so only the combined
# "all_locations" matrix is requested; `age_adjust` is passed through from the
# baseline parameters.
requested_mixing_locations = ["all_locations"]
mixing_matrices = build_synthetic_matrices(
    iso3,
    source_iso3,
    modelled_age_groups,
    age_adjust,
    requested_locations=requested_mixing_locations,
)

In [21]:
# Combined (all-locations) contact matrix; summing each row gives that age
# group's total daily contacts.
all_locations_matrix = mixing_matrices["all_locations"]
age_group_labels = [age_string_map[age] for age in modelled_age_groups]
print(f"Total daily contacts for each age group is {all_locations_matrix.sum(axis=1)}")

# Label the axes with age groups instead of bare integer indices.
px.imshow(
    all_locations_matrix,
    x=age_group_labels,
    y=age_group_labels,
    labels={"x": "Contact age group", "y": "Age group", "color": "Daily contacts"},
    title="Synthetic mixing matrix (all locations)",
)

Total daily contacts for each age group is [5.4292262  7.24222393 6.08134097 5.20221704 4.38075489]


## Age adjustment

### Latency

In [22]:
# Age-specific latency parameters: each sub-dict maps an age-group lower bound to a value.
# NOTE(review): only ages 0, 5 and 15 are listed; presumably the 15 value applies to
# all older modelled groups (35, 50) — confirm in the model code.
age_params = baseline_params["age_specific_latency"]

#### Early Activation

In [23]:
# Early activation values keyed by age-group lower bound.
pretty_print(age_params["early_activation"])

{0: 0.0066, 5: 0.0027, 15: 0.00027}


#### Stabilisation

In [24]:
# Stabilisation values keyed by age-group lower bound.
pretty_print(age_params["stabilisation"])

{0: 0.012, 5: 0.012, 15: 0.0054}


#### Late Activation

In [25]:
# Late activation values keyed by age-group lower bound.
pretty_print(age_params["late_activation"])

{0: 1.9e-11, 5: 6.4e-06, 15: 3.3e-06}


## Calibration target

Initially, the model was calibrated to the population size.

In [26]:
# Run the baseline scenario and compare the modelled total population with the
# population totals in the inputs database.
# Reuse the project loaded at the top of the notebook (`p`, from `model`/`region`)
# rather than re-fetching it by hard-coded name. This guarantees the model run
# and `baseline_params` come from the same project.
project = p
model_0 = project.run_baseline_model(baseline_params)
df_0 = model_0.get_outputs_df()
derived_df_0 = model_0.get_derived_outputs_df()

# Observed total population per year. The single bin is open-ended so a
# terminal age band (e.g. start_age == 100) is not excluded by the
# right-open interval.
total_pop = input_db.query(table_name="population", conditions={"iso3": iso3})
total_pop["age_bins"] = find_agebins(total_pop["start_age"], [0, float("inf")], ["Total"])
total_pop = total_pop.groupby(["year", "age_bins"])["population"].agg(["sum"]).reset_index()

fig2_1 = px.line(
    derived_df_0,
    x=derived_df_0.index,
    y="total_population",
)
fig2_1.update_traces(name="Modelled", showlegend=True)
fig2_2 = px.scatter(total_pop, x="year", y="sum")
fig2_2.update_traces(marker=dict(color="red"), name="Actual (input database)", showlegend=True)

fig2_3 = go.Figure(
    data=fig2_1.data + fig2_2.data,
)
fig2_3.update_layout(
    title="Modelled vs Actual Population", title_x=0.5, xaxis_title="Year", yaxis_title="Population"
)
fig2_3.show()