Description
I am trying to prepare data similar to https://github.com/modelscope/ms-swift/issues/4118.
Each record has the following format:
{
"id": "example_109811_230681.jpeg",
"messages": [
{
"role": "system",
"content": "……" # long, omitted here
},
{
"role": "user",
"content": "output the starting and ending points of each road in the satellite image, as well as all nodes in the satellite image:" # an image placeholder precedes this text but did not render here
}
],
"images":["/109811_230681.jpeg"],
"objects":
{
"ref": [],
"bbox": [[192,0],[0,234]……] # every entry is a point
},
"solution": "road1[endpoint[,];point[,,,]];node[]"
}
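To rule out malformed records before launching training, a small checker can walk the file and report any record whose shape differs from the example above. This is a minimal sketch of my own, not ms-swift's validation logic; the expected shape (numeric `[x, y]` points in `objects.bbox`, a string `solution`) is an assumption taken from the sample record.

```python
import json
from numbers import Real
from typing import Any, Dict, List


def load_records(path: str) -> List[Dict[str, Any]]:
    """Load either a JSON array file or a JSON Lines file into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        text = f.read()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Not a single JSON document: treat it as one JSON object per line.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    return data if isinstance(data, list) else [data]


def check_record(record: Dict[str, Any]) -> List[str]:
    """Return the problems found in one record; an empty list means it matches the assumed shape."""
    problems: List[str] = []
    messages = record.get("messages")
    if not isinstance(messages, list) or not all(
        isinstance(m, dict) and {"role", "content"}.issubset(m) for m in messages
    ):
        problems.append("messages must be a list of {role, content} dicts")
    images = record.get("images", [])
    if not isinstance(images, list) or not all(isinstance(p, str) for p in images):
        problems.append("images must be a list of path strings")
    objects = record.get("objects", {})
    if not isinstance(objects, dict):
        problems.append("objects must be a dict")
    else:
        for i, point in enumerate(objects.get("bbox", [])):
            # Assumed: each point is [x, y] with numeric (not string, not bool) coordinates.
            ok = (isinstance(point, list) and len(point) == 2
                  and all(isinstance(v, Real) and not isinstance(v, bool) for v in point))
            if not ok:
                problems.append(f"objects.bbox[{i}] is not a numeric [x, y] pair: {point!r}")
    if not isinstance(record.get("solution"), str):
        problems.append("solution must be a string")
    return problems
```

Usage: `for rec in load_records("/rl_1.json"): print(rec.get("id"), check_record(rec))`. An all-empty report would suggest the data itself is well formed and the failure lies elsewhere.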
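I first converted the same batch of data to DPO format; training on it worked and the results were normal. I then used a script to convert the DPO-format data into the format above (deleting the rejected data and removing the extra point information from bbox). I also verified that "objects" is a dict, and tried making the numbers in [192,0] ints and then strings. In every case training fails with the error below.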
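The DPO-to-GRPO conversion described above can be sketched as follows. This is a hypothetical reconstruction, not the script I actually used: it assumes the rejected answer lives under `rejected_response` (the key used in ms-swift's DPO examples; adjust if your file differs) and that the chosen answer is the trailing assistant message, which becomes `solution`.

```python
from typing import Any, Dict

# Hypothetical: DPO-only keys to drop; adjust to the keys your DPO file actually uses.
DPO_ONLY_KEYS = ("rejected_response",)


def dpo_to_grpo(record: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new GRPO-style record; the input record is not modified."""
    out = {k: v for k, v in record.items() if k not in DPO_ONLY_KEYS}
    messages = list(out.get("messages", []))
    # Assumption: the DPO "chosen" answer is the final assistant turn.
    if messages and messages[-1].get("role") == "assistant":
        chosen = messages.pop()
        out.setdefault("solution", chosen.get("content", ""))
    out["messages"] = messages  # prompt now ends with the user turn
    objects = dict(out.get("objects") or {})
    # Coerce every coordinate to float so all points share one numeric type.
    objects["bbox"] = [[float(v) for v in point] for point in objects.get("bbox", [])]
    objects.setdefault("ref", [])
    out["objects"] = objects
    return out
```

Copying `messages` and `objects` before editing keeps the original DPO records intact, so both files can be compared afterwards.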
Training command:
NPROC_PER_NODE=6 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
swift rlhf \
--rlhf_type grpo \
--model /checkpoint-26000 \
--external_plugins /grpo_test/randomrewardtest.py \
--reward_funcs random_reward \
--use_vllm true \
--vllm_mode server \
--vllm_server_host xxx \
--vllm_server_port 8011 \
--model_type qwen2_5_vl \
--dataset /rl_1.json \
--split_dataset_ratio 0.01 \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--save_steps 50 \
--eval_steps 20 \
--save_total_limit 2 \
--logging_steps 10 \
--train_type full \
--freeze_vit False \
--freeze_aligner False \
--learning_rate 1e-6 \
--warmup_ratio 0.05 \
--lr_scheduler_type cosine_with_min_lr \
--lr_scheduler_kwargs '{"min_lr": 1e-7}' \
--max_length 12288 \
--gradient_checkpointing_kwargs '{"use_reentrant": false}' \
--report_to swanlab \
--swanlab_project swift-robot \
--deepspeed zero3 \
--bf16 true \
--attn_impl flash_attn \
--output_dir output/grpo-test/ \
--num_generations 6 \
--generation_batch_size 96 \
--temperature 1.0 \
--log_completions true \
--async_generate true \
--beta 0.001
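As a sanity check on the batch settings, the arithmetic below assumes ms-swift follows TRL's GRPO convention: `generation_batch_size` counts completions, must be divisible by `num_generations`, and by default equals per-device batch × processes × gradient accumulation steps. Under that assumption the values above are consistent.

```python
# Values copied from the command above.
nproc_per_node = 6
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
generation_batch_size = 96
num_generations = 6

# Completions consumed by one optimizer step across all training ranks.
completions_per_step = nproc_per_node * per_device_train_batch_size * gradient_accumulation_steps
# Each prompt is sampled num_generations times, so the batch must split evenly.
assert generation_batch_size % num_generations == 0
prompts_per_generation = generation_batch_size // num_generations
print(completions_per_step, prompts_per_generation)  # 96 16
```

So one generation round of 96 completions (16 unique prompts × 6 samples) would feed exactly one optimizer step, which suggests the crash is not caused by the batch configuration.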
[rank1]: Traceback (most recent call last):
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/cli/rlhf.py", line 5, in
[rank1]: rlhf_main()
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/rlhf.py", line 200, in rlhf_main
[rank1]: return SwiftRLHF(args).main()
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/base.py", line 49, in main
[rank1]: result = self.run()
[rank1]: ^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/sft.py", line 187, in run
[rank1]: return self.train(trainer)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/llm/train/sft.py", line 235, in train
[rank1]: trainer.train(trainer.args.resume_from_checkpoint)
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/mixin.py", line 676, in train
[rank1]: res = super().train(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer.py", line 2328, in train
[rank1]: return inner_training_loop(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer.py", line 2578, in _inner_training_loop
[rank1]: self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer_callback.py", line 506, in on_train_begin
[rank1]: return self.call_event("on_train_begin", args, state, control)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/transformers/trainer_callback.py", line 556, in call_event
[rank1]: result = getattr(callback, event)(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 84, in on_train_begin
[rank1]: self.trainer._prefetch(train_dataloader)
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 738, in _prefetch
[rank1]: results = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 696, in _infer_single_or_multi_turn
[rank1]: rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 626, in _rollout
[rank1]: rollout_outputs = self._server_rollout(inputs, request_config, is_global_inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2269, in _server_rollout
[rank1]: infer_requests = self.inputs2requests(inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2460, in inputs2requests
[rank1]: return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/swift/trainers/rlhf_trainer/grpo_trainer.py", line 2460, in
[rank1]: return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.11/site-packages/dacite/core.py", line 74, in from_dict
[rank1]: raise WrongTypeError(field_path=field.name, field_type=field_type, value=value)
[rank1]: dacite.exceptions.WrongTypeError: wrong value type for field "objects" - should be "Dict" instead of value "{'ref': [], 'bbox': [[0.0, 230.0], [1023.0, 339.0], [14.0, 251.0], [126.0, 451.0], [226.0, 552.0], [367.0, 596.0], [568.0, 542.0], [407.0, 0.0], [1023.0, 3.0], [422.0, 18.0], [513.0, 98.0], [597.0, 119.0], [710.0, 115.0]],'bbox_type':None, 'image_id': None} of type "dict"
What might be causing this error? Is there any documentation on the expected data format? I am not sure how to prepare the data.