In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, jacfwd, jacrev
import optax

In [2]:
def predict(params, features):
    """Return a linear logit together with its sigmoid.

    Args:
        params: Mapping with a weight entry ``"w"`` and a bias entry ``"b"``.
        features: Array that is matrix-multiplied with ``params["w"]``.

    Returns:
        An array whose leading axis has length 2: index 0 holds the logit,
        index 1 holds ``sigmoid(logit)``.
    """
    weights, bias = params["w"], params["b"]
    logit = features @ weights + bias
    probability = jax.nn.sigmoid(logit)
    # Stack along a new leading axis so both quantities travel in one array.
    return jnp.stack([logit, probability])

In [3]:
# Vectorize `predict` over axis 0 of `features`; in_axes=None for `params` means the
# same parameter set is shared, unbatched, by every example.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

In [4]:
# Three 2-feature examples, evaluated against one shared parameter set.
X = jnp.array(
    [
        [1.0, 2.0],
        [2.5, -1.0],
        [-0.25, 0.75],
    ]
)
params = {"w": jnp.array([1.5, -0.5]), "b": 0.25}

# One output row per row of X, holding [logit, sigmoid(logit)].
logits_probs = batched_predict(params, X)
print("shape:", logits_probs.shape)
print("first row:", logits_probs[0])

shape: (3, 2)
first row: [0.75      0.6791787]


In [5]:
# Print the full `logits_probs` array (every row, not only the first one).
print(logits_probs)

[[ 0.75        0.6791787 ]
 [ 4.5         0.9890131 ]
 [-0.5         0.37754068]]


In [6]:
def scaled_loss(scale, preds, targets):
    """Scaled sum of squared errors.

    Args:
        scale: Multiplier applied to the summed squared error.
        preds: Predicted values.
        targets: Reference values; subtracted elementwise from ``preds``.

    Returns:
        ``scale * sum((preds - targets) ** 2)``, summed over all elements.
    """
    squared_residuals = (preds - targets) ** 2
    total = jnp.sum(squared_residuals)
    return scale * total

# `scale` is shared across the batch (in_axes=None); `preds` and `targets` are both
# mapped over axis 0, so one loss value comes back per row.
batched_loss = jax.vmap(scaled_loss, in_axes=(None, 0, 0))

preds = jnp.array([[0.2, 0.8], [0.6, 0.4]])
targets = jnp.array([[0.0, 1.0], [1.0, 0.0]])
scale = 0.5

print(batched_loss(scale, preds, targets))

[0.04 0.16]


In [7]:
def normalize(x, eps=1e-8):
    """Standardize ``x`` to zero mean and (approximately) unit standard deviation.

    The mean and standard deviation are taken over all elements of ``x``.

    Args:
        x: Array to standardize.
        eps: Small constant added to the standard deviation so that a constant
            input (standard deviation 0) yields zeros instead of 0/0 = NaN.
            Pass ``eps=0.0`` for the unguarded division.

    Returns:
        ``(x - mean(x)) / (std(x) + eps)``, with the same shape as ``x``.
    """
    x_centered = x - jnp.mean(x)
    # Guard the denominator: std(x) is exactly 0 for constant inputs.
    return x_centered / (jnp.std(x) + eps)

# A (32, 28, 28) array of standard-normal samples drawn from a fixed key.
images = jax.random.normal(jax.random.PRNGKey(0), shape=(32, 28, 28))

# Normalize each axis-0 slice independently. out_axes=-1 moves the mapped axis
# (size 32) to the end of the result, which is why the shape is (28, 28, 32).
channel_last_norm = jax.vmap(normalize, in_axes=0, out_axes=-1)(images)
print("channel_last_norm shape:", channel_last_norm.shape)

channel_last_norm shape: (28, 28, 32)


In [8]:
def euclidean_distance(x, y):
    """Euclidean (L2) distance between ``x`` and ``y``.

    Args:
        x: First point.
        y: Second point; subtracted elementwise from ``x``.

    Returns:
        ``sqrt(sum((x - y) ** 2))`` as a scalar array.

    Note:
        The derivative of ``sqrt`` is unbounded at 0, so differentiating this
        function where ``x == y`` does not give a finite gradient.
    """
    difference = x - y
    squared_norm = jnp.sum(difference ** 2)
    return jnp.sqrt(squared_norm)

# Inner vmap: hold one x fixed and sweep over the rows of y.
_distances_to_all = vmap(euclidean_distance, in_axes=(None, 0))
# Outer vmap: sweep over the rows of x, giving a (rows of x, rows of y) matrix.
pairwise_dist = vmap(_distances_to_all, in_axes=(0, None))

X = jnp.array(
    [
        [0., 0.],
        [1., 0.],
        [0., 1.],
    ]
)
Y = jnp.array(
    [
        [0., 0.],
        [1., 1.],
    ]
)
print(pairwise_dist(X, Y))

[[0.        1.4142135]
 [1.        1.       ]
 [1.        1.       ]]


In [9]:
def loss_per_example(params, x, y):
    """Mean sigmoid binary cross-entropy of a linear model's logit against ``y``.

    Args:
        params: Mapping with a weight entry ``"w"`` and a bias entry ``"b"``.
        x: Features, matrix-multiplied with ``params["w"]``.
        y: Labels passed to ``optax.sigmoid_binary_cross_entropy`` alongside
            the logit.

    Returns:
        The mean of the elementwise cross-entropy values.
    """
    logit = x @ params["w"] + params["b"]
    cross_entropy = optax.sigmoid_binary_cross_entropy(logit, y)
    return jnp.mean(cross_entropy)

