New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HOWTO: Per example gradient #306
Conversation
Codecov Report
@@ Coverage Diff @@
## master #306 +/- ##
=======================================
Coverage 79.33% 79.33%
=======================================
Files 34 34
Lines 2255 2255
=======================================
Hits 1789 1789
Misses 466 466 Continue to review full report at Codecov.
|
+ grad = jax.tree_map(mean_fn, grads) | ||
+ | ||
optimizer = optimizer.apply_gradient(grad) | ||
metrics = compute_metrics(logits, batch['label']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check compute_metrics. It probably computes the loss a second time. Instead, we should use jax.value_and_grad
where vmap_loss_grad is computed.
We've redone our HOWTO system. Here's an example of the new format: Rendered in ReadTheDocs: https://flax.readthedocs.io/en/latest/howtos/ensembling.html We'd happy take pull requests revamping this HOWTO to the new system. |
Looks like this PR is inactive, so I'll close it. We're tracking the HOWTO request here: #858 |
This howto exemplifies how per-example gradients can be retrieved using
jax.vmap
.