In [11]:
from collections import defaultdict
import itertools
import jax
import numpy as np

from benchmark import benchmark_chains, cumulative_avg, err, ess, get_num_latents
import blackjax
from blackjax.mcmc.mhmclmc import rescale
from blackjax.util import run_inference_algorithm
import jax.numpy as jnp 

from inference_models import models

def sampler_mhmclmc(step_size, L):

    def s(logdensity_fn, num_steps, initial_position, key):

        num_steps_per_traj = L/step_size
        alg = blackjax.mcmc.mhmclmc.mhmclmc(
        logdensity_fn=logdensity_fn,
        step_size=step_size,
        integration_steps_fn = lambda k: jnp.round(jax.random.uniform(k) * rescale(num_steps_per_traj+ 0.5)) ,
        # integration_steps_fn = lambda _ : 5,
        # integration_steps_fn = lambda key: jnp.ceil(jax.random.poisson(key, L/step_size )) ,

        )
        
        _, out, info = run_inference_algorithm(
        rng_key=key,
        initial_state_or_position=initial_position,
        inference_algorithm=alg,
        num_steps=num_steps, 
        transform=lambda x: x.position, 
        progress_bar=True)

        print(out.mean(axis=0))

        # print(info.acceptance_rate.mean(), "acceptance probability")
        # print(out.var(axis=0), "acceptance probability")

        return out, num_steps_per_traj

    return s

results = defaultdict(float)

# def benchmark(model, sampler, n=100000):



#     identity_fn = model.sample_transformations['identity']
#     # print('True mean', identity_fn.ground_truth_mean)
#     # print('True std', identity_fn.ground_truth_standard_deviation)
#     # print("Empirical mean", samples.mean(axis=0))
#     # print("Empirical std", samples.std(axis=0))

#     logdensity_fn = model.unnormalized_log_prob
#     d = get_num_latents(model)
#     initial_position = jax.random.normal(jax.random.PRNGKey(0), (d,))
#     samples, num_steps_per_traj = sampler(logdensity_fn, n, initial_position, jax.random.PRNGKey(0))
#     # print(samples[-1], samples[0], "samps", samples.shape)

#     favg, fvar = identity_fn.ground_truth_mean, identity_fn.ground_truth_standard_deviation**2
#     err_t = err(favg, fvar, jnp.average)(cumulative_avg(samples))
#     # print(err(favg, fvar, jnp.average)(samples[:2]), samples[:10], "errt")
#     # print(err_t[-1], "benchmark err_t[0]")
#     ess_per_sample = ess(err_t, grad_evals_per_step=2)
    
#     return ess_per_sample

for model in ["Banana"]:
    # for sampler in ["mclmc"]:
    # for sampler in itertools.product([60.786648], [15.196002]):
    # for sampler in itertools.product([15.196002], [60.786648]):
    for sampler in itertools.product([1.196002], np.linspace(50, 70, 5)):
        # result = benchmark(models[model], samplers[sampler])
        # result = benchmark_chains(models[model], samplers[sampler])
        result = benchmark_chains(models[model], sampler_mhmclmc(*sampler), n=1000000, batch=1)
        # result = benchmark(models[model], sampler_mhmclmc(*sampler), n=1000000)
        # print(result, result2, "results")
        results[(model, sampler)] = result
results


Traced<ShapedArray(float32[2])>with<BatchTrace(level=1/0)> with
  val = Array([[-0.004184 , -1.5177863]], dtype=float32)
  batch_dim = 0
crossing 1000000
True mean [0. 0.]
True std [10.          4.35889894]
Empirical mean [-0.004184  -1.5177863]
Empirical std [7.0250397 2.204241 ]



Traced<ShapedArray(float32[2])>with<BatchTrace(level=1/0)> with
  val = Array([[ 0.00508401, -1.5205412 ]], dtype=float32)
  batch_dim = 0
crossing 1000000
True mean [0. 0.]
True std [10.          4.35889894]
Empirical mean [ 0.00508401 -1.5205412 ]
Empirical std [7.0250273 2.2118993]



Traced<ShapedArray(float32[2])>with<BatchTrace(level=1/0)> with
  val = Array([[ 0.00371127, -1.5180322 ]], dtype=float32)
  batch_dim = 0
crossing 1000000
True mean [0. 0.]
True std [10.          4.35889894]
Empirical mean [ 0.00371127 -1.5180322 ]
Empirical std [7.0259323 2.2059674]



Traced<ShapedArray(float32[2])>with<BatchTrace(level=1/0)> with
  val = Array([[ 0.01146103, -1.5166684 ]], dtype=float32)
  batch_dim = 0
crossing 1000000
True mean [0. 0.]
True std [10.          4.35889894]
Empirical mean [ 0.01146103 -1.5166684 ]
Empirical std [7.0318165 2.213777 ]



Traced<ShapedArray(float32[2])>with<BatchTrace(level=1/0)> with
  val = Array([[ 7.5468089e-04, -1.5207368e+00]], dtype=float32)
  batch_dim = 0
crossing 1000000
True mean [0. 0.]
True std [10.          4.35889894]
Empirical mean [ 7.5468089e-04 -1.5207368e+00]
Empirical std [7.0231533 2.2026107]


defaultdict(float,
            {('Banana', (1.196002, 50.0)): Array(0., dtype=float32),
             ('Banana', (1.196002, 55.0)): Array(0., dtype=float32),
             ('Banana', (1.196002, 60.0)): Array(0., dtype=float32),
             ('Banana', (1.196002, 65.0)): Array(0., dtype=float32),
             ('Banana', (1.196002, 70.0)): Array(0., dtype=float32)})

In [9]:
import seaborn as sns

import matplotlib.pyplot as plt

# Extract x and y values from the keys of the results dictionary
x_values = [key[0] for key in results.keys()]
y_values = [key[1] for key in results.keys()]

print(len(x_values))
# raise Exception

# Extract heat values from the dictionary
heat_values = list(results.values())


# Reshape the heat values into a 2D array
# heat_array = np.array(heat_values).reshape((len(x_values), len(y_values)))
heat_array = np.array(heat_values).reshape((10,2))

# Create the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(heat_array, annot=True, cmap='viridis')
plt.xlabel('num steps per traj')
plt.ylabel('step size')
plt.title('Heatmap of Results')
plt.show()


2


ValueError: cannot reshape array of size 2 into shape (10,2)

In [10]:
results

defaultdict(float,
            {('simple', (10.0, 1.0)): Array(0.18656716, dtype=float32),
             ('simple',
              (10.0, 1.4444444444444444)): Array(0.25452492, dtype=float32),
             ('simple',
              (10.0, 1.8888888888888888)): Array(0.24509802, dtype=float32),
             ('simple',
              (10.0, 2.333333333333333)): Array(0.17281109, dtype=float32),
             ('simple',
              (10.0, 2.7777777777777777)): Array(0.17999998, dtype=float32),
             ('simple',
              (10.0, 3.2222222222222223)): Array(0.14502098, dtype=float32),
             ('simple',
              (10.0, 3.6666666666666665)): Array(0.15321758, dtype=float32),
             ('simple',
              (10.0, 4.111111111111111)): Array(0.1253831, dtype=float32),
             ('simple',
              (10.0, 4.555555555555555)): Array(0.11315059, dtype=float32),
             ('simple', (10.0, 5.0)): Array(0.11111111, dtype=float32),
             ('simple',
         