4 changes: 3 additions & 1 deletion docs/source/Instruction/GRPO.md
@@ -85,7 +85,9 @@ A conversation between User and Assistant. The user asks a question, and the Ass
- num_generations: The number of samples per prompt, the G value in the paper; it needs to be divisible by per_device_eval_batch_size * nproc_per_node.
- max_completion_length: The maximum length for sampled generation, default is 512.
- reward_funcs: Reward functions that score the model's generated results; four rule-based functions are built in: accuracy, format, cosine and repetition, see the swift/plugin/orm.py file for details.
- log_completions: Whether to log the model-generated content during training, used together with report_to wandb, default is False.
- reward_weights: The weight of each reward function. Must match the number of reward functions. If None, all rewards are weighted equally with `1.0`.
- Note: If `--reward_model` is included in GRPO training, it is appended to the end of the reward functions.
- log_completions: Whether to log the model-generated content during training, used together with `--report_to wandb`, default is False.
- use_vllm: Whether to use vLLM as the generation backend for sampling, default is False; using it is recommended to speed up training.
- vllm_device: The device on which vLLM is deployed, default is `auto`, i.e. the first unused GPU; use `cuda:x` to specify a particular card.
- vllm_gpu_memory_utilization: vLLM passthrough parameter.
4 changes: 3 additions & 1 deletion docs/source/Instruction/命令行参数.md
@@ -365,7 +365,9 @@ The reward model parameters are used in PPO and GRPO.
- num_generations: The G value in the GRPO algorithm, default is 8.
- max_completion_length: The maximum generation length in the GRPO algorithm, default is 512.
- reward_funcs: Reward functions for the GRPO algorithm; options are `accuracy`, `format`, `cosine` and `repetition`, see swift/plugin/orm.py. You can also customize your own reward function in the plugin. Default is `[]`.
- log_completions: Whether to log the model-generated content during training, used together with report_to wandb, default is False.
- reward_weights: The weight of each reward function. Must match the number of reward functions. If None, all rewards are weighted equally with `1.0`.
- Note: If `--reward_model` is included in GRPO training, it is appended to the end of the reward functions.
- log_completions: Whether to log the model-generated content during training, used together with `--report_to wandb`, default is False.
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
- vllm_device: The device on which vLLM is deployed, e.g. `cuda:0` to deploy it on card 0; default is `auto`, i.e. the last available GPU.
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
2 changes: 1 addition & 1 deletion docs/source/Instruction/预训练与微调.md
@@ -121,7 +121,7 @@ swift infer \
```
- If you use full-parameter training, use `--model` instead of `--adapters` to specify the trained checkpoint directory.
- You can choose to merge the LoRA weights (by additionally specifying `--merge_lora true`) and then specify `--infer_backend vllm/lmdeploy` for inference acceleration.
- You can use `swift app` instead of `--swift infer` for interface-based inference.
- You can use `swift app` instead of `swift infer` for interface-based inference.

For batch inference on the validation set of the dataset:
```shell
4 changes: 3 additions & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -376,7 +376,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
- num_generations: The G value in the GRPO algorithm, default is 8.
- max_completion_length: The maximum generation length in the GRPO algorithm, default is 512.
- reward_funcs: Reward functions in the GRPO algorithm; options include `accuracy`, `format`, `cosine` and `repetition`, as seen in `swift/plugin/orm.py`. You can also customize your own reward functions in the plugin. Default is `[]`.
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with report_to wandb, default is False.
- reward_weights: Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`.
- Note: If `--reward_model` is included in GRPO training, it is added to the end of the reward functions.
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb`, default is False.
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
- vllm_device: Set the device for vLLM deployment. For example, if deployed on card 0, use `cuda:0`; default is `auto`, which means using the last available GPU.
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
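The new `reward_weights` option interacts with `reward_funcs` and an optional `--reward_model` as described in the list above; the sketch below illustrates only that relationship. The helper name and argument plumbing are assumptions for illustration, not swift's actual API.

```python
from typing import List, Optional

def resolve_reward_weights(reward_funcs: List[str],
                           reward_weights: Optional[List[float]],
                           has_reward_model: bool = False) -> List[float]:
    # Hypothetical helper: a --reward_model, if given, is appended after the
    # rule-based reward functions, so it takes the last weight in the list.
    total = len(reward_funcs) + (1 if has_reward_model else 0)
    if reward_weights is None:
        return [1.0] * total  # None -> every reward weighted equally with 1.0
    if len(reward_weights) != total:
        raise ValueError(f'Number of reward weights ({len(reward_weights)}) must match '
                         f'number of reward functions ({total})')
    return reward_weights

# accuracy counts double, the appended reward model counts half:
print(resolve_reward_weights(['accuracy', 'format'], [2.0, 1.0, 0.5], has_reward_model=True))
```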
4 changes: 3 additions & 1 deletion docs/source_en/Instruction/GRPO.md
@@ -87,7 +87,9 @@ Hyperparameters
- num_generations: The number of samples for each prompt, referred to as the G value in the paper; it needs to be divisible by per_device_eval_batch_size * nproc_per_node.
- max_completion_length: The maximum length for sampling generation, default is 512.
- reward_funcs: Reward functions to score the results generated by the model. Includes the built-in accuracy, format, cosine and repetition rule-based functions, detailed in the swift/plugin/orm.py file.
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with report_to wandb, default is False.
- reward_weights: Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`.
- Note: If `--reward_model` is included in GRPO training, it is added to the end of the reward functions.
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb`, default is False.
- use_vllm: Whether to use vLLM as the back-end for sampling generation; default is False. Using it is recommended to speed up training.
- vllm_device: Device for deploying vLLM, default is auto, meaning the first unused GPU. Use cuda:x to specify a particular card.
- vllm_gpu_memory_utilization: vLLM pass-through parameter.
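Since `num_generations` is the G used for group-wise reward normalization, a minimal sketch of how the flattened rewards are regrouped per prompt may help; the shapes and the epsilon value here are illustrative assumptions, not the trainer's exact code.

```python
import torch

num_generations = 4                        # G: completions sampled per prompt
rewards = torch.rand(2 * num_generations)  # flattened rewards for 2 prompts x G samples

# Group rewards by prompt and normalize within each group (GRPO-style advantages).
grouped = rewards.view(-1, num_generations)
mean = grouped.mean(dim=1, keepdim=True)
std = grouped.std(dim=1, keepdim=True)
advantages = ((grouped - mean) / (std + 1e-4)).view(-1)
print(advantages.shape)  # torch.Size([8])
```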
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Pre-training-and-Fine-tuning.md
@@ -128,7 +128,7 @@ swift infer \

- If you are using full parameter training, please replace `--adapters` with `--model` to specify the directory of the trained checkpoint.
- You can choose to merge LoRA (by additionally specifying `--merge_lora true`) and then specify `--infer_backend vllm/lmdeploy` for inference acceleration.
- You can use `swift app` instead of `--swift infer` for interface-based inference.
- You can use `swift app` instead of `swift infer` for interface-based inference.

For batch inference on the validation set of the dataset:

1 change: 1 addition & 0 deletions examples/export/quantize/mllm/awq.sh
@@ -1,3 +1,4 @@
# Test environment: transformers==4.47.1, autoawq==0.2.8
CUDA_VISIBLE_DEVICES=0 \
swift export \
--model Qwen/Qwen2-VL-2B-Instruct \
3 changes: 3 additions & 0 deletions swift/llm/argument/rlhf_args.py
@@ -43,6 +43,9 @@ class GRPOArguments(GRPOArgumentsMixin):
num_generations: int = 8 # G in the GRPO paper
max_completion_length: int = 512
reward_funcs: List[str] = field(default_factory=list)
reward_weights: Optional[List[float]] = None
log_completions: bool = False

# vLLM in GRPO
use_vllm: bool = False
vllm_device: Optional[str] = 'auto' # 'cuda:0'
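For reference, a toy dataclass mirroring the new fields shows how they fit together; `GRPOArgumentsSketch` is a stand-in defined here for illustration, not a class from swift.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class GRPOArgumentsSketch:
    """Toy stand-in mirroring the fields touched by this change."""
    num_generations: int = 8
    max_completion_length: int = 512
    reward_funcs: List[str] = field(default_factory=list)
    reward_weights: Optional[List[float]] = None  # None -> every reward weighted 1.0
    log_completions: bool = False

args = GRPOArgumentsSketch(reward_funcs=['accuracy', 'format'], reward_weights=[2.0, 1.0])
print(args.reward_weights)
```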
1 change: 0 additions & 1 deletion swift/trainers/arguments.py
@@ -70,7 +70,6 @@ def place_model_on_device(self):

@dataclass
class GRPOArgumentsMixin:
log_completions: bool = False
# vllm_device, vllm_gpu_memory_utilization, and vllm_max_model_len are defined in HfGRPOConfig.
vllm_max_num_seqs: int = 256
vllm_enforce_eager: bool = False
15 changes: 11 additions & 4 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -9,10 +9,8 @@
import torch
import torch.nn as nn
from accelerate.utils import broadcast_object_list, gather, gather_object
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOTrainer as HFGRPOTrainer
from trl.models import unwrap_model_for_generation

from swift.llm import InferRequest, RequestConfig, to_device
from swift.plugin.orm import orms
@@ -65,6 +63,15 @@ def __init__(self,
if not self.reward_funcs:
raise ValueError('You must specify reward_funcs or reward_model')

# Reward weights
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
f'functions ({len(reward_funcs)})')
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

self.num_generations = args.num_generations
model.warnings_issued['estimate_tokens'] = True
kwargs['data_collator'] = lambda x: x
@@ -244,8 +251,8 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

rewards_per_func = gather(rewards_per_func)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
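The behavioural change in `_prepare_inputs` is the switch from an unweighted sum to a weighted sum over the per-function reward matrix; a standalone sketch with made-up values:

```python
import torch

# rewards_per_func: (num_samples, num_reward_funcs), one column per reward function
rewards_per_func = torch.tensor([[1.0, 0.0],
                                 [0.5, 1.0]])
reward_weights = torch.tensor([2.0, 1.0])  # e.g. from --reward_weights 2.0 1.0

unweighted = rewards_per_func.sum(dim=1)                                # old: tensor([1.0, 1.5])
weighted = (rewards_per_func * reward_weights.unsqueeze(0)).sum(dim=1)  # new: tensor([2.0, 2.0])
print(unweighted, weighted)
```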
4 changes: 2 additions & 2 deletions swift/utils/torch_utils.py
@@ -240,7 +240,7 @@ def safe_ddp_context(hash_id: str):
yield


def get_device(rank: Optional[Union[str, int]] = None) -> 'torch.device':
def get_device(rank: Optional[Union[str, int]] = None) -> str:
if rank is None:
rank = get_dist_setting()[1]
if rank < 0 or rank is None:
@@ -256,7 +256,7 @@ def get_device(rank: Optional[Union[str, int]] = None) -> 'torch.device':
else:
device = 'cpu'

return torch.device(device)
return device
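With the return type changed to a plain string such as `'cuda:0'`, the common call sites still work, since PyTorch accepts device strings directly; the helper below is a simplified stand-in for illustration, not the real `get_device`.

```python
import torch

def get_device_sketch(rank: int = 0) -> str:
    # Simplified stand-in: the real helper also handles npu/mps/cpu and negative ranks.
    return f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

device = get_device_sketch(0)
x = torch.zeros(2).to(device)      # .to() accepts a device string
y = torch.zeros(2, device=device)  # and so does the device= keyword
print(device, x.device, y.device)
```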


def get_device_count() -> int: