I'm currently using a pure-numpy implementation of Annealed Importance Sampling (for binary multivariate RVs), and I'm trying to improve performance by porting the code to jax. I'm benchmarking the jax version against pure-numpy and numba-jitted versions; at the moment, I'm getting better performance using numba than using jax, and I'm wondering if it's because I'm not using jax correctly (very likely), or maybe this specific computation is somehow better suited for numba (if that's the reason, I'd also be very interested in learning why :-). From certain model sizes (around N=75 on my laptop), the jax implementation is actually slower than the pure-numpy version.
I'm attaching the comparison; AIS is documented in the jax version but the numba and numpy functions are very similar. The model is an all-to-all connected Ising model with random weights and biases:
from time import time
import numpy as onp
from numba import njit
import jax
import jax.numpy as np
from jax import random, jit, vmap
from jax.ops import index_update, index_add, index
from jax.lax import fori_loop
from jax.scipy.special import logsumexp
def ais_jax(N, K, n_betas, alphas, init_states):
old_nrgs = np.zeros(K) #energies according the the uniform distribution
log_w_ratio = np.zeros(K) #vector holding log of importance weights ratios
key = random.PRNGKey(0)
betas = np.array(np.linspace(0, 1, n_betas)) #inverse temp. vector
@jit
def run_mh(j, loop_carry):
states, log_w_ratio, old_nrgs, key = loop_carry
beta = betas[j]
key, key2 = random.split(key)
#log_p of former states
log_p_former_states = -((states@alphas*beta)*states).sum(1)
#new proposed states
proposed = index_update(states, index[:, j % N], 1-states[:, j % N])
new_nrgs = ((proposed@alphas*beta)*proposed).sum(1)
dE = new_nrgs - old_nrgs
accept = ((dE < 0) | (jax.random.uniform(key2, shape=(K,)) < np.exp(-dE)))
#new accepted states and nrgs
states = np.where(accept[:, np.newaxis], proposed, states)
old_nrgs = np.where(accept, new_nrgs, old_nrgs)
#log_p of new states and importance weights
log_p_new_states = -((states@alphas*beta)*states).sum(1)
log_w_ratio += log_p_former_states - log_p_new_states
return states, log_w_ratio, old_nrgs, key
states, log_w_ratio, old_nrgs, key = fori_loop(
1, n_betas-1, run_mh, (init_states, log_w_ratio, old_nrgs, key))
log_w_ratio += -((states@alphas)*states).sum(1)
#computing logZ using logsumexp
return N*np.log(2)-np.log(K)+logsumexp(log_w_ratio)
@njit
def ais_numba(N, K, n_betas, alphas, states):
old_nrgs = onp.zeros(K)
log_w_ratio = onp.zeros(K)
betas = onp.linspace(0, 1, n_betas)
for j in range(n_betas-1):
beta = betas[j]
log_p_former_states = -((states@alphas*beta)*states).sum(1)
proposed = states.copy()
proposed[:,j % N] = 1-proposed[:,j % N]
new_nrgs = ((proposed@alphas*beta)*proposed).sum(1)
dE = new_nrgs - old_nrgs
accept = ((dE < 0) | (onp.random.rand(K) < onp.exp(-dE)))
states[accept] = proposed[accept]
old_nrgs = onp.where(accept, new_nrgs, old_nrgs)
log_p_new_states = -((states@alphas*beta)*states).sum(1)
log_w_ratio += log_p_former_states - log_p_new_states
log_w_ratio += -((states@alphas)*states).sum(1)
max_lws = onp.max(log_w_ratio)
return N*onp.log(2)-onp.log(K)+onp.log(onp.exp(log_w_ratio-max_lws).sum())+max_lws
def ais_numpy(N, K, n_betas, alphas, states):
old_nrgs = onp.zeros(K)
log_w_ratio = onp.zeros(K)
betas = onp.linspace(0, 1, n_betas)
for j in range(n_betas-1):
beta = betas[j]
log_p_former_states = -((states@alphas*beta)*states).sum(1)
proposed = states.copy()
proposed[:,j % N] = 1-proposed[:,j % N]
new_nrgs = ((proposed@alphas*beta)*proposed).sum(1)
dE = new_nrgs - old_nrgs
accept = ((dE < 0) | (onp.random.rand(K) < onp.exp(-dE)))
states[accept] = proposed[accept]
old_nrgs = onp.where(accept, new_nrgs, old_nrgs)
log_p_new_states = -((states@alphas*beta)*states).sum(1)
log_w_ratio += log_p_former_states - log_p_new_states
log_w_ratio += -((states@alphas)*states).sum(1)
max_lws = onp.max(log_w_ratio)
return N*onp.log(2)-onp.log(K)+onp.log(onp.exp(log_w_ratio-max_lws).sum())+max_lws
K = 100
n_betas = 100000
#iterate over different model sizes
for N in range(15,100,5):
#five repititions per model size
for rep in range(5):
#randomize weights and biases, use matrix represnetation
alphas = -onp.diag(onp.random.normal(2, 0.02, N))
alphas[onp.triu_indices(N, 1)] = onp.random.normal(-0.05, 0.01, N*(N-1)//2)
alphas += alphas.T
init_states = onp.random.randint(0, 2, size=(K, N)).astype('float')
start = time()
logZ_ais_np = ais_numpy(N, K, n_betas, alphas, init_states)
ais_np_time = time()-start
start = time()
logZ_ais_nb = ais_numba(N, K, n_betas, alphas, init_states)
ais_nb_time = time()-start
start = time()
logZ_ais_jax = onp.array(ais_jax(N, K, n_betas, alphas, np.array(init_states.astype(int))))
ais_jax_time = time()-start
print(f"N={N}, trial={rep}")
print(f"numba/np = {(logZ_ais_nb/logZ_ais_np).round(3)}, numba_time/numpy_time = {ais_nb_time/ais_np_time}")
print(f"jax/np = {(logZ_ais_jax/logZ_ais_np).round(3)}, jax_time/numpy_time = {ais_jax_time/ais_np_time}")
I'm currently using a pure-numpy implementation of Annealed Importance Sampling (for binary multivariate RVs), and I'm trying to improve performance by porting the code to jax. I'm benchmarking the jax version against pure-numpy and numba-jitted versions; at the moment, I'm getting better performance using numba than using jax, and I'm wondering if it's because I'm not using jax correctly (very likely), or maybe this specific computation is somehow better suited for numba (if that's the reason, I'd also be very interested in learning why :-). From certain model sizes (around N=75 on my laptop), the jax implementation is actually slower than the pure-numpy version.
I'm attaching the comparison; AIS is documented in the jax version but the numba and numpy functions are very similar. The model is an all-to-all connected Ising model with random weights and biases: