In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import h5py

import io
import os
import time
import glob
import fnmatch
from itertools import chain, cycle
from textwrap import dedent
from collections import defaultdict
from collections import OrderedDict as odict
from os.path import exists
from pandas.api.types import union_categoricals

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter

from bokeh.io import push_notebook, show, output_notebook, output_file, save
from bokeh.models import Band, ColumnDataSource
from bokeh.models import TapTool, Select, Div, BoxZoomTool, ResetTool
from bokeh.models import HoverTool, CustomJS, TabPanel, Tabs
from bokeh.plotting import figure 
from bokeh.layouts import gridplot, layout, row, column

import bspinn
from bspinn.io_utils import get_ovatgrps, drop_unqcols
from bspinn.io_cfg import configs_dir, results_dir
from bspinn.io_cfg import keyspecs, nullstr
from bspinn.io_utils import deep2hie, hie2deep
from bspinn.io_utils import save_h5data, load_h5data
from bspinn.io_utils import get_h5du, resio, get_dfidxs
from bspinn.summary import summarize

import yaml
from ruamel import yaml as ruyaml
from IPython import display as ICD

In [None]:
# The default plotting options of a dashboard. `get_dashdata` uses these 
# values for every option that the `plot` section of the yaml config 
# leaves unspecified.
dflt_dashcfg = {
    # The columns defining the axes, the line groups, and the rng seeds
    'xcol': 'epoch',
    'ycol': None,
    'huecol': 'fpidxgrp',
    'colsep': None,
    'rngcol': 'rng_seed',
    # The figure grid geometry
    'frame_width': 350,
    'frame_height': 235,
    'ncols': 4,
    'sharex': True,
    'sharey': True,
    # The titles, menus, and colors
    'header': 'Ablation Studies',
    'menu_width': 250,
    'color_reset': 'figure',
    'fig_title': 'Ablation: {ablname}',
    'colors': 'snsdark',
    'tooltip': None,
    # The axis types and tick options
    'y_axis_type': 'auto',
    'y_tick_fmt': None,
    'y_tick_lbls': None,
    'x_axis_type': 'auto',
    'x_tick_fmt': None,
    'x_tick_lbls': None,
}

#################################################
########### Aggregation of RNG Seeds ############
#################################################
def get_aggdf(hpdf, stdf, xcol, huecol, rngcol, agg='sem'):
    """Aggregates the stats of each `(huecol, xcol)` group of rows.

    The hyper-parameter columns kept by `drop_unqcols` (plus `huecol`) 
    are placed side-by-side with the stats, the rows are grouped by 
    `(huecol, xcol)`, and each stat column `col` is summarized into 
    the `{col}/mean`, `{col}/low`, and `{col}/high` columns.

    Parameters
    ----------
    hpdf: (pd.DataFrame) the hyper-parameters. It must include the 
        `huecol` column.

    stdf: (pd.DataFrame) the stats. It is concatenated column-wise 
        with the hyper-parameters, so it presumably shares its index 
        with `hpdf` (TODO confirm). The `nullstr` values are treated 
        as missing values.

    xcol: (str) the second grouping column (e.g., `'epoch'`).

    huecol: (str) the first grouping column.

    rngcol: (str) a stat column that is excluded from the aggregation.

    agg: (None | str | callable) the aggregation scheme.
        * `None`, `'mean'`, or `'sem'`: the mean and the standard error 
          of the mean are computed, and the low/high values are set to 
          `mean -/+ 1.96 * sem`. The `{col}/sem` columns are included 
          in the output as well.
        * callable: a pandas aggregation function that is applied to 
          each stat column of each group, and returns a 
          `(mean, low, high)` tuple.

    Returns
    -------
    outdict: (dict) a dictionary with the following keys:
        * `aggdf`: the aggregated data-frame.
        * `hpcols`: the list of hyper-parameter column names.
        * `stcols`: the list of aggregated stat column names.

    Raises
    ------
    ValueError: when `agg` is not one of the supported options.
    """
    # Validating the aggregation scheme before doing any work
    is_func = callable(agg)
    if (not is_func) and (agg not in (None, 'mean', 'sem')):
        raise ValueError(f'agg={agg} not defined')

    grphpdf = drop_unqcols(hpdf)
    if huecol not in grphpdf.columns:
        grphpdf.insert(0, huecol, hpdf[huecol])
    # Re-creating the columns out of plain python lists
    for col in grphpdf.columns:
        grphpdf[col] = grphpdf[col].tolist()

    # `replace` returns a new data-frame, so `stdf` is left untouched
    grpstdf = stdf.replace(nullstr, np.nan)

    stcols = [col for col in grpstdf.columns 
            if col not in (xcol, rngcol)]
    hpcols = list(grphpdf.columns)

    grpdf = pd.concat([grphpdf, grpstdf], axis=1)
    grpbycols = [huecol, xcol]

    if is_func:
        aggpolicy = {c: 'first' for c in hpcols}
        aggpolicy.update({c: agg for c in stcols})
        # The hue column is restored from the group keys by `reset_index`
        aggpolicy.pop(huecol, None)
        aggdf = grpdf.groupby(grpbycols, sort=False).agg(aggpolicy)
        aggcolsdict = odict()
        for col in stcols:
            # Splitting the (mean, low, high) tuples into three columns
            pkg = tuple(zip(*(aggdf[col].tolist())))
            aggcolsdict[f'{col}/mean'] = pkg[0]
            aggcolsdict[f'{col}/low'] = pkg[1]
            aggcolsdict[f'{col}/high'] = pkg[2]
        aggdf = pd.concat([aggdf.reset_index().drop(columns=stcols), 
            pd.DataFrame(aggcolsdict)], axis=1)
    else:
        aggpolicy = {c: ['first'] for c in hpcols}
        aggpolicy.update({c: ['mean', 'sem'] for c in stcols})
        aggdf = grpdf.groupby(grpbycols, sort=False).agg(aggpolicy)
        for col in stcols:
            # The normal-approximation 95% confidence interval of the mean
            aggdf.loc[:, (col, 'low')] = aggdf[col]['mean'] - 1.96 * aggdf[col]['sem']
            aggdf.loc[:, (col, 'high')] = aggdf[col]['mean'] + 1.96 * aggdf[col]['sem']

        # Ordering the columns, and flattening the two-level column names
        aggdf = aggdf.reindex(columns=hpcols+stcols, level=0)
        aggdf.columns = aggdf.columns.map('/'.join)
        # The hue column is taken from its `first` aggregate instead
        aggdf = aggdf.reset_index().drop(columns=huecol)

        scc = [col for col in aggdf.columns if not col.endswith('/first')]
        aggdf = aggdf.rename(columns={f'{col}/first': col for col in hpcols})
        aggdf = aggdf[hpcols + scc]
    
    outdict = dict(aggdf=aggdf, hpcols=hpcols, stcols=stcols)
    return outdict

