Conversation
Pull request overview
This pull request bumps the version from 0.0.7 to 0.0.8 and introduces several significant changes to the julax library and its examples. The PR focuses on improving the mesh management API, adding support for Splash Attention on TPU, refactoring the dataset pipeline, and making several performance and usability improvements.
Changes:
- Refactored mesh management from computed property to explicit context manager pattern, making mesh lifecycle more explicit and controllable
- Added Splash Attention support with TPU-optimized kernels for faster attention computation
- Replaced `opencv-python` with `opencv-python-headless` across examples to avoid GUI dependencies
- Added `remat` (rematerialization) option to the `Repeat` layer for gradient checkpointing, reducing memory usage during training
- Changed loss accumulation from bfloat16 to float32 for improved numerical stability
- Refactored dataset pipeline with better resource management and state handling
- Added context manager support to Experiment class for proper resource cleanup
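The `remat` option maps onto JAX's rematerialization primitive. A minimal sketch (not julax's actual `Repeat` implementation) of how gradient checkpointing trades recomputation for activation memory, with identical gradients either way:

```python
import jax
import jax.numpy as jnp

def layer(x):
    # A toy layer whose intermediate activations would normally be
    # kept alive for the backward pass.
    return jnp.tanh(x) * 2.0

def stack(x, remat=False):
    # With remat=True, activations are recomputed during the backward
    # pass instead of being stored after the forward pass.
    f = jax.checkpoint(layer) if remat else layer
    for _ in range(4):
        x = f(x)
    return jnp.sum(x)

x = jnp.ones((8,))
g_plain = jax.grad(stack)(x)
g_remat = jax.grad(lambda x: stack(x, remat=True))(x)
```

Rematerialization changes the compute/memory trade-off, not the math, so the two gradients match exactly.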
Reviewed changes
Copilot reviewed 19 out of 22 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| pyproject.toml | Version bump to 0.0.8 |
| uv.lock | Updated dependencies including julax version, opencv-python-headless, new etils and gcsfs packages |
| src/julax/__init__.py | Added exports for utility functions (identity, create_mesh, get_mesh) |
| src/julax/utils.py | Modified get_mesh to support abstract mesh fallback, added allow_split_physical_axes parameter |
| src/julax/layers/base.py | Removed donate_argnames from jit_forward decorator |
| src/julax/layers/core.py | Changed loss dtype from bfloat16 to float32 in Learner and Trainer |
| src/julax/layers/connectors.py | Added remat parameter to Repeat layer for gradient checkpointing |
| src/julax/experiment/run.py | Removed mesh context manager wrapper, now uses get_mesh() directly |
| src/julax/experiment/experiment.py | Removed mesh as computed property, added context manager methods, added start_prefetch calls |
| examples/03_Llama_3/training.py | Major refactor: removed create_experiment function, added main with mesh context and multiple model configurations |
| examples/03_Llama_3/model.py | Changed attention tensor layout from [B,S,N,H] to [B,N,S,H] (heads-first), added Splash Attention implementation, added debug model configuration |
| examples/03_Llama_3/dataset.py | Complete refactor with new iterator-based architecture using etils.epath for better resource management |
| examples/03_Llama_3/pyproject.toml | Added etils and gcsfs dependencies |
| examples/03_Llama_3/convert.py | Fixed hardcoded model path to use model_path parameter |
| examples/01_mnist/pyproject.toml | Changed opencv-python to opencv-python-headless |
| examples/03_Llama_3/debug/tokenizer/* | Added tokenizer configuration files for development |
| examples/03_Llama_3/README.md | Added performance optimization tracking documentation |
| Dockerfile | Added vim to installed packages |
| .gitignore | Added exception for debug data directory |
| .github/workflows/docker.yml | Restricted SHA tagging to main branch only |
```python
mesh = get_mesh()
splash_spec = P("data", None)

# The kernel (mask) should be replicated across data parallelism devices
replicated_sharding = jax.sharding.NamedSharding(mesh, P(None, None))
kernel_spec = kernel.manual_sharding_spec(replicated_sharding)

@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(
        kernel_spec,
        splash_spec,
        splash_spec,
        splash_spec,
    ),
    out_specs=splash_spec,
    check_vma=False,
)
def sharded_kernel(kernel, q, k, v):
    def _apply_kernel(q, k, v):
        res = kernel(q, k, v)
        if isinstance(res, tuple):
            return res[0]
        return res

    return jax.vmap(_apply_kernel)(q, k, v)

def splash_attention(inputs):
```
`make_splash_attention_fn` calls `get_mesh()` at function definition time (when the model is created), which captures the mesh for use in the `jax.shard_map`. This means the splash attention function will be bound to whatever mesh was active when `create_model` was called. This should work correctly with the new pattern where `create_mesh` is used as a context manager around model creation, but ensure that `create_model` is always called within a mesh context when `attention_backend="splash"` is used.
Suggested change (moves the mesh lookup inside `splash_attention` so it resolves at call time):

```python
def splash_attention(inputs):
    mesh = get_mesh()
    splash_spec = P("data", None)
    # The kernel (mask) should be replicated across data parallelism devices
    replicated_sharding = jax.sharding.NamedSharding(mesh, P(None, None))
    kernel_spec = kernel.manual_sharding_spec(replicated_sharding)

    @partial(
        jax.shard_map,
        mesh=mesh,
        in_specs=(
            kernel_spec,
            splash_spec,
            splash_spec,
            splash_spec,
        ),
        out_specs=splash_spec,
        check_vma=False,
    )
    def sharded_kernel(kernel, q, k, v):
        def _apply_kernel(q, k, v):
            res = kernel(q, k, v)
            if isinstance(res, tuple):
                return res[0]
            return res

        return jax.vmap(_apply_kernel)(q, k, v)
```
Based on my tests, an error is thrown if it is not initialized within a mesh context.
@copilot create an issue to revisit the mesh-related configs. Maybe create an abstract mesh in the initialization & precompiling stage?
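The definition-time binding described in this thread can be reproduced with plain Python closures. In the sketch below, `get_mesh` is a stand-in for julax's helper and the string meshes are hypothetical:

```python
_current_mesh = None  # stand-in for the ambient mesh context

def get_mesh():
    # Mirrors the reported behavior: error when no mesh context is active.
    if _current_mesh is None:
        raise RuntimeError("no mesh context active")
    return _current_mesh

def make_definition_bound():
    mesh = get_mesh()  # resolved now, like make_splash_attention_fn
    return lambda: mesh

def make_call_bound():
    return lambda: get_mesh()  # resolved on every call

_current_mesh = "mesh-A"
f_def = make_definition_bound()
f_call = make_call_bound()
_current_mesh = "mesh-B"
# f_def still returns "mesh-A"; f_call now returns "mesh-B".
```

This is why the definition-bound variant fails fast when no mesh context is active at model-creation time, and why an abstract mesh available during initialization would sidestep the ordering constraint.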
Before:

```python
def run(exp: Experiment, observer: ObserverBase) -> tuple[int, Param, State]:
    with exp.mesh as mesh:
        step, param, state, input_iter = exp.restore()
        observer(step, exp, param, state)

        compiled_step = None

        for x_local in input_iter:
            if exp.max_steps is not None and step >= exp.max_steps:
                logger.info(f"Reached max steps {exp.max_steps}, stopping training.")
                break
            x = jax.make_array_from_process_local_data(
                sharding=jax.sharding.NamedSharding(
                    mesh, PartitionSpec(exp.batch_axis_names)
                ),
                local_data=x_local,
            )

            if compiled_step is None:
                compiled_step = exp.precompile(x, param, state)

            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                param, state = compiled_step(x, param, state)
            step += 1

            exp.save(step, param, state, input_iter)
            observer(step, exp, param, state)

    exp.close()
    return step, param, state
```

After:

```python
def run(exp: Experiment, observer: ObserverBase) -> tuple[int, Param, State]:
    step, param, state, input_iter = exp.restore()
    observer(step, exp, param, state)

    compiled_step = None

    for x_local in input_iter:
        if exp.max_steps is not None and step >= exp.max_steps:
            logger.info(f"Reached max steps {exp.max_steps}, stopping training.")
            break
        x = jax.make_array_from_process_local_data(
            sharding=jax.sharding.NamedSharding(
                get_mesh(), PartitionSpec(exp.batch_axis_names)
            ),
            local_data=x_local,
        )

        if compiled_step is None:
            compiled_step = exp.precompile(x, param, state)

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            param, state = compiled_step(x, param, state)
        step += 1

        exp.save(step, param, state, input_iter)
        observer(step, exp, param, state)

    return step, param, state
```
The `run` function no longer calls `exp.close()` at the end. Previously, the mesh context manager ensured cleanup, but now that the mesh context is removed from `run()`, the experiment's resources (especially `checkpoint_manager`) may not be properly closed if an exception occurs during training. Consider wrapping the training loop in a try/finally block to ensure `exp.close()` is called, or document that callers should use the `Experiment` as a context manager.
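A minimal sketch of the suggested try/finally shape, with a dummy class standing in for julax's `Experiment` (the real restore/precompile/save machinery is elided):

```python
class DummyExperiment:
    """Hypothetical stand-in that only models the cleanup contract."""

    def __init__(self, batches):
        self._batches = batches
        self.closed = False

    def restore(self):
        return 0, "param", "state", iter(self._batches)

    def train_step(self, x, param, state):
        if x == "boom":
            raise RuntimeError("step failed")
        return param, state

    def close(self):
        self.closed = True

def run(exp, observer=None):
    step, param, state, input_iter = exp.restore()
    try:
        for x in input_iter:
            param, state = exp.train_step(x, param, state)
            step += 1
    finally:
        # Runs on normal completion, break, and exceptions alike,
        # so the checkpoint manager is always released.
        exp.close()
    return step, param, state
```

The finally clause gives the same guarantee the old `with exp.mesh` wrapper provided implicitly: cleanup happens even when a training step raises.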
```python
if self._file:
    self._file.close()
    self._file = None
if self._raw_ctx:
    self._raw_ctx.__exit__(None, None, None)
    self._raw_ctx = None
```
The `close()` method of `JsonlDatasetIterator` closes both the text wrapper and the raw context, but it doesn't handle potential exceptions during cleanup. If closing `_file` raises an exception, `_raw_ctx.__exit__()` won't be called, potentially leaving resources open. Consider using a try/finally pattern or catching exceptions to ensure both resources are cleaned up.
Suggested change:

```python
file = self._file
raw_ctx = self._raw_ctx
# Clear references first to avoid leaving the iterator in a
# partially-closed state if cleanup raises.
self._file = None
self._raw_ctx = None
try:
    if file:
        file.close()
finally:
    if raw_ctx:
        raw_ctx.__exit__(None, None, None)
```
```diff
 def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: ...

-@partial(jit, static_argnums=0, donate_argnames=("p", "s"))
+@partial(jit, static_argnums=0)
```
Removing `donate_argnames=("p", "s")` from the `jit_forward` decorator may lead to increased memory usage during forward passes. The donation optimization allows JAX to reuse input buffers for outputs when they're no longer needed, which is particularly important for large model parameters and states. Consider whether this removal was intentional or whether it might cause memory issues in training loops. Note that `forward_and_backward` in `Trainer` still uses donation, so the main training loop should be fine, but direct calls to `jit_forward` won't benefit from the optimization.
Suggested change:

```diff
-@partial(jit, static_argnums=0)
+@partial(jit, static_argnums=0, donate_argnames=("p", "s"))
```
Donate is only needed in the trainer.
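As a general JAX illustration of what donation buys (not julax code): donation lets XLA reuse an input buffer for the output, which matters when params/state are large. On backends without donation support (e.g. CPU) JAX warns and falls back to copying, so results are unchanged.

```python
import jax
import jax.numpy as jnp

# Same update rule with and without donating the first argument.
step_plain = jax.jit(lambda p, g: p - 0.1 * g)
step_donated = jax.jit(lambda p, g: p - 0.1 * g, donate_argnums=(0,))

p = jnp.ones((4,))
g = jnp.ones((4,))
out_plain = step_plain(p, g)
# Pass a fresh buffer to the donating version: on accelerators the
# donated input is invalidated and must not be reused afterwards.
out_donated = step_donated(jnp.ones((4,)), g)
```

The trade-off flagged in the review is exactly this invalidation: donation is safe inside a training loop that discards the old `(p, s)`, but hostile to callers who still hold references to the inputs.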
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@findmyway I've opened a new pull request, #11, to work on those changes. Once the pull request is ready, I'll request review from you.