In [None]:
# JAX picks up these environment variables when it initializes, so this cell
# must run before the `import jax` cell below.
# Enable 64-bit (float64/int64) dtypes.
%env JAX_ENABLE_X64=1
# Run on the CPU backend.
%env JAX_PLATFORM_NAME=cpu
# Uncomment to run every jitted function eagerly (handy when debugging):
# %env JAX_DISABLE_JIT=1

In [None]:
import jax
from jax import vmap, jit, grad, lax
import jax.numpy as jnp
from functools import partial

In [None]:
# Small demo inputs for the kernels below.
# NOTE: `x` and `y` are rebound to integer ranges in a later cell.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([1.0, 2.0, 3.0])
m = x <= 2

@jit
def csum(x, m):
    """Return the sum of ``x**2`` taken only over entries where mask ``m`` is True."""
    squares = x ** 2
    return squares.sum(where=m)

@jit
def kernel(x, y):
    """Elementwise ``sin(x) + exp(-y**2)``, broadcasting ``x`` against ``y``."""
    decay = jnp.exp(-(y ** 2))
    return jnp.sin(x) + decay

@jit
def kernel2(x):
    """Sum the entries of ``x`` that are strictly greater than 2.

    Uses a masked reduction instead of ``jnp.nonzero(..., size=4, fill_value=-1)``.
    With the padded-index approach, every unused slot held index ``-1``, which
    points at the LAST element, so ``x[-1]`` was added once per padding slot.
    Inputs with more than four matches were also silently truncated.

    Args:
        x: array of numbers; all entries are considered regardless of shape.

    Returns:
        Scalar sum of the entries with ``x > 2`` (0.0 when none match).
    """
    # The float literal keeps the result floating-point, as the original
    # ``tot = 0.0`` accumulator did, even for integer inputs.
    return jnp.sum(jnp.where(x > 2, x, 0.0))

@jit
def kernel3(x, y):
    """Sum ``kernel(xi, y)`` over every row ``xi`` of ``x`` and all entries of ``y``.

    The Python loop over ``x`` is unrolled at trace time (one ``kernel`` call per
    row), and ``kernel`` is looked up as a global when this function is traced.
    """
    per_row = [kernel(xi, y).sum() for xi in x]
    # Start from 0.0 and add left to right, exactly like a running total.
    return sum(per_row, 0.0)

@jit
def kernel4(x, m, y, my):
    """Masked double sum with the outer mask applied via ``jnp.where``.

    For every row ``xi`` of ``x`` whose mask ``mi`` is True, adds the sum of
    ``kernel(xi, y)`` restricted to entries where ``my`` is True; rows with
    ``mi`` False contribute 0. The inner sum is computed for every row, then
    selected or discarded.
    """
    acc = 0.0
    for row, keep in zip(x, m):
        inner = jnp.sum(kernel(row, y), where=my)
        acc = acc + jnp.where(keep, inner, 0)
    return acc

@jit
def kernel5(x, m, y, my):
    """Masked double sum with the outer mask applied via ``jax.lax.cond``.

    For every row ``xi`` of ``x`` whose mask ``mi`` is True, adds the sum of
    ``kernel(xi, y)`` restricted to entries where ``my`` is True; rows with
    ``mi`` False contribute 0.0.
    """
    def masked_row_sum(xi):
        return jnp.sum(kernel(xi, y), where=my)

    def skip_row(_xi):
        return 0.0

    acc = 0.0
    for xi, mi in zip(x, m):
        acc = acc + jax.lax.cond(mi, masked_row_sum, skip_row, xi)
    return acc

In [None]:
%timeit kernel5(x, x > 2, y, y > 1)

In [None]:
%timeit kernel4(x, x > 2, y, y > 1)

In [None]:
# Gradient of kernel4's scalar output with respect to its first argument `x`
# (grad differentiates argnums=0 by default); the masks and `y` are not
# differentiated.
jit(grad(kernel4))(x, x > 2, y, y > 1)

In [None]:
# %timeit csum(x, m)

In [None]:
# %timeit csum(x, m)

In [None]:
@jit
def f(carry, row):
    """Scan step that counts the even entries of ``row``.

    Returns ``(carry + count, count)`` where ``count`` is the number of
    entries with ``value % 2 == 0``.
    """
    count = sum(
        jax.lax.cond(value % 2 == 0, lambda: 1, lambda: 0)
        for value in row
    )
    return carry + count, count

# Three rows of two values; the lax.scan below feeds `f` one row per step.
numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
numbers

