In [None]:
import numpy as np
import pandas as pd
from autumn.infrastructure.remote import springboard
from plotly.subplots import make_subplots
from plotly import graph_objects as go
from emutools.tex import DummyTexDoc, StandardTexDoc
from aust_covid.calibration import get_targets
# Route pandas .plot() through plotly, so DataFrame.plot() returns a plotly Figure
# (plot_multi_spaghetti relies on this when it takes `.plot().data` as trace objects).
pd.options.plotting.backend = 'plotly'
from inputs.constants import PLOT_START_DATE, ANALYSIS_END_DATE, RUN_IDS

In [None]:
# Download each alternate-analysis run listed in RUN_IDS and load the 'spaghetti' table
# from its output/results.hdf into `spaghettis`, keyed by the analysis name from RUN_IDS.
# Loop variables are deliberately NOT called `type`/`id`: at notebook top level that would
# rebind those builtins (to strings) for every cell executed afterwards.
# NOTE(review): download_all() runs on every execution of this cell; if it re-fetches
# already-downloaded files, consider guarding it so Restart & Run All stays cheap.
rts = springboard.task.RemoteTaskStore()
rts.cd('projects/aust_covid/alternate_analyses')
spaghettis = {}
for analysis, run_id in RUN_IDS.items():
    mt = rts.get_managed_task(run_id)
    mt.download_all()
    spaghettis[analysis] = pd.read_hdf(mt.local.path / 'output/results.hdf', 'spaghetti')

In [None]:
def plot_multi_spaghetti(output, targets, n_cols=2, row_height=500):
    """Plot one output's spaghetti trajectories for every analysis, with its calibration target overlaid.

    Draws one subplot per analysis in ``RUN_IDS`` (data taken from the notebook-level
    ``spaghettis`` dict), laid out in a grid ``n_cols`` wide with as many rows as needed.
    Each subplot shows every (chain, draw) trajectory of ``output`` plus the target points
    as black markers. The x-axis is limited to PLOT_START_DATE..ANALYSIS_END_DATE.

    Args:
        output: Name of the output to plot. It is used both as the top-level column key
            into each ``spaghettis[analysis]`` frame and to look up the target by ``name``.
        targets: Iterable of target objects exposing ``name`` and ``data``. ``data`` is
            presumably a pandas Series (its ``.index`` is the x-axis); confirm against get_targets.
        n_cols: Number of subplot columns. The default of 2 reproduces the original 2x2 layout
            when there are four analyses.
        row_height: Figure height in pixels per subplot row. The default of 500 gives the
            original 1000px height for two rows.

    Returns:
        The assembled plotly Figure.

    Raises:
        StopIteration: If no element of ``targets`` has ``name == output``.
    """
    target = next(t for t in targets if t.name == output)
    analyses = list(RUN_IDS.keys())

    # Size the grid from the number of analyses instead of a fixed 2x2, so adding a
    # fifth run does not fail when traces are placed in row 3.
    n_rows = max(1, -(-len(analyses) // n_cols))
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=analyses, shared_yaxes=True)

    for i, analysis in enumerate(analyses):
        row_idx, col_idx = divmod(i, n_cols)
        row, col = row_idx + 1, col_idx + 1  # plotly subplot positions are 1-based

        # Flatten the (chain, draw) column tuples into 'chain, draw' trace labels.
        # set_axis returns a relabelled copy instead of reassigning .columns in place.
        spaghetti = spaghettis[analysis][output]
        spaghetti = spaghetti.set_axis(
            [f'{str(chain)}, {str(draw)}' for chain, draw in spaghetti.columns], axis=1
        )
        fig.add_traces(spaghetti.plot().data, rows=row, cols=col)
        fig.add_trace(
            go.Scatter(
                x=target.data.index,
                y=target.data,
                mode='markers',
                marker={'color': 'black', 'size': 12},
            ),
            row=row,
            col=col,
        )

    fig.update_layout(height=row_height * n_rows, title={'text': output})
    fig.update_xaxes(range=(PLOT_START_DATE, ANALYSIS_END_DATE))
    return fig

In [None]:
# Load the calibration targets. get_targets takes a TeX document argument; a DummyTexDoc is
# passed here (presumably so no TeX output is written from this notebook — confirm in emutools.tex).
targets = get_targets(DummyTexDoc())

In [None]:
# Notifications (output 'notifications_ma') across all analyses; the returned figure is the cell output.
plot_multi_spaghetti(output='notifications_ma', targets=targets)

In [None]:
# Deaths (output 'deaths_ma') across all analyses; the returned figure is the cell output.
plot_multi_spaghetti(output='deaths_ma', targets=targets)

In [None]:
# Adult seropositive proportion (output 'adult_seropos_prop') across all analyses; the returned figure is the cell output.
plot_multi_spaghetti(output='adult_seropos_prop', targets=targets)