GRPO multi-node training is slow; can the script be optimized further? #3186

@owl-10

Description

```bash
export NNODES=8
export MASTER_PORT=${MASTER_PORT:-34229}
export GPUS=${GPUS:-8}
# Node 0 launches only 7 training processes; the remaining GPU is left for the
# vLLM rollout engine (--vllm_device auto below). All other nodes use all 8 GPUs.
export NPROC_PER_NODE_NODE0=7
export NPROC_PER_NODE_OTHERS=${GPUS}

export NCCL_IB_DISABLE=0
export NCCL_P2P_DISABLE=0
export NCCL_SHM_DISABLE=0
export TF_CPP_MIN_LOG_LEVEL=3
export LAUNCHER=pytorch

# --- Determine NODE_RANK (priority: MLP_ROLE_INDEX > NODE_RANK) ---
if [ -n "$MLP_ROLE_INDEX" ]; then
    export NODE_RANK=$MLP_ROLE_INDEX
    echo "Using MLP_ROLE_INDEX for NODE_RANK: $NODE_RANK"
elif [ -n "$NODE_RANK" ]; then
    echo "Using NODE_RANK: $NODE_RANK"
else
    echo "Error: NODE_RANK or MLP_ROLE_INDEX environment variable must be set (e.g., export NODE_RANK=0)."
    exit 1
fi

# --- Configure NPROC_PER_NODE based on NODE_RANK ---
if [ "$NODE_RANK" -eq 0 ]; then
    export NPROC_PER_NODE=$NPROC_PER_NODE_NODE0
    echo "NPROC_PER_NODE set to: $NPROC_PER_NODE (for Node 0)"
else
    export NPROC_PER_NODE=$NPROC_PER_NODE_OTHERS
    echo "NPROC_PER_NODE set to: $NPROC_PER_NODE (for Node Rank > 0)"
fi

export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

echo "NNODES: $NNODES"
# Note: MASTER_ADDR is not set above; torchrun below uses MLP_WORKER_0_HOST instead.
echo "MASTER_ADDR: $MASTER_ADDR"
echo "MASTER_PORT: $MASTER_PORT"
echo "NODE_RANK: $NODE_RANK"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "----------------------------------------"

# --- Distributed training command using torchrun ---
torchrun \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MLP_WORKER_0_HOST \
    --nproc_per_node=$NPROC_PER_NODE \
    --master_port=$MLP_WORKER_0_PORT \
    swift/cli/rlhf.py \
    --rlhf_type grpo \
    --model /path/to/internvl2_5-8b \
    --reward_funcs acc_reward \
    --model_type internvl2_5 \
    --use_vllm true \
    --vllm_device auto \
    --vllm_gpu_memory_utilization 0.7 \
    --vllm_max_model_len 8192 \
    --train_type full \
    --torch_dtype bfloat16 \
    --max_completion_length 512 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 10 \
    --logging_steps 1 \
    --max_length 4096 \
    --output_dir output_multi_node \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 32 \
    --dataset_num_proc 32 \
    --num_generations 7 \
    --log_completions true \
    --temperature 0.9 \
    --top_p 0.9 \
    --deepspeed zero2 \
    --system 'examples/train/grpo/prompt_baseline.txt'
```
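Not part of the original script, but before tuning trainer arguments it is worth ruling out an interconnect bottleneck, since multi-node GRPO spends much of its time in all-reduce and rollout traffic. The sketch below assumes the NVIDIA nccl-tests repository (https://github.com/NVIDIA/nccl-tests) can be built on the node and uses `eth0` as a placeholder interface name; `NCCL_DEBUG` and `NCCL_SOCKET_IFNAME` are standard NCCL environment variables.

```bash
# Sketch only: check NCCL transport and raw all-reduce bandwidth.
# "eth0" is a placeholder; replace it with the actual NIC/IB interface name.

# 1) Enable NCCL logging for a short run to confirm which transport is used
#    ("NET/IB" is expected on an InfiniBand cluster; "NET/Socket" is much slower).
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0

# 2) Measure all-reduce bandwidth across the 8 local GPUs with nccl-tests.
git clone https://github.com/NVIDIA/nccl-tests.git
cd nccl-tests && make -j CUDA_HOME=/usr/local/cuda
./build/all_reduce_perf -b 8 -e 1G -f 2 -g 8
```

If the bandwidth numbers look healthy, the main levers left inside the script are on the generation side (`num_generations`, `max_completion_length`, `vllm_gpu_memory_utilization`, `vllm_max_model_len`), since rollout usually dominates GRPO step time; which values actually help is workload-dependent rather than a guaranteed speedup.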
