In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import h5py

import io
import os
import time
import glob
import fnmatch
from itertools import chain
from textwrap import dedent
from collections import defaultdict
from collections import OrderedDict as odict
from os.path import exists
from pandas.api.types import union_categoricals

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter
from adjustText import adjust_text

import bspinn
from bspinn.io_cfg import configs_dir, results_dir
from bspinn.io_cfg import keyspecs, nullstr
from bspinn.io_utils import deep2hie, hie2deep
from bspinn.io_utils import save_h5data, load_h5data
from bspinn.io_utils import get_h5du, resio
from bspinn.summary import summarize

# Utility Functions

In [None]:
#####################################################
############# Plotting Utility Functions ############
#####################################################

def get_linetext(ax, line_idx):
    """
    Gets the legend label at position `line_idx` of a matplotlib axis.

    If `ax` has no legend yet, a temporary legend is created with
    `ax.legend()` just to read the label. It is always removed again,
    even when the lookup fails, so the axis is left unchanged.

    Parameters
    ----------
    ax: (plt.Axis) an axis object

    line_idx: (int) the index of the legend entry to read. Callers use
      this as the index of the line drawn in `ax`; the two only agree
      when the legend entries line up with the drawn lines (NOTE(review):
      unlabeled lines are not given legend entries by matplotlib).

    Outputs
    -------
    text: (str) the legend text at position `line_idx`

    Raises
    ------
    IndexError: if the legend has fewer than `line_idx + 1` entries.
    """
    existing = ax.get_legend()
    leg = existing if existing is not None else ax.legend()
    try:
        return leg.texts[line_idx].get_text()
    finally:
        if existing is None:
            # Never leave the temporary legend behind, even on error.
            leg.remove()
    
def autoannot(fig, ax, xpercs=None, **kwargs):
    """
    Automatically annotates the lines of a line plot with their labels.

    A text label is placed on each line at the x value given by the
    requested percentile of the line's x data, at the line's
    (linearly interpolated) y value. All labels are then repositioned
    by `adjust_text`, which draws an arrow back to the line.

    Parameters
    ----------
    fig: (plt.Figure) a figure object (not used by this function;
      kept for backward compatibility).

    ax: (plt.Axis) an axis object

    xpercs: An object mapping each label to the horizontal
      axis percentile (in [0, 1]) for the arrow.

      This argument can accept various data types:

        * `None` will evenly space the lines in [0.1, 0.6].

        * `tuple`, `list` or `np.ndarray` of length two will define
            the even spacing interval.

        * `dict` can be provided as a mapping between the
            labels and their x percentiles for the arrow.

        * `callable` objects can define the mapping between
            the labels and their x percentiles for the arrow.

    **kwargs: (dict) all the options that will be piped to
        `adjust_text` (overriding the defaults below).

    Raises
    ------
    ValueError: if `xpercs` is of an unsupported type.
    """
    # Only lines with more than one point are annotated; empty or
    # single-point line artists are skipped.
    indexed = [(i, line) for i, line in enumerate(ax.get_lines())
               if line.get_xydata().shape[0] > 1]
    labels = [get_linetext(ax, i) for i, _ in indexed]
    lines = [line for _, line in indexed]

    if xpercs is None:
        xpercs = np.linspace(0.1, 0.6, len(lines))
    elif isinstance(xpercs, (tuple, list, np.ndarray)) and len(xpercs) == 2:
        xpercs = np.linspace(*xpercs, len(lines))
    elif isinstance(xpercs, dict):
        xpercs = [xpercs[lbl] for lbl in labels]
    elif callable(xpercs):
        xpercs = [xpercs(lbl) for lbl in labels]
    else:
        raise ValueError(f'Not sure how to use xpercs={xpercs}')

    if not lines:
        # Nothing to annotate; avoid calling adjust_text on an empty list.
        return

    plttxts = []
    for line, xperc, label in zip(lines, np.asarray(xpercs) * 100, labels):
        lxy = line.get_xydata()
        # np.interp requires increasing sample points, so sort by x;
        # otherwise unsorted line data would give a y value off the line.
        lxy = lxy[np.argsort(lxy[:, 0], kind='stable')]
        x = np.percentile(lxy[:, 0], xperc)
        y = np.interp(x, lxy[:, 0], lxy[:, 1])
        plttxts.append(ax.text(x, y, label))

    adjust_kwargs = dict(expand_points=(4.5, 2.0),
        arrowprops=dict(arrowstyle="->", color='black', lw=1.5,
            relpos=(0., 0.5), connectionstyle='arc3,rad=0.2'))
    adjust_kwargs.update(kwargs)
    adjust_text(plttxts, ax=ax, **adjust_kwargs)



