In [None]:
from plotly.graph_objects import Figure
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from aust_covid.inputs import load_national_data, load_owid_data, load_calibration_targets, load_who_data, load_serosurvey_data
from inputs.constants import INPUTS_PATH, SUPPLEMENT_PATH, PLOT_START_DATE, ANALYSIS_END_DATE
from emutools.tex import DummyTexDoc, StandardTexDoc
from aust_covid.calibration import get_targets
from aust_covid.utils import add_image_to_doc

In [None]:
# Supplement section ('targets') that this notebook writes into.
app_doc = StandardTexDoc(SUPPLEMENT_PATH, 'targets', 'Targets', 'austcovid')
# The raw-data loaders require a doc argument; they are given a DummyTexDoc so that,
# presumably, anything they write is kept out of the 'targets' section (confirm in emutools.tex).
dummy_doc = DummyTexDoc()


def get_target_data(targets, name):
    """Return the ``data`` attribute of the first target whose ``name`` equals ``name``.

    Args:
        targets: Iterable of target objects exposing ``.name`` and ``.data``
            (as returned by ``get_targets``).
        name: Target name to look up, e.g. ``'deaths_ma'``.

    Raises:
        KeyError: If no target has that name. The message lists the available
            names, instead of the bare ``StopIteration`` that ``next()`` on an
            exhausted generator would raise.
    """
    for target in targets:
        if target.name == name:
            return target.data
    available = ', '.join(str(t.name) for t in targets)
    raise KeyError(f'No calibration target named {name!r} (available: {available})')


# Raw data series, shown in the figure for comparison against the targets.
national_data = load_national_data(dummy_doc)
owid_data = load_owid_data(dummy_doc)
combined_data = load_calibration_targets(dummy_doc)
death_data = load_who_data(dummy_doc)
serosurvey_data = load_serosurvey_data(dummy_doc)

# Calibration targets; unlike the raw loaders, these are passed the real supplement document.
targets = get_targets(app_doc)
case_targets = get_target_data(targets, 'notifications_ma')
death_targets = get_target_data(targets, 'deaths_ma')
serosurvey_targets = get_target_data(targets, 'adult_seropos_prop')
serosurvey_ceiling = get_target_data(targets, 'seropos_ceiling')

In [None]:
# Three panels: cases span the whole top row; deaths (bottom-left) and
# seroprevalence (bottom-right) share the bottom row.
subplot_specs = [
    [{'colspan': 2}, None],
    [{}, {}],
]
fig = make_subplots(rows=2, cols=2, specs=subplot_specs)

# (series, legend name, subplot row, subplot col). Traces are added in exactly this order.
panel_traces = [
    (combined_data, 'combined cases', 1, 1),
    (national_data, 'national cases', 1, 1),
    (owid_data, 'owid cases', 1, 1),
    (case_targets, 'final case target (smoothed)', 1, 1),
    (death_data, 'who deaths', 2, 1),
    (death_targets, 'death target (smoothed)', 2, 1),
    (serosurvey_data, 'serosurvey data', 2, 2),
    (serosurvey_targets, 'serosurvey target', 2, 2),
    (serosurvey_ceiling, 'seroprevalence ceiling', 2, 2),
]
for trace_series, trace_name, trace_row, trace_col in panel_traces:
    scatter = go.Scatter(x=trace_series.index, y=trace_series, name=trace_name)
    fig.add_trace(scatter, row=trace_row, col=trace_col)

fig.update_layout(height=600)
# Clip every panel to the shared analysis window; this is the cell's last expression.
fig.update_xaxes(range=(PLOT_START_DATE, ANALYSIS_END_DATE))

In [None]:
# Add the targets figure to the supplement document, then write the document.
targets_caption = 'Calibration targets with raw data comparison.'
add_image_to_doc(fig, 'targets', targets_caption, app_doc, 'Targets')
app_doc.write_doc()