In [14]:
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

In [2]:
"""Jax numpy"""


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [3]:
"""Random number generation"""

key = jax.random.key(1701)
x = jax.random.normal(key, (1_000_000,))

In [4]:
"""JIT compilation"""
%timeit selu(x).block_until_ready()

selu_jit = jax.jit(selu)
_ = selu_jit(x)

%timeit selu_jit(x).block_until_ready()

173 μs ± 5.62 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
23 μs ± 1.67 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
"""Gradients"""


def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


x_small = jnp.arange(3.0)
deriv_fn = jax.grad(sum_logistic)
print(deriv_fn(x_small))


def finite_diff(f, x, eps=1e-3):
    return jnp.array(
        [(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))]
    )


print(finite_diff(sum_logistic, x_small))

[0.25       0.19661194 0.10499357]
[0.24998187 0.1965761  0.10502338]


In [7]:
"""Jacobians and Hessians"""

print(jax.jacobian(jnp.exp)(x_small))


def hessian(fn):
    return jax.jit(jax.jacfwd(jax.jacrev(fn)))


print(hessian(sum_logistic)(x_small))
print(jax.hessian(sum_logistic)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]
[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]
[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]


In [None]:
"""jax.vmap()"""

key1, key2 = jax.random.split(key)
mat = jax.random.normal(key1, (150, 100))
batched_x = jax.random.normal(key2, (10, 100))


def apply_matrix(x):
    return jnp.dot(mat, x)


def naively_batched_apply_matrix(batched_x):
    return jnp.stack([apply_matrix(x) for x in batched_x])


@jax.jit
def batched_apply_matrix(batched_x):
    return jnp.dot(batched_x, mat.T)


@jax.jit
def vmap_batched_apply_matrix(batched_x):
    return jax.vmap(apply_matrix)(batched_x)


print("Naively batched")
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
print("Manually batched")
_ = batched_apply_matrix(batched_x).block_until_ready()
%timeit batched_apply_matrix(batched_x).block_until_ready()
print("Auto-vectorized with vmap")
_ = vmap_batched_apply_matrix(batched_x).block_until_ready()
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

np.testing.assert_allclose(
    naively_batched_apply_matrix(batched_x),
    batched_apply_matrix(batched_x),
    atol=1e-4,
    rtol=1e-6,
)

Naively batched
729 μs ± 31.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Manually batched
24.4 μs ± 1.92 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Auto-vectorized with vmap
23 μs ± 1.93 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
"""Jax expression"""

print(jax.make_jaxpr(selu)(x_small))
print(jax.make_jaxpr(sum_logistic)(x_small))

{ lambda ; a:f32[3]. let
    b:bool[3] = gt a 0.0
    c:f32[3] = exp a
    d:f32[3] = mul 1.6699999570846558 c
    e:f32[3] = sub d 1.6699999570846558
    f:f32[3] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[3] h:f32[3] i:f32[3]. let
          j:f32[3] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[3] = mul 1.0499999523162842 f
  in (k,) }
{ lambda ; a:f32[3]. let
    b:f32[3] = neg a
    c:f32[3] = exp b
    d:f32[3] = add 1.0 c
    e:f32[3] = div 1.0 d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }


In [31]:
"""PRNG keys"""

key = jax.random.key(43)
print(key)

print(jax.random.normal(key))
print(jax.random.normal(key))

for i in range(3):
    new_key, sub_key = jax.random.split(key)
    del key
    val = jax.random.normal(sub_key)
    del sub_key
    print(f"draw {i}: {val}")
    key = new_key

key, *subkeys = jax.random.split(key, num=4)
print(key, subkeys)

Array((), dtype=key<fry>) overlaying:
[ 0 43]
0.07520543
0.07520543
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
Array((), dtype=key<fry>) overlaying:
[3722464693 2600049559] [Array((), dtype=key<fry>) overlaying:
[1615207904  772808876], Array((), dtype=key<fry>) overlaying:
[ 153309274 3877468463], Array((), dtype=key<fry>) overlaying:
[911978728  92600883]]


