In [None]:
import numpy as np
import pandas as pd
from autumn.infrastructure.remote import springboard
from plotly.subplots import make_subplots
pd.options.plotting.backend = 'plotly'

In [None]:
# Remote task identifier for each alternate analysis, keyed by analysis name,
# plus the column of the 'likelihood' results table that the cells below extract and plot.
run_ids = dict(
    none='2023-10-09T1251-none-d20k-t10k-b5k',
    mob='2023-10-09T1253-mob-d20k-t10k-b5k',
    vacc='2023-10-09T1254-vacc-d20k-t10k-b5k',
    both='2023-10-09T1255-both-d20k-t10k-b5k',
)
measure = 'logposterior'

In [None]:
# For each analysis, load the run's 'likelihood' table and reshape the `measure`
# column into a frame with one column per chain (index level 0 of the table) and
# one row per within-chain index (index level 1).
rts = springboard.task.RemoteTaskStore()
rts.cd('projects/aust_covid/alternate_analyses')
logpost_dfs = {}
for analysis, run_id in run_ids.items():
    mt = rts.get_managed_task(run_id)
    results_path = mt.local.path / 'output' / 'results.hdf'
    # Fetch outputs only when they are not cached locally, so Restart & Run All works
    # on a fresh machine without hand-uncommenting a download step.
    # NOTE(review): assumes `mt.local.path` is a pathlib.Path and that download_all()
    # populates output/results.hdf under it — confirm against springboard.
    if not results_path.exists():
        mt.download_all()
    like_df = pd.read_hdf(results_path, 'likelihood')
    # unstack level 0 (chain) into columns; name the axes as the previous pivot did
    # so downstream plots keep the same axis/legend titles.
    logpost_dfs[analysis] = (
        like_df[measure]
        .unstack(level=0)
        .rename_axis(index='index', columns='chain')
    )

In [None]:
# Overlay a kernel density estimate of each analysis's 'loglikelihood' samples.
loglikelihood_by_analysis = {}
for analysis, run_id in run_ids.items():
    mt = rts.get_managed_task(run_id)
    results_path = mt.local.path / 'output' / 'results.hdf'
    loglikelihood_by_analysis[analysis] = pd.read_hdf(results_path, 'likelihood')['loglikelihood']
# Constructing from a dict aligns the series on the union of their indices, so runs
# with differing sample counts are NaN-padded instead of being silently truncated to
# the first run's index (as assigning into an empty DataFrame did).
likelihoods = pd.DataFrame(loglikelihood_by_analysis)
# `seaborn` was never imported here (NameError). pandas' own KDE plot does the same job;
# force the matplotlib backend because the notebook-wide plotly backend has no KDE plot.
likelihoods.plot.kde(backend='matplotlib', title='loglikelihood');

In [None]:
# Trace plot of `measure` per chain, one panel per analysis, laid out two panels per row.
# The row count is derived from the number of analyses (a fixed 2x2 grid fails for >4).
n_cols = 2
analyses = list(run_ids.keys())
n_rows = max(1, -(-len(analyses) // n_cols))  # ceiling division
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=analyses, shared_yaxes=True)
for i, analysis in enumerate(analyses):
    row_idx, col_idx = divmod(i, n_cols)
    traces = logpost_dfs[analysis].plot().data
    # Each panel produces one trace per chain; show the legend entries only once
    # (from the first panel) instead of repeating every chain for every panel.
    # NOTE(review): relies on plotly express grouping traces by legendgroup so that
    # toggling a chain in the legend affects all panels — confirm in the rendered figure.
    for trace in traces:
        trace.showlegend = i == 0
    fig.add_traces(traces, rows=row_idx + 1, cols=col_idx + 1)
# 500 px per row (1000 px for the current 2x2 layout).
fig.update_layout(height=500 * n_rows, title={'text': measure})
fig