# Imports and Defaults

In [1]:
import json
import os
from collections import namedtuple
from functools import lru_cache
from zipfile import ZipFile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Plot styling. One set_theme call is enough: every call resets seaborn's
# plotting context (font scaling included), so the earlier call that passed
# font_scale=16 was fully overridden by this one and has been removed.
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    "text.usetex": True,  # typeset figure text with LaTeX
    "font.family": "serif",
    "figure.dpi": 300,
})

# Seeded generator so any random draws in the notebook are reproducible.
seed = 1234
rng = np.random.default_rng(seed)

In [3]:
# Name of the posterior under study; it is interpolated into the data
# directory and the reference-draw archive path further down.
posterior = "irt_2pl"

# Load Data

In [4]:
# Base directory of the baseline runs for the selected posterior. Despite the
# name, this is a directory: get_fname appends the sampler sub-directory and
# the per-chain file name to it.
fname = f"../../data/{posterior}/baseline"

# Sub-directory name of each sampler run, written as `key=value` pairs joined
# by double underscores.
# NOTE(review): the DR-HMC id uses `step_count_factor=0.9` where the DR-G-HMC
# and GHMC ids use `step_count_method=const_step_count` -- confirm this
# difference is intentional.
drghmc_sampler = "adapt_metric=True__damping=0.08__max_proposals=3__metric=1__probabilistic=False__reduction_factor=4__sampler_type=drghmc__step_count_method=const_step_count__step_size_factor=2"
drhmc_sampler = "adapt_metric=True__damping=1.0__max_proposals=3__metric=1__probabilistic=False__reduction_factor=4__sampler_type=drhmc__step_count_factor=0.9__step_size_factor=2"
nuts_sampler = "adapt_metric=True__metric=identity__sampler_type=nuts"
ghmc_sampler = "adapt_metric=True__damping=0.08__max_proposals=1__metric=1__probabilistic=False__reduction_factor=4__sampler_type=ghmc__step_count_method=const_step_count__step_size_factor=2"

In [5]:
def get_fname(sampler, chain):
    """Return the path of the saved history archive for one chain.

    The result is ``<fname>/<sampler>/history__chain=<chain>.npz``, where
    ``fname`` is the module-level base directory.
    """
    chain_file = f"history__chain={chain}.npz"
    sampler_dir = os.path.join(fname, sampler)
    return os.path.join(sampler_dir, chain_file)

