v0.1.9
This version includes the changes listed below.
Note also that this is planned to be one of the last v0.1.x releases: v0.2.0 is planned for release within the next two months. v0.2.0 will contain many new features but also some breaking API changes; we will provide a detailed migration guide alongside it.
- Migrating the training pipeline from `pmap` to SPMD `jax.jit` with `NamedSharding`, enabling multi-host training on TPU and multi-GPU setups.
- Adding support for multi-host data-parallel training.
- Checkpointing now works across multiple hosts (only process 0 saves).
- Lifting restrictions on the compatible versions of the `orbax-checkpoint` dependency.
- Removing the `key` field from `TrainingState` and the `random_key` parameter from `init_training_state`. Old checkpoints containing a `key` field can still be restored; the `key` is skipped via Orbax `partial_restore`.
- Adding multi-host utilities: `create_device_mesh`, `create_replicated_sharding`, `create_dp_sharding`, and `sync_string` for cross-host communication.
- Adding a `SYSTEM_METRICS` log category for per-process runtime and throughput metrics. Important: some of these metrics were previously logged under `TRAIN_METRICS`, so please update your custom loggers to keep logging them.
- Disabling async checkpointing for multi-host compatibility.
- Deprecating the `should_parallelize` parameter in `TrainingLoop` in favour of `mesh`.
- Deprecating the `devices` parameter in `GraphDatasetBuilder.get_splits()` in favour of `mesh`.
- Updating the pre-commit configuration to use ruff in place of isort, black and flake8.
- Adding early stopping when a simulation has exploded, meaning that its temperature is NaN or greater than 1e6.
- Fixing a bug in the ASE simulation engine: runs were irreproducible because no fixed random seed was used.
- Enabling the initialization of a simulation using 2D `(num_atoms, 3)` arrays for positions and velocities. This allows for initializing a simulation using a single frame, rather than restoring from 3D arrays containing a multi-step trajectory.
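The move from `pmap` to SPMD `jax.jit` with `NamedSharding` follows a common JAX pattern. You build a device mesh, replicate the parameters, shard the batch along a data axis, and write the step for the global batch so the compiler inserts the cross-device reductions. The sketch below shows that pattern using public JAX APIs only. `make_data_parallel_mesh`, `replicated` and `data_parallel` are hypothetical stand-ins for `create_device_mesh`, `create_replicated_sharding` and `create_dp_sharding`, whose actual signatures and behaviour may differ. The sketch is written for a single host. In a real multi-host run, `jax.distributed.initialize()` would be called first, and the global batch would be assembled from per-process data rather than with a plain `jax.device_put`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


# Hypothetical stand-in for create_device_mesh: a 1D mesh whose single
# "data" axis spans every device visible to JAX.
def make_data_parallel_mesh() -> Mesh:
    return Mesh(np.array(jax.devices()), axis_names=("data",))


# Hypothetical stand-in for create_replicated_sharding: an empty PartitionSpec
# means every device holds a full copy (e.g. parameters, optimizer state).
def replicated(mesh: Mesh) -> NamedSharding:
    return NamedSharding(mesh, P())


# Hypothetical stand-in for create_dp_sharding: split the leading (batch)
# dimension across the "data" mesh axis.
def data_parallel(mesh: Mesh) -> NamedSharding:
    return NamedSharding(mesh, P("data"))


def loss_fn(params, batch):
    preds = batch @ params  # shape: (global_batch,)
    return jnp.mean(preds**2)


# One jit-compiled step replaces pmap: the function is written for the global
# batch, and XLA inserts the cross-device reduction for the mean, which under
# pmap typically required an explicit lax.pmean.
@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return params - 0.01 * grads, loss


mesh = make_data_parallel_mesh()
global_batch = 4 * len(jax.devices())  # must be divisible by the "data" axis size
params = jax.device_put(jnp.ones((3,)), replicated(mesh))
batch = jax.device_put(
    jnp.arange(global_batch * 3, dtype=jnp.float32).reshape(global_batch, 3),
    data_parallel(mesh),
)
new_params, loss = train_step(params, batch)

# In a multi-host run, only one process should write checkpoints.
is_checkpoint_writer = jax.process_index() == 0
```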
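The early-stopping rule for exploded simulations can be written as a simple predicate that is checked at each logged step. The sketch below is illustrative only. `has_exploded` and `first_exploded_step` are hypothetical names, not the library's API, and the temperature is assumed to be in kelvin.

```python
import math

# Threshold from the release notes (temperature assumed to be in kelvin).
EXPLOSION_TEMPERATURE_THRESHOLD = 1e6


def has_exploded(temperature: float) -> bool:
    """Hypothetical check: NaN or a temperature above the threshold."""
    return math.isnan(temperature) or temperature > EXPLOSION_TEMPERATURE_THRESHOLD


def first_exploded_step(temperatures):
    """Return the index of the first exploded step, or None if none exploded."""
    for step, temperature in enumerate(temperatures):
        if has_exploded(temperature):
            return step  # a real engine would stop the simulation here
    return None
```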
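Accepting both array shapes amounts to a small normalization step before the simulation starts. In the hypothetical helper `initial_frame` below, a 2D `(num_atoms, 3)` array is used directly. A 3D `(num_steps, num_atoms, 3)` trajectory contributes its last frame. Picking the last frame is an assumption made for illustration and is not necessarily what the library does.

```python
import numpy as np


def initial_frame(array) -> np.ndarray:
    """Hypothetical helper: return a (num_atoms, 3) frame for initialization.

    Accepts a single frame (num_atoms, 3) or a trajectory
    (num_steps, num_atoms, 3); for a trajectory, the last frame is assumed
    to be the restart point.
    """
    arr = np.asarray(array, dtype=float)
    if arr.ndim == 2 and arr.shape[1] == 3:
        return arr  # already a single frame
    if arr.ndim == 3 and arr.shape[2] == 3:
        return arr[-1]  # last step of the trajectory
    raise ValueError(
        f"expected (num_atoms, 3) or (num_steps, num_atoms, 3), got {arr.shape}"
    )
```

The same helper would apply to both positions and velocities, since the release notes state that both accept the 2D form.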