Data format problem when training with GRPO #6874

@wangat

I am trying to prepare data in a way similar to https://github.com/modelscope/ms-swift/issues/4118.

The data format is:
{
  "id": "example_109811_230681.jpeg",
  "messages": [
    {
      "role": "system",
      "content": "……" # fairly long
    },
    {
      "role": "user",
      "content": "(the image placeholder is not displayed here)output the starting and ending points of each road in the satellite image, as well as all nodes in the satellite image:"
    }
  ],
  "images": ["/109811_230681.jpeg"],
  "objects":
  {
    "ref": [],
    "bbox": [[192,0],[0,234]……] # all entries are points
  },
  "solution": "road1[endpoint[,];point[,,,]];node[]"
}
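For reference, a record with the same keys can be built and serialized from Python, which avoids the `#` annotations and elided values that make the snippet above invalid as literal JSON. This is only a sketch: `make_record` is a hypothetical helper, and every value passed to it is a placeholder rather than the real data.

```python
import json


def make_record(record_id, image_path, system_prompt, user_prompt, points, solution):
    """Build one record with the same top-level keys as the example above."""
    return {
        "id": record_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "images": [image_path],
        # Each "bbox" entry is a 2-element point [x, y], not a 4-element box.
        "objects": {"ref": [], "bbox": [list(p) for p in points]},
        "solution": solution,
    }


# Placeholder values for illustration only (not the real dataset).
record = make_record(
    record_id="example.jpeg",
    image_path="/path/to/example.jpeg",
    system_prompt="system prompt goes here",
    user_prompt="user prompt goes here",
    points=[(192, 0), (0, 234)],
    solution="solution string goes here",
)

# ensure_ascii=False keeps any non-ASCII prompt text readable in the file.
serialized = json.dumps(record, ensure_ascii=False)
print(serialized)
```

Writing the file through `json.dumps` guarantees that what is on disk parses as JSON, so any remaining error comes from the content of the fields rather than from the file syntax.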

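The same batch of data was first processed into DPO format; with that format, training ran normally and the results looked correct. I then used a script to convert the DPO-format data into the format above (removing the rejected data and dropping the extra point information in bbox). I also checked that "objects" is of type dict, and tried changing the numbers in [192,0] to int and to str. In every case training fails with the traceback shown further down.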
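The conversion step described above might look roughly like the following. It is a sketch under assumptions, not the script that was actually used: `rejected_response` is assumed to be the name of the rejected field, and "dropping the extra point information" is read here as truncating every `bbox` entry to its first two coordinates and casting them to int.

```python
import copy


def dpo_to_grpo(record, rejected_key="rejected_response", point_len=2):
    """Return a copy of a DPO-style record without the rejected answer.

    Assumptions (not confirmed by the issue): the rejected answer lives under
    `rejected_key`, and each bbox entry should keep only `point_len` values.
    """
    out = copy.deepcopy(record)  # leave the caller's record untouched
    out.pop(rejected_key, None)  # drop the rejected answer if present

    objects = out.get("objects") or {}
    objects["bbox"] = [
        [int(float(v)) for v in entry[:point_len]]  # accepts int, float or numeric str
        for entry in objects.get("bbox", [])
    ]
    out["objects"] = objects
    return out


# Hypothetical DPO-style input used only to exercise the function.
dpo_record = {
    "messages": [{"role": "user", "content": "prompt goes here"}],
    "rejected_response": "rejected answer goes here",
    "objects": {"ref": [], "bbox": [[192.0, 0.0, 5, 5], ["0", "234"]]},
}
grpo_record = dpo_to_grpo(dpo_record)
print(grpo_record["objects"]["bbox"])
```

Working on a deep copy keeps the original DPO records intact, so both formats can be compared side by side when a run fails.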

The training command is:
NPROC_PER_NODE=6 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
swift rlhf \
--rlhf_type grpo \
--model /checkpoint-26000 \
--external_plugins /grpo_test/randomrewardtest.py \
--reward_funcs random_reward \
--use_vllm true \
--vllm_mode server \
--vllm_server_host xxx \
--vllm_server_port 8011 \
--model_type qwen2_5_vl \
--dataset /rl_1.json \
--split_dataset_ratio 0.01 \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--save_steps 50 \
--eval_steps 20 \
--save_total_limit 2 \
--logging_steps 10 \
--train_type full \
--freeze_vit False \
--freeze_aligner False \
--learning_rate 1e-6 \
--warmup_ratio 0.05 \
--lr_scheduler_type cosine_with_min_lr \
--lr_scheduler_kwargs '{"min_lr": 1e-7}' \
--max_length 12288 \
--gradient_checkpointing_kwargs '{"use_reentrant": false}' \
--report_to swanlab \
--swanlab_project swift-robot \
--deepspeed zero3 \
--bf16 true \
--attn_impl flash_attn \
--output_dir output/grpo-test/ \
--num_generations 6 \
--generation_batch_size 96 \
--temperature 1.0 \
--log_completions true \
--async_generate true \
--beta 0.001

[rank1]: Traceback (most recent call last):
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/cli/rlhf.py", line 5, in <module>
[rank1]: rlhf_main()
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/rlhf.py", line 200, in rlhf_main
[rank1]: return SwiftRLHF(args).main()
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/base.py", line 49, in main
[rank1]: result = self.run()
[rank1]: ^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/sft.py", line 187, in run
[rank1]: return self.train(trainer)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/sft.py", line 235, in train
[rank1]: trainer.train(trainer.args.resume_from_checkpoint)
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/mixin.py", line 676, in train
[rank1]: res = super().train(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer.py", line 2328, in train
[rank1]: return inner_training_loop(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer.py", line 2578, in _inner_training_loop
[rank1]: self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer_callback.py", line 506, in on_train_begin
[rank1]: return self.call_event("on_train_begin", args, state, control)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer_callback.py", line 556, in call_event
[rank1]: result = getattr(callback, event)(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 84, in on_train_begin
[rank1]: self.trainer._prefetch(train_dataloader)
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 738, in _prefetch
[rank1]: results = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 696, in _infer_single_or_multi_turn
[rank1]: rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 626, in _rollout
[rank1]: rollout_outputs = self._server_rollout(inputs, request_config, is_global_inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2269, in _server_rollout
[rank1]: infer_requests = self.inputs2requests(inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2460, in inputs2requests
[rank1]: return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2460, in <listcomp>
[rank1]: return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/dacite/core.py", line 74, in from_dict
[rank1]: raise WrongTypeError(field_path=field.name, field_type=field_type, value=value)
[rank1]: dacite.exceptions.WrongTypeError: wrong value type for field "objects" - should be "Dict" instead of value "{'ref': [], 'bbox': [[0.0, 230.0], [1023.0, 339.0], [14.0, 251.0], [126.0, 451.0], [226.0, 552.0], [367.0, 596.0], [568.0, 542.0], [407.0, 0.0], [1023.0, 3.0], [422.0, 18.0], [513.0, 98.0], [597.0, 119.0], [710.0, 115.0]],'bbox_type':None, 'image_id': None} of type "dict"
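The value printed in the error differs from the record in the dataset file in two visible ways: it carries 'bbox_type' and 'image_id' keys set to None, and its coordinates are floats. Whether either difference is what the type check rejects is not confirmed here. The sketch below only makes such differences easy to list; `leaf_types` and `diff_objects` are hypothetical helpers, and the two dicts are illustrative values modelled on the file and on the error message.

```python
def leaf_types(value):
    """Collect the type names of all non-list leaves inside nested lists."""
    if isinstance(value, (list, tuple)):
        found = set()
        for item in value:
            found |= leaf_types(item)
        return found
    return {type(value).__name__}


def diff_objects(source, received):
    """Summarise how `received` differs from `source` (both are dicts)."""
    shared = set(source) & set(received)
    return {
        # keys that appeared somewhere between the file and the trainer
        "added_keys": sorted(set(received) - set(source)),
        # keys whose value is None on arrival
        "none_valued": sorted(k for k, v in received.items() if v is None),
        # shared keys whose leaf value types changed (e.g. int -> float)
        "retyped": sorted(
            k for k in shared if leaf_types(source[k]) != leaf_types(received[k])
        ),
    }


# Illustrative values: `source` mirrors the file, `received` mirrors the error.
source = {"ref": [], "bbox": [[192, 0], [0, 234]]}
received = {
    "ref": [],
    "bbox": [[192.0, 0.0], [0.0, 234.0]],
    "bbox_type": None,
    "image_id": None,
}
report = diff_objects(source, received)
print(report)
```

Running the same comparison on a real record and on the value taken from the traceback shows which fields were changed after loading, which narrows down where to look.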

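What could be causing this error? Is there any documentation on the expected data format? For now I am not sure how the data should be prepared.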
