In [6]:
import jax
from jax import grad
import jax.numpy as jnp
import jax.random
from jaxlib.xla_extension import DeviceArray
from jax._src.basearray import Array
import numpy as np
from functools import partial
import matplotlib.pyplot as plt

# JAX

### Gradient

In [3]:
alpha=1/2

def f(x, exponent=alpha):
    """Power function ``x**exponent``.

    The exponent is bound as a default argument when this cell runs, so a later
    cell that rebinds the global ``alpha`` (it becomes an array further down)
    cannot silently change what ``f`` computes.

    Args:
        x: point at which the power is evaluated.
        exponent: power to raise ``x`` to; defaults to the value ``alpha`` had
            when the function was defined (1/2).

    Returns:
        ``x`` raised to ``exponent``.
    """
    return x**exponent

# grad differentiates with respect to the first argument only.
print(grad(f)(1.))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


0.5


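Gradient of the CES (constant elasticity of substitution) function, first by automatic differentiation and then in closed form as a check.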

In [7]:
def CES(x, alpha, sigma, nu:float=1, theta:float=1):
    """CES aggregator of the inputs ``x``.

    Computes ``theta * (sum_i alpha_i * x_i**((sigma-1)/sigma)) ** (nu*sigma/(sigma-1))``.

    Args:
        x: array of input quantities.
        alpha: array of weights, multiplied elementwise with the powered inputs.
        sigma: substitution parameter; ``sigma == 1`` divides by zero in the
            outer exponent and is not handled here.
        nu: scales the outer exponent (default 1).
        theta: multiplicative scale of the result (default 1).

    Returns:
        Scalar value of the aggregate.
    """
    return theta*jnp.sum(alpha*x**((sigma-1)/sigma))**(nu*sigma/(sigma-1))

theta = 1
sigma = 1.2
nu = 0.9
alpha = jnp.array([0.3,0.7])
x = jnp.array([0.5,1.5], dtype=jnp.float32)
# Forward theta as well, so that changing it above changes the gradient too
# instead of silently falling back to the default of 1.
grad(partial(CES, alpha=alpha, sigma=sigma, nu=nu, theta=theta))(x)

Array([0.51635367, 0.48230636], dtype=float32)

In [8]:
# Closed-form gradient of CES, to compare with the autodiff result above.
# With S = sum(alpha * x**((sigma-1)/sigma)) the derivative is
#   theta * nu * alpha * x**(-1/sigma) * S**(nu*sigma/(sigma-1) - 1),
# and S**(nu*sigma/(sigma-1) - 1) == (CES/theta)**k with k = (sigma*(nu-1)+1)/(nu*sigma).
# The leading factor is therefore theta**(1-k) = theta**((sigma-1)/(nu*sigma)),
# not theta itself; the two agree only when theta == 1.
theta**((sigma-1)/(nu*sigma))*nu*alpha*x**(-1/sigma)*CES(x, alpha, sigma, nu, theta)**((sigma*(nu-1)+1)/(nu*sigma))

Array([0.51635367, 0.48230636], dtype=float32)

### Multiplying matrices

In [46]:
import time

def multiply_matrix(size, package='np',loops=20):
    """Time one matrix product ``A @ A`` for a random matrix of shape ``size``.

    Args:
        size: shape passed to ``np.random.uniform`` (e.g. ``(n, n)``).
        package: ``'np'`` to multiply the NumPy array, ``'jax'`` to multiply a
            JAX copy of it.
        loops: number of repetitions; when greater than 1 the mean of that
            many single measurements is returned.

    Returns:
        Elapsed wall-clock time in seconds (mean over ``loops`` runs).

    Raises:
        ValueError: if ``package`` is neither ``'np'`` nor ``'jax'``.
    """
    # Fail early with a clear message; previously an unknown package fell
    # through both branches and crashed on the undefined timing variable.
    if package not in ('np', 'jax'):
        raise ValueError(f"unknown package {package!r}; expected 'np' or 'jax'")
    if loops>1:
        return np.mean([multiply_matrix(size,package,loops=1) for _ in range(loops)])
    np.random.seed(1)
    A = np.random.uniform(size=size)
    if package == 'jax':
        # Convert before starting the clock so only the product is timed.
        A = jnp.array(A)
    # perf_counter measures wall-clock time; process_time adds up CPU time
    # over all threads, which overstates multi-threaded matrix products.
    start = time.perf_counter()
    product = A@A
    if package == 'jax':
        # JAX dispatches asynchronously: wait until the product really exists.
        product.block_until_ready()
    return time.perf_counter()-start

In [60]:
def _square(matrix):
    """Return the matrix product of ``matrix`` with itself."""
    return matrix@matrix

# Only the numerical work is compiled. Jitting a function that reads the clock
# would run the clock calls once, at trace time, and return that constant on
# every later call while the unused product is optimised away.
_square_jit = jax.jit(_square)

def _timed_square(square, matrix):
    """Return the wall-clock seconds ``square(matrix)`` needs to complete."""
    start = time.perf_counter()
    product = square(matrix)
    # JAX returns before the computation has finished (asynchronous dispatch);
    # NumPy results have no such method and are already complete.
    if hasattr(product, 'block_until_ready'):
        product.block_until_ready()
    return time.perf_counter()-start

def multiply_matrix(matrix):
    """Time ``matrix @ matrix`` executed eagerly.

    Args:
        matrix: NumPy or JAX array supporting ``@`` with itself.

    Returns:
        Elapsed wall-clock time in seconds as a Python float.
    """
    return _timed_square(_square, matrix)

def multiply_matrix_jit(matrix):
    """Time ``matrix @ matrix`` executed through the jit-compiled product.

    The first call for a given shape and dtype also includes compilation.

    Args:
        matrix: array accepted by ``jax.jit`` (NumPy input is converted by JAX).

    Returns:
        Elapsed wall-clock time in seconds as a Python float.
    """
    return _timed_square(_square_jit, matrix)

In [61]:
def time_function(fun,loops,**kwargs):
    """Call ``fun(**kwargs)`` and average its return value over ``loops`` runs.

    Args:
        fun: callable whose return value is the measurement.
        loops: number of calls; values of 1 or less give a single call whose
            result is returned unchanged.
        **kwargs: keyword arguments forwarded to every call of ``fun``.

    Returns:
        The single result, or ``np.mean`` of the ``loops`` results.
    """
    if not loops>1:
        return fun(**kwargs)
    samples = []
    for _ in range(loops):
        samples.append(fun(**kwargs))
    return np.mean(samples)


In [76]:
# Powers of ten from 10^1 to 10^5 (the arange stop value, 6, is exclusive).
10**np.arange(1,6,1)

array([    10,    100,   1000,  10000, 100000])

In [77]:
# Square matrices with side 10, 100, 1000 and 10000.
sizes = 10**np.arange(1,5,1)
np.random.seed(1)
matrices = [np.random.uniform(size=(i,i)) for i in sizes]
# JAX copies of the same matrices, converted once here instead of once per
# plotted curve (the list used to hold unrelated NumPy matrices and was unused).
matrices_jax = [jnp.array(matrix) for matrix in matrices]

fig,ax=plt.subplots()
ax.plot(sizes, [time_function(multiply_matrix, loops=10, matrix=matrix) for matrix in matrices],label='np')
ax.plot(sizes, [time_function(multiply_matrix, loops=10, matrix=matrix) for matrix in matrices_jax],label='jax')
ax.plot(sizes, [time_function(multiply_matrix_jit, loops=10, matrix=matrix) for matrix in matrices_jax],label='jax jit')
ax.legend()
ax.set(title='Matrix multiplication',xlabel='Matrix size',ylabel='Mean time per multiplication (s)')
# plt.show() works with the notebook's inline backend; Figure.show() expects a
# GUI backend and only emits a warning there.
plt.show()

In [66]:
def test_numpy(A):
    return A@A

def test_jax(A):
    return A@A

@jax.jit
def test_jax_jit(A):
    return A@A    

# Benchmark input: one 10000 x 10000 matrix of uniform draws, kept both as a
# NumPy array and as a JAX copy so each function is timed on its own array type.
size=(10000,10000)
# NOTE(review): no seed is set in this cell, so the values depend on whatever
# state the global NumPy generator is in when the cell runs.
A_numpy = np.random.uniform(size=size)
A_jax = jnp.array(A_numpy)

In [67]:
%%timeit
# NumPy baseline on the NumPy copy of the benchmark matrix.
test_numpy(A_numpy)

1.17 s ± 178 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [68]:
%%timeit
test_jax(A_jax)

470 ms ± 18.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [69]:
%%timeit
test_jax_jit(A_jax)

498 ms ± 21.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Infinity times 0 is nan

In [12]:
# inf * 0 has no defined value in floating point, so the result is nan.
jnp.inf*0

nan

In [13]:
# False counts as 0 in arithmetic, so multiplying by a boolean mask gives nan as well.
jnp.inf*False

nan

### While Loop

In [14]:
# Arguments are (cond_fun, body_fun, init_val). The carried state is a pair:
# the loop keeps adding 1 to the second entry while it is below 4 and leaves
# the first entry untouched, so (-3, -1) ends as (-3, 4).
jax.lax.while_loop(lambda x: x[1]<4, lambda x: (x[0], x[1]+1), (-3,-1))

(DeviceArray(-3, dtype=int32, weak_type=True),
 DeviceArray(4, dtype=int32, weak_type=True))

### jit

In [4]:
# Fixed seed so the random test matrix is the same on every run.
seed = 123
key = jax.random.PRNGKey(seed)
# 5 x 5 matrix of uniform draws between -1 and 1, so it contains both signs;
# the jit examples below act on its negative entries.
A = jax.random.uniform(key, shape=(5,5), minval=-1, maxval=1)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


JAX can't `jit` functions that use conditional slicing, i.e. indexing with a boolean mask such as `A<0` whose entries are only known at run time.

In [None]:
@jax.jit
def foo(A:Array)->Array:
    """Try to zero the negative entries of ``A`` through a boolean-mask index.

    Under ``jit`` the mask ``A<0`` is a traced value, which JAX refuses to use
    as an index.
    """
    return A.at[A<0].set(0)

# The failure is the point of this cell: catch it and print the first line of
# the message so that "Run All" continues past the demonstration.
try:
    print(foo(A))
except jax.errors.NonConcreteBooleanIndexError as err:
    print(f"{type(err).__name__}: {str(err).splitlines()[0]}")

But sometimes there are workarounds

In [9]:
@jax.jit
def foo(A:Array)->Array:
    """Replace negative entries of ``A`` by zero with an elementwise maximum.

    No boolean-mask indexing is involved, so the function can be jit-compiled.
    """
    clipped = jnp.maximum(A,0)
    return clipped
foo(A)

Array([[0.70078254, 0.        , 0.5044954 , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.29831648, 0.600971  ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.07184577, 0.6359787 , 0.        , 0.        ],
       [0.        , 0.        , 0.09762692, 0.40657496, 0.        ]],      dtype=float32)

Slightly more complicated

In [10]:
def foo_eager(A):
    """Reference version: overwrite the negative entries of ``A`` with 999.

    Uses the boolean-mask update that the notebook reports as not jit-able, so
    it is kept un-jitted and under its own name instead of being shadowed by
    the jitted ``foo`` defined right after it.
    """
    return A.at[A<0].set(999)

@jax.jit
def foo(A:Array)->Array:
    """jit-compatible replacement of the negative entries of ``A`` by 999.

    Negative entries are first clamped to 0, then the mask (True where ``A``
    was negative) adds 999 exactly at those positions.
    """
    mask = A<0
    return jnp.maximum(A,0)+mask*999

# The workaround has to reproduce the eager reference.
assert jnp.array_equal(foo(A), foo_eager(A))
foo(A)

Array([[7.0078254e-01, 9.9900000e+02, 5.0449538e-01, 9.9900000e+02,
        9.9900000e+02],
       [9.9900000e+02, 9.9900000e+02, 9.9900000e+02, 2.9831648e-01,
        6.0097098e-01],
       [9.9900000e+02, 9.9900000e+02, 9.9900000e+02, 9.9900000e+02,
        9.9900000e+02],
       [9.9900000e+02, 7.1845770e-02, 6.3597870e-01, 9.9900000e+02,
        9.9900000e+02],
       [9.9900000e+02, 9.9900000e+02, 9.7626925e-02, 4.0657496e-01,
        9.9900000e+02]], dtype=float32)

In [13]:
# Small integer array used to try a masked update in the next cell.
x = jnp.array([1,0,0,1])

In [18]:
# NOTE(review): the recorded output is the unchanged array [1, 0, 0, 1], so
# setting None at the masked positions did not mark them as missing here —
# confirm what this cell was meant to show.
x.at[x==0].set(None)



Array([1, 0, 0, 1], dtype=int32)

In [None]:
# NOTE(review): this rebinds A, which was the 5 x 5 random matrix of the jit
# examples above; re-running those cells after this one would use this vector.
A = jnp.array([-jnp.inf,3,4,-1])
# The minimum of an array containing -inf is -inf.
jnp.min(A)

DeviceArray(-inf, dtype=float32)

In [None]:
# -inf * 0 is nan, and adding 1 to nan is still nan.
-jnp.inf*0+1

nan

### Random number generation: JAX vs NumPy

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
# Inputs shared by the timing cells below.
seed = 123
key = jax.random.PRNGKey(seed)
# Grid of evenly spaced points on [0, 1] (linspace's default count) to sample from.
x = jnp.linspace(0,1)
# Number of values drawn per call.
n=1

In [None]:
%%timeit
jax.random.choice(key, a=x, shape=(n,))

95 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%%timeit
np.random.choice(a=x, size=n)

13.6 µs ± 69.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
# Number of repeated draws used by the loop and vmap timings below.
n_it = 10000

In [None]:
%%timeit
for _ in range(n_it):
    jax.random.choice(key, a=x, shape=(n,))

943 ms ± 627 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
jax.vmap(lambda i: jax.random.choice(key, a=x, shape=(n,)))(jnp.array(range(n_it)))

1.92 ms ± 8.18 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### vmap

In [38]:
def foo(x):
    """Return the shape each mapped slice has inside ``vmap``."""
    slice_shape = x.shape
    return slice_shape

# Column vector of shape (3, 1); vmap maps over the leading axis.
x = jnp.array([[1],[2],[3]])

print(jax.vmap(foo)(x))


def foo(x,y):
    """Print the per-slice shapes of both operands, then multiply them."""
    for operand in (x, y):
        print(operand.shape)
    return x*y

# Shape (3, 2): each mapped slice of y has two entries, each slice of x has one.
y = jnp.array([[3,10],[2,20],[1,30]])

print(jax.vmap(foo)(x,y))

(DeviceArray([1, 1, 1], dtype=int32, weak_type=True),)
(1,)
(2,)
[[ 3 10]
 [ 4 40]
 [ 3 90]]


(3,)

### Pass functions in jitted functions

Use `Partial` from `jax` instead of `functools.partial`. In this example, `foo` evaluates the function `fun` at `x`. However, `fun` also requires a second argument, `add`, which `Partial` binds in advance.

In [3]:
# Public import path; jax._src is private and may change without notice.
from jax.tree_util import Partial

@jax.jit
def foo(x, fun):
    """Evaluate ``fun`` at ``x`` inside a jitted function.

    ``fun`` is passed as an ordinary argument, so it has to be something jit
    accepts as input, such as a ``Partial``.
    """
    return fun(x)

@jax.jit
def f(x, add):
    """Return ``x + add``."""
    return x+add

# Partial binds `add`, leaving a one-argument callable for foo.
foo(1, Partial(f, add=1))

Array(2, dtype=int32, weak_type=True)

Static argnames

In [None]:
@partial(jax.jit, static_argnames='hola')
def repeat_add(x, hola):
    """Add ``x`` to zero ``hola`` times with a Python loop.

    ``hola`` controls ``range``, which needs a concrete Python integer, so it
    is declared static: jit compiles one version per distinct value of
    ``hola`` instead of tracing it.

    Args:
        x: value to accumulate.
        hola: number of additions (static argument).

    Returns:
        ``hola * x``.
    """
    total = jnp.zeros_like(x)
    for _ in range(hola):
        total = total + x
    return total

repeat_add(2., hola=3)

### CPU vs GPU

In [None]:
import jax