In [None]:
import numpy as np
import pandas as pd
from autumn.infrastructure.remote import springboard
import seaborn as sns
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
pd.options.plotting.backend = 'plotly'

In [None]:
# Connect to the remote task store and move into this project's alternate-analysis area
rts = springboard.task.RemoteTaskStore()
rts.cd('projects/aust_covid/alternate_analyses')

# Analysis label -> remote run ID whose outputs are compared throughout this notebook
run_ids = dict(
    none='2023-10-09T1251-none-d20k-t10k-b5k',
    mob='2023-10-09T1253-mob-d20k-t10k-b5k',
    vacc='2023-10-09T1254-vacc-d20k-t10k-b5k',
    both='2023-10-09T1255-both-d20k-t10k-b5k',
)

In [None]:
# Run once if needed: fetch every run's files from the remote store so that the
# local 'output/results.hdf' files read by the cells below are available
for run_id in run_ids.values():
    rts.get_managed_task(run_id).download_all()

In [None]:
def get_like_components_dfs(components):
    """
    Get dictionary containing dataframe for each (important) likelihood component,
    with one column per analysis type.

    Args:
        components: Names of columns to extract from each run's stored 'likelihood' table
    Returns:
        Dictionary keyed by component name. Each value is a dataframe whose columns are
        the analysis names (keys of run_ids) and whose index is that of the stored
        'likelihood' table, outer-joined across analyses so no run's samples are dropped
    """
    # Read each run's likelihood table once, rather than once per component
    like_tables = {}
    for analysis, run_id in run_ids.items():
        mt = rts.get_managed_task(run_id)
        like_tables[analysis] = pd.read_hdf(mt.local.path / 'output' / 'results.hdf', 'likelihood')

    if not like_tables:
        # pd.concat refuses an empty mapping; mirror the empty frames of the old behaviour
        return {comp: pd.DataFrame(columns=[]) for comp in components}

    # concat aligns on the union of the runs' indexes, so later analyses are not
    # truncated to the first run's rows (as assigning into an empty frame would do)
    return {
        comp: pd.concat({analysis: table[comp] for analysis, table in like_tables.items()}, axis=1)
        for comp in components
    }


def get_outcome_df_by_chain():
    """
    Reshape each analysis's stored likelihood table so that its chains sit side by side.

    Returns:
        Dictionary keyed by analysis name (keys of run_ids). Each value is the stored
        'likelihood' table pivoted to one row per value of its second index level
        (the row index is named 'index') and columns (original column, chain), with the
        chain taken from the first index level
    """
    def _pivot_by_chain(run_id):
        # Expose the two index levels as ordinary columns so pivot can use them
        local_dir = rts.get_managed_task(run_id).local.path
        raw = pd.read_hdf(local_dir / 'output/results.hdf', 'likelihood')
        levels = raw.index
        with_keys = raw.assign(
            chain=levels.get_level_values(0),
            index=levels.get_level_values(1),
        )
        return with_keys.pivot(index='index', columns=['chain'])

    return {analysis: _pivot_by_chain(run_id) for analysis, run_id in run_ids.items()}


def plot_like_components_by_analysis(like_outputs, components, plot_type, burn_in):
    """
    Plot the post-burn-in distribution of each likelihood component, one panel per
    component, with all analyses shown together in each panel.

    Args:
        like_outputs: Dictionary of dataframes keyed by component, columns being the
            analyses and rows having a multi-level index whose second level is compared
            against burn_in
        components: Components to plot, in panel order (three panels per row)
        plot_type: Name of a seaborn plotting function taking wide-form data and ax,
            e.g. 'kdeplot', 'violinplot', 'histplot'
        burn_in: Rows whose second index level is below this value are excluded
    Returns:
        The matplotlib figure
    """
    ncols = 3
    nrows = max(1, int(np.ceil(len(components) / ncols)))
    # squeeze=False keeps axes 2-D even for a single row, so ravel always works
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 3.5 * nrows), squeeze=False)
    axes = axes.ravel()
    plotter = getattr(sns, plot_type)
    for ax, comp in zip(axes, components):
        comp_df = like_outputs[comp]
        # DataFrame.loc does not accept a three-part key like [:, burn_in:, :]; a boolean
        # mask on the second index level selects the post-burn-in rows and, unlike a
        # MultiIndex label slice, does not require a lexsorted index
        post_burn_in = comp_df[comp_df.index.get_level_values(1) >= burn_in]
        plotter(post_burn_in, ax=ax)
        subtitle = comp.replace('log', '').replace('ll_', '').replace('_ma', '').replace('_', ' ')
        ax.set_title(subtitle)
    # Hide any grid cells left over when the components don't fill the last row
    for ax in axes[len(components):]:
        ax.set_visible(False)
    fig.suptitle('Log posterior and components')
    fig.tight_layout()
    return fig


def plot_indicator_progression(like_dfs, measure):
    """
    Plot how one measure evolves for every chain, with one subplot per analysis laid
    out two per row and a shared y-axis.

    Args:
        like_dfs: Dictionary keyed by analysis name of dataframes with (measure, chain)
            columns, such as those returned by get_outcome_df_by_chain
        measure: Top-level column to plot, e.g. 'logposterior'
    Returns:
        The plotly figure
    """
    analyses = list(like_dfs)
    ncols = 2
    nrows = max(1, int(np.ceil(len(analyses) / ncols)))
    fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=analyses, shared_yaxes=True)
    for i, analysis in enumerate(analyses):
        row, col = divmod(i, ncols)
        # pandas' plotting backend is set to plotly, so .plot() returns a plotly figure;
        # its traces are copied into this analysis's subplot (make_subplots is 1-based)
        fig.add_traces(like_dfs[analysis][measure].plot().data, rows=row + 1, cols=col + 1)
    fig.update_layout(height=500 * nrows, title={'text': measure})
    return fig

In [None]:
# Likelihood components to compare across analyses (one plot panel each)
components = [
    'logposterior',
    'logprior',
    'loglikelihood',
    'll_adult_seropos_prop',
    'll_deaths_ma',
    'll_notifications_ma',
]
# Burn-in cut-off passed to the plots; presumably matches the 'b5k' in the run IDs — confirm
burn_in = 5000
like_outputs = get_like_components_dfs(components)
plot_like_components_by_analysis(like_outputs, components, plot_type='kdeplot', burn_in=burn_in);

In [None]:
# Same components as violin plots
plot_like_components_by_analysis(like_outputs, components, plot_type='violinplot', burn_in=burn_in);

In [None]:
# Same components as histograms
plot_like_components_by_analysis(like_outputs, components, plot_type='histplot', burn_in=burn_in);

In [None]:
# Per-chain progression of the log posterior for each analysis
like_dfs = get_outcome_df_by_chain()
plot_indicator_progression(like_dfs, measure='logposterior')