NOTE: This is the last version that explicitly sets jax.config.jax_pmap_shmap_merge to False. After this release, users must set it themselves if their codebase is designed to use the old pmap. kfac_jax now supports both the old and the new pmap (though this may change in the future so that only the new pmap is supported).
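A minimal sketch of opting back into the old pmap in your own code, assuming the flag is still defined by your JAX version. The helper name `use_legacy_pmap_if_available` is hypothetical, not part of kfac_jax or JAX. It takes the config object as an argument so it can be called as `use_legacy_pmap_if_available(jax.config)` early in your program, before any pmap'd function is traced.

```python
def use_legacy_pmap_if_available(config) -> bool:
    """Hypothetical helper: opt into the old pmap if the flag exists.

    Sets jax_pmap_shmap_merge to False via config.update() and returns True.
    If this JAX version does not define the flag, nothing is changed and
    False is returned.
    """
    flag = "jax_pmap_shmap_merge"
    if not hasattr(config, flag):
        # Flag not defined by this JAX version: leave the config untouched.
        return False
    config.update(flag, False)
    return True
```

Checking for the flag first avoids an error on JAX versions where the option no longer exists.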
What's Changed
- has_aux fix by @copybara-service[bot] in #341
- Adding a basic implementation of an adaptive technique to set the initial damping value (used in the automatic damping adaptation). by @copybara-service[bot] in #346
- Adding a note that clarifies the format of the array arguments to the optimizer's step() and init() functions. by @copybara-service[bot] in #347
- Filtering out only scalar values to be logged in polyak stats by @copybara-service[bot] in #348
- Enable training on a fixed number of batches from the training dataset in a pre-emption safe way. by @copybara-service[bot] in #352
- Minor code quality improvements by @copybara-service[bot] in #353
- Deterministic resume when num_batches specified. by @copybara-service[bot] in #354
- Silence pytype error for deprecated JAX API by @copybara-service[bot] in #355
- Remove layer tags from processed jaxpr in kfac_jax transforms. This is required as the initial layer tags for the underlying function are not valid anymore once we apply one of the kfac_jax transforms. In general, this is a necessity for subsequently using the jaxpr of the transformed function, either within or outside the kfac_jax framework. by @copybara-service[bot] in #357
- Fixing broken test by @copybara-service[bot] in #360
- Internal Change by @copybara-service[bot] in #359
- Internal Change by @copybara-service[bot] in #361
- Fixing a PyType issue caused by recent JAX CL. by @copybara-service[bot] in #365
- Ignore pytype errors produced with --use-functools-partial-overlay by @copybara-service[bot] in #364
- Minor improvement to schedule code for examples by @copybara-service[bot] in #366
- Removing broken "mask" feature from sigmoid_cross_entropy loss in examples code. by @copybara-service[bot] in #372
- [kfac_jax] Prepare for jax_pmap_shmap_merge=True. by @copybara-service[bot] in #373
- Improve logging of parameter registrations. by @copybara-service[bot] in #375
- Excluding opt_state from eval worker when possible. by @copybara-service[bot] in #374
- Updates to schedules module in examples code: by @copybara-service[bot] in #377
- Fixing issue that broke the tracer and scanner logic when a layer had a literal (i.e. a constant) as input. by @copybara-service[bot] in #378
- Minor fixes to docstrings. by @copybara-service[bot] in #376
- [pmap] Make kfac_jax get_first more robust under jax_pmap_shmap_merge by @copybara-service[bot] in #379
- Bumping version number in preparation for next official PyPI release. by @copybara-service[bot] in #381
- Adding function using_legacy_pmap() to detect when the legacy pmap is being used (i.e. when jax.config.jax_pmap_shmap_merge exists and is False). This should allow the 0.0.8 release to continue working when jax.config.jax_pmap_shmap_merge is removed from JAX, while also supporting older versions of JAX. by @copybara-service[bot] in #382
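The detection rule described for using_legacy_pmap() ("the flag exists and is False") can be sketched as below. This is an illustrative reimplementation under that stated rule, not kfac_jax's actual code; the name `using_legacy_pmap_sketch` is hypothetical. It takes the config object as a parameter, so in practice it would be called as `using_legacy_pmap_sketch(jax.config)`.

```python
def using_legacy_pmap_sketch(config) -> bool:
    """Hypothetical sketch of the legacy-pmap check.

    Returns True only when jax_pmap_shmap_merge is defined on the config and
    is set to a falsy value. If the flag has been removed from JAX (attribute
    missing), the new pmap is assumed and False is returned.
    """
    flag = getattr(config, "jax_pmap_shmap_merge", None)
    return flag is not None and not bool(flag)
```

Treating a missing flag as "new pmap" is what lets code written this way keep working once the option is removed from JAX, while still honoring an explicit False on older versions.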
Full Changelog: v0.0.7...v0.0.8