In [None]:
import plotly.graph_objects as go
from plotly.colors import qualitative
from plotly.subplots import make_subplots

from inputs.constants import ANALYSIS_END_DATE, PLOT_START_DATE, PROJECT_PATH
from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data, get_ifrs, load_raw_pop_data, get_raw_state_mobility
from aust_covid.model import build_model
from aust_covid.plotting import plot_state_mobility, plot_processed_mobility, plot_example_model_matrices
from emutools.tex import StandardTexDoc
from emutools.inputs import load_param_info

In [None]:
# Supplement document: passed to the loaders and model builder below, and saved by
# app_doc.save_content() in the final cell.
app_doc = StandardTexDoc(
    PROJECT_PATH / 'supplement',
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)

In [None]:
# Calibration targets that the diagnostic plot compares against model outputs:
# case notifications, deaths and serosurvey estimates. Each loader also receives
# app_doc (presumably to record a data description in the supplement; confirm).
case_targets, death_targets, serosurvey_targets = (
    load_calibration_targets(app_doc),
    load_who_data(app_doc),
    load_serosurvey_data(app_doc),
)

In [None]:
# Parameter values from parameters.yml, with the IFRs from get_ifrs() overlaid.
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml')
ifrs = get_ifrs(app_doc)
# Update an explicit copy of the 'value' column and assign it back, rather than calling
# param_info['value'].update(...) directly. That is a chained in-place update, and under
# pandas Copy-on-Write (default in pandas 3) it does not propagate to param_info, so the
# IFRs would be silently missing from `parameters`.
param_values = param_info['value'].copy()
# Series.update only overwrites labels already present in the index (and skips NaN values
# in `ifrs`), exactly as the original call did. IFR keys absent from parameters.yml are ignored.
param_values.update(ifrs)
param_info['value'] = param_values
parameters = param_values.to_dict()

In [None]:
# Build the model. mobility_sens=True presumably selects the mobility sensitivity
# configuration; confirm in aust_covid.model.build_model.
aust_model = build_model(
    app_doc,
    mobility_sens=True,
)

In [None]:
# Run the model with the parameter dictionary built from parameters.yml plus the IFRs.
aust_model.run(parameters=parameters)

In [None]:
# Diagnostic figure: model outputs against calibration targets, plus age-specific
# deaths on a linear (bottom-left) and log (bottom-right) y-axis.
fig = make_subplots(
    rows=3,
    cols=2,
    subplot_titles=[
        'cases', 'deaths',
        'adult seropositive proportion', 'reproduction number',
        'deaths by age group', 'deaths by age group (log scale)',
    ],
)
derived_outputs = aust_model.get_derived_outputs_df()
x_vals = derived_outputs.index
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['notifications_ma'], name='modelled cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=case_targets.index, y=case_targets, name='reported cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['deaths_ma'], name='deaths_ma'), row=1, col=2)
fig.add_trace(go.Scatter(x=death_targets.index, y=death_targets, name='reported deaths ma'), row=1, col=2)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['adult_seropos_prop'], name='adult seropos'), row=2, col=1)
fig.add_trace(go.Scatter(x=serosurvey_targets.index, y=serosurvey_targets, name='seropos estimates'), row=2, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['reproduction_number'], name='reproduction number'), row=2, col=2)

# The same age-specific series is drawn in both bottom panels. Each pair gets one colour
# and one legend group, and the log-panel copy is hidden from the legend. As a result each
# age group appears once in the legend, and clicking it toggles both panels together.
agegroup_palette = qualitative.Plotly
for i_agegroup, agegroup in enumerate(aust_model.stratifications['agegroup'].strata):
    trace_name = f'{agegroup} deaths'
    trace_style = {
        'name': trace_name,
        'legendgroup': trace_name,
        'line': {'color': agegroup_palette[i_agegroup % len(agegroup_palette)]},
    }
    agegroup_deaths = derived_outputs[f'deathsXagegroup_{agegroup}']
    fig.add_trace(go.Scatter(x=x_vals, y=agegroup_deaths, **trace_style), row=3, col=1)
    fig.add_trace(go.Scatter(x=x_vals, y=agegroup_deaths, showlegend=False, **trace_style), row=3, col=2)

# Address the log panel by grid position rather than the internal axis name 'yaxis6'.
# Log-axis ranges are given in log10 units, so this displays 0.01 to 100.
fig.update_yaxes(type='log', range=[-2.0, 2.0], row=3, col=2)
fig.update_xaxes(range=(PLOT_START_DATE, ANALYSIS_END_DATE))
fig.update_layout(height=600, width=1200)
fig.show()

In [None]:
# Example model matrices for the current parameter set. show_fig=True presumably
# displays the figure inline; the call is left as the cell's final expression.
plot_example_model_matrices(
    aust_model,
    parameters,
    app_doc,
    show_fig=True,
)

In [None]:
# Save the supplement document. Kept as the final cell, presumably so the output includes
# everything the loaders, model builder and plots contributed to app_doc.
app_doc.save_content()