# Summary Data Save and Load

In [None]:
figdir = './11_plotting'
# Pure-Python equivalent of `mkdir -p`; works outside IPython and on any OS.
os.makedirs(figdir, exist_ok=True)

smrypath = f'{figdir}/poisson.h5'
# ! rm -f {smrypath}   # uncomment to discard the cache and re-summarize


def _detrended_errors(mdl_sol, gt_sol):
    """
    Computes model-vs-ground-truth errors after removing (a) each row's
    mean and (b) each row's least-squares plane `b1*x1 + b2*x2 + b0`
    from both the model and the ground-truth solutions.

    Parameters
    ----------
    mdl_sol: (np.ndarray) model solutions with shape (n_mdl, n_g).

    gt_sol: (np.ndarray) ground-truth solutions with shape (n_mdl, n_g).

    Outputs
    -------
    msenob, maenob: (np.ndarray) per-row MSE and MAE after mean removal,
      each with shape (n_mdl,).

    mseunb, maeunb: (np.ndarray) per-row MSE and MAE after plane removal,
      each with shape (n_mdl,).
    """
    n_mdl, n_g = mdl_sol.shape
    assert gt_sol.shape == (n_mdl, n_g)

    # The plane regression assumes the n_g points form a square
    # n_gpd x n_gpd grid over [-1, 1]^2, flattened with x1 on the slow
    # axis and x2 on the fast axis (this is how the design matrix is built).
    n_gpd = int(np.sqrt(n_g))
    assert n_gpd * n_gpd == n_g, f'expected a square grid, got n_g={n_g}'

    x1_msh, x2_msh = np.meshgrid(np.linspace(-1, 1, n_gpd),
                                 np.linspace(-1, 1, n_gpd), indexing='ij')
    x = np.stack([x1_msh.reshape(-1), x2_msh.reshape(-1), np.ones(n_g)], axis=1)
    assert x.shape == (n_g, 3)

    # Least-squares projector: beta = A @ y fits y ~ b1*x1 + b2*x2 + b0.
    A_tilde = (np.linalg.pinv(x.T @ x) @ x.T)[None, ...]
    assert A_tilde.shape == (1, 3, n_g)

    def plane_residual(sol):
        # Residual (fitted plane - data) for every row of `sol`.
        y = sol[:, :, None]
        beta = np.matmul(A_tilde, y)
        assert beta.shape == (n_mdl, 3, 1)
        return np.matmul(x[None, :, :], beta) - y

    errunb = (plane_residual(mdl_sol) - plane_residual(gt_sol))[..., 0]
    assert errunb.shape == (n_mdl, n_g)

    errnob = ((mdl_sol - mdl_sol.mean(axis=1, keepdims=True))
              - (gt_sol - gt_sol.mean(axis=1, keepdims=True)))

    return (np.square(errnob).mean(axis=1), np.abs(errnob).mean(axis=1),
            np.square(errunb).mean(axis=1), np.abs(errunb).mean(axis=1))


if exists(smrypath):
    get_h5du(smrypath, verbose=True, detailed=False)
    data = load_h5data(smrypath)
    data_ = hie2deep(data, maxdepth=1)
    dfd_bts = data_['bts']
    dfd_mse = data_['mse']
    dfd_ds = data_['ds']
else:
    keynames = ['hp', 'stat', 'var/eval/ug']
    dfd_bts = summarize('01_poisson/01_btstrp2d', keys=keynames)
    dfd_mse = summarize('01_poisson/02_mse2d', keys=keynames)
    dfd_ds = summarize('01_poisson/03_ds2d', keys=keynames)

    # Suppress floating-point warnings only while computing the metrics,
    # instead of changing numpy's global error state for the rest of the
    # session (which the cached branch above never did).
    with np.errstate(all="ignore"):
        for dfd in (dfd_bts, dfd_mse, dfd_ds):
            msenob, maenob, mseunb, maeunb = _detrended_errors(
                dfd['var/eval/ug/sol/mdl'], dfd['var/eval/ug/sol/gt'])
            dfd['stat']['perf/ug/mdl/mse2'] = msenob
            dfd['stat']['perf/ug/mdl/mae2'] = maenob
            dfd['stat']['perf/ug/mdl/mse3'] = mseunb
            dfd['stat']['perf/ug/mdl/mae3'] = maeunb
            print('.', end='', flush=True)
    print('', flush=True)

    # Evaluation arrays ('var/eval/...') are excluded from the cache file.
    data = deep2hie({'bts': dfd_bts, 'mse': dfd_mse, 'ds': dfd_ds})
    data = odict([(k, v) for k, v in data.items() if 'var/eval' not in k])
    save_h5data(data, smrypath)
    get_h5du(smrypath, verbose=True, detailed=False)