class BStrapAgg:
    """A bootstrapping aggregator of a series of values.

    Calling an instance on a series re-samples its values with 
    replacement `n_boot` times, computes the `stat` statistic of 
    each re-sample, and returns the average of these statistics 
    along with their `q[0]` and `q[1]` percentiles.

    Parameters
    ----------
    n_boot: (int) the number of bootstrap re-samples.
    q: (sequence) the low and high percentile levels (in percents).
    stat: (str) the re-sample statistic; a key of `stat2func`.
    seed: (int) the seed of the private random state.
    """
    # The supported re-sample statistics
    stat2func = {'mean': np.mean, 'median': np.median}

    def __init__(self, n_boot=20, q=(5, 95), 
        stat='mean', seed=12345):
        self.stat_f = self.stat2func[stat]
        self.np_random = np.random.RandomState(seed)
        self.n_boot = n_boot
        self.q = q

    def __call__(self, series):
        """Returns the `(mean, low, high)` bootstrap summary of `series`."""
        values = series.values
        n_vals = values.size
        q_low, q_high = self.q[0], self.q[1]
        # Invalid floating-point operations are silenced in here
        with np.errstate(invalid='ignore'):
            # All the re-samples are drawn in a single call
            draws = self.np_random.choice(values, self.n_boot * n_vals, 
                replace=True)
            boot_stats = self.stat_f(draws.reshape(self.n_boot, n_vals), 
                axis=1)
            low = np.percentile(boot_stats, q_low)
            high = np.percentile(boot_stats, q_high)
            return (boot_stats.mean(), low, high)

#################################################
########### Dashboard Data & Building ###########
#################################################
def get_dflt_ablspec(hpdf, fpgrps, huecol='fpidxgrp', ablcol='ablgrp'):
    """Creates a default ablation spec for each group of hue values.

    Parameters
    ----------
    hpdf: (pd.DataFrame) the hyper-parameters, including the `fpidx` 
        and `huecol` columns.
    fpgrps: (iterable) the groups of `huecol` values; each group 
        selects the rows of one ablation.
    huecol: (str) the column that the groups refer to.
    ablcol: (str) a column name left out of the ablation columns.

    Returns
    -------
    grpspec: (list of dict) one `dict(name=..., columns=[cols])` per 
        group, where `cols` is the sorted list of the columns that 
        `drop_unqcols` keeps for the rows of the group (other than 
        `ablcol` and `huecol`), and `name` is their `', '` join.
    """
    # The distinct hyper-parameter rows, indexed by the hue column
    hue_hpdf = drop_unqcols(hpdf.drop_duplicates().drop('fpidx', axis=1))
    hue_hpdf = hue_hpdf.reset_index(drop=True).set_index(huecol)

    def grp_columns(fpidxs):
        # The sorted columns that identify a single group
        grp_hpdf = drop_unqcols(hue_hpdf.loc[fpidxs].reset_index(huecol))
        return sorted(set(grp_hpdf.columns) - {ablcol, huecol})

    return [dict(name=', '.join(cols), columns=[cols]) 
        for cols in map(grp_columns, fpgrps)]

