Initialize chord dataset after accelerator setup in GRPOTrainer #6638
Conversation
The get_chord_sft_dataloader() method relies on GRPOTrainer.accelerator, but it was previously called before the parent class (super().__init__()) finished initializing the accelerator. As a result, get_chord_sft_dataloader() raised an AttributeError for the non-existent attribute GRPOTrainer.accelerator.
Code Review
This pull request addresses a critical bug in GRPOTrainer where get_chord_sft_dataloader() was invoked before the accelerator attribute was initialized, leading to an AttributeError. The fix correctly repositions the CHORD dataset initialization to occur after the parent class's __init__ method has set up the accelerator. Additionally, the logic for preparing the CHORD dataset has been cleanly refactored into a new _prepare_chord_dataset method. This not only resolves the bug but also improves the code's structure and readability. The changes are correct and well-implemented.
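The reordering described in the review can be sketched in miniature. The class and method names below are simplified stand-ins mirroring the PR's description, not the actual ms-swift implementation:

```python
class Trainer:
    """Stand-in for the parent trainer class that creates the accelerator."""

    def __init__(self):
        self.accelerator = "accelerator"  # placeholder for the real Accelerator


class GRPOTrainerSketch(Trainer):
    """Simplified trainer: CHORD setup now runs after super().__init__()."""

    def __init__(self, chord_sft_dataset=None):
        self.chord_sft_dataset = chord_sft_dataset
        super().__init__()  # parent sets self.accelerator first
        if self.chord_sft_dataset is not None:
            self._prepare_chord_dataset()  # safe: accelerator exists now

    def _prepare_chord_dataset(self):
        # Uses self.accelerator, which super().__init__() has already set.
        self.chord_sft_iterator = iter(self.chord_sft_dataset)


trainer = GRPOTrainerSketch(chord_sft_dataset=["sample"])
print(next(trainer.chord_sft_iterator))  # → sample
```

The key point is that `_prepare_chord_dataset` is only invoked once the parent constructor has finished, so every attribute it touches is guaranteed to exist.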
thanks!
…ner (modelscope#6638) The get_chord_sft_dataloader() method relies on GRPOTrainer.accelerator, but it was previously called before the parent class (super().__init__()) finished initializing the accelerator. As a result, get_chord_sft_dataloader() raised an AttributeError for the non-existent attribute GRPOTrainer.accelerator.
PR type
PR information
The get_chord_sft_dataloader() method relies on the GRPOTrainer.accelerator attribute, but it was called before the parent class (super().__init__()) finished initializing accelerator (that initialization happens in swift/trainers/rlhf_trainer/grpo_trainer.py, line 82). As a result, calling get_chord_sft_dataloader() raises an AttributeError because the accelerator attribute does not yet exist.
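The failure mode itself is a plain Python attribute-ordering bug, reproducible in a few lines (the class names below are illustrative only, not the real ms-swift classes):

```python
class Parent:
    def __init__(self):
        # The parent class is responsible for creating `accelerator`.
        self.accelerator = object()


class BuggyChild(Parent):
    def __init__(self):
        # BUG: reads self.accelerator before super().__init__() has set it.
        _ = self.accelerator
        super().__init__()


class FixedChild(Parent):
    def __init__(self):
        super().__init__()
        # OK: the parent's __init__ has already created the attribute.
        _ = self.accelerator


try:
    BuggyChild()
except AttributeError as exc:
    print(exc)  # 'BuggyChild' object has no attribute 'accelerator'

FixedChild()  # runs without error
```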
Hope it helps. Thank you.
Experiment results
Command:
```shell
swift rlhf \
    --model Qwen/Qwen3-VL-8B-Instruct \
    --dataset /data/pokemon/data.json \
    --split_dataset_ratio 0.01 \
    --num_train_epochs 8 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 5 \
    --logging_steps 50 \
    --warmup_ratio 0.1 \
    --dataloader_num_workers 8 \
    --dataset_num_proc 2 \
    --lr_scheduler_type cosine \
    --load_from_cache_file true \
    --gradient_checkpointing true \
    --report_to all \
    --use_hf true \
    --torch_dtype bfloat16 \
    --log_completions true \
    --deepspeed zero3 \
    --max_length 4096 \
    --max_completion_length 2048 \
    --freeze_vit true \
    --target_modules all-linear \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --rlhf_type grpo \
    --num_generations 8 \
    --temperature 1.0 \
    --beta 0.001 \
    --external_plugins /pokemon.py \
    --reward_funcs pokemon_grpo_format pokemon_grpo_acc repetition \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.4 \
    --vllm_tensor_parallel_size 1 \
    --vllm_max_model_len 4096 \
    --vllm_data_parallel_size 8 \
    --train_type lora \
    --output_dir saved/grpo-e3-r8-b1-beta3-chord \
    --learning_rate 1e-3 \
    --lora_rank 8 \
    --lora_alpha 16 \
    --chord_sft_per_device_train_batch_size 1 \
    --chord_sft_dataset /data/pokemon/data.json \
    --chord_enable_phi_function false \
    --chord_mu_warmup_steps 25 \
    --chord_mu_decay_steps 300 \
    --chord_mu_peak 0.75 \
    --chord_mu_valley 0.15
```
Exception (in rank1):
```text
[rank1]: Traceback (most recent call last):
[rank1]:   File "/ms-swift/swift/cli/rlhf.py", line 7, in <module>
[rank1]:     rlhf_main()
[rank1]:   File "/ms-swift/swift/llm/train/rlhf.py", line 227, in rlhf_main
[rank1]:     return SwiftRLHF(args).main()
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/llm/base.py", line 49, in main
[rank1]:     result = self.run()
[rank1]:              ^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/ray/base.py", line 170, in wrapper
[rank1]:     return func(self, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/llm/train/sft.py", line 196, in run
[rank1]:     trainer = trainer_cls(
[rank1]:               ^^^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py", line 81, in __init__
[rank1]:     self._prepare_algorithm_params()
[rank1]:   File "/ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py", line 1867, in _prepare_algorithm_params
[rank1]:     self.chord_sft_iterator = make_chord_sft_dataset(self, self.chord_sft_dataset)
[rank1]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/trainers/rlhf_trainer/utils.py", line 934, in make_chord_sft_dataset
[rank1]:     chord_sft_dataloader = get_chord_sft_dataloader(
[rank1]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/ms-swift/swift/trainers/rlhf_trainer/utils.py", line 920, in get_chord_sft_dataloader
[rank1]:     dataloader = trainer.accelerator.prepare(DataLoader(dataset, **dataloader_params))
[rank1]:                  ^^^^^^^^^^^^^^^^^^^
[rank1]: AttributeError: 'GRPOTrainer' object has no attribute 'accelerator'
```