hpdf_mse = dfd_mse['hp']
statdf_mse = dfd_mse['stat']

hpdf_bts = dfd_bts['hp']
statdf_bts = dfd_bts['stat']

hpdf_ds = dfd_ds['hp']
statdf_ds = dfd_ds['stat']

# Poisson Figures

In [None]:
nrows, ncols = 1, 2
fig, axes = plt.subplots(nrows, ncols,
    figsize=(2.2*ncols, 2.1*nrows),
    sharex=True, sharey=True, dpi=144)
axes = np.array(axes).reshape(nrows, ncols)

# One entry per experiment:
#   (method title, N, result id `fpidx`, panel index, summary dict,
#    start `a` and stop `b` of the row slice of stored model solutions
#    that is averaged for the heatmap)
axspecs = [('Det. Sampling',   1,    '01_poisson/02_mse2d.1.0', 0, dfd_mse, -100,   -1),
           ('Det. Sampling',  10,    '01_poisson/02_mse2d.3.0', 0, dfd_mse,   -2,   -1),
           ('Det. Sampling', 100,    '01_poisson/02_mse2d.5.0', 0, dfd_mse,   -2,   -1),
           ('Dbl. Sampling',   1,     '01_poisson/03_ds2d.0.0', 1,  dfd_ds, -100,   -1),
           ('Dbl. Sampling',  10,     '01_poisson/03_ds2d.2.0', 1,  dfd_ds,   -5,   -4),
           ('Dbl. Sampling', 100,     '01_poisson/03_ds2d.0.1', 1,  dfd_ds,   -5,   -4),
           ('Bootstrapping',   1, '01_poisson/01_btstrp2d.0.0', 2, dfd_bts,   -2,   -1)]

n_des = 100
for ax_title, n, fpidx, ax_idx, dfd, a, b in axspecs:
    # Plot only the requested N; specs whose panel index does not fit
    # in the subplot grid are skipped.
    if n != n_des or ax_idx >= axes.size:
        continue

    mdlsols = resio(fpidx)('var/eval/ug/sol/mdl')
    n_g = mdlsols.shape[1]
    n_gpd = int(np.sqrt(n_g))
    assert n_gpd * n_gpd == n_g, f'{fpidx}: expected a square grid, got n_g={n_g}'

    # Average rows a:b of the stored solutions and lay them out on the
    # grid (x1 along the first array axis, x2 along the second).
    z_msh = mdlsols[a:b, :].mean(axis=0).reshape(n_gpd, n_gpd)
    x1_msh, x2_msh = np.meshgrid(np.linspace(-1, 1, n_gpd),
                                 np.linspace(-1, 1, n_gpd), indexing='ij')

    ax = axes.flat[ax_idx]
    ax.pcolormesh(x1_msh, x2_msh, z_msh, shading='auto',
         norm=None, cmap='RdBu', linewidth=0, rasterized=True)
    ax.set_title(f'{ax_title} (N={n})', fontsize=10)

fig.savefig(f'{figdir}/hmap_poiss_n{n_des}.pdf', dpi=200, bbox_inches="tight")

In [None]:
color_dict = {'Dbl. Sampling': 'blue', 'Det. Sampling': 'red', 'Bootstrapping': 'green'}
fig2, ax2 = plt.subplots(1, 1, figsize=(1*2.1, 1*2.2),
    sharex=False, sharey=False, dpi=144)

n_des = 100
stdfs = []
for ax_title, n, fpidx, ax_idx, dfd, a, b in axspecs:
    # Keep the requested N for every method; `axspecs` only lists
    # Bootstrapping at N=1, so Bootstrapping is always kept.
    if n != n_des and ax_title != 'Bootstrapping':
        continue
    hpdf_, stdf_ = dfd['hp'], dfd['stat']
    # Row labels of this experiment, ordered by (epoch, rng_seed).
    myidx = (stdf_[hpdf_['fpidx'] == fpidx]
             .reset_index()
             .sort_values(by=['epoch', 'rng_seed'])['index'].values)
    stdfs.append(stdf_.loc[myidx].assign(method=ax_title))

snsdf = pd.concat(stdfs, axis=0, ignore_index=True)
# Plotted before the log scale is set, so seaborn aggregates in linear space.
sns.lineplot(snsdf, x='epoch', y='perf/ug/mdl/mse2', hue='method',
    palette=color_dict, errorbar=('ci', 50), legend=False, ax=ax2)