def get_chain_data(sampler, chain, downsample=1000):
    """Load the thinned draws and gradient-evaluation record of one chain.

    Args:
        sampler: Sampler identifier (sub-directory name) passed to get_fname.
        chain: Chain index passed to get_fname.
        downsample: Thinning stride; every ``downsample``-th entry is kept,
            starting from the first.

    Returns:
        Tuple ``(draws, grad_evals)`` holding the thinned ``"draws"`` and
        ``"grad_evals"`` arrays stored in the chain's ``.npz`` archive.

    Raises:
        Whatever ``np.load`` raises for a missing or unreadable archive, and
        ``KeyError`` if either array is absent from it.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it as soon as both arrays are read.
    with np.load(get_fname(sampler, chain)) as data:
        draws = data["draws"][::downsample]
        grad_evals = data["grad_evals"][::downsample]
    return draws, grad_evals

In [6]:
# Four samplers are loaded in this notebook, so the cache holds four entries;
# with three, loading the fourth sampler evicted the first.
@lru_cache(maxsize=4)
def get_data(sampler, n_chains=30):
    """Load the thinned draws and gradient evaluations of every chain of a sampler.

    Chains that cannot be loaded are skipped with a printed message instead of
    being dropped silently; KeyboardInterrupt and SystemExit are not caught
    and propagate to the caller.

    Args:
        sampler: Sampler identifier passed on to get_chain_data.
        n_chains: Number of chain indices to try, starting from 0.

    Returns:
        Tuple ``(data_list, grad_list)`` of per-chain arrays, kept as lists
        because chains may differ in length. The result is memoised, so
        repeated calls return the same list objects -- do not mutate them.
    """
    data_list, grad_list = [], []
    for chain in range(n_chains):
        print(chain)
        try:
            draws, grad_evals = get_chain_data(sampler, chain)  # draws = (n_samples, n_params)
        except Exception as err:
            # Best-effort loading: report the chain that failed and move on.
            print(f"skipping chain {chain}: {type(err).__name__}: {err}")
            continue
        data_list.append(draws)
        grad_list.append(np.array(grad_evals))

    # do not stack b/c variable len chains
    return data_list, grad_list

In [7]:
# Load the GHMC chains; get_data prints each chain index as it is attempted.
ghmc_data, ghmc_grads = get_data(ghmc_sampler)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


In [None]:
# Load per-chain draws and gradient-evaluation records for every sampler.
# get_data is wrapped in lru_cache, so a repeated call with the same sampler
# id can be answered from the cache rather than re-read from disk.
ghmc_data, ghmc_grads = get_data(ghmc_sampler)
drghmc_data, drghmc_grads = get_data(drghmc_sampler)
drhmc_data, drhmc_grads = get_data(drhmc_sampler)
nuts_data, nuts_grads = get_data(nuts_sampler)

# Compute True Params from Reference Draws

In [None]:
# Reference statistics, one entry per parameter: mean/std of the reference
# draws (p1_*) and mean/std of the squared reference draws (p2_*).
TrueParams = namedtuple('TrueParams', ['p1_mean', 'p1_std', 'p2_mean', 'p2_std'])

def load_true_params(posterior):
    """Read the reference draws of a posterior from its zipped JSON file.

    Opens ``../../posteriors/<posterior>/<posterior>.ref_draws.json.zip`` and
    returns the parsed content of its ``<posterior>.ref_draws.json`` member.
    """
    archive_path = f"../../posteriors/{posterior}/{posterior}.ref_draws.json.zip"
    member_name = f"{posterior}.ref_draws.json"
    with ZipFile(archive_path) as archive, archive.open(member_name) as handle:
        return json.load(handle)

def compute_true_params(posterior):
    """Summarise the reference draws of a posterior as a TrueParams tuple.

    The draws returned by load_true_params are stacked into one array and
    reduced over axes 0 and 2, giving one mean and one standard deviation per
    entry of axis 1, for both the draws and their squares.
    """
    chains = load_true_params(posterior)  # list of dicts (per the original note)
    # Original note on the layout: [num_chains, num_params, num_samples].
    draws = np.array([list(chain.values()) for chain in chains])
    squared = draws**2
    pooled_axes = (0, 2)
    return TrueParams(
        p1_mean=np.mean(draws, axis=pooled_axes),
        p1_std=np.std(draws, axis=pooled_axes),
        p2_mean=np.mean(squared, axis=pooled_axes),
        p2_std=np.std(squared, axis=pooled_axes),
    )

In [None]:
# Reference moments of the selected posterior, used as ground truth below.
tp = compute_true_params(posterior)

# Compute Per-Chain Error

In [None]:
def get_mean(data):
    """Return the running mean of ``data`` along its first axis.

    Entry ``k`` of the result is the mean of ``data[:k + 1]``.

    Args:
        data: Array-like with samples along axis 0; any number of trailing
            axes is supported.

    Returns:
        Array with the same shape as ``data``.
    """
    data = np.asarray(data)
    n = data.shape[0]
    # Shape the sample counts as (n, 1, ..., 1) so they broadcast over every
    # trailing axis. A fixed (n, 1) shape turned 1-D input into an (n, n) matrix.
    denom = np.arange(1, n + 1).reshape((n,) + (1,) * (data.ndim - 1))
    return np.cumsum(data, axis=0) / denom

def get_error(data, true_mean, true_std):
    """Standardized error of the running mean, reduced to one value per row.

    The running mean from get_mean is compared with ``true_mean``, the
    absolute deviation is divided by ``true_std``, and each row is collapsed
    with a 2-norm along axis 1.
    """
    running_mean = get_mean(data)
    standardized = np.abs(running_mean - true_mean) / true_std
    # ord=2 is the Euclidean norm; ord=np.inf would give the maximum instead.
    return np.linalg.norm(standardized, ord=2, axis=1)

In [None]:
# Running standardized error of every chain, one curve per chain: [chains, samples].
# First moment: the draws against the reference mean/std.
drghmc_error, drhmc_error, nuts_error = (
    [get_error(chain, tp.p1_mean, tp.p1_std) for chain in chains]
    for chains in (drghmc_data, drhmc_data, nuts_data)
)

# Second moment: the squared draws against the reference mean/std of the squares.
drghmc_error_squared, drhmc_error_squared, nuts_error_squared = (
    [get_error(chain**2, tp.p2_mean, tp.p2_std) for chain in chains]
    for chains in (drghmc_data, drhmc_data, nuts_data)
)

# Error vs Grad Evals

In [None]:
# Flatten each sampler's per-chain error curves into one array and record, for
# every entry, the index of the chain it came from.
# (A stray bare `drghmc_error` expression in the middle of this cell did
# nothing and has been removed.)
drghmc_error_flat = np.concatenate(drghmc_error)
drhmc_error_flat = np.concatenate(drhmc_error)
nuts_error_flat = np.concatenate(nuts_error)

# full_like reuses the shape and dtype of each error curve, so the chain
# indices share the dtype of the error values.
drghmc_chain_idx = np.concatenate([np.full_like(e, i) for i, e in enumerate(drghmc_error)])
drhmc_chain_idx = np.concatenate([np.full_like(e, i) for i, e in enumerate(drhmc_error)])
nuts_chain_idx = np.concatenate([np.full_like(e, i) for i, e in enumerate(nuts_error)])

# Final Error Box Plot

In [None]:
# Final value of every chain's error curve, one number per chain: [chains].
drghmc_e1, drhmc_e1, nuts_e1 = (
    np.array([curve[-1] for curve in curves])
    for curves in (drghmc_error, drhmc_error, nuts_error)
)

# Same for the errors of the squared draws.
drghmc_e2, drhmc_e2, nuts_e2 = (
    np.array([curve[-1] for curve in curves])
    for curves in (drghmc_error_squared, drhmc_error_squared, nuts_error_squared)
)

In [None]:
# Long-format table of final errors: one row per (chain, error type), labelled
# with the sampler the chain belongs to.
data = pd.DataFrame(
    {
        "Sampler": ["DR-G-HMC"] * len(drghmc_e1) + ["DR-HMC"] * len(drhmc_e1) + ["NUTS"] * len(nuts_e1),
        "Errors": np.concatenate((drghmc_e1, drhmc_e1, nuts_e1)),
        "Error Squared": np.concatenate((drghmc_e2, drhmc_e2, nuts_e2)),
    }
).melt(id_vars="Sampler", var_name="Error Type", value_name="Error")
data

In [None]:
# Box plots of the final per-chain errors: one panel per error type and one
# box per sampler.
# NOTE: despite the name, `fig` holds whatever sns.catplot returns (a seaborn
# grid object exposing .axes and .set), not a matplotlib Figure.
fig = sns.catplot(
    data=data,
    kind="box",
    x="Sampler",
    y="Error",
    col="Error Type",
    hue="Sampler",
    hue_order=["DR-G-HMC", "DR-HMC", "NUTS"],
    aspect=1.5,
    showmeans=True,
    meanline=True,
    meanprops=dict(linestyle="--", linewidth=2, color="black"),  # style of the mean marker line
    legend=False,
)

# Panel titles are assigned by position -- assumes panel 0 is the "Errors"
# column and panel 1 the "Error Squared" column; TODO confirm the panel order.
fig.axes.flat[0].set_title(r'Error in Mean ($\mathcal{L}_{\theta, T}$)')
fig.axes.flat[1].set_title(r'Error in Variance ($\mathcal{L}_{\theta^2, T}$)')
# Log-scaled y axis with fixed limits applied to every panel.
fig.set(yscale="log")
fig.set(ylim=(1e-2, 1e1))
plt.show()

In [None]:
# Per-parameter standardized error of each sampler's pooled posterior mean.
# drghmc_data, drhmc_data and nuts_data are lists of numpy arrays (one per
# chain); a sampler's estimate is the average of its per-chain means, so every
# chain is weighted equally regardless of its length.
# The header lists only the columns that are printed: the old trailing "Ref"
# column never had a value in any row.
print("\t\tDR-G-HMC\tDR-HMC\t\tNUTS")

drghmc_mean = np.mean([np.mean(chain, axis=0) for chain in drghmc_data], axis=0)
drhmc_mean = np.mean([np.mean(chain, axis=0) for chain in drhmc_data], axis=0)
nuts_mean = np.mean([np.mean(chain, axis=0) for chain in nuts_data], axis=0)
ref_mean = tp.p1_mean
ref_std = tp.p1_std

# |estimate - reference| in units of the reference standard deviation.
drghmc_diff = np.abs(drghmc_mean - ref_mean) / ref_std
drhmc_diff = np.abs(drhmc_mean - ref_mean) / ref_std
nuts_diff = np.abs(nuts_mean - ref_mean) / ref_std

for idx, (m1_diff, m2_diff, m3_diff) in enumerate(zip(drghmc_diff, drhmc_diff, nuts_diff)):
    print(f"Diff {idx + 1}: \t{m1_diff:.4f}\t\t{m2_diff:.4f}\t\t{m3_diff:.4f}")


In [None]:
# Max-norm (ord=np.inf) over parameters of each sampler's standardized error
# in the pooled mean.
norm_ord = np.inf
for label, sampler_mean in (
    ("DRGHMC Error:\t", drghmc_mean),
    ("DRHMC Error:\t", drhmc_mean),
    ("NUTS Error:\t", nuts_mean),
):
    standardized = np.abs(sampler_mean - ref_mean) / ref_std
    print(label, np.linalg.norm(standardized, ord=norm_ord))