Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/resources/grpo_openr1_multimodal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
66 changes: 66 additions & 0 deletions docs/source/BestPractices/GRPO多模态训练.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,69 @@ User:
Assistant:
"<think>\nTo find the measure of angle ADC, we first need to understand the properties of the tangent and the circle. Since AB is a tangent to the circle at point A, the line segment OA (which is perpendicular to the tangent AB) is perpendicular to AB. This means that angle OAB is 90 degrees.\n\nGiven that angle ABO is 32 degrees, we can find angle BAO by subtracting angle ABO from 90 degrees:\n∠BAO = 90° - ∠ABO = 90° - 32° = 58°\n\nSince angle BAO is an angle in the minor arc AD, the angle ADC, which is an angle in the minor arc AD, is half the angle BAO. Therefore, angle ADC is:\n∠ADC = 1/2 × ∠BAO = 1/2 × 58° = 29°\n\nSo, the measure of angle ADC is 29 degrees.\n</think>\n<answer>\nThe measure of angle ADC is 29 degrees.\n</answer>"
```
## Multimodal Open R1 数据集实验
### 任务与数据集定义
本任务为参考[open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal.git)的实验,使用数据集:[lmms-lab/multimodal-open-r1-8k-verified](https://www.modelscope.cn/datasets/lmms-lab/multimodal-open-r1-8k-verified),该数据集专注于多模态的数学推理任务,数据由GPT4o基于`Math360K`和`Geo170K`数据集生成,包含推理路径和可验证答案。数据集中已包含了image, problem和solution字段,我们也不需要针对prompt进行修改,因此无需额外定义数据集。
### 奖励函数
我们直接使用以上定义过的`MultiModalAccuracyORM`奖励函数。
### GRPO训练实验记录
#### 训练参数:
选取的模型和大部分超参数与上一个实验相似,由于训练的时候出现了OOM,我们设置`MAX_PIXELS=262144`以降低显存占用。
```shell
WANDB_API_KEY=your_wandb_api_key \
MAX_PIXELS=262144 \
MASTER_PORT=29600 \
NPROC_PER_NODE=6 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-VL-3B-Instruct \
--external_plugins examples/train/grpo/plugin/plugin.py \
--reward_funcs external_r1v_acc format \
--use_vllm true \
--vllm_device auto \
--vllm_gpu_memory_utilization 0.6 \
--train_type full \
--torch_dtype bfloat16 \
--dataset 'lmms-lab/multimodal-open-r1-8k-verified' \
--max_length 8192 \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 2 \
--save_strategy 'steps' \
--eval_strategy 'steps' \
--eval_steps 400 \
--save_steps 400 \
--save_total_limit 10 \
--logging_steps 1 \
--output_dir output/GRPO_GEOQA \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--num_generations 8 \
--temperature 1.0 \
--repetition_penalty 1.1 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero3 \
--log_completions true \
--report_to wandb \
--num_iterations 2 \
--num_infer_workers 2 \
--async_generate false \
--beta 0.001 \
--max_grad_norm 0.5 \

```

#### 实验现象
![image.png](../../resources/grpo_openr1_multimodal.png)
训练曲线如上图所示。
一共8k数据,训练了1268个step,Accuracy Reward收敛到0.5左右,completion_length基本收敛到200token左右,reward_std最终在0.2左右震荡。
以下是训练后的模型completition示例:
```
User:
"Based on the map, which state falls into the lowest percentage range of lots, and what is that percentage range?"
Assistant:
"<think>\nThe image provided shows a color-coded map of the USA indicating different ranges of lot percentages across various states. According to the legend at the bottom, the lowest percentage range (6.0% - 6.5%) is represented by white. In the image, Alabama (AL) is shaded in white, which corresponds to the 6.0% - 6.5% category. Therefore, based on the map, the state that falls into the lowest percentage range of lots is Alabama, with the percentage range of 6.0% - 6.5%.\nTherefore, the answer is 6.0% - 6.5%.\n</think>\n<answer>Alabama</answer>"
```