Goal: manually test batching semantics and a few other potential ADEV optimizations on some basic examples, to see whether they are worth pursuing.

In [6]:
import jax.numpy as jnp
import jax
from jax import jit

def g(b, dp, dr):
    """Return ``dr * dp`` if ``b`` is true, otherwise ``(1.0 - dr) * dp``.

    The branch on ``b`` is taken through ``jax.lax.cond`` applied to ``dr``;
    the selected value is then scaled by ``dp``.
    """
    def keep(value):
        return value

    def complement(value):
        return 1.0 - value

    selected = jax.lax.cond(b, keep, complement, dr)
    return selected * dp

def dkont(x):
    """Final continuation: return the first element of ``x``."""
    first = x[0]
    return first

# v1: using array inplace updates
def f1(dp_list, dr_init):
    """Two-pass (down, then up) tree computation written with Python loops.

    The tree is stored in a flat array of length ``2**len(dp_list)``: index 0
    holds ``dr_init`` and level ``i`` occupies indices ``2**i .. 2**(i+1) - 1``.

    Args:
        dp_list: sequence supporting ``len`` and integer indexing; entry ``i``
            is used for tree level ``i`` (presumably the dual-number flip
            probabilities -- TODO confirm).
        dr_init: value stored at index 0 before the downward pass.

    Returns:
        ``array2[0]`` after the upward pass.

    NOTE(review): three things in the index arithmetic look unintended and
    should be confirmed against the intended tree layout:
      * Downward pass: nodes are written at ``2**i + j`` (a 1-based heap
        layout), but the parent is read from ``(index - 1) // 2`` (the 0-based
        heap parent formula).
      * Upward pass: for ``i == n - 1`` the reads at ``2 * (2**i + j)`` and
        ``2 * (2**i + j) + 1`` are >= ``2**n``, i.e. past the end of
        ``array2``. JAX is documented to clamp out-of-bounds reads rather than
        raise, so this does not fail loudly.
      * Upward pass: only indices >= 1 are written, so the returned
        ``array2[0]`` is still the value stored from ``dkont(array[mid:])``.
    """
    k = 2   # here it's just 2 because it's a flip_enum
    n = len(dp_list)
    array = jnp.zeros(k**n)
    array = array.at[0].set(dr_init)

    # going down the tree, breadth first
    # Both loops are plain Python loops, so 2**n - 1 separate .at[].set updates are issued.
    for i in range(n):
        for j in range(2**i):
            # Node 2**i + j is computed from the node at (2**i + j - 1) // 2 using dp_list[i].
            # jnp.mod(1 + j, 2) is 1 for even j and 0 for odd j.
            array = array.at[2**i+j].set(g(jnp.mod(1+j, 2), dp_list[i], array[(2**i+j-1)//2])) # 1+j mod 2 is a hacky way to encode True/False

    array2 = jnp.zeros(k**n)
    # only need the second half of array as these are the final values passed to the final continuation dkont

    # jax.debug.print("array: {x}", x=array)
    mid = 2**(n-1)
    # jax.debug.print("mid: {x}", x=mid)
    # jax.debug.print("array2: {x}", x=array2)
    # dkont returns element 0 of its argument, so the single value array[mid]
    # is broadcast over array2[:mid] (including array2[0]).
    array2 = array2.at[:mid].set(dkont(array[mid:]))
    # jax.debug.print("array2: {x}", x=array2)
    for i in range(n-1,-1,-1):  # going up the tree, breadth first
        for j in range(2**i):
            # Combine the entries at 2*idx and 2*idx + 1 with weights dp and (1 - dp),
            # where idx = 2**i + j and dp = dp_list[n-i-1].
            array2 = array2.at[2**i+j].set(array2[2*(2**i+j)] * dp_list[n-i-1] + array2[2*(2**i+j)+1] * (1.0 - dp_list[n-i-1])) # still computing in dual-number land

    return array2[0]

# v2: same as v1 but using more JAX primitives
def f2(dp_list, dr_init):
    """Two-pass (down, then up) tree computation using nested ``jax.lax.fori_loop``.

    Uses the same flat layout and index arithmetic as the Python-loop version
    ``f1``: an array of length ``2**len(dp_list)`` with ``dr_init`` at index 0
    and level ``i`` stored at indices ``2**i .. 2**(i+1) - 1``.

    Args:
        dp_list: indexed with the ``fori_loop`` index, so presumably it must be
            a JAX array (or otherwise support traced indices) -- TODO confirm.
        dr_init: value stored at index 0 before the downward pass.

    Returns:
        ``array2[0]`` after the upward pass.

    NOTE(review): the index arithmetic has the same open questions as ``f1``:
      * children are written at ``2**i + j`` but the parent is read from
        ``(index - 1) // 2``;
      * in the first upward iteration (level ``n - 1``) the reads at
        ``2 * (2**(n-1) + j)`` are >= ``2**n``, past the end of the array
        (JAX is documented to clamp such reads instead of raising);
      * the upward pass only writes indices >= 1, so the returned
        ``array2[0]`` is still the value stored from ``dkont(array[mid:])``.
    """
    k = 2
    n = len(dp_list)
    array = jnp.zeros(k**n)
    array = array.at[0].set(dr_init)

    # Downward pass. Outer loop: level i in [0, n). Inner loop: position j in [0, 2**i).
    # The inner bound 2**i depends on the outer loop index, so it is presumably
    # a traced (non-static) trip count -- TODO confirm the performance impact.
    array = jax.lax.fori_loop(0, n, lambda i, x: jax.lax.fori_loop(0, 2**i, lambda j, y: y.at[2**i+j].set(g(jnp.mod(1+j, 2), dp_list[i], y[(2**i+j-1)//2])), x), array)

    array2 = jnp.zeros(k**n)

    mid = 2**(n-1)
    # dkont returns element 0 of its argument, so array[mid] is broadcast over array2[:mid].
    array2 = array2.at[:mid].set(dkont(array[mid:]))

    # Upward pass. Outer index i counts up while the level n-1-i counts down;
    # each node combines entries 2*idx and 2*idx + 1 with weights dp_list[i] and 1 - dp_list[i].
    array2 = jax.lax.fori_loop(0, n, lambda i, x: jax.lax.fori_loop(0, 2**(n-1-i), lambda j, y: y.at[2**(n-1-i)+j].set(y[2*(2**(n-1-i)+j)] * dp_list[i] + y[2*(2**(n-1-i)+j)+1] * (1.0 - dp_list[i])), x), array2)

    return array2[0]

# v3: using jax.lax.dynamic_slice
# NOTE(review): a later `def f3` in this file (the "v4" stub) rebinds the name
# f3, so this definition is shadowed once the whole cell has run.
def f3(dp_list, dr_init):
    """Work-in-progress variant of ``f2`` intended to use ``jax.lax.dynamic_slice``.

    The executed code is currently the same nested-``fori_loop`` computation as
    ``f2``; the ``dynamic_slice`` attempt is only present as a commented-out line.

    Args:
        dp_list: indexed with the ``fori_loop`` index, so presumably it must be
            a JAX array (or otherwise support traced indices) -- TODO confirm.
        dr_init: value stored at index 0 before the downward pass.

    Returns:
        ``array2[0]`` after the upward pass.

    NOTE(review): the index arithmetic has the same open questions as in ``f1``
    and ``f2``: parent index ``(index - 1) // 2`` vs. children stored at
    ``2**i + j``; reads at indices >= ``2**n`` in the first upward iteration;
    and the upward pass never writing index 0, which is the value returned.
    """
    k = 2
    n = len(dp_list)
    array = jnp.zeros(k**n)
    array = array.at[0].set(dr_init)

    # Downward pass: level i in [0, n), position j in [0, 2**i).
    array = jax.lax.fori_loop(0, n, 
        lambda i, x: jax.lax.fori_loop(0, 2**i, lambda j, y: y.at[2**i+j].set(g(jnp.mod(1+j, 2), dp_list[i], y[(2**i+j-1)//2])), x), array)

    # same as v2 but using dynamic_slice, using only one for loop
    # (unfinished attempt, kept for reference)
    # array = jax.lax.fori_loop(0, n, lambda i, x: jax.lax.dynamic_slice(x, [2**(n-1-i), 2**(i+1)], [2**(n-1-i), 0]), array)

    array2 = jnp.zeros(k**n)

    mid = 2**(n-1)
    # dkont returns element 0 of its argument, so array[mid] is broadcast over array2[:mid].
    array2 = array2.at[:mid].set(dkont(array[mid:]))

    # Upward pass: level n-1-i for i in [0, n); each node combines entries
    # 2*idx and 2*idx + 1 with weights dp_list[i] and 1 - dp_list[i].
    array2 = jax.lax.fori_loop(0, n, lambda i, x: jax.lax.fori_loop(0, 2**(n-1-i), lambda j, y: y.at[2**(n-1-i)+j].set(y[2*(2**(n-1-i)+j)] * dp_list[i] + y[2*(2**(n-1-i)+j)+1] * (1.0 - dp_list[i])), x), array2)

    return array2[0]

# v4: using naive parallelism
# hmmm, doesn't actually seem easier to write than v1
# NOTE(review): this is labelled v4 but is named f3, so it rebinds the v3
# function defined earlier in the file -- confirm whether it should be f4.
def f3(dp_list, dr_init):
    """Unfinished stub: return a length-``2**len(dp_list)`` array filled with ``dr_init``."""
    branching = 2
    depth = len(dp_list)
    size = branching ** depth
    filled = jnp.full(size, dr_init)
    return filled
    
# Testing
dp_list = jnp.arange(1, 25)
dr_init = 1.0
# print(f1(dp_list, dr_init))
print(f2(dp_list, dr_init))
print()
# jitted1 = jit(f1)
# print(jitted1(dp_list, dr_init))
jitted2 = jit(f2)
print(jitted2(dp_list, dr_init))

# %timeit jitted1(dp_list, dr_init)
%timeit jitted2(dp_list, dr_init)
# TODO: this is somehow very slow on GPU even though the doc says that the updates should be in-place as the function is jitted
# takes 400ms for len(dp_list) = 14 on GPU, and compile time is also not constant
# and growing rapidly with len(dp_list).

# Update from Matin: need to replace inner for_loop with jax.lax.dynamic_slice and dynamic_update_slice and a vmap. also jitting tree-map should be doing something similar on its own, at the price of a possibly much higher compile time.

-9.923923e+21

-9.923923e+21
79.5 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
