Everything works well in the main training loop, but I get an error when training reaches a logging step (every logging_steps steps):
Traceback (most recent call last):
File "/home/jnguan/code/zett/train.py", line 1605, in <module>
main()
File "/home/jnguan/code/zett/train.py", line 1516, in main
lambda x: x.flatten(), stack_forest(train_metrics)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 69, in stack_forest
return jax.tree_util.tree_map(stack_args, *forest)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 68, in <lambda>
stack_args = lambda *args: np.stack(args)
^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in stack
arrays = [asanyarray(arr) for arr in arrays]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 390, in __array__
return np.asarray(self._value, dtype=dtype)
^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 588, in _value
if self.is_fully_replicated:
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 354, in is_fully_replicated
return self.sharding.is_fully_replicated
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'UnspecifiedValue' object has no attribute 'is_fully_replicated'
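For reference, the failing logging step reduces to roughly the pattern below. This is a minimal sketch, not the actual code in train.py: `stack_forest` is reimplemented from the two lines visible in the traceback, the outer `tree_map` is my assumption about what wraps the `lambda x: x.flatten()` call, and the metric names and values are made up. With host NumPy arrays the pattern runs without error. In my run the leaves are `jax.Array` objects, so `np.stack` goes through `__array__`, which is where the `AttributeError` is raised.

```python
import jax
import numpy as np


def stack_forest(forest):
    # Same two lines as flax.training.common_utils.stack_forest in the traceback:
    # stack the corresponding leaves of every tree along a new leading axis.
    stack_args = lambda *args: np.stack(args)
    return jax.tree_util.tree_map(stack_args, *forest)


# Hypothetical per-step metrics (names and values are illustrative only).
# Each leaf has one entry per device, as it would after a replicated step.
train_metrics = [
    {"loss": np.array([0.9, 1.1]), "learning_rate": np.array([1e-4, 1e-4])},
    {"loss": np.array([0.8, 1.0]), "learning_rate": np.array([1e-4, 1e-4])},
]

# Shape per leaf after stacking: (num_steps, num_devices).
stacked = stack_forest(train_metrics)

# Flatten each leaf to one dimension, as the logging code does before reporting.
flat = jax.tree_util.tree_map(lambda x: x.flatten(), stacked)

print({name: value.shape for name, value in flat.items()})
```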
I am trying to train a hypernetwork on English and Chinese data in order to transfer a bilingual tokenizer to TinyLlama.
My devices are 2× A100 80GB, with CUDA driver version 12.2.
My config is:
data/langs.txt is:
Full log:
zett-142044.log
My environment: