In [1]:
import jax
from jax import grad
import jax.numpy as jnp
import jax.random
from jaxlib.xla_extension import DeviceArray
import numpy as np
from functools import partial

# JAX

## Gradient

In [2]:
alpha = 1/2

def f(x, alpha: float = alpha):
    """Return ``x ** alpha``.

    ``alpha`` defaults to the module-level ``alpha`` as it is *when this cell
    runs* (1/2). Binding it as a default argument keeps ``f`` stable when a
    later cell rebinds the global ``alpha`` (the CES section reassigns it to a
    weight vector). Otherwise that rebinding would silently change what ``f``
    computes.
    """
    return x**alpha

In [3]:
df_dx = grad(f)  # derivative of f, itself a callable
print(df_dx(1.))



0.5


In [37]:
def CES(x, alpha, sigma, nu: float = 1, theta: float = 1):
    """Constant-elasticity-of-substitution (CES) aggregator.

    Computes ``theta * (sum(alpha * x**rho)) ** (nu / rho)`` with
    ``rho = (sigma - 1) / sigma``.

    Args:
        x: input quantities (array).
        alpha: weights, multiplied element-wise with ``x**rho``.
        sigma: elasticity of substitution. ``sigma == 1`` (the Cobb-Douglas
            limit) is not handled: the outer exponent divides by zero.
        nu: outer exponent multiplier; the aggregate is homogeneous of degree
            ``nu`` in ``x``.
        theta: multiplicative scale factor.

    Returns:
        The aggregate. ``jnp.sum`` reduces over all elements, so this is a
        scalar.
    """
    inner_exponent = (sigma - 1) / sigma
    outer_exponent = nu * sigma / (sigma - 1)
    weighted_sum = jnp.sum(alpha * x**inner_exponent)
    return theta * weighted_sum**outer_exponent

In [48]:
# Parameters for the CES gradient check below.
# NOTE: this rebinds the global `alpha` (previously the float exponent 1/2)
# to a vector of weights.
theta, sigma, nu = 1, 1.2, 0.9  # scale, elasticity of substitution, degree of homogeneity
alpha = jnp.array([0.3, 0.7])  # weights
x = jnp.array([0.5, 1.5], dtype=jnp.float32)  # evaluation point

In [49]:
# Automatic gradient of CES w.r.t. x. Pass *every* parameter, theta included,
# so this stays comparable with the analytic formula below even if theta != 1.
grad(partial(CES, alpha=alpha, sigma=sigma, nu=nu, theta=theta))(x)

DeviceArray([0.51635367, 0.48230636], dtype=float32)

In [52]:
# Analytic gradient, to compare with the autodiff result above. With
#   Y = CES(x) = theta * S**(nu*sigma/(sigma-1)),  S = sum(alpha * x**((sigma-1)/sigma)),
# the derivative is
#   dY/dx_i = theta * nu * alpha_i * x_i**(-1/sigma) * (Y/theta)**((sigma*(nu-1)+1)/(nu*sigma)).
# Y must be divided by theta; using Y directly is only correct when theta == 1.
(theta * nu * alpha * x**(-1/sigma)
 * (CES(x, alpha, sigma, nu, theta) / theta)**((sigma*(nu-1) + 1) / (nu*sigma)))

DeviceArray([0.51635367, 0.48230636], dtype=float32)

In [33]:
def test_numpy(A):
    return A@A

def test_jax(A):
    return A@A

@jax.jit
def test_jax_jit(A):
    return A@A    

size=(1000,1000)
A_numpy = np.random.uniform(size=size)
A_jax = jnp.array(A_numpy)

In [34]:
%%timeit
test_numpy(A_numpy)

3.48 ms ± 38.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [35]:
%%timeit
test_jax(A_jax)

4.55 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [36]:
%%timeit
test_jax_jit(A_jax)

4.56 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [67]:
0 * jnp.inf  # IEEE 754: infinity times zero is undefined, so the result is nan

nan

In [68]:
False * jnp.inf  # False acts as 0 in arithmetic, so this is again 0 * inf = nan

nan

## While loop

In [37]:
def cond_fun(state):
    """Keep looping while the counter (second slot) is below 4."""
    return state[1] < 4

def body_fun(state):
    """Leave the first slot untouched and increment the counter."""
    first, counter = state
    return (first, counter + 1)

jax.lax.while_loop(cond_fun, body_fun, (-3, -1))

(DeviceArray(-3, dtype=int32, weak_type=True),
 DeviceArray(4, dtype=int32, weak_type=True))

## jit

In [38]:
# Reproducible 5x5 test matrix with entries drawn uniformly from [-1, 1).
seed = 123
key = jax.random.PRNGKey(seed)
A = jax.random.uniform(key, (5, 5), minval=-1, maxval=1)

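JAX can't `jit` a function that indexes with a boolean mask such as `A.at[A<0]` (at least in the JAX version used here): the number of selected elements depends on the runtime values, so it is not known when the function is traced.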

In [None]:
@jax.jit
def foo(A: jnp.ndarray) -> jnp.ndarray:
    """Zero the negative entries of ``A`` via boolean-mask indexing.

    Kept deliberately as a counter-example: the mask ``A < 0`` selects a
    data-dependent number of elements, which ``jax.jit`` may refuse to trace
    (the behavior depends on the JAX version).
    """
    return A.at[A < 0].set(0)

# The failure *is* the demonstration: catch it and show it, so that
# "Restart & Run All" does not stop at this cell.
try:
    foo(A)
except Exception as err:  # the exact exception class differs across JAX versions
    print(f"{type(err).__name__}: {err}")
else:
    print("No error: this JAX version can jit the boolean-mask update.")

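But there is often a workaround: rewrite the update with element-wise operations (here `jnp.maximum`), whose output shape does not depend on the data.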

In [50]:
@jax.jit
def foo(A: DeviceArray) -> DeviceArray:
    """Clamp ``A`` from below at zero, i.e. zero out its negative entries.

    ``jnp.maximum`` works element-wise and returns an array with the same
    static shape as ``A``. That is why ``jax.jit`` can trace it, unlike the
    boolean-mask update in the previous cell.
    """
    lower_bound = 0
    clamped = jnp.maximum(A, lower_bound)
    return clamped


foo(A)

DeviceArray([[0.70078254, 0.        , 0.5044954 , 0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.29831648, 0.600971  ],
             [0.        , 0.        , 0.        , 0.        , 0.        ],
             [0.        , 0.07184577, 0.6359787 , 0.        , 0.        ],
             [0.        , 0.        , 0.09762692, 0.40657496, 0.        ]],            dtype=float32)

A slightly more complicated case: replace the negative entries with a value other than zero.

In [64]:
@jax.jit
def foo(A: jnp.ndarray, fill_value: float = jnp.inf) -> jnp.ndarray:
    """Replace the negative entries of ``A`` with ``fill_value`` (default ``+inf``).

    This is a jit-compatible equivalent of ``A.at[A < 0].set(fill_value)``.
    ``jnp.where`` selects element-wise and keeps the output shape static, so no
    boolean-mask indexing is needed. Pass ``fill_value=999.`` for the
    earlier "set to 999" variant.

    The previous trick ``jnp.maximum(A, 0) + (A < 0) * jnp.inf`` is unsafe.
    For entries that are *not* masked it evaluates ``0 * inf``, which is
    ``nan`` under IEEE 754 (compare ``jnp.inf*0`` above). Whether the
    non-negative values survived therefore depended on the backend not
    following IEEE semantics.
    """
    return jnp.where(A < 0, fill_value, A)
A_filled = foo(A)
A_filled

DeviceArray([inf,  3.,  4., inf], dtype=float32)

In [61]:
# NOTE: rebinds `A` (previously the 5x5 random matrix) to a 1-D array.
# The minimum of an array that contains -inf is -inf.
A = jnp.asarray([-jnp.inf, 3, 4, -1])
A.min()

DeviceArray(-inf, dtype=float32)

In [66]:
1 + 0 * -jnp.inf  # 0 * (-inf) is nan, and nan absorbs the subsequent + 1

nan