+++ {"id": "kORMl5KmfByI"}
Authors: Vlatimir Mikulik & Matteo Hessel
Computing gradients is a critical part of modern machine learning methods. This section covers a few advanced topics in automatic differentiation that are relevant to modern machine learning.
While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible video to get a deeper sense of what's going on.
The Autodiff Cookbook is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It's not necessary to understand this to do most things in JAX. However, some features (like defining custom derivatives) depend on understanding this, so it's worth knowing this explanation exists if you ever need to use them.
+++ {"id": "qx50CO1IorCc"}
JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.
We illustrate this in the single-variable case:
The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:
:id: Kqsbj98UTVdi
import jax
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
+++ {"id": "ItEt15OGiiAF"}
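The higher-order derivatives of $f$ are:
$$\begin{array}{l} f'(x) = 3x^2 + 4x - 3 \\ f''(x) = 6x + 4 \\ f'''(x) = 6 \\ f^{iv}(x) = 0 \end{array}$$
Computing any of these in JAX is as easy as chaining the `grad` function: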
:id: 5X3yQqLgimqH
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
+++ {"id": "fVL2P_pcj8T1"}
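Evaluating the above at $x=1$ gives:
$$\begin{array}{l} f'(1) = 4 \\ f''(1) = 10 \\ f'''(1) = 6 \\ f^{iv}(1) = 0 \end{array}$$
Using JAX: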
:id: tJkIp9wFjxL3
:outputId: 581ecf87-2d20-4c83-9443-5befc1baf51d
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
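As a sanity check, the chained `grad` calls can be compared against derivatives worked out by hand. The sketch below assumes the same polynomial as above; the `analytic` list and the evaluation point `x = 2.0` are illustrative choices, not part of the original example.

```python
import jax

# Same polynomial as in the text.
f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Hand-derived derivatives, used only as a reference (illustrative).
analytic = [
    lambda x: 3*x**2 + 4*x - 3,  # f'
    lambda x: 6*x + 4,           # f''
    lambda x: 6.0,               # f'''
    lambda x: 0.0,               # f''''
]

g = f
results = []
for order, reference in enumerate(analytic, start=1):
    g = jax.grad(g)  # each pass differentiates once more
    results.append((float(g(2.0)), float(reference(2.0))))
    print(f"order {order}: jax={results[-1][0]}, analytic={results[-1][1]}")
```

Each loop iteration wraps the previous derivative function in another `jax.grad`, which is the same stacking of transformations shown above.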
+++ {"id": "3-fTelU7LHRr"}
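In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function $f: \mathbb{R}^n \to \mathbb{R}$ is represented by its Hessian matrix, defined according to
$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial x_i \partial x_j}.$$
The Hessian of a real-valued function of several variables, $f: \mathbb{R}^n \to \mathbb{R}$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances; see the video about autodiff linked above for an explanation.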
:id: ILhkef1rOB6_
def hessian(f):
  return jax.jacfwd(jax.grad(f))
+++ {"id": "xaENwADXOGf_"}
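Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x}^\top \mathbf{x}$: if $i = j$, $\frac{\partial^2 f}{\partial x_i \partial x_j}(\mathbf{x}) = 2$; otherwise, $\frac{\partial^2 f}{\partial x_i \partial x_j}(\mathbf{x}) = 0$.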
:id: Xm3A0QdWRdJl
:outputId: e1e8cba9-b567-439b-b8fc-34b21497e67f
import jax.numpy as jnp
def f(x):
  return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))
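To illustrate that forward- and reverse-mode Jacobians give the same Hessian, here is a minimal sketch on a function with off-diagonal second derivatives. The test function `g` is hypothetical; `jax.hessian` is JAX's built-in convenience wrapper and is included only for comparison.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical test function whose Hessian is not diagonal.
    return jnp.sum(x**3) + x[0] * x[1]

x = jnp.array([1., 2., 3.])

hess_fwd_over_rev = jax.jacfwd(jax.grad(g))(x)  # forward-mode Jacobian of the gradient
hess_rev_over_rev = jax.jacrev(jax.grad(g))(x)  # reverse-mode Jacobian of the gradient
hess_builtin = jax.hessian(g)(x)                # built-in helper, for comparison

print(hess_fwd_over_rev)
print(jnp.allclose(hess_fwd_over_rev, hess_rev_over_rev))
print(jnp.allclose(hess_fwd_over_rev, hess_builtin))
```

All three should agree numerically; which composition is faster depends on the shapes involved.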
+++ {"id": "7_gbi34WSUsD"}
Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. The Autodiff Cookbook explains some tricks, like the Hessian-vector product, that let you use the Hessian without materialising the whole matrix.
If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook.
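As a rough illustration of the Hessian-vector product idea, the sketch below pushes a tangent vector through the gradient function with `jax.jvp` (forward-over-reverse). The helper name `hvp`, the test function `g`, and the vectors are assumptions made for this example; the dense Hessian is built only to check the result.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: differentiate grad(f) at x in direction v.
    # This never builds the full Hessian matrix.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def g(x):
    # Hypothetical scalar-valued test function.
    return jnp.sum(jnp.sin(x) * x**2)

x = jnp.array([1., 2., 3.])
v = jnp.array([0.5, -1., 2.])

hv = hvp(g, x, v)
hv_dense = jax.hessian(g)(x) @ v  # reference: materialise the Hessian, then multiply

print(hv)
print(jnp.allclose(hv, hv_dense, atol=1e-4))
```

For a function of `n` inputs, `hvp` returns a vector of length `n`, whereas the dense route needs an `n` by `n` matrix first.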
+++ {"id": "zMT2qAi-SvcK"}
Some meta-learning techniques, such as Model-Agnostic Meta-Learning (MAML), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:
# `loss_fn`, `lr`, `params` and `data` are assumed to be defined elsewhere.
def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)
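The snippet above leaves `loss_fn`, `lr`, `params` and `data` undefined. A minimal runnable sketch, assuming a hypothetical least-squares loss for a linear model and made-up data, could look like this:

```python
import jax
import jax.numpy as jnp

lr = 0.01  # inner-loop step size (assumed value)

def loss_fn(params, data):
    # Hypothetical least-squares loss for a linear model.
    x, y = data
    return jnp.mean((x @ params - y) ** 2)

def meta_loss_fn(params, data):
    """Computes the loss after one step of SGD."""
    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)

# Made-up parameters and data, for illustration only.
params = jnp.array([0.5, -0.5])
data = (jnp.array([[1., 2.], [3., 4.], [5., 6.]]), jnp.array([1., 2., 3.]))

# Differentiates through the inner gradient step, so second-order terms are included.
meta_grads = jax.grad(meta_loss_fn)(params, data)
print(meta_grads)
```

Because `meta_loss_fn` itself calls `jax.grad`, the outer `jax.grad` is a second-order computation, exactly the stacking of transformations described earlier.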
+++ {"id": "3h9Aj3YyuL6P"}
Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.
Consider for instance the TD(0) (temporal difference) reinforcement learning update. This is used to learn to estimate the value of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1})$ in a state $s_{t-1}$ is parameterised by a linear function.
:id: fjLqbCb6SiOm
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
+++ {"id": "85S7HBo1tBzt"}
Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$.
:id: T6cRPau6tCSE
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
+++ {"id": "QO5CHA9_Sk01"}
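The TD(0) update to the network parameters is:
$$\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})$$
This update is not the gradient of any loss function.
However, it can be written as the gradient of the pseudo loss function
$$L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2$$
if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored.
How can we implement this in JAX? If we write the pseudo loss naively we get: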
:id: uMcFny2xuOwz
:outputId: 79c10af9-10b8-4e18-9753-a53918b9d72d
def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  # Pseudo loss with the -1/2 factor; the gradient still flows through `target` here.
  return -0.5 * (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
+++ {"id": "CPnjm59GG4Gq"}
But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$.
We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$:
:id: WCeq7trKPS4V
:outputId: 0f38d754-a871-4c47-8e3a-a961418a24cc
def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  # The -1/2 factor makes the gradient match the sign and scale of the TD(0) update.
  return -0.5 * (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
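One way to convince yourself that `stop_gradient` has the intended effect is to compare the result with the TD(0) update written out by hand. The sketch below is self-contained and assumes the pseudo loss carries a factor of $-\frac{1}{2}$, so that its gradient equals the update exactly; the names `pseudo_loss` and `expected` are illustrative.

```python
import jax
import jax.numpy as jnp

# Linear value function and the example transition from the text.
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

def pseudo_loss(theta, s_tm1, r_t, s_t):
    v_tm1 = value_fn(theta, s_tm1)
    # The target is treated as a constant during differentiation.
    target = jax.lax.stop_gradient(r_t + value_fn(theta, s_t))
    return -0.5 * (target - v_tm1) ** 2

# Hand-written TD(0) update: TD error times the gradient of v(s_{t-1}),
# which is simply s_{t-1} for a linear value function.
td_error = r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)
expected = td_error * s_tm1

delta_theta = jax.grad(pseudo_loss)(theta, s_tm1, r_t, s_t)
print(delta_theta)
print(jnp.allclose(delta_theta, expected))
```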
+++ {"id": "TNF0CkwOTKpD"}
This will treat `target` as if it did **not** depend on the parameters $\theta$, so the gradient no longer includes a contribution from the target.
`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
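As a rough sketch of the "subset of the parameters" case, consider a hypothetical two-part model stored in a dictionary with keys `body` and `head` (both names are made up for this example). Wrapping the intermediate features in `jax.lax.stop_gradient` means this loss only produces a non-zero gradient for `head`:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Hypothetical two-part model: "body" produces features, "head" consumes them.
    features = jnp.tanh(params["body"] * x)
    features = jax.lax.stop_gradient(features)  # block gradients flowing into "body"
    return jnp.sum((params["head"] * features - 1.0) ** 2)

params = {"body": jnp.array(0.5), "head": jnp.array(2.0)}
grads = jax.grad(loss)(params, jnp.array([1., 2., 3.]))

print(grads["body"])  # zero: this loss does not train "body"
print(grads["head"])  # non-zero
```

The `body` parameters could then be trained by a separate loss without interference from this one.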
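The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f$ that is used as part of a larger function we wish to differentiate, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: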
:id: hdORJENmVHvX
:outputId: f0839541-46a4-45a9-fce7-ead08f20046b
def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
+++ {"id": "Wx3RNE0Sw5mn"}
While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.
For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping or normalisation on a sample-by-sample basis.
In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.
In JAX we can define the code to compute the gradient per-sample in an easy but efficient way.
Just combine the `jit`, `vmap` and `grad` transformations together:
:id: tFLyd9ifw4GG
:outputId: bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
+++ {"id": "VxvYVEYQYiS_"}
Let's walk through this one transformation at a time.
First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:
:id: rPO67QQrY5Bk
:outputId: fbb45b98-2dbf-4865-e6e5-87dc3eef5560
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
+++ {"id": "cU36nVAlcnJ0"}
This function computes one row of the array above.
+++ {"id": "c6DQF0b3ZA5u"}
Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs: each output in the batch corresponds to the gradient for the corresponding member of the input batch.
:id: 5agbNKavaNDM
:outputId: ab081012-88ab-4904-a367-68e9f81445f0
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
+++ {"id": "K-v34yLuan7k"}
This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying `theta` as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:
:id: S6kd5MujbGrr
:outputId: d3d731ef-3f7d-4a0a-ce91-7df57627ddbd
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
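The effect of `in_axes` is easier to see on a toy function. In the sketch below, `scale` and its inputs are hypothetical; `w` is shared across the batch (`None`) while `xs` is mapped over its leading axis (`0`), and the result is compared against an explicit Python loop.

```python
import jax
import jax.numpy as jnp

def scale(w, x):
    # Hypothetical per-example function: one scalar weight, one input vector.
    return w * jnp.sum(x)

w = 2.0                            # shared across all examples
xs = jnp.arange(6.).reshape(3, 2)  # batch of 3 examples

# in_axes=(None, 0): do not map over `w`, map over axis 0 of `xs`.
batched = jax.vmap(scale, in_axes=(None, 0))(w, xs)

# Reference: apply the function one example at a time.
looped = jnp.stack([scale(w, x) for x in xs])

print(batched)
print(jnp.allclose(batched, looped))
```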
+++ {"id": "O0hbsm70be5T"}
Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:
:id: Fvr709FcbrSW
:outputId: 627db899-5620-4bed-8d34-cd1364d3d187
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
:id: FH42yzbHcNs2
:outputId: c8e52f93-615a-4ce7-d8ab-fb6215995a39
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
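Per-example gradients make sample-by-sample clipping straightforward. The following is a minimal sketch under stated assumptions: the squared-error `loss`, the helper `clipped_mean_grad`, the small constant `1e-12` and the data are all made up for illustration, and the clipping rule (rescale each gradient to at most `max_norm`, then average) is one common choice rather than a prescribed method.

```python
import jax
import jax.numpy as jnp

def loss(theta, x, y):
    # Hypothetical per-example squared error for a linear model.
    return (jnp.dot(theta, x) - y) ** 2

def clipped_mean_grad(theta, xs, ys, max_norm):
    # One gradient per example: shape (batch, num_params).
    per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(theta, xs, ys)
    norms = jnp.linalg.norm(per_example, axis=1, keepdims=True)
    # Rescale only the gradients whose norm exceeds max_norm.
    scale = jnp.minimum(1.0, max_norm / (norms + 1e-12))
    return jnp.mean(per_example * scale, axis=0), norms.squeeze(-1)

theta = jnp.array([0.1, -0.1, 0.])
xs = jnp.array([[1., 2., -1.], [10., 0., 5.]])
ys = jnp.array([1., -2.])

clipped_grad, norms = jax.jit(clipped_mean_grad)(theta, xs, ys, 1.0)
print(norms)         # per-example gradient norms before clipping
print(clipped_grad)  # average of the clipped gradients
```

The per-example norms returned here could equally be used to prioritise data by gradient magnitude.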