Skip to content

Annealed Importance Sampling - numpy/numba/jax comparison #1421

@adamhaber

Description

@adamhaber

I'm currently using a pure-numpy implementation of Annealed Importance Sampling (for binary multivariate RVs), and I'm trying to improve performance by porting the code to jax. I'm benchmarking the jax version against pure-numpy and numba-jitted versions; at the moment, I'm getting better performance with numba than with jax, and I'm wondering whether that's because I'm not using jax correctly (very likely), or because this specific computation is somehow better suited to numba (if that's the reason, I'd also be very interested in learning why :-). Above a certain model size (around N=75 on my laptop), the jax implementation is actually slower than the pure-numpy version.

I'm attaching the comparison; AIS is documented in the jax version but the numba and numpy functions are very similar. The model is an all-to-all connected Ising model with random weights and biases:

from time import time

import numpy as onp
from numba import njit
import jax
import jax.numpy as np
from jax import random, jit, vmap
from jax.ops import index_update, index_add, index
from jax.lax import fori_loop
from jax.scipy.special import logsumexp

def ais_jax(N, K, n_betas, alphas, init_states):
    """Estimate log Z of a binary pairwise model with Annealed Importance Sampling.

    The unnormalised target is f(x) = exp(-x^T alphas x) over x in {0, 1}^N.
    K independent chains are annealed from the uniform distribution (beta = 0)
    to the target (beta = 1); at each inverse temperature one coordinate
    (``j % N``) is proposed to flip and accepted with the Metropolis rule.

    Args:
        N: number of binary variables (columns of ``init_states``).
        K: number of chains (rows of ``init_states``).
        n_betas: number of inverse temperatures on the linear grid [0, 1].
        alphas: (N, N) coupling matrix.
        init_states: (K, N) array of 0/1 samples from the uniform distribution.

    Returns:
        Scalar estimate of log Z.
    """
    states = np.asarray(init_states)
    alphas = np.asarray(alphas)
    log_w_ratio = np.zeros(K)           # log importance weight of each chain
    key = random.PRNGKey(0)
    betas = np.linspace(0, 1, n_betas)  # inverse temperature schedule

    # Unscaled energies x^T A x of the current states. Carrying these (rather
    # than beta-scaled energies from an earlier step) lets the Metropolis test
    # compare both states at the *current* beta, and needs one matmul per step.
    nrgs = ((states @ alphas) * states).sum(1)

    # No @jit here: fori_loop already traces and compiles its body.
    def run_mh(j, loop_carry):
        states, log_w_ratio, nrgs, key = loop_carry
        beta = betas[j]
        key, key2 = random.split(key)

        # Propose flipping a single coordinate in every chain.
        col = j % N
        proposed = states.at[:, col].set(1 - states[:, col])
        proposed_nrgs = ((proposed @ alphas) * proposed).sum(1)

        # Energy difference with both states evaluated at the current beta.
        dE = beta * (proposed_nrgs - nrgs)
        accept = (dE < 0) | (random.uniform(key2, shape=(K,)) < np.exp(-dE))

        states = np.where(accept[:, np.newaxis], proposed, states)
        nrgs = np.where(accept, proposed_nrgs, nrgs)

        # log f_j(x_old) - log f_j(x_new) equals dE for accepted moves, else 0.
        log_w_ratio = log_w_ratio + np.where(accept, dE, 0.0)
        return states, log_w_ratio, nrgs, key

    # Starts at j = 1: the j = 0 step (beta = 0) contributes no weight.
    states, log_w_ratio, nrgs, key = fori_loop(
        1, n_betas - 1, run_mh, (states, log_w_ratio, nrgs, key))

    # Final factor: log f(x) of the last states under the target (beta = 1).
    log_w_ratio = log_w_ratio - nrgs

    # log Z = log Z_0 + log mean(w), with Z_0 = 2^N for the uniform start.
    return N * np.log(2) - np.log(K) + logsumexp(log_w_ratio)

@njit
def ais_numba(N, K, n_betas, alphas, states):
    """Numba-compiled Annealed Importance Sampling estimate of log Z.

    The unnormalised target is f(x) = exp(-x^T alphas x) over x in {0, 1}^N.

    Args:
        N: number of binary variables (columns of ``states``).
        K: number of chains (rows of ``states``).
        n_betas: number of inverse temperatures on the linear grid [0, 1].
        alphas: (N, N) float coupling matrix.
        states: (K, N) float array of 0/1 samples from the uniform
            distribution; it is not modified.

    Returns:
        Scalar estimate of log Z.
    """
    # Work on a private copy so the caller's initial states stay untouched.
    states = states.copy()
    log_w_ratio = onp.zeros(K)
    betas = onp.linspace(0, 1, n_betas)
    no_change = onp.zeros(K)  # weight increment for rejected proposals

    # Unscaled energies x^T A x of the current states; scaling by the current
    # beta inside the loop keeps the Metropolis test at a single temperature.
    nrgs = ((states @ alphas) * states).sum(1)

    for j in range(n_betas - 1):
        beta = betas[j]

        # Propose flipping a single coordinate in every chain.
        proposed = states.copy()
        proposed[:, j % N] = 1 - proposed[:, j % N]
        proposed_nrgs = ((proposed @ alphas) * proposed).sum(1)

        # Energy difference with both states evaluated at the current beta.
        dE = beta * (proposed_nrgs - nrgs)
        accept = ((dE < 0) | (onp.random.rand(K) < onp.exp(-dE)))
        states[accept] = proposed[accept]

        # log f_j(x_old) - log f_j(x_new) equals dE for accepted moves, else 0.
        log_w_ratio += onp.where(accept, dE, no_change)
        nrgs = onp.where(accept, proposed_nrgs, nrgs)

    # Final factor: log f(x) of the last states under the target (beta = 1).
    log_w_ratio += -nrgs
    max_lws = onp.max(log_w_ratio)

    # Hand-rolled logsumexp; log Z = N log 2 - log K + logsumexp(log w).
    return N*onp.log(2)-onp.log(K)+onp.log(onp.exp(log_w_ratio-max_lws).sum())+max_lws