def get_dashdata(data, ymlpath, write_yml=False,
    en_exclude=True, en_include=True):
    """Prepares the aggregated figure data and plot options of a dashboard.

    The dashboard config is loaded from `ymlpath` if the file exists 
    (an empty config is used otherwise). Its `figures`, `aggregate`, 
    `drop`, `rename`, and `plot` sections are filled with defaults 
    wherever they are unspecified. For each method, the ablation groups 
    that its `figures` spec does not cover are appended to the spec, 
    and the stats of each ablation are aggregated using `get_aggdf`.

    Parameters
    ----------
    data: (iterable of tuple) the `(method, hpdf, stdf)` triplets, 
        where `hpdf` and `stdf` are the hyper-parameter and stats 
        data-frames of the method. `hpdf` must hold categorical 
        `fpidx` and hue columns. Both data-frames are filtered with 
        the same row masks, so their rows presumably correspond 
        one-to-one (TODO confirm).

    ymlpath: (str) the path of the yaml dashboard config.

    write_yml: (bool) whether to write the completed config to `ymlpath`.

    en_exclude: (bool) whether to apply the `exclude` filters of the 
        figure specs.

    en_include: (bool) whether to apply the `include` filters of the 
        figure specs.

    Returns
    -------
    outdict: (dict) the plot options (all the `dflt_dashcfg` keys but 
        `rngcol`), plus a `data` key holding the list of 
        `(method, ablname, aggdf, hpcols, stcols)` tuples.

    Raises
    ------
    ValueError: when the aggregation type of the config is undefined.
    """
    def rplc_all(name, replacer):
        # Applies all the sub-string replacements to a column name
        for kk, vv in replacer.items():
            name = name.replace(kk, vv)
        return name

    if exists(ymlpath):
        # NOTE(review): the module-level `load`/`dump` functions seem to be 
        # deprecated in the recent ruamel.yaml releases -- confirm against 
        # the installed version.
        with open(ymlpath, 'r') as fp:
            dashconfig = ruyaml.load(fp, ruyaml.RoundTripLoader)
        # An empty yaml file is loaded as None
        if dashconfig is None:
            dashconfig = dict()
    else:
        dashconfig = dict()

    figspec = dashconfig.setdefault('figures', dict())
    aggcfg = dashconfig.setdefault('aggregate', dict())
    dropcfg = dashconfig.setdefault('drop', dict())
    renamecfg = dashconfig.setdefault('rename', dict())
    plotcfg = dashconfig.setdefault('plot', dict())

    #################################################
    ############### Aggregator Setup ################
    #################################################
    # The length assertions make sure no unknown option is specified
    aggtype = aggcfg.setdefault('type', 'mean')
    if aggtype == 'bootstrap':
        n_boot = aggcfg.setdefault('n_boot', 20)
        q = aggcfg.setdefault('q', [5, 95])
        stat = aggcfg.setdefault('stat', 'mean')
        aggregator = BStrapAgg(n_boot=n_boot, q=q, stat=stat, seed=12345)
        assert len(aggcfg) == 4
    elif aggtype in ('sem', 'mean'):
        aggregator = 'sem'
        assert len(aggcfg) == 1
    else:
        raise ValueError(f'aggregation type={aggtype} not defined')

    #################################################
    ########### Column Definition Options ###########
    #################################################
    drop_colpats = dropcfg.setdefault('columns', [])
    assert len(dropcfg) == 1

    col_renamer = renamecfg.setdefault('columns', dict())
    col_replacer = renamecfg.setdefault('colrplc', dict())
    assert len(renamecfg) == 2

    # A key with a null value in the yaml file yields None. These local 
    # fall-backs make all the uses below work, and leave the config as is.
    drop_colpats = [] if drop_colpats is None else drop_colpats
    col_renamer = dict() if col_renamer is None else col_renamer
    col_replacer = dict() if col_replacer is None else col_replacer
    
    for opt, dfltoptval in dflt_dashcfg.items():
        optval = plotcfg.get(opt, dfltoptval)
        if opt.endswith('col'):
            # The column options must follow the stat columns renaming
            optval = col_renamer.get(optval, optval)
            if isinstance(optval, str):
                optval = rplc_all(optval, col_replacer)
        plotcfg[opt] = optval
    assert len(plotcfg) == len(dflt_dashcfg), plotcfg

    xcol = plotcfg['xcol']
    huecol = plotcfg['huecol']
    rngcol = plotcfg['rngcol']
    
    #################################################
    ################ Data Processing ################
    #################################################
    dashdata = []
    for method, hpdf, stdf in data:
        meth_ablspec = figspec.setdefault(method, list())

        # Getting a copy of the input dataframes
        hpdf, stdf = hpdf.copy(), stdf.copy()

        # Dropping some stat columns
        stdfcols = stdf.columns.tolist()
        drop_cols = list(chain.from_iterable(fnmatch.filter(stdfcols, pat) 
            for pat in drop_colpats))
        stdf = stdf.drop(columns=drop_cols)

        # Shortening the fpidx and fpidxgrp columns to save space on html
        hpdf['fpidx'] = hpdf['fpidx'].cat.codes
        hpdf[huecol] = hpdf[huecol].cat.codes

        # Renaming the columns
        stdf = stdf.rename(columns=col_renamer)
        colsrplcd = {col: rplc_all(col, col_replacer) 
            for col in stdf.columns.tolist()}
        stdf = stdf.rename(columns=colsrplcd)

        # Grouping the hyper-parameters 
        fpgrps = get_ovatgrps(hpdf)
        dfltabldefs = get_dflt_ablspec(hpdf, fpgrps, huecol, 'ablgrp')
        grphpcols = [tuple(x['columns'][0]) for x in dfltabldefs]
        hpcols2fpgrp = dict(zip(grphpcols, fpgrps))
        
        # Adding existing ablations that do not exist in
        # the ablation config file
        mcols = [l['columns'] for l in meth_ablspec]
        mcols = [x if x != '*' else grphpcols for x in mcols]
        mohpc = chain.from_iterable(mcols)
        mohpc = [tuple(sorted(x)) for x in mohpc]
        used_fpgrpidxs = set(chain.from_iterable(hpcols2fpgrp[x] for x in mohpc))
        for fpgrp, abldict in zip(fpgrps, dfltabldefs):
            if any(fpgrpidx not in used_fpgrpidxs for fpgrpidx in fpgrp):
                meth_ablspec.append(abldict)

        # Adding the method column
        hpdf.insert(0, 'method', method)

        # Adding the ablation group column
        hpdfgrp_list, stdfgrp_list = [], []
        for abldict in meth_ablspec:
            ablname = abldict['name']
            exclist = abldict.get('exclude', [])
            exclist = exclist if en_exclude else []
            inclist = abldict.get('include', [])
            inclist = inclist if en_include else []
            allhpcols = abldict['columns'] if (abldict['columns'] != '*') else grphpcols
            for hpcols in allhpcols:
                fpgrp = hpcols2fpgrp[tuple(sorted(hpcols))]
                ii = hpdf[huecol].isin(fpgrp)
                if len(exclist) > 0:
                    # Dropping the rows that match any of the exclusions
                    exc_il = [get_dfidxs(hpdf, excdict) for excdict in exclist]
                    jj = np.stack(exc_il, axis=1).any(axis=1)
                    ii = np.logical_and(ii, np.logical_not(jj))
                if len(inclist) > 0:
                    # Keeping the rows that match any of the inclusions
                    inc_il = [get_dfidxs(hpdf, incdict) for incdict in inclist]
                    jj = np.stack(inc_il, axis=1).any(axis=1)
                    ii = np.logical_and(ii, jj)
                hpdfgrp = hpdf[ii].copy()
                stdfgrp = stdf[ii]
                hpdfgrp.insert(1, 'ablgrp', ablname)
                hpdfgrp_list.append(hpdfgrp)
                stdfgrp_list.append(stdfgrp)
        hpdf = pd.concat(hpdfgrp_list, axis=0, ignore_index=True)
        stdf = pd.concat(stdfgrp_list, axis=0, ignore_index=True)

        hpdf = hpdf.fillna(nullstr)
        stdf = stdf.fillna(nullstr)

        # Aggregating the stats of each ablation group (i.e., figure)
        for ablname, ablhpdf in hpdf.groupby('ablgrp', sort=False):
            ablstdf = stdf.loc[ablhpdf.index.values]
            aggdict = get_aggdf(ablhpdf, ablstdf, xcol, huecol, rngcol, aggregator)
            aggdf, hpcols, stcols = aggdict['aggdf'], aggdict['hpcols'], aggdict['stcols']
            dashdata.append((method, ablname, aggdf, hpcols, stcols))

    if write_yml: 
        if exists(ymlpath):
            # Round-trip dumping of the config that was loaded from the file
            with open(ymlpath, 'w') as fp:
                ruyaml.dump(dashconfig, fp, ruyaml.RoundTripDumper)
        else:
            with open(ymlpath, 'w') as fp:
                yaml.dump(dashconfig, fp, sort_keys=False, default_flow_style=None)

    outdict = dict(plotcfg)
    outdict['data'] = dashdata
    outdict.pop('rngcol', None)
    return outdict

