To run the code locally, I cloned the Colab notebook and set up the environment.
However, running the code from the notebook raised an error. The console output is below:
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/flax/optim/base.py:52: DeprecationWarning: Use optax instead of flax.optim. Refer to the update guide https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html for detailed instructions.
'for detailed instructions.', DeprecationWarning)
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py:183: UserWarning: pjit is an experimental feature and probably has bugs!
warn("pjit is an experimental feature and probably has bugs!")
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/_src/lib/xla_bridge.py:430: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
"jax.host_count has been renamed to jax.process_count. This alias "
Traceback (most recent call last):
File "/mnt/fast/lwd/aisheet/test.py", line 252, in
inference_model = InferenceModel(checkpoint_path, MODEL)
File "/mnt/fast/lwd/aisheet/test.py", line 88, in __init__
self.restore_from_checkpoint(checkpoint_path)
File "/mnt/fast/lwd/aisheet/test.py", line 134, in restore_from_checkpoint
[restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 522, in from_checkpoint_or_scratch
return (self.from_checkpoint(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 508, in from_checkpoint
self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng))
File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 466, in from_checkpoints
yield _restore_path(path, restore_cfg)
File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 458, in _restore_path
fallback_state=fallback_state)
File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 880, in restore
return self._restore_train_state(state_dict)
File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 891, in _restore_train_state
train_state, train_state_axes)
File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 639, in move_params_to_devices
train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 729, in __call__
return self._pjitted_fn(*args)
File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 266, in wrapped
args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 250, in infer_params
tuple(isinstance(a, GDA) for a in args_flat))
File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
ans = call(fun, *args)
File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 385, in _pjit_jaxpr
allow_uneven_sharding=False)
File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 581, in _check_shapes_against_resources
raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit arguments was given the resource assignment of PartitionSpec(None, 'model'), which implies that the size of its dimension 1 should be divisible by 3, but it is equal to 1024
The error is raised at line 252 of test.py: inference_model = InferenceModel(checkpoint_path, MODEL)
I have no idea why this happens. Any help working it out would be appreciated, thanks!
To anyone who comes across this bug: I managed to solve it by limiting the number of CUDA devices visible to Python.
The pre-trained weights seem to load only on a single GPU in this setup (the error shows a dimension of size 1024 being split across 3 devices, and 1024 is not divisible by 3), so adding this at the top of your code fixes it:
import os
# Set this before importing jax, so that only GPU 0 is visible to it.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
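
For anyone wondering why the device count matters, here is a minimal sketch of the arithmetic the error message describes. It is not the actual check inside pjit; `shards_evenly` is a hypothetical helper, and the only numbers taken from the traceback are the dimension size (1024) and the device count (3).

```python
def shards_evenly(dim_size: int, num_devices: int) -> bool:
    """Return True if a dimension of dim_size splits evenly across num_devices."""
    return dim_size % num_devices == 0


# Size of dimension 1, as reported in the ValueError above.
dim_size = 1024

# Compare a few device counts; 3 is the one from the error message.
for num_devices in (1, 2, 3, 4):
    print(num_devices, shards_evenly(dim_size, num_devices))
```

With one visible GPU the split is trivially even, which matches the workaround above. Whether this checkpoint also loads on 2 or 4 GPUs was not tested in this thread.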