
Error when trying to load the model #27

Closed
levscaut opened this issue Mar 9, 2022 · 1 comment
levscaut commented Mar 9, 2022

In order to run the code locally, I cloned the Colab notebook and set up the environment.
Yet when running the code from the notebook, an error occurred, with console output like below:

```
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/flax/optim/base.py:52: DeprecationWarning: Use optax instead of flax.optim. Refer to the update guide https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html for detailed instructions.
  'for detailed instructions.', DeprecationWarning)
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py:183: UserWarning: pjit is an experimental feature and probably has bugs!
  warn("pjit is an experimental feature and probably has bugs!")
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/_src/lib/xla_bridge.py:430: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  "jax.host_count has been renamed to jax.process_count. This alias "
Traceback (most recent call last):
  File "/mnt/fast/lwd/aisheet/test.py", line 252, in <module>
    inference_model = InferenceModel(checkpoint_path, MODEL)
  File "/mnt/fast/lwd/aisheet/test.py", line 88, in __init__
    self.restore_from_checkpoint(checkpoint_path)
  File "/mnt/fast/lwd/aisheet/test.py", line 134, in restore_from_checkpoint
    [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 522, in from_checkpoint_or_scratch
    return (self.from_checkpoint(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 508, in from_checkpoint
    self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 466, in from_checkpoints
    yield _restore_path(path, restore_cfg)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 458, in _restore_path
    fallback_state=fallback_state)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 880, in restore
    return self._restore_train_state(state_dict)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 891, in _restore_train_state
    train_state, train_state_axes)
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 639, in move_params_to_devices
    train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 729, in __call__
    return self._pjitted_fn(*args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 266, in wrapped
    args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 250, in infer_params
    tuple(isinstance(a, GDA) for a in args_flat))
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 385, in _pjit_jaxpr
    allow_uneven_sharding=False)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 581, in _check_shapes_against_resources
    raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit arguments was given the resource assignment of PartitionSpec(None, 'model'), which implies that the size of its dimension 1 should be divisible by 3, but it is equal to 1024
```

This occurs when executing the line around 252: `inference_model = InferenceModel(checkpoint_path, MODEL)`.
I have no idea why this happened; I hope you could help me work this out, thanks!
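For context on what the error means: pjit's sharding check requires that any dimension assigned to a mesh axis be evenly divisible by the number of devices on that axis, so a 1024-wide `'model'` dimension cannot be split across 3 GPUs. A minimal pure-Python sketch of that constraint (the function name `check_sharding` is hypothetical, not t5x's or JAX's actual code):

```python
def check_sharding(dim_size: int, axis_devices: int) -> None:
    """Mimic pjit's constraint: a sharded dimension must split evenly
    across the devices assigned to its mesh axis."""
    if dim_size % axis_devices != 0:
        raise ValueError(
            f"dimension of size {dim_size} is not divisible "
            f"by {axis_devices} devices"
        )

# With 3 visible GPUs, the 1024-wide 'model' dimension fails:
try:
    check_sharding(1024, 3)
except ValueError as e:
    print("sharding error:", e)

# With a single visible GPU, 1024 % 1 == 0 and the check passes:
check_sharding(1024, 1)
```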

@levscaut (Author) commented
To anyone who comes across this bug: I managed to solve it by limiting the number of CUDA devices visible to Python.
The pre-trained weights only apply to one GPU, so simply adding this at the top of your code will solve it:

```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
```
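One caveat worth noting (my own addition, not from the thread): `CUDA_VISIBLE_DEVICES` only takes effect if it is set before the CUDA backend is initialized, so it must run before `import jax`. A minimal sketch:

```python
import os

# Must be set BEFORE `import jax` (or anything else that initializes
# CUDA); once the XLA backend has started, changing this has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# import jax
# jax.device_count() should now report a single GPU device,
# so the 'model' partition axis has size 1 and 1024 % 1 == 0.
```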