def build_dashboard(doc, data, xcol, ycol, header,
    huecol, fig_title, colors=None, frame_width=350, 
    frame_height=235, ncols=4, sharex=True, sharey=True, 
    tooltip=None, menu_width=150, colsep=None, color_reset='figure',
    y_axis_type='auto', y_tick_fmt=None, y_tick_lbls=None, 
    x_axis_type='auto', x_tick_fmt=None, x_tick_lbls=None):
    """Builds the bokeh layout of a tabbed ablation dashboard.

    Each tab holds a grid of figures. Each figure draws one line (the 
    `{ycol}/mean` column) and one band (the `{ycol}/low` and 
    `{ycol}/high` columns) per `huecol` group. A row of select menus 
    switches the y-axis column of all the figures at once through a 
    javascript callback.

    Parameters
    ----------
    doc: (None | document) when not None, the layout is added to it 
        as a root, and nothing is returned.

    data: (iterable of tuple) the `(tabname, ablname, aggdf, hpcols, 
        stcols)` figure definitions, e.g., the `data` output of 
        `get_dashdata`.

    xcol: (str) the x-axis column.

    ycol: (None | str) the initial y-axis column (without the stat 
        suffix). When None, the last `*/mean` column of the first 
        figure is used.

    header: (str) the dashboard heading text.

    huecol: (str) the column that defines the lines of a figure.

    fig_title: (None | str) the figure title template, formatted with 
        `ablname`. When None, `'Ablation: {ablname}'` is used.

    colors: (None | 'snsdark' | list) the list of line colors.

    frame_width, frame_height: (int) the frame size of each figure.

    ncols: (int) the number of grid columns within each tab.

    sharex, sharey: (bool | str) the axis range sharing; `True`/`'all'` 
        shares among all the figures, `'tab'` shares within each tab, 
        and `False`/`'none'` disables the sharing.

    tooltip: (None | list) extra columns to show in the hover tooltips.

    menu_width: (int) the width of each select menu.

    colsep: (None | str) the separator that splits the stat column 
        names into the hierarchical menu sections. When None, a 
        `'######'` place-holder is used.

    color_reset: (str) where the color cycle restarts; one of 
        `'figure'`, `'tab'`, or `'all'`.

    y_axis_type, x_axis_type: (str) passed to the bokeh figures.

    y_tick_fmt, x_tick_fmt: (None | 'eng') the tick label format; the 
        `'eng'` format requires the tick labels to be specified.

    y_tick_lbls, x_tick_lbls: (None | list) the fixed tick locations.

    Returns
    -------
    fulllayout: the bokeh layout, only when `doc` is None.

    Raises
    ------
    ValueError: when `data` is empty, or the sharing options are undefined.
    """
    # The data is traversed multiple times, and cannot be empty
    data = list(data)
    if len(data) == 0:
        raise ValueError('data is empty; at least one figure is required')

    if colors in (None, 'snsdark'):
        colors = ['#001c7f', '#b1400d', '#12711c', '#8c0800', '#591e71', 
                  '#592f0d', '#a23582', '#3c3c3c', '#b8850a', '#006374']
    else:
        assert isinstance(colors, list)

    if fig_title is None:
        fig_title = 'Ablation: {ablname}'

    # The column separator is resolved before its first use, so that the 
    # python and javascript sides shorten the y-axis labels identically.
    if colsep is None:
        colsep = '######'
    # The stat suffix of the aggregated columns (i.e., `/mean`, `/low`, and 
    # `/high`) is always joined with this separator, irrespective of colsep.
    statsep = '/'

    all_bkfigdata = []
    all_tabbkfigdata = []
    tabs_list = []

    tab2data = odict()
    for tabname, ablname, aggdf, hpcols, stcols in data:
        figsdata = tab2data.setdefault(tabname, [])
        figsdata.append((ablname, aggdf, hpcols, stcols))
        if ycol is None:
            ycol = [c[:-5] for c in aggdf.columns 
                if c.endswith('/mean')][-1]
    assert ycol is not None
    assert sharex in (True, False, 'tab', 'all', 'none')
    assert sharey in (True, False, 'tab', 'all', 'none')
    assert color_reset in ('figure', 'tab', 'all')

    # The stat columns of all the figures (ordered by first appearance)
    all_stcols = list(odict.fromkeys(chain.from_iterable(
        figdef[4] for figdef in data)))
    screen_width = int(frame_width*(ncols+0.1))

    if color_reset == 'all':
        colors_cycle = cycle(colors)
    for tabname, figsdata in tab2data.items():
        n_tabfig = len(figsdata)
        n_rows = int(np.ceil(n_tabfig / ncols))
        tab_bkfigdata = []

        if color_reset == 'tab':
            colors_cycle = cycle(colors)
        for fig_idx, (ablname, aggdf, hpcols, stcols) in enumerate(figsdata):
            ax_row, ax_col = fig_idx // ncols, fig_idx % ncols
            idcols = [col for col in hpcols if col not in ('fpidx', huecol)]
            if tooltip is not None:
                idcols += tooltip

            # With shared axes, only the border figures show their axes
            show_xaxis = (sharex in (False, 'none')) or (ax_row == (n_rows-1))
            show_yaxis = (sharey in (False, 'none')) or (ax_col == 0)

            zoomtool = BoxZoomTool()
            figtools = [zoomtool, 'reset,pan,wheel_zoom']
            
            hover_opts = dict(show_arrow=False,
                line_policy='next', mode='mouse', toggleable=False)
            
            figure_opts = dict(y_axis_type=y_axis_type, 
                x_axis_type=x_axis_type, frame_height=frame_height, 
                frame_width=frame_width, sizing_mode='inherit')
            
            fig = figure(**figure_opts, tools=figtools)
            fig.toolbar.active_drag = zoomtool

            source_list = []
            fpidf_list = []
            if color_reset == 'figure':
                colors_cycle = cycle(colors)
            for lineidx, (fpidx, fpidf) in enumerate(aggdf.groupby(huecol, sort=False)):
                color = next(colors_cycle)
                fpidf = fpidf.copy()

                # Filling the infinite and missing values by interpolation
                fpidf = fpidf.set_index(xcol)
                fpidf = fpidf.replace(np.inf, np.nan)
                fpidf = fpidf.interpolate(method='index', limit_direction='both')
                fpidf = fpidf.reset_index()

                for col in fpidf.columns:
                    if fpidf[col].dtype == 'bool':
                        fpidf[col] = fpidf[col].astype(str)

                # The `y/*` columns are the ones that get plotted; the 
                # javascript callback overwrites them upon a menu change.
                fpidf['y/mean'] = fpidf[f'{ycol}/mean']
                fpidf['y/low'] = fpidf[f'{ycol}/low']
                fpidf['y/high'] = fpidf[f'{ycol}/high']
                
                source = ColumnDataSource(fpidf)
                line = fig.line(x=xcol, y=f'y/mean', source=source,  
                    legend_label=str(fpidx), color=color, line_width=5)

                tooltip_tups = [(col, '@{'+col+'}') for col in idcols 
                    if not (fpidf[col] == nullstr).values.all()]
                if len(tooltip_tups) > 0:
                    fig.add_tools(HoverTool(renderers=[line], 
                        tooltips=tooltip_tups, **hover_opts))
                
                band = Band(base=xcol, lower=f'y/low', 
                    upper=f'y/high', source=source,
                    fill_alpha=0.3, fill_color=color)
                fig.add_layout(band)
                source_list.append(source)
                fpidf_list.append(fpidf)

            fontsize = "10pt"
            if show_xaxis:
                fig.xaxis.axis_label = xcol
                fig.xaxis.axis_label_text_font_size = fontsize
                fig.xaxis.major_label_text_font_size = fontsize
                fig.xaxis.axis_label_text_color = "black"
            else:
                fig.xaxis.visible = False

            if show_yaxis:
                # Long labels are shortened to their first section
                ycol_ = ycol if len(ycol) <= 50 else ycol.split(colsep)[0]
                fig.yaxis.axis_label = ycol_
                fig.yaxis.axis_label_text_font_size = fontsize
                fig.yaxis.major_label_text_font_size = fontsize
                fig.yaxis.axis_label_text_color = "black"
            else:
                fig.yaxis.visible = False

            fig.title.text = fig_title.format(ablname=ablname) 
            fig.legend.visible = False 
            fig.legend.destroy()

            for tick_fmt, tick_lbls, axis in [(x_tick_fmt, x_tick_lbls, fig.xaxis),
                                              (y_tick_fmt, y_tick_lbls, fig.yaxis)]:
                if tick_lbls is not None:
                    axis.ticker = tick_lbls
                if tick_fmt == 'eng':
                    msg_ = f'for eng fmt, tick labels must be specified'
                    assert tick_lbls is not None, msg_
                    mpleng = EngFormatter()
                    engfmtr = lambda n: mpleng.format_eng(n).replace(' ', '')
                    axis.major_label_overrides = {x: engfmtr(x) for x in tick_lbls}
                else:   
                    assert tick_fmt in ('eng', None)
            
            fig.outline_line_color = 'black'
            fig.min_border = 10

            figdict=dict(fig=fig, source_list=source_list, fpidf_list=fpidf_list)
            tab_bkfigdata.append(figdict)

        all_bkfigdata += tab_bkfigdata
        all_tabbkfigdata.append(tab_bkfigdata)
        tab_figs = [x['fig'] for x in tab_bkfigdata]
        gridfigs = gridplot(tab_figs, ncols=ncols, sizing_mode='inherit')
        tabpanel = TabPanel(child=gridfigs, title=tabname, closable=True)
        tabs_list.append(tabpanel)
    
    tabgridfigs = Tabs(tabs=tabs_list, sizing_mode='inherit')
    all_figs = [x['fig'] for x in all_bkfigdata]
    all_tabfigs = [[x['fig'] for x in tab_bkfigdata] for tab_bkfigdata in all_tabbkfigdata]
    
    # Sharing the axis ranges by linking them to the first figure
    if sharey in (True, 'all'):
        for fig in all_figs:
            fig.y_range = all_figs[0].y_range
    elif sharey == 'tab':
        for tabfigs in all_tabfigs:
            for fig in tabfigs:
                fig.y_range = tabfigs[0].y_range
    elif sharey in (False, 'none'):
        pass
    else:
        raise ValueError(f'sharey={sharey} not defined')
    
    if sharex in (True, 'all'):
        for fig in all_figs:
            fig.x_range = all_figs[0].x_range
    elif sharex == 'tab':
        for tabfigs in all_tabfigs:
            for fig in tabfigs:
                fig.x_range = tabfigs[0].x_range
    elif sharex in (False, 'none'):
        pass
    else:
        raise ValueError(f'sharex={sharex} not defined')

    # Menu update callback in javascript
    code_js = """
        ////////////////////////////////////////////////////
        ///////////////// Menu Separation //////////////////
        ////////////////////////////////////////////////////
        var opts, menu, val;
        var ccd = cdeep;
        var sects = sectscds.data["sects"];
        for (let i=0; i < n_menu; i++) {
            if (i <= menu_idx) {
                if (i < menu_idx) {
                    val = sects[i];
                } else {
                    val = right_menus[0].value;
                    sects[i] = val;
                }
                ccd = ccd[val];
            } else {
                menu = right_menus[i-menu_idx];
                if (ccd == null) {
                    val = null;
                    menu.visible = false;
                    menu.options = [];
                } else {
                    opts = Object.keys(ccd);
                    if (opts.includes(sects[i])) {
                        val = sects[i];
                    } else {
                        val = opts[0];
                    }
                    ccd = ccd[val];
                    menu.visible = true;
                    menu.options = opts;
                    menu.value = val;
                }
                sects[i] = val;
            }

        }
        // The new axis
        var ycol_new = sects.filter(Boolean).join(colsep);
        
        ////////////////////////////////////////////////////
        ////////////////// Source Update ///////////////////
        ////////////////////////////////////////////////////
        // var ycol_new = cb_obj.value;
        const statnames = ["mean", "low", "high"];
        for (let j=0; j<all_srcs.length; j++) {
            var sdata = all_srcs[j].data;
            for (let k=0; k< statnames.length; k++) {
                var stn = statnames[k];
                var ycol_stn = ycol_new.concat(statsep, stn);
                var y_stn = "y".concat(statsep, stn);
                sdata[y_stn] = [];
                if (ycol_stn in sdata) {
                    for (let i=0;i<sdata[ycol_stn].length; i++) {
                        sdata[y_stn].push(sdata[ycol_stn][i]);
                    }
                }
            }
        }

        for (let j=0; j<all_srcs.length; j++) {
            all_srcs[j].change.emit();
        }

        var axlbl = ycol_new;
        if (axlbl.length > 50) {
            axlbl = axlbl.split(colsep)[0];
        }
        for (let j=0; j<all_yaxes.length; j++) {
            all_yaxes[j].axis_label = axlbl;
        }
    """
    # Creating Menus and seperating the columns
    cdeep = hie2deep({c: None for c in all_stcols}, sep=colsep)
    n_menu = max(c.count(colsep) for c in all_stcols) + 1
    sects = ycol.split(colsep)
    sects += (n_menu - len(sects)) * [None]
    all_menus = []
    dd = cdeep
    for i in range(n_menu):
        sect = sects[i]
        assert (dd is None) == (sect is None)
        opts = [] if sect is None else list(dd.keys())
        ddd = dict(title='Y-Axis') if i == 0 else dict(title="      ")
        menu = Select(options=opts, value=sect, **ddd,
            width=menu_width, height=50, sizing_mode='fixed')
        if sect is None:
            menu.visible = False
        menu.align = 'end'
        if dd is not None:
            dd = dd[sect]
        all_menus.append(menu)

    sectscds = ColumnDataSource(data={'sects': sects})
    all_srcs = list(chain.from_iterable(figdata['source_list'] 
        for figdata in all_bkfigdata))
    all_yaxes = [fig.yaxis[0] for fig in all_figs]

    for menu_idx, menu in enumerate(all_menus):
        # Each menu only controls itself and the menus to its right
        menu_args = dict(cdeep=cdeep, n_menu=n_menu, 
            menu_idx=menu_idx, right_menus=all_menus[menu_idx:], 
            sectscds=sectscds, colsep=colsep, statsep=statsep)
        data_args = dict(all_srcs=all_srcs, all_figs=all_figs, 
            all_yaxes=all_yaxes)
        jscb = CustomJS(args={**menu_args, **data_args}, code=code_js)
        menu.js_on_change('value', jscb)
    
    # heading fills available width
    heading = Div(text=f'<h1 style="text-align: center">{header}</h1>', 
        width=screen_width-menu_width*n_menu, height=50, 
        sizing_mode='stretch_width')
    
    fulllayout = column(row(heading, *all_menus, sizing_mode='stretch_width'), 
        tabgridfigs, sizing_mode='stretch_both')
    
    if doc is not None:
        doc.add_root(fulllayout)
    else:
        return fulllayout