def ais_numpy(N, K, n_betas, alphas, states):
    """Pure-NumPy Annealed Importance Sampling estimate of log Z.

    The unnormalised target is f(x) = exp(-x^T alphas x) over x in {0, 1}^N.
    K independent chains are annealed from the uniform distribution (beta = 0)
    to the target (beta = 1); at each inverse temperature one coordinate
    (``j % N``) is proposed to flip and accepted with the Metropolis rule.

    Args:
        N: number of binary variables (columns of ``states``).
        K: number of chains (rows of ``states``).
        n_betas: number of inverse temperatures on the linear grid [0, 1].
        alphas: (N, N) coupling matrix.
        states: (K, N) array of 0/1 samples from the uniform distribution;
            it is not modified.

    Returns:
        Scalar estimate of log Z.
    """
    # Work on a private copy so the caller's initial states stay untouched.
    states = states.copy()
    log_w_ratio = onp.zeros(K)
    betas = onp.linspace(0, 1, n_betas)

    # Unscaled energies x^T A x of the current states; scaling by the current
    # beta inside the loop keeps the Metropolis test at a single temperature.
    nrgs = ((states @ alphas) * states).sum(1)

    for j in range(n_betas - 1):
        beta = betas[j]

        # Propose flipping a single coordinate in every chain.
        proposed = states.copy()
        proposed[:, j % N] = 1 - proposed[:, j % N]
        proposed_nrgs = ((proposed @ alphas) * proposed).sum(1)

        # Energy difference with both states evaluated at the current beta.
        dE = beta * (proposed_nrgs - nrgs)
        accept = (dE < 0) | (onp.random.rand(K) < onp.exp(-dE))
        states[accept] = proposed[accept]

        # log f_j(x_old) - log f_j(x_new) equals dE for accepted moves, else 0.
        log_w_ratio += onp.where(accept, dE, 0.0)
        nrgs = onp.where(accept, proposed_nrgs, nrgs)

    # Final factor: log f(x) of the last states under the target (beta = 1).
    log_w_ratio += -nrgs
    max_lws = onp.max(log_w_ratio)

    # Hand-rolled logsumexp; log Z = N log 2 - log K + logsumexp(log w).
    return N*onp.log(2)-onp.log(K)+onp.log(onp.exp(log_w_ratio-max_lws).sum())+max_lws

K = 100
n_betas = 100000

# Trigger numba's JIT compilation once on a tiny problem with the same argument
# types, so the first timed call below does not include compile time.
ais_numba(2, K, 2, onp.zeros((2, 2)), onp.zeros((K, 2)))

# iterate over different model sizes
for N in range(15,100,5):
    # five repetitions per model size
    for rep in range(5):
        # randomize weights and biases, use matrix representation
        alphas = -onp.diag(onp.random.normal(2, 0.02, N))
        alphas[onp.triu_indices(N, 1)] = onp.random.normal(-0.05, 0.01, N*(N-1)//2)
        alphas += alphas.T

        init_states = onp.random.randint(0, 2, size=(K, N)).astype('float')

        # Each sampler gets its own copy: the numpy/numba versions update the
        # state array in place, and every run must start from the same
        # uniform samples.
        start = time()
        logZ_ais_np = ais_numpy(N, K, n_betas, alphas, init_states.copy())
        ais_np_time = time()-start
        
        start = time()
        logZ_ais_nb = ais_numba(N, K, n_betas, alphas, init_states.copy())
        ais_nb_time = time()-start
        
        # onp.array(...) forces the JAX result to be materialised before the
        # clock is read.
        start = time()
        logZ_ais_jax = onp.array(ais_jax(N, K, n_betas, alphas, np.array(init_states.astype(int))))
        ais_jax_time = time()-start

        print(f"N={N}, trial={rep}")
        print(f"numba/np = {(logZ_ais_nb/logZ_ais_np).round(3)}, numba_time/numpy_time = {ais_nb_time/ais_np_time}")
        print(f"jax/np = {(logZ_ais_jax/logZ_ais_np).round(3)}, jax_time/numpy_time = {ais_jax_time/ais_np_time}")

Metadata

Metadata

Assignees

No one assigned

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions