kfac_jax 0.0.8

@james-martens james-martens released this 25 Feb 00:40
· 18 commits to main since this release

NOTE: This is the last version that explicitly sets jax.config.jax_pmap_shmap_merge to False. After this release, users must set it themselves if their codebase is designed to use the old pmap. kfac_jax now supports both the old and the new pmap (though this might change in the future so that only the new one is supported).

What's Changed

  • has_aux fix by @copybara-service[bot] in #341
  • Adding a basic implementation of an adaptive technique to set the initial damping value (used in the automatic damping adaptation). by @copybara-service[bot] in #346
  • Adding a note that clarifies the format of the array arguments to the optimizer's step() and init() functions. by @copybara-service[bot] in #347
  • Filtering out only scalar values to be logged in polyak stats by @copybara-service[bot] in #348
  • Enable training on a fixed number of batches from the training dataset in a pre-emption safe way. by @copybara-service[bot] in #352
  • Minor code quality improvements by @copybara-service[bot] in #353
  • Deterministic resume when num_batches specified. by @copybara-service[bot] in #354
  • Silence pytype error for deprecated JAX API by @copybara-service[bot] in #355
  • Remove layer tags from processed jaxpr in kfac_jax transforms. This is required as the initial layer tags for the underlying function are not valid anymore once we apply one of the kfac_jax transforms. In general, this is a necessity for subsequently using the jaxpr of the transformed function, either within or outside the kfac_jax framework. by @copybara-service[bot] in #357
  • Fixing broken test by @copybara-service[bot] in #360
  • Internal Change by @copybara-service[bot] in #359
  • Internal Change by @copybara-service[bot] in #361
  • Fixing a PyType issue caused by recent JAX CL. by @copybara-service[bot] in #365
  • Ignore pytype errors produced with --use-functools-partial-overlay by @copybara-service[bot] in #364
  • Minor improvement to schedule code for examples by @copybara-service[bot] in #366
  • Removing broken "mask" feature from sigmoid_cross_entropy loss in examples code. by @copybara-service[bot] in #372
  • [kfac_jax] Prepare for jax_pmap_shmap_merge=True. by @copybara-service[bot] in #373
  • Improve logging of parameter registrations. by @copybara-service[bot] in #375
  • Excluding opt_state from eval worker when possible. by @copybara-service[bot] in #374
  • Updates to schedules module in examples code: by @copybara-service[bot] in #377
  • Fixing issue that broke the tracer and scanner logic when a layer had a literal (i.e. a constant) as input. by @copybara-service[bot] in #378
  • Minor fixes to docstrings. by @copybara-service[bot] in #376
  • [pmap] Make kfac_jax get_first more robust under jax_pmap_shmap_merge by @copybara-service[bot] in #379
  • Bumping version number in preparation for next official PyPI release. by @copybara-service[bot] in #381
  • Adding function using_legacy_pmap() to detect when the legacy pmap is being used (i.e. when jax.config.jax_pmap_shmap_merge exists and is False). This should allow the 0.0.8 release to continue working when jax.config.jax_pmap_shmap_merge is removed from JAX, while also supporting older versions of JAX. by @copybara-service[bot] in #382
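
The detection rule described for using_legacy_pmap() in #382 (the flag exists and is False) can be illustrated with a minimal sketch. This is not the kfac_jax implementation: the helper name using_legacy_pmap_sketch and the stand-in config objects are hypothetical, and SimpleNamespace is used in place of jax.config so the example runs without JAX installed.

```python
from types import SimpleNamespace

# Flag name taken from the release note above.
FLAG = "jax_pmap_shmap_merge"


def using_legacy_pmap_sketch(config) -> bool:
    """Return True only when the flag exists on `config` and is False.

    Hypothetical helper for illustration; not the kfac_jax function.
    """
    # A missing flag falls back to None, and `None is False` is False,
    # so a JAX version that has removed the flag counts as "new pmap".
    return getattr(config, FLAG, None) is False


# Stand-ins for jax.config in three situations (assumed, not real JAX objects).
flag_set_false = SimpleNamespace(jax_pmap_shmap_merge=False)  # old pmap requested
flag_set_true = SimpleNamespace(jax_pmap_shmap_merge=True)    # new pmap
flag_removed = SimpleNamespace()                              # flag no longer exists

for name, cfg in [
    ("flag False", flag_set_false),
    ("flag True", flag_set_true),
    ("flag removed", flag_removed),
]:
    print(f"{name}: legacy pmap = {using_legacy_pmap_sketch(cfg)}")
```

Treating a missing flag as "not legacy" is what lets the same check keep working once the flag is removed from JAX. Passing the real jax.config object instead of a stand-in is expected to behave the same way, but that is an assumption not verified here.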

Full Changelog: v0.0.7...v0.0.8