# grad() differentiates loss_per_example with respect to its first argument (params);
# vmap then maps over axis 0 of the features and labels while sharing params, which
# yields one gradient per example.
grad_per_example = vmap(grad(loss_per_example), in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, -2.0]), "b": 0.3}
features = jnp.array([[1.0, 0.0], [0.5, 1.5], [2.0, -1.0]])
# One binary label per example. A (3, 3) label matrix would hand each example a
# length-3 label vector that silently broadcasts against its single scalar logit.
# NOTE(review): label values are illustrative — confirm the intended labels.
targets = jnp.array([1.0, 0.0, 1.0])
assert targets.shape == features.shape[:1], "expected exactly one label per example"

grads = grad_per_example(params, features, targets)
print("dw shape:", grads["w"].shape)
print("db shape:", grads["b"].shape)

dw shape: (3, 2)
db shape: (3,)


In [10]:
def softmax_logits(theta):
    """Apply softmax to ``theta`` along its last axis.

    Args:
        theta: Array of unnormalized scores.

    Returns:
        ``jax.nn.softmax(theta)``, with the same shape as ``theta``.
    """
    probabilities = jax.nn.softmax(theta, axis=-1)
    return probabilities

theta = jnp.array([1.0, 0.0, -1.0])

# Reverse-mode Jacobian: entry [i, j] is d softmax(theta)[i] / d theta[j].
jacobian = jax.jacrev(softmax_logits)(theta)
print(jacobian)

[[ 0.22269541 -0.1628034  -0.05989202]
 [-0.1628034   0.18483643 -0.02203304]
 [-0.05989202 -0.02203305  0.08192506]]


In [11]:
def dropout_layer(key, x, rate=0.1):
    """Inverted dropout: zero random elements of ``x`` and rescale the survivors.

    Each element is kept with probability ``1 - rate`` and divided by that
    probability; every other element is set to exactly zero.

    Args:
        key: JAX PRNG key used to draw the keep mask.
        x: Array to apply dropout to.
        rate: Probability of dropping an element. ``rate=1.0`` drops everything
            and returns zeros (multiplying the mask and then dividing by
            ``1 - rate`` would give 0/0 = NaN in that case).

    Returns:
        Array with the shape of ``x``: ``x / (1 - rate)`` where kept, 0 elsewhere.
    """
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Swap a zero keep probability for 1 so the division never produces inf/NaN
    # (in the forward pass or in gradients); the mask is all-False in that case,
    # so the substituted value is never visible in the output.
    safe_keep_prob = jnp.where(keep_prob > 0.0, keep_prob, 1.0)
    scaled = x / safe_keep_prob
    # Select rather than multiply by the mask so dropped entries are exactly 0
    # even when the corresponding input is inf or NaN.
    return jnp.where(keep, scaled, jnp.zeros_like(scaled))

def batched_dropout(key, inputs, rate=0.1):
    """Apply ``dropout_layer`` to each axis-0 slice of ``inputs`` with its own key.

    Args:
        key: JAX PRNG key; split into one sub-key per axis-0 slice so every
            slice draws an independent mask.
        inputs: Array whose leading axis is mapped over.
        rate: Drop probability forwarded unchanged to ``dropout_layer`` for
            every slice. Defaults to 0.1, which is ``dropout_layer``'s own default.

    Returns:
        Array with the shape of ``inputs`` after per-slice dropout.
    """
    keys = jax.random.split(key, inputs.shape[0])
    # Keys and inputs are mapped over axis 0; the rate is shared (in_axes=None).
    return vmap(dropout_layer, in_axes=(0, 0, None))(keys, inputs, rate)

# A (4, 5) block of ones makes surviving (rescaled) and dropped (zero) entries easy to see.
activations = jnp.ones(shape=(4, 5))
key = jax.random.PRNGKey(42)

print(batched_dropout(key, activations))

[[1.1111112 1.1111112 0.        1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]]


In [12]:
def rnn_cell(carry, x_t, params):
    """One step of a tanh recurrent cell.

    Args:
        carry: Previous hidden state.
        x_t: Input for the current step.
        params: Sequence unpacked as ``(w_hh, w_xh, b)``: the hidden-to-hidden
            weights, the input-to-hidden weights and the bias.

    Returns:
        A ``(new_carry, output)`` pair; both are the new hidden state
        ``tanh(carry @ w_hh + x_t @ w_xh + b)``.
    """
    recurrent_weights, input_weights, bias = params
    pre_activation = carry @ recurrent_weights + x_t @ input_weights + bias
    new_hidden = jnp.tanh(pre_activation)
    # The hidden state is returned twice: once as the carry, once as the step output.
    return new_hidden, new_hidden

def run_sequence(inputs, params):
    """Run ``rnn_cell`` over ``inputs`` and return the output of every step.

    Args:
        inputs: Sequence scanned along its leading axis, one slice per step.
        params: Cell parameters forwarded unchanged to ``rnn_cell``; the first
            dimension of ``params[0]`` sets the length of the initial hidden state.

    Returns:
        The per-step outputs stacked by ``jax.lax.scan``; the final carry is
        discarded.
    """
    hidden_size = params[0].shape[0]
    initial_hidden = jnp.zeros(hidden_size)

    def step(hidden, x_t):
        # Close over `params` so the step has the (carry, x) signature scan expects.
        return rnn_cell(hidden, x_t, params)

    _final_hidden, outputs = jax.lax.scan(step, initial_hidden, inputs)
    return outputs

# Map run_sequence over axis 0 of `inputs` (one sequence per batch entry) while
# sharing `params` across the batch.
batched_run_sequence = jax.vmap(run_sequence, in_axes=(0, None))
# Compiled version of the batched runner.
jitted_batched_run = jax.jit(batched_run_sequence)