Updates jax instructions for Spark#5884
Conversation
There was a problem hiding this comment.
Documentation Review
Thanks for adding CUDA 13 / aarch64 instructions and removing the outdated Spark limitation! The tab-set structure is a clean improvement for platform-specific guidance.
I found one critical issue with the version constraint in the aarch64 tab that would result in a broken (CPU-only) installation.
Update (commit ae1059c): Previous concerns addressed. ✅
- ✅ Bug (jax[cuda13]<0.6.0 empty range): Fixed — version constraint removed, so
pipwill now correctly resolve a JAX version that includes thecuda13extra. - ✅ Nit (blanket version guidance line): Fixed — the contradictory guidance line was removed entirely.
Note on new change: The <0.6.0 and flax<0.10.7 constraints were also removed from the x86_64/CUDA 12 tab. If the CuDNN v9.7 vs v9.8 conflict with PyTorch 2.7 has been resolved upstream, this is fine. Otherwise, users on x86_64 may hit CuDNN version mismatches with unconstrained JAX. Flagging for awareness — no action needed if the team has validated this.
Greptile SummaryThis PR updates the SKRL+JAX installation documentation to support DGX Spark (aarch64) by splitting the install instructions into two tabs — CUDA 12 for x86_64 and CUDA 13 for aarch64 — and removes the previously listed Spark limitation around JAX GPU support.
Confidence Score: 3/5The core intent is correct but the aarch64 CUDA 13 install command uses a version constraint that likely resolves to no valid package, breaking the primary use case this PR enables. The aarch64 install command docs/source/overview/reinforcement-learning/rl_existing_scripts.rst — specifically the aarch64 CUDA 13 tab-item around line 282 Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[User installs SKRL with JAX backend] --> B{Architecture?}
B -- x86_64 --> C["jax[cuda12]<0.6.0 + flax<0.10.7"]
B -- aarch64 / DGX Spark --> D["jax[cuda13]<0.6.0 + flax<0.10.7"]
C --> E[skrl jax dependencies installed]
D --> F["Likely empty version range\n(cuda13 extra introduced in 0.6.0+)"]
F --> G[pip error or CPU-only JAX]
E --> H[Run train.py --ml_framework jax]
Reviews (1): Last reviewed commit: "Updates jax instructions for Spark" | Re-trigger Greptile |
…nvs, shutdown error, mjcf fix (#5888) # Description Cherry-pick PRs from develop: - #5866 - #5879 - #5882 - #5884 - #5889 --------- Signed-off-by: Yize Wang <yizew@nvidia.com> Signed-off-by: Kelly Guo <kellyg@nvidia.com> Co-authored-by: YizeWang <37894497+YizeWang@users.noreply.github.com> Co-authored-by: Yize Wang <yizew@nvidia.com>
Description
Updates documentation for Spark to install cuda 13 jax. Remove limitation of SKRL Jax training on Spark.
Type of change
Checklist
pre-commitchecks with./isaaclab.sh --formatconfig/extension.tomlfileCONTRIBUTORS.mdor my name already exists there