# How to get the derivative wrt. the hidden activations of a model in Flax/JAX? #1152
Original question by @untom.
---
Answer by @jheek: It is annoying if you don't have code that factors nicely into functions, so getting the gradient with respect to all hidden activations takes a trick:

```python
class GradWrapper(nn.Module):
    mdl: nn.Module

    @nn.compact
    def __call__(self, *args, **kwargs):
        y = self.mdl(*args, **kwargs)
        # A zero-valued variable added to the output; its gradient equals
        # the gradient wrt the activation y.
        eps = self.variable('inter_grads', 'activation', lambda: jnp.zeros_like(y))
        return y + eps.value

variables = model.init(...)
grads = jax.grad(model.apply)(variables, batch)
param_grads = grads['params']
inter_grads = grads['inter_grads']
```

So the trick is to add a "delta epsilon" everywhere you want intermediate gradients (it's like calculus 101 all over again 😛). The wrapper itself might not be the right place to put it; all you need is a zero-valued variable added to whatever you want to track. The only problem with this pattern is that, for optimal performance, you should make sure you don't "materialize" the zeros: if XLA knows the variable is just zeros, it can optimize the redundant `+ zeros` away. XLA should realize that it's just zeros in this case if you simply do `inter_grads = jax.tree_map(jnp.zeros_like, variables['inter_grads'])`.
---
We now have the `Module.perturb` API for this. Check out the "Extracting gradients of intermediate values" guide for a complete walkthrough. In the meantime, here is a short example from the documentation:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense3', x)  # registers a zero-valued perturbation
        return nn.Dense(2)(x)

def loss(params, perturbations, inputs, targets):
    variables = {'params': params, 'perturbations': perturbations}
    preds = model.apply(variables, inputs)
    return jnp.square(preds - targets).mean()

x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
print(intm_grads['dense3'])  # ==> [[-1.456924   -0.44332537  0.02422847]
                             #      [-1.456924   -0.44332537  0.02422847]]
```