In [None]:
# The directory holding the dashboard yaml configs and html outputs. It is 
# created in pure python, so that the cell does not depend on a unix shell.
workdir = './12_bokeh'
os.makedirs(workdir, exist_ok=True)

# The High-Dimensional Poisson Dashboard (Version 4)

This part relies on data generated in the `13_hdpviz.ipynb` notebook. In particular, the `13_hdpviz/02_hdpviz.h5` file must be present for this part to work.

Some of the improvements over the previous version were:

1. This dashboard added the deterministic radius sampling scheme.

In [None]:
# Loading the hyper-parameter and stats data-frames of the summary
smrypath = f'13_hdpviz/02_hdpviz.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = data['hp'], data['stat']

# Sorting the hyper-parameters, and applying the same order to the stats
hpdf = hpdf.sort_values(by=['dim', 'trg/btstrp', 'srfpts/dblsmpl', 'vol/n'])
i1 = hpdf.index.values
statdf = statdf.loc[i1, :].reset_index(drop=True)
hpdf = hpdf.reset_index(drop=True)

# 13_hdpviz/02_hdpviz.h5 is not generated by the original summary script, 
# so we need to make some adjustments to it to avoid errors: the index 
# columns must be categorical, and all the categorical columns must 
# accept the null string as a value.
for col in ('fpidx', 'fpidxgrp'):
    hpdf[col] = pd.Categorical(hpdf[col])