ax2.set_yscale('log')
ax2.set_ylabel('Ground Truth MSE')
if n_des == 1:
    ax2.set_yticks([10**i for i in [-4, -3, -2, -1, 0, 1, 2]])
    ax2.set_ylim(1e-4, 1e3)
elif n_des == 100:
    ax2.set_ylim(0.0003, 0.031)
    # `numdecs` is intentionally not passed: it never had an effect and was
    # deprecated in Matplotlib 3.8, so passing it breaks on newer releases.
    ylocator = mpl.ticker.LogLocator(base=10.0, subs=(1.0, 0.3), numticks=5)
    ax2.yaxis.set_major_locator(ylocator)
    ax2.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())

ax2.set_xlabel('Epoch')
ax2.set_xticks(np.linspace(0, 200_000, 5))
engfmt = EngFormatter(sep='')
ax2.xaxis.set_major_formatter(engfmt)

fig2.savefig(f'{figdir}/nmse_vs_epoch_poiss_n{n_des}.pdf', dpi=200, bbox_inches="tight")

# The Smoluchowski Summary

In [None]:
figdir = './11_plotting'
# Pure-Python equivalent of `mkdir -p`; works outside IPython and on any OS.
os.makedirs(figdir, exist_ok=True)

smrypath = f'{figdir}/smoluchowski.h5'
# ! rm -f {smrypath}   # uncomment to discard the cache and re-summarize

if exists(smrypath):
    # Cached summary: load the hierarchical h5 file and split it per method.
    get_h5du(smrypath, verbose=True, detailed=False)
    data = load_h5data(smrypath)
    data_ = hie2deep(data, maxdepth=1)
    dfd_bts, dfd_mse = data_['bts'], data_['mse']
else:
    # No cache yet: summarize the raw results and write the cache file.
    keynames = ['hp', 'stat']
    dfd_bts = summarize('02_smoll/01_btstrp', keys=keynames)
    dfd_mse = summarize('02_smoll/02_mse', keys=keynames)

    data = deep2hie({'bts': dfd_bts, 'mse': dfd_mse})
    save_h5data(data, smrypath)
    get_h5du(smrypath, verbose=True, detailed=False)

hpdf_mse, statdf_mse = dfd_mse['hp'], dfd_mse['stat']
hpdf_bts, statdf_bts = dfd_bts['hp'], dfd_bts['stat']

# Smoluchowski Figures Plotting

In [None]:
# (N, problem dimension d, result id `fpidx`); d selects the panel.
vartups = [(1,   1, '02_smoll/02_mse.0.0'),
           (1,   2, '02_smoll/02_mse.1.5'),
           (1,   3, '02_smoll/02_mse.1.6'),
           (100, 1, '02_smoll/02_mse.0.1'),
           (100, 2, '02_smoll/02_mse.3.5'),
           (100, 3, '02_smoll/02_mse.3.6'),
           (1,   1, '02_smoll/01_btstrp.0.0'),
           (1,   2, '02_smoll/01_btstrp.1.0'),
           (1,   3, '02_smoll/01_btstrp.0.1')]

fig, axes = plt.subplots(1, 3, figsize=(3*2.8, 1*1.9),
    sharex=True, sharey=True, dpi=144)

# Draw every curve first. Seaborn aggregates in the axis' current scale,
# so switching the (shared) y axis to log between lineplot calls would
# give the first curve an arithmetic mean/CI and the rest log-space ones.
for N, d, fpidx in vartups:
    if 'btstrp' in fpidx:
        color = 'green'
        s_df, h_df = statdf_bts, hpdf_bts
    else:
        color = 'red' if N == 1 else 'blue'
        s_df, h_df = statdf_mse, hpdf_mse

    mysdf = s_df[h_df['fpidx'] == fpidx]
    sns.lineplot(mysdf, x='epoch', y='perf/ug1/mdl/mse', color=color,
        errorbar=('ci', 50), ax=axes[d-1])

# Then format each panel once.
for d, ax in enumerate(axes, start=1):
    ax.set_yscale('log')
    ax.set_ylabel('Ground Truth MSE')
    ax.set_yticks([10**i for i in [-6, -5, -4, -3, -2, -1, 0, 1]])
    ax.set_ylim(1e-7, 10)

    ax.set_xlabel('Epoch')
    ax.set_xticks(np.linspace(0, 200_000, 5))
    # One formatter per axis: formatter instances are bound to their axis.
    ax.xaxis.set_major_formatter(EngFormatter(sep=''))

    ax.set_title(f'{d}D Problem')

fig.savefig(f'{figdir}/mse_vs_epoch_sm.pdf', dpi=200, bbox_inches="tight")