In [None]:
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
from plotly import graph_objects as go
from emutools.tex import DummyTexDoc, StandardTexDoc
from emutools.plotting import get_row_col_for_subplots
from aust_covid.calibration import get_targets
pd.options.plotting.backend = 'plotly'
from inputs.constants import PLOT_START_DATE, ANALYSIS_END_DATE, RUN_IDS, PROJECT_PATH, RUNS_PATH

In [None]:
# Load the 'spaghetti' table from each run's results file, keyed by analysis name.
spaghettis = dict(
    (analysis, pd.read_hdf(RUNS_PATH / run_id / 'output/results.hdf', 'spaghetti'))
    for analysis, run_id in RUN_IDS.items()
)
# Calibration targets; a DummyTexDoc is passed in place of a real document
# (presumably so no TeX output is produced here — confirm against emutools).
targets = get_targets(DummyTexDoc())

In [None]:
def plot_multi_spaghetti(output, targets, n_cols=2):
    """Plot sampled trajectories ("spaghetti") of one output, one panel per analysis.

    For each analysis in ``RUN_IDS``, every column of ``spaghettis[analysis][output]``
    (one per (chain, draw) pair) is drawn as a line. The calibration target with
    the same name is overlaid as black markers.

    Args:
        output: Name of the model output to plot. Must also match the ``name``
            of one entry in ``targets``.
        targets: Iterable of target objects exposing ``.name`` and ``.data``
            (``.data`` is plotted against its index).
        n_cols: Number of subplot columns. The number of rows is derived from
            the number of analyses so every analysis gets its own panel.

    Returns:
        The assembled plotly figure.

    Raises:
        ValueError: If no target named ``output`` exists in ``targets``.
    """
    target = next((t for t in targets if t.name == output), None)
    if target is None:
        raise ValueError(f'No calibration target named {output!r}')

    analyses = list(RUN_IDS.keys())
    # Ceiling division: a hard-coded row count overflows once there are more
    # analyses than rows * n_cols panels.
    n_rows = max(1, -(-len(analyses) // n_cols))
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=analyses, shared_yaxes=True)
    for i, analysis in enumerate(analyses):
        row, col = get_row_col_for_subplots(i, n_cols)
        spaghetti = spaghettis[analysis][output]
        # Relabel the (chain, draw) column tuples as readable trace names,
        # producing a new frame instead of assigning to .columns.
        spaghetti = spaghetti.set_axis([f'{chain}, {draw}' for chain, draw in spaghetti.columns], axis=1)
        # Request the plotly backend explicitly rather than relying on the global pandas option.
        fig.add_traces(spaghetti.plot(backend='plotly').data, rows=row, cols=col)
        fig.add_trace(
            go.Scatter(
                x=target.data.index,
                y=target.data,
                mode='markers',
                marker={'color': 'black', 'size': 12},
                name='target',
                showlegend=i == 0,  # one legend entry for the target, not one per panel
            ),
            row=row,
            col=col,
        )
    # 500px per row keeps the original 1000px height for the default 2x2 grid.
    fig.update_layout(height=500 * n_rows, title={'text': output})
    fig.update_xaxes(range=(PLOT_START_DATE, ANALYSIS_END_DATE))
    return fig

In [None]:
# Spaghetti plot of the 'notifications_ma' output against its target, per analysis.
plot_multi_spaghetti(output='notifications_ma', targets=targets)

In [None]:
# Spaghetti plot of the 'deaths_ma' output against its target, per analysis.
plot_multi_spaghetti(output='deaths_ma', targets=targets)

In [None]:
# Spaghetti plot of the 'adult_seropos_prop' output against its target, per analysis.
plot_multi_spaghetti(output='adult_seropos_prop', targets=targets)