Updates jax instructions for Spark #5884

Merged
kellyguo11 merged 2 commits into isaac-sim:develop from kellyguo11:spark-jax
May 31, 2026
@kellyguo11
Contributor

Description

Updates the Spark documentation to install the CUDA 13 build of JAX. Removes the limitation on SKRL JAX training on Spark.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • Documentation update

Checklist

  • I have read and understood the contribution guidelines
  • I have run the pre-commit checks with ./isaaclab.sh --format
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the changelog and the corresponding version in the extension's config/extension.toml file
  • I have added my name to the CONTRIBUTORS.md or my name already exists there

@github-actions github-actions Bot added the documentation label (Improvements or additions to documentation) on May 30, 2026

@isaaclab-review-bot isaaclab-review-bot Bot left a comment

Documentation Review

Thanks for adding CUDA 13 / aarch64 instructions and removing the outdated Spark limitation! The tab-set structure is a clean improvement for platform-specific guidance.

I found one critical issue with the version constraint in the aarch64 tab that would result in a broken (CPU-only) installation.


Update (commit ae1059c): Previous concerns addressed. ✅

  • Bug (jax[cuda13]<0.6.0 empty range): Fixed — version constraint removed, so pip will now correctly resolve a JAX version that includes the cuda13 extra.
  • Nit (blanket version guidance line): Fixed — the contradictory guidance line was removed entirely.

Note on new change: The jax[cuda12]<0.6.0 and flax<0.10.7 constraints were also removed from the x86_64/CUDA 12 tab. If the cuDNN v9.7 vs v9.8 conflict with PyTorch 2.7 has been resolved upstream, this is fine. Otherwise, users on x86_64 may hit cuDNN version mismatches with unconstrained JAX. Flagging for awareness; no action is needed if the team has validated this.
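
One way to check that an install did not silently fall back to CPU is to ask JAX which backend it selected. The following is a minimal sketch and is not part of the PR: classify_backend is a hypothetical helper, while jax.default_backend() and jax.devices() are existing JAX calls. The import is guarded so the script also runs where JAX is absent.

```python
import importlib.util


def classify_backend(backend: str) -> str:
    """Classify a JAX backend name as GPU-accelerated or a fallback.

    Hypothetical helper for illustration; "cuda" is accepted alongside
    "gpu" in case a platform-specific name is reported.
    """
    if backend.lower() in ("gpu", "cuda"):
        return "ok: JAX is using the GPU"
    return f"warning: JAX is running on '{backend}', not the GPU"


if importlib.util.find_spec("jax") is not None:
    import jax

    # default_backend() reports the platform JAX will place arrays on.
    print(classify_backend(jax.default_backend()))
    # devices() lists the devices visible to that backend.
    print(jax.devices())
else:
    print("jax is not installed in this environment")
```

A warning on a machine with a working GPU would suggest that the installed JAX wheel lacks the matching CUDA plugin, or that a library mismatch of the kind described above prevented the plugin from loading.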

Comment thread docs/source/overview/reinforcement-learning/rl_existing_scripts.rst Outdated
Comment thread docs/source/overview/reinforcement-learning/rl_existing_scripts.rst Outdated
@greptile-apps
Contributor

greptile-apps Bot commented May 30, 2026

Greptile Summary

This PR updates the SKRL+JAX installation documentation to support DGX Spark (aarch64) by splitting the install instructions into two tabs — CUDA 12 for x86_64 and CUDA 13 for aarch64 — and removes the previously listed Spark limitation around JAX GPU support.

  • rl_existing_scripts.rst: Adds a nested tab-set under the JAX tab-item with separate pip install commands for x86_64 (CUDA 12) and aarch64 (CUDA 13); updates the JAX docs link from jax.readthedocs.io to docs.jax.dev; removes the no-longer-applicable CuDNN incompatibility note.
  • index.rst: Drops the Spark-specific limitation bullet that stated JAX SKRL training was unvalidated and CPU-only on DGX Spark.

Confidence Score: 3/5

The core intent is correct but the aarch64 CUDA 13 install command uses a version constraint that likely resolves to no valid package, breaking the primary use case this PR enables.

The aarch64 install command jax[cuda13]<0.6.0 combines two constraints that are almost certainly mutually exclusive: the CUDA 13 pip extra was introduced in JAX 0.6.0+, while the upper bound <0.6.0 excludes every version that supports it. DGX Spark users following these docs would hit a broken install path.

docs/source/overview/reinforcement-learning/rl_existing_scripts.rst — specifically the aarch64 CUDA 13 tab-item around line 282
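
The empty-range argument above can be illustrated with a small sketch. It assumes, as the review does, that the cuda13 extra first appears in release 0.6.0 or later; that starting release is taken from the review text rather than verified here, and the candidate list is illustrative, not a complete list of JAX releases. The helper names are hypothetical.

```python
def parse_version(text):
    """Turn a dotted release string such as '0.6.0' into a comparable tuple."""
    return tuple(int(part) for part in text.split("."))


def versions_in_range(candidates, first_with_extra, upper_bound):
    """Return candidates v with first_with_extra <= v < upper_bound."""
    lower = parse_version(first_with_extra)
    upper = parse_version(upper_bound)
    return [v for v in candidates if lower <= parse_version(v) < upper]


# Illustrative release numbers only.
candidates = ["0.5.3", "0.6.0", "0.6.2", "0.7.0"]

# Assumption from the review: the cuda13 extra needs 0.6.0 or later,
# while the documented constraint demanded a release below 0.6.0.
print(versions_in_range(candidates, first_with_extra="0.6.0", upper_bound="0.6.0"))
# prints []
```

Because the lower bound implied by the extra equals the exclusive upper bound in the constraint, no release can satisfy both, whichever release actually introduced the extra at or after 0.6.0.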

Important Files Changed

| Filename | Overview |
| --- | --- |
| docs/source/overview/reinforcement-learning/rl_existing_scripts.rst | Splits JAX SKRL install instructions into x86_64 (CUDA 12) and aarch64 (CUDA 13) tabs; the jax[cuda13]<0.6.0 version spec in the aarch64 tab is likely unsatisfiable because CUDA 13 support wasn't introduced until JAX 0.6.0+. |
| docs/source/setup/installation/index.rst | Removes the Spark-specific JAX/SKRL limitation bullet point now that CUDA 13 JAX is supported; the surrounding list remains consistent. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User installs SKRL with JAX backend] --> B{Architecture?}
    B -- x86_64 --> C["jax[cuda12]<0.6.0 + flax<0.10.7"]
    B -- aarch64 / DGX Spark --> D["jax[cuda13]<0.6.0 + flax<0.10.7"]
    C --> E[skrl jax dependencies installed]
    D --> F["Likely empty version range\n(cuda13 extra introduced in 0.6.0+)"]
    F --> G[pip error or CPU-only JAX]
    E --> H[Run train.py --ml_framework jax]
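
The architecture branch in the flowchart can be sketched in a few lines. This is an illustration only: jax_requirement and the mapping are hypothetical names, the extras cuda12 and cuda13 are the ones named in this review, and the version bounds are omitted to match the state after the fix described in the update note. It does not reproduce the exact command in the documentation.

```python
import platform

# Extras named in the review: CUDA 12 for x86_64, CUDA 13 for aarch64 (DGX Spark).
JAX_EXTRA_BY_ARCH = {"x86_64": "cuda12", "aarch64": "cuda13"}


def jax_requirement(machine=None):
    """Return a pip requirement string for the given machine architecture.

    Hypothetical helper; defaults to the architecture of the current host.
    """
    machine = (machine or platform.machine()).lower()
    extra = JAX_EXTRA_BY_ARCH.get(machine)
    if extra is None:
        raise ValueError(f"no JAX CUDA extra mapped for architecture '{machine}'")
    return f"jax[{extra}]"


print(jax_requirement("x86_64"))   # prints jax[cuda12]
print(jax_requirement("aarch64"))  # prints jax[cuda13]
```

Raising on an unmapped architecture, instead of defaulting to plain jax, avoids the silent CPU-only outcome the flowchart warns about.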

Reviews (1): Last reviewed commit: "Updates jax instructions for Spark"

Comment thread docs/source/overview/reinforcement-learning/rl_existing_scripts.rst Outdated
Comment thread docs/source/overview/reinforcement-learning/rl_existing_scripts.rst Outdated
@kellyguo11 kellyguo11 merged commit e7f3b62 into isaac-sim:develop May 31, 2026
37 checks passed
kellyguo11 added a commit that referenced this pull request May 31, 2026
…nvs, shutdown error, mjcf fix (#5888)

# Description

Cherry-pick PRs from develop:

- #5866 
- #5879 
- #5882 
- #5884
- #5889

---------

Signed-off-by: Yize Wang <yizew@nvidia.com>
Signed-off-by: Kelly Guo <kellyg@nvidia.com>
Co-authored-by: YizeWang <37894497+YizeWang@users.noreply.github.com>
Co-authored-by: Yize Wang <yizew@nvidia.com>

Labels

documentation Improvements or additions to documentation
