ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) #9130

@wolfvoid

Description

Checklist

  • I have searched existing issues, and this is a new bug report.

Bug Description

When training Qwen3.5-4B with GRPO using DeepSpeed multi-GPU setup, the training crashes during vLLM engine initialization with the following error:

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

The error occurs in vllm/model_executor/layers/mamba/gdn_linear_attn.py when the GDN attention kernel tries to access tensors on non-zero CUDA devices.

Environment

  • ms-swift: 4.2.0dev0
  • Python: 3.10
  • PyTorch: 2.10.0
  • Triton: 3.6.0
  • vLLM: 0.19.0 (V1 engine)
  • GPU: NVIDIA H800 × 8
  • DeepSpeed: zero2
  • Training Mode: GRPO (RLHF)

Root Cause Analysis

This is a Triton multi-GPU context mismatch bug (similar to triton#2441):

  1. DeepSpeed launches 8 processes (rank0-7), each bound to a different GPU (cuda:0 to cuda:7)
  2. vLLM V1 initializes the Qwen3.5 model, which uses GDN (Gated DeltaNet) attention
  3. GDN attention internally calls Triton kernels (e.g., fused_gdn_gating_kernel)
  4. Triton kernel launches use the process's current CUDA context, which defaults to cuda:0 if the launching code never re-binds the device
  5. When rank2 (on cuda:2) tries to launch the kernel with tensors on cuda:2, the cuda:0 context cannot access them
  6. cuPointerGetAttribute returns CUDA_ERROR_INVALID_VALUE, causing the error

Why this happens specifically with Qwen3.5:

  • Qwen3.5 uses a custom GDN attention mechanism in vLLM
  • The VLLM_ATTENTION_BACKEND environment variable does NOT affect Qwen3.5 (it forces GDN usage)
  • This differs from standard attention mechanisms that properly handle multi-GPU contexts

How to Reproduce

# Environment setup
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NPROC_PER_NODE=8

# DeepSpeed launch
deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 28956 \
    swift/cli/rlhf.py \
    --rlhf_type grpo \
    --model /path/to/Qwen3.5-4B-Base \
    --model_type qwen3_5 \
    --template qwen3_5 \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_tensor_parallel_size 2 \
    --deepspeed zero2 \
    --num_generations 8 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --torch_dtype bfloat16 \
    ...

Key trigger conditions:

  • Multi-GPU training (DeepSpeed)
  • --use_vllm true (using vLLM for rollout generation)
  • Qwen3.5 model (uses GDN attention)

Workarounds

Option 1: Disable vLLM (Confirmed Working) ✅

Use transformers engine instead of vLLM:
--use_vllm false
Trade-off: 5-10x slower rollout generation, but training works correctly.

Additional fix needed when using transformers engine:
The transformers engine passes all data fields to model.generate(), causing:
ValueError: The following model_kwargs are not used by the model: ['solution', 'task_type', ...]

Fix: filter out non-model parameters in `swift/infer_engine/transformers_engine.py` (around line 388):

```python
import inspect

def _infer_full(self, inputs: Dict[str, Any], *, ...):
    # Drop fields (e.g. reward metadata) that model.forward() does not accept
    model_forward_params = set(inspect.signature(self.model.forward).parameters.keys())
    filtered_inputs = {k: v for k, v in inputs.items() if k in model_forward_params}
    generate_kwargs = {'generation_config': generation_config, **filtered_inputs}
```
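
The same filtering idea can be shown standalone with a dummy model in place of `self.model` (all names below are illustrative, not the actual ms-swift code):

```python
import inspect

class DummyModel:
    """Stand-in for a HF model; forward() accepts only real model kwargs."""
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return {"input_ids": input_ids}

model = DummyModel()

# GRPO batches carry extra reward fields the model cannot consume
inputs = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "solution": "42",       # reward field, not a model kwarg
    "task_type": "math",    # reward field, not a model kwarg
}

# Keep only the kwargs that model.forward() actually accepts
# (inspect.signature on a bound method already excludes `self`)
allowed = set(inspect.signature(model.forward).parameters)
filtered = {k: v for k, v in inputs.items() if k in allowed}

print(sorted(filtered))  # ['attention_mask', 'input_ids']
```

With this filter in place, `solution` and `task_type` never reach `model.generate()`, avoiding the `model_kwargs are not used by the model` error.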

Option 2: Use vLLM Server Mode (Not Verified)

--vllm_mode server
May avoid context pollution by running vLLM in a separate process.

Trade-off: Requires weight synchronization between processes, doubles GPU memory usage.

Suggested Fix

For ms-swift:

  1. Short-term: Add automatic filtering of non-model parameters in transformers_engine.py when --use_vllm false is used for GRPO training
  2. Long-term: Document the incompatibility between Qwen3.5 + vLLM V1 + DeepSpeed multi-GPU, recommend --use_vllm false for GRPO training with Qwen3.5

For vLLM (upstream):

Fix GDN attention to set the CUDA device context before each Triton kernel launch, e.g. in `vllm/model_executor/layers/mamba/gdn_linear_attn.py`:

```python
with torch.cuda.device(tensor.device):
    fused_gdn_gating_kernel[grid](...)
```

Additional Information

The error only occurs with:

  • vLLM enabled (--use_vllm true)
  • Multi-GPU training (single GPU works fine)
  • Qwen3.5 model (other models using standard attention work fine)

SFT training works because it doesn't use vLLM for generation.