In [32]:
"""Lack of sequential equivalence"""

key = jax.random.key(42)
subkeys = jax.random.split(key, num=3)
sequences = np.stack([jax.random.normal(subkey) for subkey in subkeys])
print(sequences)

key = jax.random.key(42)
print(jax.random.normal(key, shape=(3,)))

key = jax.random.key(42)
subkeys = jax.random.split(key, num=3)
print(jax.vmap(jax.random.normal)(subkeys))

[0.07592554 0.60576403 0.4323065 ]
[-0.02830462  0.46713185  0.29570296]
[0.07592554 0.60576403 0.4323065 ]


In [11]:
"""Impure functions"""


@jax.jit
def log2_with_print(x):
    print(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2


print("Call 1")
print(log2_with_print(3.0))
print("Call 2")
print(log2_with_print(3.0))
print(jax.make_jaxpr(log2_with_print)(3.0))

Call 1
Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
1.5849625
Call 2
1.5849625
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
      name=log2_with_print
      jaxpr={ lambda ; c:f32[]. let
          d:f32[] = log c
          e:f32[] = log 2.0
          f:f32[] = div d e
        in (f,) }
    ] a
  in (b,) }


In [12]:
"""Retracing"""


@jax.jit
def ndim_func(x):
    print(x)
    if x.ndim == 2:
        return -x
    else:
        return x


print("Branch 1")
print(ndim_func(jnp.array([3.0])))
print(ndim_func(jnp.array([3.0])))
print("Branch 2")
print(ndim_func(jnp.array([3.0, 3.0])))
print(ndim_func(jnp.array([3.0, 3.0])))

Branch 1
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace>


[3.]
[3.]
Branch 2
Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>
[3. 3.]
[3. 3.]


In [20]:
"""Conditionals"""


# @jax.jit  # Fails
def abs_val(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i


print(abs_val(3.0, 5))


@jax.jit  # Works (always define outside scope)
def loop_body(prev_i):
    return prev_i + 1


def abs_val_jit(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i


print(abs_val_jit(3.0, 5))

8.0
8.0


In [19]:
"""Static arguments - needs recompilation for each value of n"""


@partial(jax.jit, static_argnames=["n"])  # Or: @jax.jit(static_argnums=(1,))
def abs_val_decorated(x, n):
    print("Compiling")
    i = 0
    while i < n:
        i += 1
    return x + i


print(abs_val_decorated(3.0, 5))
print(abs_val_decorated(4.0, 5))
print(abs_val_decorated(4.0, 6))

Compiling
8.0
9.0
Compiling
10.0


In [22]:
"""Higher-order derivatives"""

f = lambda x: x**3 + 2 * x**2 - 3 * x + 1

dfdx = jax.grad(f)
d2fdx2 = jax.grad(dfdx)
d3fdx3 = jax.grad(d2fdx2)
d4fdx4 = jax.grad(d3fdx3)

print(dfdx(1.0))
print(d2fdx2(1.0))
print(d3fdx3(1.0))
print(d4fdx4(1.0))

4.0
10.0
6.0
0.0


In [24]:
"""Logistic regression example"""

key = jax.random.key(0)


def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)


def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


inputs = jnp.array(
    [[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.30], [0.74, -2.49, 1.39]]
)

targets = jnp.array([True, True, False, True])


def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))


key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

W_grad = jax.grad(loss, argnums=0)(W, b)
print(f"W_grad: {W_grad}")
b_grad = jax.grad(loss, 1)(W, b)
print(f"b_grad: {b_grad}")
grads = jax.grad(loss, argnums=(0, 1))(W, b)
print(f"Both gradients: {grads}")

W_grad: [-0.43314588 -0.7354602  -1.2598921 ]
b_grad: -0.690017580986023
Both gradients: (Array([-0.43314588, -0.7354602 , -1.2598921 ], dtype=float32), Array(-0.6900176, dtype=float32))


In [25]:
"""Differentiating w.r.t. lists, tuples, and dicts"""


def loss2(params_dict):
    preds = predict(params_dict["W"], params_dict["b"], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))


