In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from autumn.infrastructure.remote import springboard
from estival.sampling import tools as esamp

In [None]:
# Remote run IDs for each alternate analysis, keyed by a short label (presumably
# 'mob' = mobility and 'vacc' = vaccination variants; confirm). All four runs share
# the same '-d20k-t10k-b5k' suffix.
run_ids = dict(
    none='2023-10-04T1338-none-d20k-t10k-b5k',
    mob='2023-10-04T1339-mob-d20k-t10k-b5k',
    vacc='2023-10-04T1340-vacc-d20k-t10k-b5k',
    both='2023-10-04T1340-both-d20k-t10k-b5k',
)

In [None]:
# Quantiles used to summarise every output: bounds of the central 95% interval,
# the interquartile range, and the median.
quantiles = [
    0.025,
    0.25,
    0.5,
    0.75,
    0.975,
]

In [None]:
# Load each analysis's 'spaghetti' table from its managed task's results.hdf and
# summarise it at the quantiles defined above.
rts = springboard.task.RemoteTaskStore()
rts.cd('projects/aust_covid/alternate_analyses')
spaghettis = {}
quantile_outputs = {}
for analysis, run_id in run_ids.items():
    mt = rts.get_managed_task(run_id)
    results_path = mt.local.path / 'output/results.hdf'
    # Fetch from the remote store only when there is no local copy yet. This replaces a
    # commented-out manual toggle, so Restart & Run All works on a fresh machine without
    # re-downloading on every run. Assumes mt.local.path behaves like pathlib.Path
    # (it already supports `/` above) and that download_all() places
    # output/results.hdf under it; confirm.
    if not results_path.exists():
        mt.download_all()
    spaghettis[analysis] = pd.read_hdf(results_path, 'spaghetti')
    quantile_outputs[analysis] = esamp.quantiles_for_results(spaghettis[analysis], quantiles)

In [None]:
from inputs.constants import PLOT_START_DATE, ANALYSIS_END_DATE

In [None]:
from aust_covid.calibration import get_targets
from emutools.tex import DummyTexDoc
# get_targets takes a TeX document argument; no document is being built in this
# notebook, so a DummyTexDoc stands in (presumably it discards whatever is recorded; confirm).
dummy_doc = DummyTexDoc()
targets = get_targets(dummy_doc)


In [None]:
# Select one calibration target's data by name, for inspection in the next cell.
target_names = [t.name for t in targets]
target_name = 'notifications_ma'
# Fail loudly if the name is missing. With a silent `if`, `target` would be left
# undefined (or stale from an earlier run), and the next cell would raise a confusing
# NameError or show the wrong data.
if target_name not in target_names:
    raise KeyError(f'No calibration target named {target_name!r}; available targets: {target_names}')
target = next(t.data for t in targets if t.name == target_name)

In [None]:
# Display the index of the selected target's data (presumably dates; confirm) to check
# it against the plotting window used below.
target.index

In [None]:
# Quantile fan chart for one analysis. The 2x2 grid has one panel per output. Each panel
# shows shaded bands between consecutive quantiles and the median as a black line. Where a
# calibration target with the same name exists, its data is overlaid as markers.
outputs = ['notifications_ma', 'deaths_ma', 'adult_seropos_prop', 'reproduction_number']
fig = make_subplots(rows=2, cols=2, subplot_titles=[o.replace('_ma', '').replace('_', ' ') for o in outputs])
max_alpha = 0.8  # opacity of the bands adjacent to the median; outer bands fade linearly
band_rgb = '0,30,180'
analysis = 'mob'  # key into run_ids / quantile_outputs
analysis_data = quantile_outputs[analysis]

# Target data keyed by name; the first match wins, as with next(...). A cell-local name
# (target_data below) avoids clobbering the `target` inspected in an earlier cell.
target_data_by_name = {}
for t in targets:
    target_data_by_name.setdefault(t.name, t.data)

n_quantiles = len(quantiles)
half_n_quantiles = n_quantiles // 2
target_in_legend = False
for i, output in enumerate(outputs):
    row, col = i // 2 + 1, i % 2 + 1
    data = analysis_data[output]
    # Only the first panel adds legend entries; later panels share them through legendgroup.
    first_panel = i == 0
    for q, quant in enumerate(quantiles):
        # Band q spans quantiles[q-1]..quantiles[q]. Opacity is highest next to the middle of
        # the list and fades towards the tails. The guard avoids dividing by zero when there
        # are fewer than 2 quantiles.
        if half_n_quantiles:
            alpha = min(q, n_quantiles - q) / half_n_quantiles * max_alpha
        else:
            alpha = max_alpha
        fig.add_trace(
            go.Scatter(
                x=data.index,
                y=data[quant],
                # 'tonexty' shades down to the previous trace (the next-lower quantile). The
                # lowest quantile has nothing beneath it, so it gets no fill instead of an
                # invisible one.
                fill='tonexty' if q > 0 else None,
                fillcolor=f'rgba({band_rgb},{alpha})',
                line={'width': 0},
                name=str(quant),
                legendgroup=str(quant),
                showlegend=first_panel,
            ),
            row=row,
            col=col,
        )
    fig.add_trace(
        go.Scatter(
            x=data.index,
            y=data[0.5],
            line={'color': 'black'},
            name='median',
            legendgroup='median',
            showlegend=first_panel,
        ),
        row=row,
        col=col,
    )
    if output in target_data_by_name:
        target_data = target_data_by_name[output]
        fig.add_trace(
            go.Scatter(
                x=target_data.index,
                y=target_data,
                mode='markers',
                marker={'size': 10.0, 'color': 'rgba(135, 206, 250, 0.2)', 'line': {'width': 1.0}},
                name='target',
                legendgroup='target',
                showlegend=not target_in_legend,
            ),
            row=row,
            col=col,
        )
        target_in_legend = True
fig.update_layout(height=700)
fig.update_xaxes(range=[PLOT_START_DATE, ANALYSIS_END_DATE])