CUDA error: too many resources requested for launch (V100, qwen2-vl) #1867

Open
Jiax323 opened this issue Aug 30, 2024 · 15 comments
@Jiax323

Jiax323 commented Aug 30, 2024

RuntimeError: CUDA error: too many resources requested for launch
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

The above error occurs when LoRA fine-tuning qwen2-vl-2b-instruct on a V100.

@wade30822

+1

@Jintao-Huang
Collaborator

#1860

@Jintao-Huang
Collaborator

See here: https://swift.readthedocs.io/zh-cn/latest/Multi-Modal/qwen2-vl%E6%9C%80%E4%BD%B3%E5%AE%9E%E8%B7%B5.html#ocr

You can save memory by reducing SIZE_FACTOR to 8 and MAX_PIXELS to 602112.
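
For the LoRA fine-tuning case reported above, applying these variables might look roughly like the sketch below (the local model path is a placeholder, and the dataset is only illustrative):

```
# Sketch: cap image resolution via SIZE_FACTOR / MAX_PIXELS before LoRA fine-tuning on a V100.
# /path/to/Qwen2-VL-2B-Instruct is a placeholder; replace it with your local checkpoint directory.
SIZE_FACTOR=8 MAX_PIXELS=602112 CUDA_VISIBLE_DEVICES=0 \
swift sft \
  --sft_type lora \
  --model_type qwen2-vl-2b-instruct \
  --model_id_or_path /path/to/Qwen2-VL-2B-Instruct \
  --dataset coco-en-mini
```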

@Jiax323
Author

Jiax323 commented Sep 2, 2024

After setting SIZE_FACTOR=8 and MAX_PIXELS=602112, the same problem still appears:
RuntimeError: CUDA error: too many resources requested for launch
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

@JackeyGuo

Yes, I'm hitting the same problem, also V100 + LoRA. Judging by GPU usage it is not an out-of-memory issue; it actually uses only about 18 GB.

@lantudou

lantudou commented Sep 3, 2024

I reproduced the same error on a V100 running inference with both the 7B and 2B models. Interestingly, the error has nothing to do with GPU memory; it seems to be caused by the bfloat16 dtype, and as I recall the V100 does not support the bf16 format. Inference works normally after switching to --dtype fp32.
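
A quick way to confirm the dtype theory on a given machine (a sketch, not from the original report) is to ask PyTorch for the device's compute capability and bf16 support:

```
# Sketch: V100 reports compute capability (7, 0); bf16 needs Ampere (8.0+), so this should print False there.
python -c "import torch; print(torch.cuda.get_device_capability(), torch.cuda.is_bf16_supported())"
```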

@lantudou

lantudou commented Sep 3, 2024

@Jintao-Huang

@lantudou

lantudou commented Sep 3, 2024

swift infer --model_type qwen2-vl-7b-instruct  --model_id_or_path /share_data/PRDATA/Qwen2-VL-7B-Instruct/ 
run sh: `/root/anaconda3/envs/qwenvl/bin/python /share_data/YuhaoSun/ms-swift/swift/cli/infer.py --model_type qwen2-vl-7b-instruct --model_id_or_path /share_data/PRDATA/Qwen2-VL-7B-Instruct/`
[INFO:swift] Successfully registered `/share_data/YuhaoSun/ms-swift/swift/llm/data/dataset_info.json`
[INFO:swift] No vLLM installed, if you are using vLLM, you will get `ImportError: cannot import name 'get_vllm_engine' from 'swift.llm'`
[INFO:swift] No LMDeploy installed, if you are using LMDeploy, you will get `ImportError: cannot import name 'prepare_lmdeploy_engine_template' from 'swift.llm'`
[INFO:swift] Start time of running main: 2024-09-03 06:08:16.803324
[INFO:swift] ckpt_dir: None
[INFO:swift] Due to `ckpt_dir` being `None`, `load_args_from_ckpt_dir` is set to `False`.
[INFO:swift] Setting template_type: qwen2-vl
[INFO:swift] Setting self.eval_human: True
[INFO:swift] Setting overwrite_generation_config: False
[INFO:swift] args: InferArguments(model_type='qwen2-vl-7b-instruct', model_id_or_path='/share_data/PRDATA/Qwen2-VL-7B-Instruct', model_revision='master', sft_type='full', template_type='qwen2-vl', infer_backend='pt', ckpt_dir=None, result_dir=None, load_args_from_ckpt_dir=False, load_dataset_config=False, eval_human=True, seed=42, dtype='AUTO', model_kwargs=None, dataset=[], val_dataset=[], dataset_seed=42, dataset_test_ratio=0.01, show_dataset_sample=-1, save_result=True, system=None, tools_prompt='react_en', max_length=None, truncation_strategy='delete', check_dataset_strategy='none', model_name=[None, None], model_author=[None, None], quant_method=None, quantization_bit=0, hqq_axis=0, hqq_dynamic_config_path=None, bnb_4bit_comp_dtype='AUTO', bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=True, bnb_4bit_quant_storage=None, max_new_tokens=2048, do_sample=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, num_beams=1, stop_words=[], rope_scaling=None, use_flash_attn=None, ignore_args_error=False, stream=True, merge_lora=False, merge_device_map='cpu', save_safetensors=True, overwrite_generation_config=False, verbose=None, local_repo_path=None, custom_register_path=None, custom_dataset_info=None, device_map_config=None, device_max_memory=[], hub_token=None, gpu_memory_utilization=0.9, tensor_parallel_size=1, max_num_seqs=256, max_model_len=None, disable_custom_all_reduce=True, enforce_eager=False, vllm_enable_lora=False, vllm_max_lora_rank=16, lora_modules=[], tp=1, cache_max_entry_count=0.8, quant_policy=0, vision_batch_size=1, self_cognition_sample=0, train_dataset_sample=-1, val_dataset_sample=None, safe_serialization=None, model_cache_dir=None, merge_lora_and_save=None, custom_train_dataset_path=[], custom_val_dataset_path=[], vllm_lora_modules=None, device_map_config_path=None)
[INFO:swift] Global seed set to 42
[INFO:swift] device_count: 4
[INFO:swift] Loading the model using model_dir: /share_data/PRDATA/Qwen2-VL-7B-Instruct
[INFO:swift] Setting torch_dtype: torch.bfloat16
[INFO:swift] model_kwargs: {'low_cpu_mem_usage': True, 'device_map': 'auto'}
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:06<00:00,  1.23s/it]
[INFO:swift] model.max_model_len: 32768
[INFO:swift] model_config: Qwen2VLConfig {
  "_name_or_path": "/share_data/PRDATA/Qwen2-VL-7B-Instruct",
  "architectures": [
    "Qwen2VLForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "image_token_id": 151655,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2_vl",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": {
    "mrope_section": [
      16,
      24,
      24
    ],
    "type": "mrope"
  },
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.45.0.dev0",
  "use_cache": true,
  "use_sliding_window": false,
  "video_token_id": 151656,
  "vision_config": {
    "in_chans": 3,
    "model_type": "qwen2_vl",
    "spatial_patch_size": 14
  },
  "vision_end_token_id": 151653,
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
  "vocab_size": 152064
}

[INFO:swift] model.generation_config: GenerationConfig {
  "bos_token_id": 151643,
  "do_sample": true,
  "eos_token_id": 151645,
  "max_new_tokens": 2048,
  "pad_token_id": 151643,
  "repetition_penalty": 1.05,
  "temperature": 0.1,
  "top_k": 1,
  "top_p": 0.001
}

[INFO:swift] [visual.patch_embed.proj.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.norm1.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.norm1.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.norm2.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.norm2.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.attn.qkv.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.attn.qkv.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.attn.proj.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.attn.proj.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.mlp.fc1.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.mlp.fc1.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.mlp.fc2.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.0.mlp.fc2.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.norm1.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.norm1.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.norm2.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.norm2.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.attn.qkv.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.attn.qkv.bias]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] [visual.blocks.1.attn.proj.weight]: requires_grad=False, dtype=torch.bfloat16, device=cuda:0
[INFO:swift] ...
[INFO:swift] Qwen2VLForConditionalGeneration(
  (visual): Qwen2VisionTransformerPretrainedModel(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
    )
    (rotary_pos_emb): VisionRotaryEmbedding()
    (blocks): ModuleList(
      (0-31): 32 x Qwen2VLVisionBlock(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): VisionSdpaAttention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (mlp): VisionMlp(
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (act): QuickGELUActivation()
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        )
      )
    )
    (merger): PatchMerger(
      (ln_q): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=5120, out_features=5120, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=5120, out_features=3584, bias=True)
      )
    )
  )
  (model): Qwen2VLModel(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2VLDecoderLayer(
        (self_attn): Qwen2VLSdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
  )
  (lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)
[INFO:swift] Qwen2VLForConditionalGeneration: 8291.3756M Params (0.0000M Trainable [0.0000%]), 234.8828M Buffers.
[INFO:swift] system: You are a helpful assistant.
[INFO:swift] Input `exit` or `quit` to exit the conversation.
[INFO:swift] Input `multi-line` to switch to multi-line input mode.
[INFO:swift] Input `reset-system` to reset the system and clear the history.
[INFO:swift] Input `clear` to clear the history.
[INFO:swift] Please enter the conversation content first, followed by the path to the multimedia file.
<<< <image> describe this image
I'm sorry, but as an AI text-based model, I don't have the ability to see or describe images. However, if you can provide a detailed description of the image, I'll do my best to assist you with any questions or information you need.
--------------------------------------------------
<<< <image> describe this.
Input an image path or URL <<< /share_data/PRDATA/99.jpg
[INFO:swift] Setting size_factor: 28. You can adjust this hyperparameter through the environment variable: `SIZE_FACTOR`.
[INFO:swift] Setting resized_height: None. You can adjust this hyperparameter through the environment variable: `RESIZED_HEIGHT`.
[INFO:swift] Setting resized_width: None. You can adjust this hyperparameter through the environment variable: `RESIZED_WIDTH`.
[INFO:swift] Setting min_pixels: 3136. You can adjust this hyperparameter through the environment variable: `MIN_PIXELS`.
[INFO:swift] Setting max_pixels: 12845056. You can adjust this hyperparameter through the environment variable: `MAX_PIXELS`.
Exception in thread Thread-3:
Traceback (most recent call last):
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "/share_data/YuhaoSun/ms-swift/swift/llm/utils/utils.py", line 695, in _model_generate
    res = model.generate(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/transformers/generation/utils.py", line 2013, in generate
    result = self._sample(
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/transformers/generation/utils.py", line 2959, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1585, in forward
    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1017, in forward
    hidden_states = self.patch_embed(hidden_states)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 243, in forward
    hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 608, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/root/anaconda3/envs/qwenvl/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 603, in _conv_forward
    return F.conv3d(
RuntimeError: CUDA error: too many resources requested for launch
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.    
The full error trace is pasted above. I haven't tried fine-tuning yet, but the cause should be similar.

@Jintao-Huang
Collaborator

Add SIZE_FACTOR=8 MAX_PIXELS=602112.

@Jintao-Huang
Collaborator

--dtype fp32
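
Combined, a sketch of the suggested inference command on a V100 might look like this (the model path is a placeholder):

```
# Sketch: shrink the image budget and force fp32, since the V100 lacks native bf16 support.
SIZE_FACTOR=8 MAX_PIXELS=602112 \
swift infer \
  --model_type qwen2-vl-7b-instruct \
  --model_id_or_path /path/to/Qwen2-VL-7B-Instruct \
  --dtype fp32
```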

@Jintao-Huang Jintao-Huang changed the title CUDA error: too many resources requested for launch CUDA error: too many resources requested for launch (V100, qwen2-vl) Sep 3, 2024
@Jiax323
Author

Jiax323 commented Sep 4, 2024

For fine-tuning, --dtype fp32 runs out of GPU memory... The earlier qwen-vl could be fine-tuned with bf16.

@wade30822

OOM occurred even when using qwen2-vl-2b with SIZE_FACTOR=8 MAX_PIXELS=602112 --dtype fp32.

@lantudou

lantudou commented Sep 4, 2024

OOM occurred even when using qwen2-vl-2b with SIZE_FACTOR=8 MAX_PIXELS=602112 --dtype fp32.

CUDA_VISIBLE_DEVICES=0,1,2,3 NPROC_PER_NODE=4 swift sft --sft_type 'full' --dtype 'fp16' --use_liger 'True' --model_id_or_path '/share_data/PRDATA/Qwen2-VL-2B-Instruct/' --template_type 'qwen2-vl' --system 'You are a helpful assistant.' --dataset coco-en-mini --learning_rate '1e-05' --gradient_accumulation_steps '16' --eval_steps '500' --save_steps '500' --eval_batch_size '1' --model_type 'qwen2-vl-2b-instruct' --deepspeed default-zero3 --add_output_dir_suffix False --output_dir /root/output/qwen2-vl-2b-instruct/v16

Full SFT with fp16 works on my 4 V100s. For LoRA, I got some strange errors and am waiting for a fix.

@JackeyGuo

On a V100 you don't need to set SIZE_FACTOR or MAX_PIXELS; just setting --dtype fp16 is enough to train.
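
A sketch of that suggestion for LoRA fine-tuning (the model path is a placeholder and the dataset is only illustrative):

```
# Sketch: force fp16 instead of the default bf16, which the V100 does not support natively.
CUDA_VISIBLE_DEVICES=0 \
swift sft \
  --sft_type lora \
  --dtype fp16 \
  --model_type qwen2-vl-2b-instruct \
  --model_id_or_path /path/to/Qwen2-VL-2B-Instruct \
  --dataset coco-en-mini
```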

@tastelikefeet
Collaborator

OOM occurred even when using qwen2-vl-2b with SIZE_FACTOR=8 MAX_PIXELS=602112 --dtype fp32.

CUDA_VISIBLE_DEVICES=0,1,2,3 NPROC_PER_NODE=4 swift sft --sft_type 'full' --dtype 'fp16' --use_liger 'True' --model_id_or_path '/share_data/PRDATA/Qwen2-VL-2B-Instruct/' --template_type 'qwen2-vl' --system 'You are a helpful assistant.' --dataset coco-en-mini --learning_rate '1e-05' --gradient_accumulation_steps '16' --eval_steps '500' --save_steps '500' --eval_batch_size '1' --model_type 'qwen2-vl-2b-instruct' --deepspeed default-zero3 --add_output_dir_suffix False --output_dir /root/output/qwen2-vl-2b-instruct/v16

Full SFT with fp16 works on my 4 V100s. For LoRA, I got some strange errors and am waiting for a fix.

Can you share the error stack?
