
v0.1.9


chrbrunk released this on 26 Mar at 10:52.

This version includes the changes listed below.

Furthermore, note that this release is planned to be one of the last v0.1.x releases; a v0.2.0 release is planned within the next two months. The v0.2.0 release will contain many new features but also some breaking API changes, so we will provide a detailed migration guide along with it.

  • Migrating the training pipeline from pmap to SPMD jax.jit with NamedSharding, enabling multi-host training on TPU and multi-GPU setups.
  • Adding support for multi-host data-parallel training.
  • Checkpointing now works across multiple hosts (only process 0 saves).
  • Lifting restrictions on the compatible versions of the orbax-checkpoint dependency.
  • Removing the key field from TrainingState and the random_key parameter from init_training_state. Old checkpoints containing a key field can still be restored; the key is skipped via Orbax partial_restore.
  • Adding multi-host utilities: create_device_mesh, create_replicated_sharding, create_dp_sharding, and sync_string for cross-host communication.
  • Adding a SYSTEM_METRICS log category for per-process runtime and throughput metrics. Important note: some of these metrics were previously logged under TRAIN_METRICS, so please update your custom loggers if you want to keep logging them.
  • Disabling async checkpointing for multi-host compatibility.
  • Deprecating should_parallelize parameter in TrainingLoop in favour of mesh.
  • Deprecating devices parameter in GraphDatasetBuilder.get_splits() in favour of mesh.
  • Updating pre-commit configuration to use ruff in place of isort, black and flake8.
  • Adding early stopping when a simulation has exploded, meaning that its temperature is NaN or greater than 1e6.
  • Fixing a bug in the ASE simulation engine: it did not use a fixed random seed, which made runs irreproducible.
  • Enabling the initialization of a simulation from 2D (num_atoms, 3) arrays for positions and velocities. This allows a simulation to be initialized from a single frame rather than restored from 3D arrays containing a multi-step trajectory.
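For readers new to the pmap-to-jit migration, here is a minimal single-process sketch of SPMD data-parallel training with jax.jit and NamedSharding. It is not mlip's implementation: make_dp_mesh, replicated_sharding, dp_sharding and train_step are hypothetical names that only mirror the roles of create_device_mesh, create_replicated_sharding and create_dp_sharding, and the linear model is a toy. In a real multi-host run, each process would typically feed only its local slice of the batch, for example via jax.make_array_from_process_local_data.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def make_dp_mesh() -> Mesh:
    # Hypothetical stand-in for create_device_mesh: a 1-D mesh over every
    # device visible to the job, with a single "data" axis.
    return Mesh(np.array(jax.devices()), axis_names=("data",))


def replicated_sharding(mesh: Mesh) -> NamedSharding:
    # Empty PartitionSpec: every device holds a full copy (e.g. parameters).
    return NamedSharding(mesh, P())


def dp_sharding(mesh: Mesh) -> NamedSharding:
    # Split the leading (batch) dimension across the "data" mesh axis.
    return NamedSharding(mesh, P("data"))


@jax.jit
def train_step(params, batch):
    # Toy linear model with a mean-squared-error loss (hypothetical).
    def loss_fn(p):
        preds = batch["x"] @ p["w"]
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD update; the compiler inserts any cross-device reductions.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


mesh = make_dp_mesh()
n_devices = len(jax.devices())
params = jax.device_put({"w": jnp.zeros((3,))}, replicated_sharding(mesh))
# The global batch size must be divisible by the size of the "data" axis.
batch = jax.device_put(
    {"x": jnp.ones((4 * n_devices, 3)), "y": jnp.ones((4 * n_devices,))},
    dp_sharding(mesh),
)
params, loss = train_step(params, batch)
print(float(loss), params["w"])
```

Because the sharding is attached to the arrays rather than expressed as a pmap axis, the same train_step runs unchanged on one CPU or on many accelerators, which is why a mesh is a natural replacement for a boolean should_parallelize flag.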
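The early-stopping criterion for exploded simulations can be sketched as a simple per-step check. This is an illustration rather than mlip's code: has_exploded and run_until_exploded are hypothetical, and treating the 1e6 threshold as kelvin is an assumption.

```python
import math
from typing import Callable, List

# Threshold from the release notes; the kelvin unit is an assumption.
EXPLOSION_THRESHOLD = 1e6


def has_exploded(temperature: float, threshold: float = EXPLOSION_THRESHOLD) -> bool:
    # NaN compares False against any number, so it needs an explicit check.
    return math.isnan(temperature) or temperature > threshold


def run_until_exploded(step_fn: Callable[[int], float], num_steps: int) -> List[float]:
    # Hypothetical MD driver: step_fn advances one step and returns the temperature.
    temperatures = []
    for step in range(num_steps):
        temperature = step_fn(step)
        temperatures.append(temperature)
        if has_exploded(temperature):
            break  # stop early instead of integrating a diverged trajectory
    return temperatures


# Temperature grows tenfold per step: 300, 3e3, 3e4, 3e5, 3e6, ...
history = run_until_exploded(lambda step: 300.0 * 10.0**step, num_steps=100)
print(len(history), history[-1])
```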
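The reproducibility fix for the ASE engine comes down to drawing random numbers from a seeded generator. The sketch below assumes initial velocities are drawn from a Maxwell-Boltzmann distribution in ASE units (amu, eV, Å); maxwell_boltzmann_velocities and its seed parameter are hypothetical and not part of mlip's or ASE's API. ASE's own MaxwellBoltzmannDistribution accepts an rng argument for the same purpose.

```python
import numpy as np

# Boltzmann constant in eV/K (CODATA 2018).
KB_EV_PER_K = 8.617333262e-5


def maxwell_boltzmann_velocities(masses, temperature_k, seed):
    """Hypothetical helper: draw velocities in ASE-style units (amu, eV, Angstrom)."""
    rng = np.random.default_rng(seed)  # fixed seed -> identical draws on every run
    masses = np.asarray(masses, dtype=float)
    # Each Cartesian component is Gaussian with variance k_B * T / m.
    sigma = np.sqrt(KB_EV_PER_K * temperature_k / masses)[:, None]
    return rng.standard_normal((masses.size, 3)) * sigma


water_masses = [1.008, 15.999, 1.008]  # H, O, H in amu
v_a = maxwell_boltzmann_velocities(water_masses, 300.0, seed=42)
v_b = maxwell_boltzmann_velocities(water_masses, 300.0, seed=42)
v_c = maxwell_boltzmann_velocities(water_masses, 300.0, seed=7)
print(np.array_equal(v_a, v_b), np.array_equal(v_a, v_c))
```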
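The two accepted array shapes for initial positions and velocities can be illustrated with a small normalisation helper. initial_frame is hypothetical, and taking the last frame of a 3D trajectory is an assumption about how a restart picks its starting point.

```python
import numpy as np


def initial_frame(array) -> np.ndarray:
    """Hypothetical helper: return one (num_atoms, 3) frame from a 2D or 3D array."""
    array = np.asarray(array, dtype=float)
    if array.ndim == 2 and array.shape[1] == 3:
        return array  # already a single frame
    if array.ndim == 3 and array.shape[2] == 3:
        return array[-1]  # assumption: restart from the last trajectory frame
    raise ValueError(
        f"expected shape (num_atoms, 3) or (num_steps, num_atoms, 3), got {array.shape}"
    )


single = np.zeros((5, 3))
trajectory = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)
start_a = initial_frame(single)
start_b = initial_frame(trajectory)
print(start_a.shape, start_b.shape)
```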