In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from autumn.infrastructure.remote import springboard
from inputs.constants import PLOT_START_DATE, ANALYSIS_END_DATE, RUN_IDS
from estival.sampling import tools as esamp
from aust_covid.calibration import get_targets
from emutools.tex import DummyTexDoc

In [None]:
def get_row_col_for_subplots(i_panel, n_cols):
    """
    Map a zero-based panel index onto the one-based (row, column) position
    expected by plotly subplots, filling each row left to right.
    """
    zero_based_row = int(np.floor(i_panel / n_cols))
    zero_based_col = i_panel % n_cols
    return zero_based_row + 1, zero_based_col + 1


def plot_output_ranges_by_analysis(quantile_outputs, targets, output, analyses, quantiles, max_alpha=0.7):
    """
    Plot the credible intervals with subplots for each analysis type,
    for a single output of interest.

    Args:
        quantile_outputs: Mapping from analysis name to that analysis's quantile results;
            ``quantile_outputs[analysis][output]`` must be indexable by every value in
            ``quantiles`` and by 0.5 (drawn as the median line), and its ``.index``
            is used as the x-axis
        targets: Calibration targets; each must expose ``.name`` and ``.data``
            (the latter plotted as markers against its own index)
        output: Name of the output shown in every panel; also used as the figure title
        analyses: Analysis names, one panel each, laid out two per row;
            any iterable is accepted because it is materialised once up front
        quantiles: Quantile levels drawn as shaded bands, each filled down to the
            previous one (so presumably given in ascending order)
        max_alpha: Opacity of the most central band(s)

    Returns:
        The plotly figure
    """
    # Materialise once: the names feed both the subplot titles and the loop,
    # so a one-shot iterator would otherwise be exhausted by the titles
    analyses = list(analyses)
    n_cols = 2
    # Size the grid to the number of panels instead of assuming at most four
    n_rows = max(1, int(np.ceil(len(analyses) / n_cols)))
    n_quantiles = len(quantiles)
    # Avoid a zero divisor when fewer than two quantiles are requested
    half_width = max(n_quantiles // 2, 1)
    target = next((t for t in targets if t.name == output), None)
    marker_format = {'size': 10.0, 'color': 'rgba(250, 135, 206, 0.2)', 'line': {'width': 1.0}}
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=analyses)
    for a, analysis in enumerate(analyses):
        row, col = get_row_col_for_subplots(a, n_cols)
        # Every panel carries the same series, so list them in the legend once;
        # legendgroup keeps a legend click toggling the matching trace in all panels
        show_legend = a == 0
        data = quantile_outputs[analysis][output]
        for q, quant in enumerate(quantiles):
            # Opacity peaks for the central quantiles and fades towards the tails;
            # 'tonexty' fills each band down to the previous trace in the panel
            alpha = min(q, n_quantiles - q) / half_width * max_alpha
            fill_colour = f'rgba(0,30,180,{str(alpha)})'
            band = go.Scatter(
                x=data.index, y=data[quant], fill='tonexty', fillcolor=fill_colour, line={'width': 0},
                name=quant, legendgroup=str(quant), showlegend=show_legend,
            )
            fig.add_trace(band, row=row, col=col)
        median = go.Scatter(
            x=data.index, y=data[0.5], line={'color': 'black'},
            name='median', legendgroup='median', showlegend=show_legend,
        )
        fig.add_trace(median, row=row, col=col)
        if target is not None:
            observed = go.Scatter(
                x=target.data.index, y=target.data, mode='markers', marker=marker_format,
                name=target.name, legendgroup=target.name, showlegend=show_legend,
            )
            fig.add_trace(observed, row=row, col=col)
    # 350 px per row keeps the original 700 px height for the default 2x2 layout
    fig.update_layout(height=350 * n_rows, title=output)
    fig.update_xaxes(range=[PLOT_START_DATE, ANALYSIS_END_DATE])
    return fig