for col in hpdf.columns:
    if not (hpdf.dtypes[col] == 'category'):
        continue
    hpdf[col] = hpdf[col].cat.add_categories(nullstr)

In [None]:
ymlpath = f'{workdir}/06_poisshidim.yml'

# Creating one dashboard tab for each problem dimension
all_dims = sorted(hpdf['dim'].unique().tolist())
data = []
for dim in all_dims:
    didx = hpdf['dim'].eq(dim)
    dhpdf, dstdf = (df.loc[didx, :].reset_index(drop=True) 
        for df in (hpdf, statdf))
    data.append((f'{dim}-Dimensional', dhpdf, dstdf))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/06_poisshidim.html')
save(obj=fulllayout, title=dashdata['header'])

# The High-Dimensional Poisson Dashboard (Version 3)

This was the third attempt at collecting the high-dimensional Poisson problem results.

Some of the improvements over the previous version were:

1. The Poisson problem dimension range is [2, 3, 4, 5, 6, 7, 8, 9, 10].

2. The target regularization weight for bootstrapping was tuned properly.

In [None]:
# Loading the hyper-parameter and stats data-frames of the summary
smrypath = f'../summary/05_poisshidim.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = data['hp'], data['stat']

# Sorting the hyper-parameters, and applying the same order to the stats
hpdf = hpdf.sort_values(by=['dim', 'trg/btstrp', 'srfpts/dblsmpl', 'vol/n'])
i1 = hpdf.index.values
statdf = statdf.loc[i1, :].reset_index(drop=True)
hpdf = hpdf.reset_index(drop=True)