In [None]:
%timeit jax.lax.scan(f, 0, numbers)

In [None]:
@jit
def f(carry, xy):
    """Scan step: add ``xy[1]`` to the running total unless ``xy[0] > 2``.

    Args:
        carry: running total threaded through ``lax.scan``.
        xy: indexable pair for one step; ``xy[0]`` is tested, ``xy[1]`` is added.

    Returns:
        ``(carry + res, res)`` where ``res`` is zero if ``xy[0] > 2`` else ``xy[1]``.
    """
    # lax.cond requires both branches to return identical types. A bare Python
    # ``0`` is an int, which mismatches a float ``xy[1]``, so return a zero of
    # ``xy[1]``'s own dtype instead.
    res = jax.lax.cond(
        xy[0] > 2,
        lambda: jnp.zeros_like(xy[1]),
        lambda: xy[1],
    )
    return carry + res, res

@jit
def f2(carry, xy):
    """Scan step using ``jnp.where``: add ``xy[1]`` unless ``xy[0] > 2``.

    Returns ``(carry + res, res)`` where ``res`` is 0 when ``xy[0] > 2``,
    otherwise ``xy[1]``.
    """
    test_val, value = xy[0], xy[1]
    res = jnp.where(test_val > 2, 0, value)
    new_carry = carry + res
    return new_carry, res

def f_slow(carry, x, y):
    """Pure-Python baseline: ``carry`` plus every ``yi`` whose paired ``xi`` is not > 2.

    Runs eagerly, element by element, for timing comparison against the
    ``lax.scan`` versions.
    """
    kept_total = sum(yi for xi, yi in zip(x, y) if not (xi > 2))
    return carry + kept_total

In [None]:
# Integer inputs 0..99 for the scan benchmarks.
# NOTE: this rebinds the `x` and `y` defined in an earlier cell; cells that are
# re-run after this one will see these integer arrays instead.
x = jnp.arange(100)
y = jnp.arange(100)

In [None]:
%timeit jax.lax.scan(f, 0, (x, y))

In [None]:
%timeit jax.lax.scan(f2, 0, (x, y))

In [None]:
%timeit f_slow(0, x, y)

In [None]:
class A:
    """Callable returning ``a * x.sum()`` for a stored scale factor ``a``.

    The compiled work lives in a jitted static helper that receives ``a`` as a
    traced argument. Marking ``self`` static instead would key the compilation
    cache on the instance, which hashes by identity, so changing ``self.a``
    after the first call would silently reuse the stale compiled value.
    """

    def __init__(self, a: float):
        # Scale applied to the sum of the input.
        self.a = a

    @staticmethod
    @jax.jit
    def _scaled_sum(a, x: jnp.ndarray) -> jnp.ndarray:
        # `a` is traced rather than baked in as a constant, so every call sees
        # the current value of `self.a`.
        return a * x.sum()

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return ``self.a * x.sum()``."""
        return self._scaled_sum(self.a, x)
    
# NOTE: this rebinds the name `kernel`, shadowing the jitted `kernel(x, y)`
# function defined earlier. kernel3/kernel4/kernel5 look `kernel` up as a
# global, so if they are retraced after this cell (e.g. new shapes or dtypes)
# they will call this one-argument object instead and fail.
kernel = A(2.0)
kernel(jnp.array([1, 2]))

In [None]:
@jit
def f(carry, xy):
    """Scan step: ``res`` is 0.0 if ``xy[0] > 2``, else ``kernel(xy[1])``.

    ``kernel`` is looked up as a global when this function is traced.
    Returns ``(carry + res, res)``.
    """
    def _zero(_operand):
        return 0.0

    res = jax.lax.cond(xy[0] > 2, _zero, kernel, xy[1])
    return carry + res, res

In [None]:
%timeit jax.lax.scan(f, 0, (x, y))

In [None]:
# Displays (final carry, stacked per-step `res` values) from the scan.
jax.lax.scan(f, 0, (x, y))

In [None]:
# Three-field scan input. The `f` bound at this point (the lax.cond/`kernel`
# version) only reads xy[0] and xy[1], so `dis_i` is ignored and `mask_ik`
# is what gets passed to `kernel`. NOTE(review): confirm that is intended.
rij = jnp.array([1, 2., 3.])
mask_ik = rij > 2
dis_i = jnp.array([3., 4, 5])
xs = rij, mask_ik, dis_i
xs

In [None]:
%timeit jax.lax.scan(f, 0, xs)