In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap

In [4]:
def predict(params, inputs):
    """Forward pass of a fully connected network with tanh hidden activations.

    Args:
        params: sequence of ``(W, b)`` pairs, one per layer, in order. Each
            layer computes ``activations @ W + b`` (via ``jnp.dot``), so ``W``'s
            leading dimension must match the previous layer's feature size.
        inputs: input array fed to the first layer (assumed batch-leading,
            e.g. ``(batch, features)`` -- TODO confirm against callers).

    Returns:
        The raw (pre-activation) output of the last layer; no nonlinearity is
        applied to it.

    Raises:
        ValueError: if ``params`` contains no layers.
    """
    layers = list(params)
    if not layers:
        # Without this guard the loop body never runs and the return value
        # would be undefined (UnboundLocalError).
        raise ValueError("params must contain at least one (W, b) layer")

    # Separate the output layer so tanh is only applied to hidden layers;
    # this avoids computing an activation that would be discarded.
    *hidden_layers, (W_out, b_out) = layers

    activations = inputs
    for W, b in hidden_layers:
        activations = jnp.tanh(jnp.dot(activations, W) + b)  # inputs to the next layer
    return jnp.dot(activations, W_out) + b_out  # no activation on last layer

In [5]:
def loss(params, inputs, targets):
    """Sum of squared errors between the network's predictions and ``targets``.

    Args:
        params: layer parameters, passed through unchanged to ``predict``.
        inputs: batch fed to ``predict``.
        targets: values subtracted element-wise from the predictions
            (presumably the same shape as the predictions; jnp broadcasting
            applies otherwise -- TODO confirm).

    Returns:
        Scalar ``jnp.sum`` over every element of ``(predictions - targets) ** 2``.
    """
    residuals = predict(params, inputs) - targets
    squared = residuals ** 2
    return jnp.sum(squared)

In [6]:
# Gradient of `loss` with respect to its first argument (`params`), compiled with jit.
grad_loss = jit(grad(loss, argnums=0))

# Per-example gradients: vmap maps grad_loss over axis 0 of `inputs` and
# `targets` while sharing a single `params` (the None entry in in_axes);
# the whole vectorized function is then compiled with jit.
perex_grads = jit(
    vmap(grad_loss, in_axes=(None, 0, 0))
)