Open
Labels
type:bug (Something isn't working)
Description
Expected Behavior
The script should run to completion and pass.
Actual Behavior
Failed with the following error:
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 361, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree_util.py", line 361, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/qwix/_src/flax_util.py", line 293, in fn
    return unbox(x.get_raw_value())
                 ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/variablelib.py", line 299, in __getattr__
    return getattr(self.raw_value, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'get_raw_value'
One or more scripts failed! See more here.
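One possible reading of the traceback, not confirmed by the maintainers: the last two frames show `x.get_raw_value()` falling through to the variable's `__getattr__`, which forwards the lookup to the wrapped array, so the error names `ArrayImpl` rather than the variable type. This would be consistent with the installed flax not defining `get_raw_value` on the variable class, but that is an assumption. The sketch below reproduces only the forwarding mechanism in plain Python; `FakeArray` and `ForwardingVariable` are hypothetical stand-ins, not the real jaxlib or flax classes.

```python
class FakeArray:
    """Hypothetical stand-in for the wrapped array: it has no get_raw_value."""

    def __init__(self, data):
        self.data = data


class ForwardingVariable:
    """Hypothetical stand-in for a wrapper that forwards unknown attributes."""

    def __init__(self, raw_value):
        self.raw_value = raw_value

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal lookup fails, so a method
        # that is missing on the wrapper is looked up on the wrapped value.
        return getattr(self.raw_value, name)


def try_get_raw_value(var):
    """Call var.get_raw_value() and return the error text if it fails."""
    try:
        return var.get_raw_value()
    except AttributeError as exc:
        return f"AttributeError: {exc}"


message = try_get_raw_value(ForwardingVariable(FakeArray([1.0, 2.0])))
print(message)
```

The printed error names the wrapped type (`FakeArray`), mirroring how the real error names `jaxlib._jax.ArrayImpl` instead of the variable class.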
Steps to Reproduce the Problem
unset JAX_PLATFORMS
pip list | egrep 'jax|flax|libtpu'
cd ..
git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . && cd ../..
pip install jax==0.8.1 flax==0.12.0 libtpu==0.0.24
pip list | egrep 'jax|flax|libtpu'
cd tunix
python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine sglang_jax --enable-lora --lora-target-modules all
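The `pip list | egrep` steps above record which versions are installed before and after the pin. A rough Python equivalent is sketched below, in case it is easier to paste its output into a report. The package list is an assumption: `jaxlib` and `qwix` are added because they appear in the traceback paths, and the distribution names are assumed to match the import names.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed distribution names; adjust if a package is published under another name.
report = installed_versions(["jax", "jaxlib", "flax", "libtpu", "qwix"])
for name, version in report.items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```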
Environment
- OS: [e.g., Ubuntu, etc.]
- Project Version: [e.g., 0.0.1]
Checklist
- I have searched the existing issues for a similar bug report.
- I have provided all the required information in the "Environment" section.
- I have provided a minimal, reproducible example.
Would you like to help us fix it?