In [None]:
ymlpath = f'{workdir}/05_poisshidim.yml'

# Creating one dashboard tab for each problem dimension
all_dims = sorted(hpdf['dim'].unique().tolist())
data = []
for dim in all_dims:
    didx = hpdf['dim'].eq(dim)
    dhpdf, dstdf = (df.loc[didx, :].reset_index(drop=True) 
        for df in (hpdf, statdf))
    data.append((f'{dim}-Dimensional', dhpdf, dstdf))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/05_poisshidim.html')
save(obj=fulllayout, title=dashdata['header'])

# The High-Dimensional Poisson Dashboard (Version 2)

This was the second earliest attempt at collecting the high-dimensional poisson problem results.

Some of the improvements over the previous version were:

1. A single Poisson charge was placed at zero for all experiments.

2. No initial condition was applied here at all.

The following issues exist with this benchmarking:

1. The Poisson problem dimension range is [8, 16].

2. The target regularization weight for bootstrapping was still not set properly.

3. A very small number of bootstrapped trainings diverged due to Number 2.

In [None]:
# Loading the hyper-parameter and stats data-frames of the summary
smrypath = f'../summary/04_poisshidim.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = data['hp'], data['stat']

# Splitting the rows into the bootstrapping and standard training runs
i1 = hpdf['trg/btstrp'].eq(True)
i2 = hpdf['trg/btstrp'].eq(False)
hpdf_bts, statdf_bts = (df[i1].reset_index(drop=True) 
    for df in (hpdf, statdf))
hpdf_mse, statdf_mse = (df[i2].reset_index(drop=True) 
    for df in (hpdf, statdf))

In [None]:
# Keeping the bootstrapping rows of the rng seeds whose maximum total 
# loss stays below one within their `fpidxgrp` group. Note that the loop 
# re-binds `hpdf`; it is re-defined right after the filtering.
stable_idxs = []
for fpidxgrp, hpdf in hpdf_bts.groupby('fpidxgrp'):
    if hpdf.shape[0] >= 1:
        stdf = statdf_bts.loc[hpdf.index, :]

        # The maximum of each stat over the rows of each rng seed
        rsstdf = stdf.groupby('rng_seed').agg('max').reset_index()
        goodrs = rsstdf.loc[rsstdf['loss/total'] < 1, 'rng_seed']
        goodinic = stdf['rng_seed'].isin(goodrs)
        stable_idxs.append(goodinic.index[goodinic].values)

