[BUG]RuntimeError: cuda rng state model-parallel-rng is not added #3583

Closed
iamsile opened this issue May 21, 2023 · 5 comments
Labels
bug Something isn't working deepspeed-chat Related to DeepSpeed-Chat

Comments

@iamsile

iamsile commented May 21, 2023

Describe the bug
Running step 2 (reward model fine-tuning) of DeepSpeed-Chat with THUDM/visualglm-6b fails with the error below.

Log output
Traceback (most recent call last):
  File "/xxxxx/deepspeed_chat/training/step2_reward_model_finetuning/main.py", line 472, in <module>
    main()
  File "/xxxxx/deepspeed_chat/training/step2_reward_model_finetuning/main.py", line 393, in main
    outputs = rm_model(**batch, use_cache=False)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1731, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/xxxxx/deepspeed_chat/training/utils/model/reward_model.py", line 53, in forward
    transformer_outputs = self.rwtranrsformer(
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/visualglm-6b/533d4ae86f232b1d7d04417398f572f71751c77d/modeling_chatglm.py", line 1462, in forward
    image_embeds = self.image_encoder(images)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/visualglm-6b/533d4ae86f232b1d7d04417398f572f71751c77d/visual.py", line 69, in forward
    enc = self.vit(image)[0]
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/visualglm-6b/533d4ae86f232b1d7d04417398f572f71751c77d/visual.py", line 28, in forward
    return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/model/base_model.py", line 144, in forward
    return self.transformer(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/model/transformer.py", line 569, in forward
    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/model/transformer.py", line 330, in forward
    return HOOKS_DEFAULT['layer_forward'](self, hidden_states, mask, *args, **kw_args)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/transformer_defaults.py", line 127, in layer_forward_default
    attention_output = self.attention(attention_input, mask, **kw_args)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/model/transformer.py", line 103, in forward
    return HOOKS_DEFAULT['attention_forward'](self, hidden_states, mask, **kw_args)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/transformer_defaults.py", line 63, in attention_forward_default
    context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/sat/transformer_defaults.py", line 38, in standard_attention
    with mpu.get_cuda_rng_tracker().fork():
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/opt/conda/envs/rlhf_tw_test/lib/python3.8/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 174, in fork
    raise Exception('cuda rng state {} is not added'.format(name))
Exception: cuda rng state model-parallel-rng is not added

Command used:
deepspeed --include localhost:7 /xxxxx/deepspeed_chat/training/step2_reward_model_finetuning/main.py \
  --data_output_path /xxxxx/deepspeed_chat/training/step2_reward_model_finetuning/data_output_zh_multi_model \
  --model_name_or_path THUDM/visualglm-6b \
  --num_padding_at_beginning 0 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --learning_rate 5e-4 \
  --num_train_epochs 1 \
  --gradient_accumulation_steps 1 \
  --num_warmup_steps 0 \
  --zero_stage 2 \
  --deepspeed \
  --output_dir /xxxxx/deepspeed_chat/training/step2_reward_model_finetuning/output_zh_multi_modal \
  --lora_dim 2 \
  --lora_module_name q_proj,k_proj \
  --only_optimize_lora

PyTorch: 2.0.1
CUDA: 11

System info:
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 18.04.5 LTS
Release: 18.04
Codename: bionic

@iamsile iamsile added bug Something isn't working training labels May 21, 2023
@jomayeri jomayeri added deepspeed-chat Related to DeepSpeed-Chat and removed training labels May 23, 2023
@zhangyuanscall

Me too.

@WangYuxiang8

Any update, please?

@jimmieliu

Same here.

@XaviLv

XaviLv commented Aug 19, 2023

I had forgotten to register the model-parallel RNG state with _CUDA_RNG_STATE_TRACKER. After adding the code below, the "cuda rng state model-parallel-rng is not added" error no longer occurs. Hope this helps.

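# Register the "model-parallel-rng" state (seed 1) with DeepSpeed's CUDA RNG tracker before the first forward pass
from deepspeed.runtime.activation_checkpointing.checkpointing import _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)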

See: SwissArmyTransformer
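To see why registering the name makes the error go away, here is a simplified, hypothetical stand-in for a named RNG-state tracker. It is not DeepSpeed's tracker class: the name ToyRNGStatesTracker is made up, and Python's global random generator stands in for the CUDA generator. It only mirrors the add()/fork() contract visible in the traceback: fork(name) refuses to run until add(name, seed) has stored a generator state under that name.

```python
import contextlib
import random


class ToyRNGStatesTracker:
    """Hypothetical, simplified named RNG-state tracker (not DeepSpeed's class).

    Python's global `random` generator stands in for the CUDA generator.
    """

    def __init__(self):
        self.states_ = {}  # name -> saved generator state

    def add(self, name, seed):
        """Register a new named stream seeded with `seed`."""
        if name in self.states_:
            raise Exception('rng state {} already exists'.format(name))
        orig = random.getstate()          # remember the caller's state
        random.seed(seed)                 # build the stream's initial state
        self.states_[name] = random.getstate()
        random.setstate(orig)             # leave the caller's stream untouched

    @contextlib.contextmanager
    def fork(self, name):
        """Temporarily switch the global generator to the named stream."""
        if name not in self.states_:
            # Message mirrors the traceback: nothing registered under `name`
            raise Exception('cuda rng state {} is not added'.format(name))
        orig = random.getstate()
        random.setstate(self.states_[name])
        try:
            yield
        finally:
            self.states_[name] = random.getstate()  # save the advanced stream
            random.setstate(orig)                   # restore the caller's stream
```

On a fresh tracker, fork('model-parallel-rng') fails the same way as in the traceback; after add('model-parallel-rng', 1), the analogue of the fix above, the fork succeeds and the caller's own random stream is restored on exit. The toy add() also rejects a name that is already registered, so in the same spirit the real registration should probably run only once per process.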

@iamsile
Author

iamsile commented Aug 21, 2023

Thank you very much! It works!

@iamsile iamsile closed this as completed Aug 21, 2023

6 participants