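# DRGHMC vs. DRHMC: Sampling Cost on `irt_2pl`

This notebook compares the run time and call profile of two `DrGhmcDiag` configurations on the same posterior, initialization, seed, and metric:
- a DRGHMC sampler taking one leapfrog step per iteration;
- a DRHMC sampler taking `step_counts` leapfrog steps per iteration.

# Imports and Defaults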

In [69]:
# Re-import edited modules (e.g. the project's `src.*` code) before each cell runs,
# and enable the %snakeviz profiling magic used in the profiling section.
%load_ext autoreload
%autoreload 2
%load_ext snakeviz

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The snakeviz extension is already loaded. To reload it, use:
  %reload_ext snakeviz


In [70]:
import os
import sys
import timeit

import numpy as np
import seaborn as sns

# Make the repository root (two levels above the kernel's current working
# directory) importable so the `src.*` imports in this cell resolve.
# Guarded so re-running the cell does not keep prepending duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.samplers.drghmc import DrGhmcDiag
from src.utils.models import BayesKitModel
from src.utils.posteriors import get_model_path, get_data_path, get_init

In [71]:
# Plot styling for any figures in this notebook.
sns.set_theme(style="whitegrid")

# RNG seed passed to both samplers.
seed = 12_345

# Which posterior to load, and where the posterior files live
# (path is relative to the kernel's working directory).
posterior_dir = "../../posteriors"
posterior_name = "irt_2pl"

# Setup

In [72]:
# Locate the model and data files for the selected posterior and wrap them
# in a BayesKitModel.
model_path = get_model_path(posterior_name, posterior_dir)
data_path = get_data_path(posterior_name, posterior_dir)
model = BayesKitModel(data_path=data_path, model_path=model_path)

# Shared starting point for both samplers (the trailing 0 presumably selects
# which stored initialization to use — confirm in get_init).
init = get_init(posterior_name, posterior_dir, 0)

# All-ones diagonal metric with the same shape as the starting point.
metric = np.ones(shape=init.shape)

If the file has changed since the last time it was loaded, this load may not update the library!


In [73]:
# DRGHMC configuration: one proposal per iteration, a single leapfrog step of
# size 0.2, and damping 0.08 (no probabilistic retry).
drghmc_sampler = DrGhmcDiag(
    model=model,
    init=init,
    metric_diag=metric,
    seed=seed,
    max_proposals=1,
    leapfrog_step_sizes=[0.2],
    leapfrog_step_counts=[1],
    damping=0.08,
    prob_retry=False,
)

In [74]:
# Leapfrog steps per DRHMC iteration; also used to match the DRGHMC timing run.
step_counts = 10

# DRHMC configuration: same step size as DRGHMC, but `step_counts` leapfrog
# steps per iteration and damping 1.0 (NOTE(review): presumably a full momentum
# refresh, i.e. plain HMC — confirm in DrGhmcDiag).
drhmc_sampler = DrGhmcDiag(
    model=model,
    init=init,
    metric_diag=metric,
    seed=seed,
    max_proposals=1,
    leapfrog_step_sizes=[0.2],
    leapfrog_step_counts=[step_counts],
    damping=1.0,
    prob_retry=False,
)

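# Timing Experiments

Each DRGHMC iteration takes one leapfrog step, while each DRHMC iteration takes `step_counts` steps. For both measurements to cover the same number of leapfrog steps, each DRGHMC measurement times `step_counts` consecutive `sample()` calls (`number=step_counts`). Each DRHMC measurement times a single call.

Every reported statistic is taken over 1000 measurements. It is the total elapsed time of one measurement, in seconds, not the time per call.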

In [75]:
def time_sampling(fn, number=1, repeat=1000, decimals=5):
    """Time repeated calls to ``fn`` and summarize the measurements.

    Uses :func:`timeit.repeat`, so each of the ``repeat`` measurements is the
    total elapsed time, in seconds, of ``number`` consecutive calls to ``fn``.
    It is not the time of a single call.

    Parameters
    ----------
    fn : callable
        Zero-argument callable to time, e.g. a sampler's ``sample`` method.
    number : int, default 1
        Calls to ``fn`` per measurement.
    repeat : int, default 1000
        Number of measurements to collect; must be at least 1.
    decimals : int or None, default 5
        Decimal places to round each statistic to. ``None`` disables rounding,
        which is useful when a statistic is much smaller than ``10**-decimals``.

    Returns
    -------
    dict
        Keys ``'mean'``, ``'std'`` (population std), ``'min'`` and ``'max'``
        of the measurements.

    Raises
    ------
    ValueError
        If ``repeat`` is less than 1.
    """
    if repeat < 1:
        # With no measurements the numpy reductions below would fail obscurely.
        raise ValueError(f"repeat must be at least 1, got {repeat}")
    times = np.asarray(timeit.repeat(fn, number=number, repeat=repeat))
    stats = {
        'mean': times.mean(),
        'std': times.std(),
        'min': times.min(),
        'max': times.max(),
    }
    if decimals is None:
        return stats
    return {key: np.round(value, decimals) for key, value in stats.items()}

In [76]:
# `step_counts` one-leapfrog-step DRGHMC iterations per measurement.
time_sampling(fn=drghmc_sampler.sample, number=step_counts)

{'mean': 0.0019, 'std': 6e-05, 'min': 0.00183, 'max': 0.0031}

In [77]:
# One `step_counts`-leapfrog-step DRHMC iteration per measurement.
time_sampling(fn=drhmc_sampler.sample, number=1)

{'mean': 0.00131, 'std': 0.00016, 'min': 0.00112, 'max': 0.002}

# Profiling Experiments

In [45]:
# Profile a single DRGHMC sample() call; -t opens the SnakeViz report in a new tab.
# NOTE(review): execution count (45) predates the timing cells — re-run top to bottom.
%snakeviz -t drghmc_sampler.sample()

 
*** Profile stats marshalled to file '/tmp/tmpnw3g9mng'.
Opening SnakeViz in a new tab...
Port 8080 in use, trying another.
Port 8081 in use, trying another.
snakeviz web server started on 127.0.0.1:8082; enter Ctrl-C to exit
http://127.0.0.1:8082/snakeviz/%2Ftmp%2Ftmpnw3g9mng


In [46]:
# Profile a single DRHMC sample() call; -t opens the SnakeViz report in a new tab.
# NOTE(review): execution count (46) predates the timing cells — re-run top to bottom.
%snakeviz -t drhmc_sampler.sample()

 
*** Profile stats marshalled to file '/tmp/tmpw9g_k9jp'.
Opening SnakeViz in a new tab...
Port 8080 in use, trying another.
Port 8081 in use, trying another.
snakeviz web server started on 127.0.0.1:8082; enter Ctrl-C to exit
http://127.0.0.1:8082/snakeviz/%2Ftmp%2Ftmpw9g_k9jp
