How to get the derivative wrt. the hidden activations of a model in Flax/JAX? #1152

Answered by cgarciae
marcvanzee asked this question in Q&A

We now have the Module.perturb API for this. See the Flax guide "Extracting gradients of intermediate values" for a complete walkthrough.
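As a rough sketch of why this works: assuming, as the Flax documentation for perturb describes, that the perturbation is a zero-valued variable added to the intermediate value, the gradient with respect to the perturbation equals the gradient with respect to the activation. The symbols h (hidden activation), ε (perturbation), and L (loss) are notation introduced here for illustration only.

```latex
\tilde{h} = h + \epsilon, \qquad \epsilon = 0
\quad\Longrightarrow\quad
\frac{\partial L}{\partial \epsilon}\bigg|_{\epsilon = 0}
  = \frac{\partial L}{\partial \tilde{h}} \cdot \frac{\partial \tilde{h}}{\partial \epsilon}
  = \frac{\partial L}{\partial h}
```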

In the meantime, here is a short example from the documentation:

import jax
import jax.numpy as jnp
import flax.linen as nn

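class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        # Add a zero-valued 'perturbations' variable to the hidden activation;
        # its gradient equals the gradient wrt. the activation itself.
        x = self.perturb('dense3', x)
        return nn.Dense(2)(x)

# `model` is the module-level Foo() instance created below.
def loss(params, perturbations, inputs, targets):
    variables = {'params': params, 'perturbations': perturbations}
    preds = model.apply(variables, inputs)
    return jnp.square(preds - targets).mean()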

x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = 

marcvanzee Dec 7, 2021
Maintainer Author

Answer selected by marcvanzee