0.12.6
What's Changed
- Annotate wrapped_fn's self argument in decorator_lift_transform for trace-ability by @copybara-service[bot] in #5298
- add with_attributes by @copybara-service[bot] in #5308
- Remove optax pin by @samanklesaria in #5292
- Removed redundant code checking wrt in nnx.Optimizer by @vfdev-5 in #5226
- Fix PyTreeNode + Generic losing parameters when Generic is last in bases by @mohsinm-dev in #5237
- add recursive_map example by @copybara-service[bot] in #5311
- Add sharding propagation support in nnx.eval_shape (clone of #5111) by @samanklesaria in #5247
- fix eval_shape's _to_variable by @copybara-service[bot] in #5316
- support out_sharding mapping in set_metadata by @copybara-service[bot] in #5312
- feat(nnx): add GQA support to MultiHeadAttention by @ayulockedin in #5259
- Allow DenyList to be compared. by @copybara-service[bot] in #5322
- add graph_updates argument to jit by @copybara-service[bot] in #5317
- add graph_updates argument in shard_map by @copybara-service[bot] in #5319
- Remove jax/tools/colab_tpu.py. by @copybara-service[bot] in #5324
- add graph_updates to vmap by @copybara-service[bot] in #5320
- Add split method to RngStream by @samanklesaria in #5270
- add graph_updates for scan by @copybara-service[bot] in #5327
- add graph_updates to while_loop by @copybara-service[bot] in #5328
- Added support for data masking in Average, Accuracy and MultiMetric by @vfdev-5 in #5326
- add graph_updates to fori_loop by @copybara-service[bot] in #5329
- add graph_updates to pmap, grad, and value_and_grad by @copybara-service[bot] in #5330
- add tree-mode-nnx FLIP by @copybara-service[bot] in #5310
- Copybara import of the project: by @copybara-service[bot] in #5331
- add graph_updates to remat by @copybara-service[bot] in #5336
- add graph_updates to eval_shape and checkify by @copybara-service[bot] in #5338
- add compat module by @copybara-service[bot] in #5340
- add more tests that check for consistent aliasing in transforms by @copybara-service[bot] in #5341
- don't allow Variable mutation in custom_vjp on differentiable arguments by @copybara-service[bot] in #5342
- support hijax and ref Variables in simple transforms by @copybara-service[bot] in #5345
- allow in/out_axes when graph_updates=False by @copybara-service[bot] in #5323
- add transform_metadata transform by @copybara-service[bot] in #5346
- clean up custom_vjp's graph_updates=False section by @copybara-service[bot] in #5350
- improve error messages for tree mode duplicates check by @copybara-service[bot] in #5347
- Introduce manual_type: ManualAxisType parameter on ShapedArray to track varying/unreduced/reduced and remove vma: frozenset parameter. by @copybara-service[bot] in #5339
- remove aliases from nnx.graph in favor of nnx.compat by @copybara-service[bot] in #5348
- check aliases on all transform args and simplify apply_variable_updates by @copybara-service[bot] in #5349
- Add intermediate value captures (extends #4925) by @samanklesaria in #5257
- fix jit_partial lower and compile by @copybara-service[bot] in #5355
- Added support for data masking in Average, Accuracy and MultiMetric by @vfdev-5 in #5332
- Do a few more cleanups after pmap_shmap merge by @copybara-service[bot] in #5358
- update to version 0.12.6 by @copybara-service[bot] in #5356
- Do a few more cleanups after pmap_shmap merge by @copybara-service[bot] in #5359
- Do a few more cleanups after the pmap shmap merge deletion. by @copybara-service[bot] in #5361
- simplify SimpleScan by @copybara-service[bot] in #5357
- improve nnx.Dict error handling by @copybara-service[bot] in #5362
- add graph node in prefix checks by @copybara-service[bot] in #5365
- improve aliasing error msg by @copybara-service[bot] in #5364
- Generalize out_sharding to work with NamedSharding and Format by @samanklesaria in #5246
- add nnx.map and nnx.abstract_with_sharding by @copybara-service[bot] in #5366
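The data-masking entries (#5326, #5332) add masking support to Average, Accuracy and MultiMetric. As a rough illustration of what masking means for a running average, here is a minimal pure-Python sketch. It assumes that masked-out entries are simply excluded from both the running total and the count. The class MaskedAverage and its update/compute methods are hypothetical and are not the nnx.metrics API; check the Flax documentation for the actual argument names.

```python
class MaskedAverage:
    """Hypothetical running average that ignores masked-out entries.

    Illustrative sketch only; this is not flax.nnx.metrics.Average.
    """

    def __init__(self):
        self.total = 0.0  # sum of unmasked values seen so far
        self.count = 0    # number of unmasked values seen so far

    def update(self, values, mask=None):
        # With no mask, every value counts (equivalent to a mask of all True).
        if mask is None:
            mask = [True] * len(values)
        if len(mask) != len(values):
            raise ValueError("mask must have the same length as values")
        for value, keep in zip(values, mask):
            if keep:
                self.total += value
                self.count += 1

    def compute(self):
        # Average over unmasked entries only; NaN if nothing has been kept yet.
        if self.count == 0:
            return float("nan")
        return self.total / self.count


metric = MaskedAverage()
metric.update([1.0, 2.0, 100.0], mask=[True, True, False])  # padding entry excluded
metric.update([3.0])
print(metric.compute())  # 2.0
```

Masking is typically used to drop padded positions in a batch, so that padding does not skew the reported metric.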
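The GQA entry (#5259) adds grouped-query attention support to MultiHeadAttention. In grouped-query attention, several query heads share one key/value head, so the number of query heads must be a multiple of the number of key/value heads. The pure-Python sketch below shows only that head-to-group mapping, assuming the common layout in which consecutive query heads share a key/value head. The function kv_head_for_query_head is hypothetical and does not reflect the parameter names of the new MultiHeadAttention option.

```python
def kv_head_for_query_head(num_query_heads, num_kv_heads):
    """Hypothetical helper: map each query head to the key/value head it reads.

    Assumes consecutive query heads share a key/value head (a common GQA layout).
    num_kv_heads == num_query_heads gives ordinary multi-head attention;
    num_kv_heads == 1 gives multi-query attention.
    """
    if num_kv_heads <= 0 or num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a positive multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads  # query heads per key/value head
    return [h // group_size for h in range(num_query_heads)]


print(kv_head_for_query_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_head(4, 4))  # [0, 1, 2, 3]
```

Sharing key/value heads this way shrinks the key/value projections and cache by a factor of group_size while keeping the full number of query heads.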
Full Changelog: v0.12.5...v0.12.6