In [1]:
import warnings
from typing import List

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from autumn.core.inputs.database import get_input_db
from autumn.core.project import get_project
from autumn.settings import Region, Models
from autumn.core import inputs
from autumn.core.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.core.utils.display import pretty_print
from autumn.core.inputs.tb_camau import queries
import pathlib
from autumn.model_features.curve.interpolate import build_static_sigmoidal_multicurve

In [3]:
# Hide library warnings so the notebook output stays readable.
warnings.filterwarnings(action="ignore")
# Send pandas .plot() calls to plotly instead of matplotlib.
pd.set_option("plotting.backend", "plotly")

In [4]:
# Observed population data; a relative path, resolved against the kernel's working directory.
csv_path = pathlib.Path().joinpath("camau.csv")

In [5]:
# Load the Ca Mau project for the TBD2 model (reload=True forces a fresh load)
# and keep a handle on its baseline parameter set.
region, model = Region.CAMAU, Models.TBD2
p = get_project(model, region, reload=True)
baseline_params = p.param_set.baseline

In [6]:
# Country code of the modelled region, plus the mixing-matrix settings from config.
iso3 = baseline_params["country"]["iso3"]
source_iso3, age_adjust = (
    baseline_params["age_mixing"][setting] for setting in ("source_iso3", "age_adjust")
)

# Display labels keyed by the lower bound of each modelled age group.
age_string_map = dict(
    zip(
        [0, 5, 15, 35, 50],
        ["0-4", "5-14", "15-34", "35-49", "50+"],
    )
)

## Population

In [7]:
# Lower bounds of the age groups used by the model.
modelled_age_groups = baseline_params["age_breakpoints"]
print("Modelled age groups are {}".format(modelled_age_groups))

Modelled age groups are [0, 5, 15, 35, 50]


In [8]:
# Observed population, read from csv_path and indexed by year. It is reused later
# (via reset_index) as the data overlay for the modelled total_population.
pop_df = pd.read_csv(csv_path).set_index("year")

# The plotly backend returns a plotly Figure, so it can be labelled like the other plots.
pop_fig = pop_df["population"].plot()
pop_fig.update_layout(
    title="Population of Ca Mau",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
    showlegend=False,
)

### Birth rate

In [9]:
# Crude birth rates are looked up by country ISO3 code, so this series is
# country-level data rather than Ca Mau specific. Use the project's configured
# country code instead of hard-coding "VNM".
# NOTE(review): the figure title says "Ca Mau" although the data is national.
birth_rates, years = inputs.get_crude_birth_rate(iso3)
# Birth rates are provided per 1,000 population; convert to a per-capita rate.
birth_rates = [rate / 1000.0 for rate in birth_rates]
birth_rates_series = pd.Series(birth_rates, index=years)

br_fig = px.line(birth_rates_series)
br_fig.update_traces(mode="markers+lines")
br_fig.update_layout(
    title="Crude birth rate of Ca Mau",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Crude birth rate",
    showlegend=False,
)

### Death rate

In [10]:
# Query and visualise the Ca Mau crude death rate from the inputs database.
# Note: this query returns (years, rates), the reverse order of get_crude_birth_rate.
years, death_rates = queries.get_camau_death_rate()
# Death rates are provided per 1,000 population; convert to a per-capita rate.
death_rates = [rate / 1000.0 for rate in death_rates]
death_rates_series = pd.Series(death_rates, index=years)

dr_fig = px.line(death_rates_series)
dr_fig.update_traces(mode="markers+lines")
dr_fig.update_layout(
    title="Crude death rate of Ca Mau",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Crude death rate",
    showlegend=False,
)

### Mixing matrix

In [11]:
print(f"Target region: {region.upper()} ({iso3})")
print(f"Proxy country: {source_iso3}")
# age_adjust is read from the project config and passed to build_synthetic_matrices,
# so report its actual value rather than claiming adjustment always happens.
print(f"Age-adjusted to target population: {age_adjust}")

Target region: CAMAU (VNM)
Proxy country: VNM
Always age-adjusted to target population


In [12]:
# Synthetic mixing matrix for the modelled age groups, built from the proxy country.
mixing_matrices = build_synthetic_matrices(
    iso3, source_iso3, modelled_age_groups, age_adjust, requested_locations=["all_locations"]
)
print(f"Total daily contacts for each age group is {mixing_matrices['all_locations'].sum(axis=1)}")

# Readable tick labels for each age group; fall back to the raw breakpoint if unmapped.
age_labels = [age_string_map.get(age, str(age)) for age in modelled_age_groups]
# Row sums are reported above as each age group's total contacts, so rows index
# the contacting group and columns the contacted group.
px.imshow(
    mixing_matrices["all_locations"],
    x=age_labels,
    y=age_labels,
    labels={"x": "Contacted age group", "y": "Age group", "color": "Daily contacts"},
    title="Synthetic mixing matrix (all locations)",
)

Total daily contacts for each age group is [ 5.64459085  6.61640819  7.65535217  9.06586113 10.04091602]


In [13]:
# Override start_population_size and infectious_seed on top of the baseline parameters.
# 'contact_rate' keeps its baseline value (an override of 0.002 was previously
# commented out here).
update_params = dict(
    start_population_size=40000,
    infectious_seed=300,
)
params = baseline_params.update(update_params, calibration_format=True)

model_0 = p.run_baseline_model(params)
derived_df_0 = model_0.get_derived_outputs_df()

In [None]:
#pretty_print(baseline_params)

In [14]:
# Overlay the modelled total population (line) on the observed population (red markers).
fig2_1 = px.line(derived_df_0, x=derived_df_0.index, y="total_population")

observed_pop = pop_df.reset_index()
fig2_2 = px.scatter(observed_pop, x="year", y="population")
fig2_2.update_traces(marker_color="red")

combined_traces = fig2_1.data + fig2_2.data
fig2_3 = go.Figure(data=combined_traces)
fig2_3.update_layout(
    title="Modelled vs Data",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Population",
)
fig2_3.show()

In [15]:
# Modelled incidence from the run above. The plotly backend returns a Figure,
# which is labelled here so the plot stands alone when the notebook is skimmed.
incidence_fig = derived_df_0["incidence"].plot()
incidence_fig.update_layout(
    title="Modelled incidence",
    title_x=0.5,
    xaxis_title="Year",
    yaxis_title="Incidence",
    showlegend=False,
)