In [None]:
import pandas as pd
from emutools.tex import StandardTexDoc
import arviz as az
from inputs.constants import ANALYSIS_START_DATE, ANALYSIS_END_DATE, SUPPLEMENT_PATH, PLOT_START_DATE
from aust_covid.plotting import plot_key_outputs, plot_cdr_examples, plot_subvariant_props, plot_dispersion_examples
from aust_covid.calibration import get_priors, get_targets
from emutools.calibration import get_sampled_outputs
import plotly.graph_objects as go
from emutools.calibration import round_sigfig
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots
from plotly import graph_objects as go
# Route pandas' DataFrame/Series .plot() calls through plotly rather than matplotlib.
pd.options.plotting.backend = 'plotly'

In [None]:
# LaTeX supplement document object; it is passed to get_targets further down.
supplement_title = "Australia's 2023 Omicron Waves Supplement"
app_doc = StandardTexDoc(
    SUPPLEMENT_PATH,
    'supplement',
    supplement_title,
    'austcovid',
)

In [None]:
def plot_spaghetti(spaghetti, indicators, cols, targets):
    """Plot sampled model trajectories ("spaghetti") for several outputs on a subplot grid.

    Args:
        spaghetti: DataFrame of sampled outputs indexed by date, whose columns have an
            indicator level followed by (chain, draw) levels.
        indicators: Indicator names to plot, one panel each, filled row by row.
        cols: Number of panel columns; as many rows as needed are created to fit
            every indicator.
        targets: Calibration targets; a target whose ``name`` matches an indicator
            has its ``data`` overlaid on that panel as markers.

    Returns:
        The plotly figure, with every series restricted to dates strictly between
        PLOT_START_DATE and ANALYSIS_END_DATE.

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError(f'cols must be at least 1, got {cols}')

    # Honour the caller's column count and size the grid to fit all indicators
    # (previously cols was overwritten with 2 and rows hard-coded to 2).
    n_rows = max(1, int(np.ceil(len(indicators) / cols)))
    fig = make_subplots(rows=n_rows, cols=cols, subplot_titles=indicators)

    def in_window(index):
        # Strict bounds at both ends, as in the original filtering.
        return (PLOT_START_DATE < index) & (index < ANALYSIS_END_DATE)

    for i, ind in enumerate(indicators):
        row_pos, col_pos = divmod(i, cols)
        row, col = row_pos + 1, col_pos + 1

        # Label each trajectory by its chain and draw so hover text identifies the sample.
        ind_spagh = spaghetti[ind]
        labels = [f'chain:{key[0]}, draw:{key[1]}' for key in ind_spagh.columns]
        ind_spagh = ind_spagh.set_axis(labels, axis=1)
        ind_spagh = ind_spagh[in_window(ind_spagh.index)]
        fig.add_traces(px.line(ind_spagh).data, rows=row, cols=col)

        target = next((t.data for t in targets if t.name == ind), None)
        if target is not None:
            target = target[in_window(target.index)]
            fig.add_trace(
                go.Scatter(
                    x=target.index,
                    y=target,
                    marker=dict(size=15.0, line=dict(width=1.0, color='DarkSlateGrey')),
                    name='targets',
                    mode='markers',
                ),
                row=row,
                col=col,
            )

    # ~400px per row; equals the original 800px for the usual two-row layout.
    fig.update_layout(showlegend=False, height=400 * n_rows)
    return fig

In [None]:
SPAGHETTI_CSV = 'spaghetti.csv'
SPAGHETTI_IDATA = 'idata_for_spaghetti.nc'

# Sampled model outputs: three column-header rows (indicator, then chain and draw)
# with the first column used as the index, converted to datetimes for date filtering.
spaghetti = pd.read_csv(SPAGHETTI_CSV, header=[0, 1, 2], index_col=[0])
spaghetti.index = pd.to_datetime(spaghetti.index)

# Posterior samples ("burnt" presumably meaning burn-in already removed) used to
# look up the parameter values behind each trajectory.
burnt_idata = az.from_netcdf(SPAGHETTI_IDATA)

targets = get_targets(app_doc)

In [None]:
# Key outputs shown on a two-column grid, with calibration targets overlaid where available.
indicators = [
    'notifications_ma',
    'deaths_ma',
    'adult_seropos_prop',
    'reproduction_number',
]
plot_spaghetti(spaghetti, indicators, cols=2, targets=targets)

#### Parameter hover figure
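Overlays every sampled trajectory of `adult_seropos_prop`, with each draw's posterior parameter values shown in the hover text. **TODO:** not sure this is actually worthwhile; it probably shouldn't be the primary output.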

In [None]:
# One line per posterior sample of the chosen indicator; hovering a line shows the
# parameter values (and chain/draw) that produced it.
hover_indicator = 'adult_seropos_prop'
fig = go.Figure()
indicator_spaghetti = spaghetti[hover_indicator]
for chain, draw in indicator_spaghetti.columns:
    # Column labels are read from the CSV header as strings, hence int().
    # Select the draw once (the original selected it twice and discarded the first result).
    draw_posterior = burnt_idata.posterior.sel(chain=int(chain), draw=int(draw))

    # Build a fresh frame per draw so no columns or index from a previous draw
    # leak into this trace (the original reused one DataFrame across iterations).
    draw_data = pd.DataFrame({'values': indicator_spaghetti[(chain, draw)]})
    for name, variable in draw_posterior.variables.items():
        value = float(variable)
        # chain/draw coordinates are shown as integers; parameters to 3 significant figures.
        draw_data[name] = int(value) if name in ('chain', 'draw') else round_sigfig(value, 3)

    lines = px.line(draw_data, y='values', hover_data=list(draw_data.columns))
    fig.add_traces(lines.data)
fig