grads = jax.grad(loss2)({"W": W, "b": b})
print(grads)

{'W': Array([-0.43314588, -0.7354602 , -1.2598921 ], dtype=float32), 'b': Array(-0.6900176, dtype=float32)}


In [26]:
"""Forward and backwards in one pass"""

loss_val, grads = jax.value_and_grad(loss, (0, 1))(W, b)
print(loss_val)
print(grads)

2.9729185
(Array([-0.43314588, -0.7354602 , -1.2598921 ], dtype=float32), Array(-0.6900176, dtype=float32))


In [27]:
"""Test gradients"""

from jax.test_util import check_grads

check_grads(loss, (W, b), order=1)
check_grads(loss, (W, b), order=2)

In [39]:
"""Pytrees"""

trees = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
    {"a": 2, "b": (2, 3)},
    jnp.array([1, 2, 3]),
]

for pytree in trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4],
]

print(jax.tree.map(lambda x: x**2, list_of_lists))
another_list_of_lists = list_of_lists
print(jax.tree.map(lambda x, y: x + y, list_of_lists, another_list_of_lists))

[1, 'a', <object object at 0x7d1ef40c6c80>]   has 3 leaves: [1, 'a', <object object at 0x7d1ef40c6c80>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]
[[1, 4, 9], [1, 4], [1, 4, 9, 16]]
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]


In [45]:
"""Pytree MLP example"""


def init_mlp_params(layer_widths):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(
                weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
                biases=np.ones(shape=(n_out,)),
            )
        )
    return params


params = init_mlp_params([1, 4, 4, 1])
print(jax.tree.map(lambda x: x.shape, params))


def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["weights"] + layer["biases"])
    return x @ last["weights"] + last["biases"]


def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


eta = 1e-4


@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(lambda p, g: p - eta * g, params, grads)


print(params)
print(update(params, np.array([[1.0], [2.0]]), np.array([[0.0], [0.0]])))

[{'biases': (4,), 'weights': (1, 4)}, {'biases': (4,), 'weights': (4, 4)}, {'biases': (1,), 'weights': (4, 1)}]
[{'weights': array([[-1.1025483 ,  0.68066132, -1.91563952,  0.89516076]]), 'biases': array([1., 1., 1., 1.])}, {'weights': array([[ 0.19618383,  0.00748414, -0.24993495,  0.61554604],
       [ 0.02626683,  0.43764833, -0.61300345, -0.36793265],
       [-0.62884433, -0.14180592,  0.36823245,  0.71332384],
       [ 0.30250433,  1.30756572, -0.98433043,  0.11461864]]), 'biases': array([1., 1., 1., 1.])}, {'weights': array([[ 1.16071927],
       [-0.89529667],
       [ 1.32671317],
       [-0.98126922]]), 'biases': array([1.])}]
[{'biases': Array([1.       , 0.9999999, 1.       , 0.9996458], dtype=float32), 'weights': Array([[-1.1025482 ,  0.68066114, -1.9156395 ,  0.8945907 ]], dtype=float32)}, {'biases': Array([1.0004411, 0.9996598, 1.       , 0.9996271], dtype=float32), 'weights': Array([[ 0.19618383,  0.00748414, -0.24993494,  0.61554605],
       [ 0.02719115,  0.43693537, -

In [48]:
"""Gradient checkpointing"""


def g(W, x):
    y = jnp.dot(W, x)
    return jnp.sin(y)


def f(W1, W2, W3, x):
    x = g(W1, x)
    x = g(W2, x)
    x = g(W3, x)
    return x


W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
print()


# Force not to save residuals
def f2(W1, W2, W3, x):
    x = jax.checkpoint(g)(W1, x)
    x = jax.checkpoint(g)(W2, x)
    x = jax.checkpoint(g)(W3, x)
    return x


jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
f32[5] output of cos from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
f32[6] output of sin from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
f32[6] output of cos from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
f32[7] output of cos from /tmp/ipykernel_1161818/916765993.py:5:11 (g)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
f32[6] output of sin from /tmp/ipykernel_1161818/916765993.py:5:11 (g)