keepidxs = np.concatenate(stable_idxs, axis=0)

hpdf_bts, statdf_bts = (df.loc[keepidxs].reset_index(drop=True) 
    for df in (hpdf_bts, statdf_bts))

# Combining filtered bootstrapping and mse data-frames
hpdf = pd.concat([hpdf_bts, hpdf_mse], axis=0, ignore_index=True)
statdf = pd.concat([statdf_bts, statdf_mse], axis=0, ignore_index=True)

# Sorting the hyper-parameters, and applying the same order to the stats
hpdf = hpdf.sort_values(by=['dim', 'trg/btstrp', 'srfpts/dblsmpl', 'vol/n'])
i1 = hpdf.index.values
statdf = statdf.loc[i1, :].reset_index(drop=True)
hpdf = hpdf.reset_index(drop=True)

In [None]:
ymlpath = f'{workdir}/04_poisshidim.yml'

# Creating one dashboard tab for each problem dimension
all_dims = sorted(hpdf['dim'].unique().tolist())
data = []
for dim in all_dims:
    didx = hpdf['dim'].eq(dim)
    dhpdf, dstdf = (df.loc[didx, :].reset_index(drop=True) 
        for df in (hpdf, statdf))
    data.append((f'{dim}-Dimensional', dhpdf, dstdf))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/04_poisshidim.html')
save(obj=fulllayout, title=dashdata['header'])

# The High-Dimensional Poisson Dashboard (Version 1)

This was the earliest attempt at collecting the high-dimensional Poisson problem results.

The following issues exist with this benchmarking:

1. The Poisson problem dimension range is [2, 4, 8, 16, 32, 64].

2. The SiLU activation with only two layers was used here.

3. An initial condition was applied that was neither effective nor harmful.

4. The target regularization weight for bootstrapping was not set properly.

5. Three poisson charges were placed stochastically in the unit ball.

6. Most bootstrapped trainings diverged due to Number 4.

In [None]:
# Loading the hyper-parameter and stats data-frames of the summary
smrypath = f'../summary/03_poisshidim.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = data['hp'], data['stat']

# Splitting the rows into the bootstrapping runs, the double sampling 
# runs, and the remaining runs that are neither of the two.
i1 = hpdf['trg/btstrp'].eq(True)
i2 = hpdf['srfpts/dblsmpl'].eq(True)
i3 = np.logical_and(np.logical_not(i1.values), np.logical_not(i2.values))
hpdf_bts, statdf_bts = (df[i1].reset_index(drop=True) 
    for df in (hpdf, statdf))
hpdf_ds, statdf_ds = (df[i2].reset_index(drop=True) 
    for df in (hpdf, statdf))
hpdf_mse, statdf_mse = (df[i3].reset_index(drop=True) 
    for df in (hpdf, statdf))

In [None]:
ymlpath = f'{workdir}/03_poisshidim.yml'

# Creating one dashboard tab for each problem dimension
all_dims = sorted(hpdf['dim'].unique().tolist())
data = []
for dim in all_dims:
    didx = hpdf['dim'].eq(dim)
    dhpdf, dstdf = (df.loc[didx, :].reset_index(drop=True) 
        for df in (hpdf, statdf))
    data.append((f'{dim}-Dimensional', dhpdf, dstdf))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/03_poisshidim.html')
save(obj=fulllayout, title=dashdata['header'])

# The 2-Dimensional Poisson Problem Ablations Dashboard

In [None]:
# Loading the summary, and nesting its data by the training method
smrypath = f'11_plotting/poisson.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
data_ = hie2deep(data, maxdepth=1)

# The hyper-parameter and stats data-frames of each training method
dfd_bts, dfd_mse, dfd_ds = data_['bts'], data_['mse'], data_['ds']
hpdf_mse, statdf_mse = dfd_mse['hp'], dfd_mse['stat']
hpdf_bts, statdf_bts = dfd_bts['hp'], dfd_bts['stat']
hpdf_ds, statdf_ds = dfd_ds['hp'], dfd_ds['stat']

In [None]:
ymlpath = f'{workdir}/01_poisson.yml'

# One (tab name, hyper-parameters, stats) triplet per training method
data = list(zip(
    ('Standard Training', 'Bootstrapping', 'Double Sampling'),
    (hpdf_mse, hpdf_bts, hpdf_ds),
    (statdf_mse, statdf_bts, statdf_ds)))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/01_poisson.html')
save(obj=fulllayout, title=dashdata['header'])

# The 1-Dimensional Smoluchowski Problem Ablations Dashboard

In [None]:
# Loading the summary, and nesting its data by the training method
smrypath = f'11_plotting/smoluchowski.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
data_ = hie2deep(data, maxdepth=1)

# The hyper-parameter and stats data-frames of each training method
dfd_bts, dfd_mse = data_['bts'], data_['mse']
hpdf_mse, statdf_mse = dfd_mse['hp'], dfd_mse['stat']
hpdf_bts, statdf_bts = dfd_bts['hp'], dfd_bts['stat']

In [None]:
ymlpath = f'{workdir}/02_smoluchowski.yml'

# One (tab name, hyper-parameters, stats) triplet per training method
data = list(zip(
    ('Standard Training', 'Bootstrapping'),
    (hpdf_mse, hpdf_bts),
    (statdf_mse, statdf_bts)))

dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the layout without a bokeh document, and saving it as html
fulllayout = build_dashboard(doc=None, **dashdata)
output_file(filename=f'{workdir}/02_smoluchowski.html')
save(obj=fulllayout, title=dashdata['header'])