In [None]:
import numpy as np
import pandas as pd
from autumn.infrastructure.remote import springboard
import seaborn as sns
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
from inputs.constants import RUN_IDS, RUNS_PATH
from emutools.plotting import get_row_col_for_subplots
# Route DataFrame.plot() through plotly: plot_indicator_progression reads the
# `.data` traces of the returned figure and adds them to a plotly subplot grid
pd.set_option('plotting.backend', 'plotly')

In [None]:
def get_like_components_dfs(components):
    """
    Get dictionary containing a dataframe for each requested likelihood component,
    with one column for each analysis type listed in RUN_IDS.

    Each run's 'likelihood' table is read from disk once and then shared
    across all requested components, rather than being re-read per component.

    Args:
        components: Names of the columns to extract from each run's likelihood table

    Returns:
        Dictionary keyed by component name, with each value a dataframe whose
        columns are the analysis names (in RUN_IDS order)

    Raises:
        KeyError: If a requested component is not a column of a run's likelihood table
    """
    components = list(components)  # Allow any iterable, since it is traversed once per run
    analyses = list(RUN_IDS.keys())
    like_outputs = {comp: pd.DataFrame(columns=analyses) for comp in components}
    for analysis, run_id in RUN_IDS.items():
        like_table = pd.read_hdf(RUNS_PATH / run_id / 'output/results.hdf', 'likelihood')
        for comp in components:
            # NOTE(review): column assignment aligns on the frame's existing index,
            # so runs are presumably expected to share the same index - confirm
            like_outputs[comp][analysis] = like_table[comp]
    return like_outputs


def get_outcome_df_by_chain():
    """
    Get dictionary containing the likelihood table of each analysis in RUN_IDS,
    pivoted so that rows follow the second index level of the stored table
    and columns are split by the first index level (labelled 'chain').

    Returns:
        Dictionary keyed by analysis name, with each value the pivoted dataframe
    """

    def pivot_run_by_chain(run_id):
        # Expose both index levels as ordinary columns so that pivot can use them
        table = pd.read_hdf(RUNS_PATH / run_id / 'output/results.hdf', 'likelihood')
        chain_labels, iter_labels = (table.index.get_level_values(level) for level in (0, 1))
        table['chain'] = chain_labels
        table['index'] = iter_labels
        return table.pivot(index='index', columns=['chain'])

    return {analysis: pivot_run_by_chain(run_id) for analysis, run_id in RUN_IDS.items()}


def plot_like_components_by_analysis(like_outputs, components, plot_type, burn_in, clips=None):
    """
    Plot the distribution of each requested likelihood component on its own panel,
    using the seaborn plotting function named by plot_type.

    Args:
        like_outputs: Dictionary keyed by component name, with each value
            the dataframe to plot for that component
        components: Names of the components to plot, one panel each
        plot_type: Name of the seaborn function to plot with,
            e.g. 'kdeplot', 'violinplot' or 'histplot'
        burn_in: Start of the slice applied as the second indexer of the
            dataframe lookup (presumably the number of early iterations to discard)
        clips: Optional dictionary of lower clipping bounds by component name,
            only used for 'kdeplot' (the upper bound is fixed at zero);
            components without an entry are not clipped

    Returns:
        The matplotlib figure holding all panels
    """
    components = list(components)
    clips = clips or {}  # Avoid a shared mutable default argument
    n_cols = 3
    # Keep the 2 x 3 layout as a minimum, only adding rows when more than six panels are needed
    n_rows = max(2, -(-len(components) // n_cols))
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(12, 3.5 * n_rows))
    axes = axes.reshape(-1)
    plotter = getattr(sns, plot_type)
    has_legend = plot_type in ('kdeplot', 'histplot')
    for m, comp in enumerate(components):
        ax = axes[m]
        if plot_type == 'kdeplot':
            lower = clips.get(comp)
            clip = (lower, 0.0) if lower is not None else None
            # 'fill' is the current name for the deprecated 'shade' keyword
            kwargs = {'common_norm': False, 'clip': clip, 'fill': True}
        else:
            kwargs = {}
        plotter(like_outputs[comp].loc[:, burn_in:, :], ax=ax, **kwargs)
        subtitle = comp.replace('log', '').replace('ll_', '').replace('_ma', '').replace('_', ' ')
        ax.set_title(subtitle)
        # Only the first panel keeps its legend; skip panels on which none was drawn
        if has_legend and ax.get_legend() is not None:
            if m == 0:
                sns.move_legend(ax, loc='upper left')
            else:
                ax.get_legend().set_visible(False)
    # Hide any panels of the grid that were not needed
    for unused_ax in axes[len(components):]:
        unused_ax.set_visible(False)
    fig.suptitle('Log posterior and components')
    fig.tight_layout()
    return fig


def plot_indicator_progression(like_dfs, measure):
    """
    Plot the requested measure for every analysis in RUN_IDS,
    with one subplot per analysis on a two-column grid sharing y-axes.

    Args:
        like_dfs: Dictionary keyed by analysis name, with each value a dataframe
            from which the measure is selected and plotted
        measure: Name of the quantity to select from each dataframe,
            also used as the figure title

    Returns:
        The plotly figure
    """
    n_cols = 2
    analyses = list(RUN_IDS.keys())
    fig = make_subplots(rows=2, cols=n_cols, subplot_titles=analyses, shared_yaxes=True)
    for position, analysis in enumerate(analyses):
        row, col = get_row_col_for_subplots(position, n_cols)
        # Build a standalone figure through the pandas plotting backend, then move its traces across
        panel_traces = like_dfs[analysis][measure].plot().data
        fig.add_traces(panel_traces, rows=row, cols=col)
    title = {'text': measure}
    fig.update_layout(height=1000, title=title)
    return fig

In [None]:
# Lower clipping bound for the kernel density of each component (the upper bound is zero);
# the keys also define which components are plotted, and in what order
clips = {
    'logposterior': -70.0,
    'logprior': -46.0,
    'loglikelihood': -28.0,
    'll_adult_seropos_prop': -4.0,
    'll_deaths_ma': -10.0,
    'll_notifications_ma': -17.0,
}
components = list(clips)
burn_in = 5000
like_outputs = get_like_components_dfs(components)
plot_like_components_by_analysis(like_outputs, components, 'kdeplot', burn_in, clips=clips);

In [None]:
# Same components and burn-in as above, shown as violin plots
plot_like_components_by_analysis(like_outputs, components, plot_type='violinplot', burn_in=burn_in);

In [None]:
# Same components and burn-in as above, shown as histograms
plot_like_components_by_analysis(like_outputs, components, plot_type='histplot', burn_in=burn_in);

In [None]:
# Log-posterior of each analysis, with the likelihood tables split by chain
like_dfs = get_outcome_df_by_chain()
plot_indicator_progression(like_dfs, measure='logposterior')