In [1]:
# Derivative
from jax import grad
import jax.numpy as jnp

def squared(x):
    """Return ``x`` raised to the power two.

    Uses the power operator protocol, so it accepts anything that supports
    ``**`` with an integer exponent (for array inputs the result is
    element-wise).
    """
    return pow(x, 2)

# Sanity check: squared is applied element-wise to an array input.
squared(jnp.array([-3., 0, 3.]))

Array([9., 0., 9.], dtype=float32)

In [2]:
# grad(squared) builds a new function that evaluates the derivative of squared.
d_squared = grad(squared)
# Derivative of x**2 at x = 2 is 2 * 2 = 4.
d_squared(2.)

Array(4., dtype=float32, weak_type=True)

In [3]:
# grad() is only defined for scalar-output functions. Feeding an array to
# d_squared makes squared return shape (3,), which raises TypeError.
# The error is the point of this cell, so catch it and show the message
# instead of letting it abort a Restart Kernel -> Run All.
try:
    d_squared(jnp.array([-3., 0., 3.]))
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: Gradient only defined for scalar-output functions. Output had shape: (3,).

In [4]:
# Auto-vectorization
from jax import vmap

# vmap maps the scalar-only d_squared over every element of the array,
# so the derivative is evaluated point by point.
vmap(d_squared)(jnp.arange(-3., 3.))

Array([-6., -4., -2.,  0.,  2.,  4.], dtype=float32)

In [5]:
# Jit
from jax import jit

# Baseline: time the un-jitted gradient. JAX dispatches work asynchronously,
# so block_until_ready() is called to wait for the result and time the
# actual computation rather than just the dispatch.
%timeit d_squared(-2.).block_until_ready()

# Same function wrapped in jit; compare its timing with the baseline above.
jit_d_squared = jit(d_squared)
%timeit jit_d_squared(-2.).block_until_ready()

512 μs ± 10.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
4.84 μs ± 15.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [6]:
# Scalar output
from jax import grad

def sum_squared(x):
    """Return the sum of the squared entries of ``x`` as a single value.

    Reducing to a scalar is what lets ``grad`` be applied to this function
    with an array argument.
    """
    squares = x**2
    return jnp.sum(squares)

# sum_squared returns a scalar, so grad accepts an array input here and
# yields one partial derivative per element (2 * x_i).
d_sum_squared = grad(sum_squared)
d_sum_squared(jnp.array([-3., 0., 3.]))

Array([-6.,  0.,  6.], dtype=float32)

In [7]:
# Surprising default: grad only differentiates w.r.t. the first argument
def squared2(x1, x2, x3):
    """Return the sum of the squares of the three arguments."""
    sq1, sq2, sq3 = x1**2, x2**2, x3**2
    return sq1 + sq2 + sq3

# With no argnums given, grad differentiates with respect to the first
# positional argument only, so the call below returns d/dx1 = 2 * 3 = 6.
# (Passing argnums=(0, 1, 2) to grad would return all three partials.)
d_squared2 = grad(squared2)
d_squared2(3., 2., 1.)

Array(6., dtype=float32, weak_type=True)

In [8]:
# Random
from jax import random

# JAX random functions take an explicit key; the key determines the samples.
key = random.key(1701)
x = random.normal(key, (10,))
display(x)
# Reusing the same key reproduces exactly the same samples.
x = random.normal(key, (10,))
display(x)
# A key built from a different seed yields different samples.
x = random.normal(random.key(1702), (10,))
x

Array([ 0.7112138 , -0.30713332, -0.35718888,  1.3868424 , -0.40255332,
       -0.7842863 ,  1.0894859 ,  1.2576592 , -0.96818185,  0.17469993],      dtype=float32)

Array([ 0.7112138 , -0.30713332, -0.35718888,  1.3868424 , -0.40255332,
       -0.7842863 ,  1.0894859 ,  1.2576592 , -0.96818185,  0.17469993],      dtype=float32)

Array([ 2.2102735 ,  0.3470336 , -1.2627498 ,  0.69964904,  0.63290447,
       -1.2527993 , -0.9961991 ,  0.07268224,  0.12105393,  0.93025947],      dtype=float32)

In [9]:
# jnp.where
def y(x):
    """Return ``x**2`` where ``x > 0`` and ``-x**2`` elsewhere, element-wise."""
    squared_x = x**2
    # jnp.where picks per element: first branch where the condition holds,
    # second branch otherwise (so x == 0 takes the negated branch).
    return jnp.where(x > 0, squared_x, -squared_x)

# Evaluate y on a negative, a zero and a positive input.
y(jnp.array([-2., 0, 2.]))

Array([-4., -0.,  4.], dtype=float32)

In [10]:
# jnp.dot
def mx(x):
    """Left-multiply ``x`` by a fixed 4x2 matrix using ``jnp.dot``."""
    rows = [
        [0., 1.],
        [10., 11.],
        [20., 21.],
        [30., 31.],
    ]
    matrix = jnp.array(rows)
    return jnp.dot(matrix, x)

# (4, 2) matrix times a (2, 3) input gives a (4, 3) result.
mx(jnp.array([[0., 1., 2.], [10., 11., 12.]]))

Array([[ 10.,  11.,  12.],
       [110., 131., 152.],
       [210., 251., 292.],
       [310., 371., 432.]], dtype=float32)

In [12]:
# Batching
# vmap maps mx over the leading axis: each row of the (2, 2) input is
# treated as its own length-2 vector, giving one length-4 result per row.
vmap(mx)(jnp.array([[0., 1.], [10., 11.]]))

Array([[  1.,  11.,  21.,  31.],
       [ 11., 221., 431., 641.]], dtype=float32)

In [14]:
# matmul
# Operands for the matrix product:
#   a = [[0, 1, 2],        b = [[0, 1],
#        [3, 4, 5]]             [2, 3],
#                               [4, 5]]
a = jnp.arange(6).reshape(2, 3)
b = jnp.arange(6).reshape(3, 2)
# (2, 3) @ (3, 2) -> (2, 2)
jnp.matmul(a, b)

Array([[10, 13],
       [28, 40]], dtype=int32)

In [15]:
# @ operator
# The @ operator computes the same matrix product as jnp.matmul(a, b).
a @ b

Array([[10, 13],
       [28, 40]], dtype=int32)