def plot_output_ranges(quantile_outputs, targets, outputs, analysis, quantiles, max_alpha=0.7):
    """
    Plot the credible intervals for several outputs of a single analysis,
    one subplot per output.

    Args:
        quantile_outputs: Mapping from analysis name to that analysis's quantile results;
            ``quantile_outputs[analysis][output]`` must be indexable by every value in
            ``quantiles`` and by 0.5 (drawn as the median line), and its ``.index``
            is used as the x-axis
        targets: Calibration targets; each must expose ``.name`` and ``.data``
            (the latter plotted as markers against its own index)
        outputs: Output names, one panel each, laid out two per row; any iterable is
            accepted because it is materialised once up front
        analysis: Key into ``quantile_outputs`` selecting the analysis to plot
        quantiles: Quantile levels drawn as shaded bands, each filled down to the
            previous one (so presumably given in ascending order)
        max_alpha: Opacity of the most central band(s)

    Returns:
        The plotly figure
    """
    # Materialise once: the names feed both the subplot titles and the loop,
    # so a one-shot iterator would otherwise be exhausted by the titles
    outputs = list(outputs)
    n_cols = 2
    # Size the grid to the number of panels instead of assuming at most four
    n_rows = max(1, int(np.ceil(len(outputs) / n_cols)))
    n_quantiles = len(quantiles)
    # Avoid a zero divisor when fewer than two quantiles are requested
    half_width = max(n_quantiles // 2, 1)
    marker_format = {'size': 10.0, 'color': 'rgba(250, 135, 206, 0.2)', 'line': {'width': 1.0}}
    # Human-readable titles: drop the '_ma' suffix and turn underscores into spaces
    titles = [o.replace('_ma', '').replace('_', ' ') for o in outputs]
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles)
    analysis_data = quantile_outputs[analysis]
    for i, output in enumerate(outputs):
        row, col = get_row_col_for_subplots(i, n_cols)
        # Quantile bands and the median are the same series in every panel, so list
        # them in the legend once; legendgroup keeps legend clicks toggling all panels
        show_legend = i == 0
        data = analysis_data[output]
        for q, quant in enumerate(quantiles):
            # Opacity peaks for the central quantiles and fades towards the tails;
            # 'tonexty' fills each band down to the previous trace in the panel
            alpha = min(q, n_quantiles - q) / half_width * max_alpha
            fill_colour = f'rgba(0,30,180,{str(alpha)})'
            band = go.Scatter(
                x=data.index, y=data[quant], fill='tonexty', fillcolor=fill_colour, line={'width': 0},
                name=quant, legendgroup=str(quant), showlegend=show_legend,
            )
            fig.add_trace(band, row=row, col=col)
        median = go.Scatter(
            x=data.index, y=data[0.5], line={'color': 'black'},
            name='median', legendgroup='median', showlegend=show_legend,
        )
        fig.add_trace(median, row=row, col=col)
        target = next((t for t in targets if t.name == output), None)
        if target is not None:
            # Targets differ per panel, so each keeps its own legend entry
            observed = go.Scatter(
                x=target.data.index, y=target.data, mode='markers', marker=marker_format, name=target.name,
            )
            fig.add_trace(observed, row=row, col=col)
    # 350 px per row keeps the original 700 px height for the default 2x2 layout
    fig.update_layout(height=350 * n_rows)
    fig.update_xaxes(range=[PLOT_START_DATE, ANALYSIS_END_DATE])
    return fig

In [None]:
# Load the stored 'spaghetti' outputs for every run in RUN_IDS and summarise each
# analysis as quantiles for plotting below
rts = springboard.task.RemoteTaskStore()
rts.cd('projects/aust_covid/alternate_analyses')
spaghettis = {}
quantile_outputs = {}
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]
for analysis, run_id in RUN_IDS.items():
    mt = rts.get_managed_task(run_id)
    results_path = mt.local.path / 'output/results.hdf'
    # Download only when the results are not already local, so Restart & Run All
    # works on a fresh machine without uncommenting anything.
    # NOTE(review): assumes mt.local.path is a pathlib.Path (supports .exists()) — confirm
    if not results_path.exists():
        mt.download_all()
    spaghettis[analysis] = pd.read_hdf(results_path, 'spaghetti')
    quantile_outputs[analysis] = esamp.quantiles_for_results(spaghettis[analysis], quantiles)

In [None]:
# Calibration targets (built without writing any TeX output) and the key outputs
# shown side by side for the 'mob' analysis
targets = get_targets(DummyTexDoc())
outputs = [
    'notifications_ma',
    'deaths_ma',
    'adult_seropos_prop',
    'reproduction_number',
]
plot_output_ranges(quantile_outputs, targets, outputs, analysis='mob', quantiles=quantiles)

In [None]:
# Notifications credible intervals, one panel per analysis
plot_output_ranges_by_analysis(quantile_outputs, targets, output='notifications_ma', analyses=RUN_IDS.keys(), quantiles=quantiles)

In [None]:
# Deaths credible intervals, one panel per analysis
plot_output_ranges_by_analysis(quantile_outputs, targets, output='deaths_ma', analyses=RUN_IDS.keys(), quantiles=quantiles)

In [None]:
# Adult seropositive proportion credible intervals, one panel per analysis
plot_output_ranges_by_analysis(quantile_outputs, targets, output='adult_seropos_prop', analyses=RUN_IDS.keys(), quantiles=quantiles)