fix: unblock PPO multi-GPU training #18
Merged
Conversation
timzsu (Collaborator) requested changes on May 5, 2026
Skimmed through the changed code and left some quick questions. I will do a more comprehensive pass tomorrow.
(Force-pushed from fd7b75e to f95f24d.)
timzsu (Collaborator) requested changes on May 6, 2026
Some comments to address. In addition, can we switch the default lr to 1e-5 and make the use of double quotes in templates consistent?
timzsu (Collaborator) requested changes on May 6, 2026
Sorry, I missed a block last time. Can you take a look at these two comments?
    backup_model = ppo_trainer.model
    backup_deepspeed = getattr(ppo_trainer, "deepspeed", None)
    ppo_trainer.model = self._resolve_model_for_save(backup_model)
    if getattr(ppo_trainer, "is_deepspeed_enabled", False):
timzsu (Collaborator):
`is_deepspeed_enabled` is a member of `PPOTrainer`, and `deepspeed` is a member of `Trainer` (and thus a member of `PPOTrainer`). Can we make these direct attribute accesses?
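For illustration, a minimal sketch of the direct-attribute version being requested, based on the diff lines above (`_resolve_model_for_save` is the PR's own helper; the attribute claims are the reviewer's):

```python
# Direct attribute access instead of getattr fallbacks, per the review note:
# Trainer defines `deepspeed`, and PPOTrainer defines `is_deepspeed_enabled`.
backup_model = ppo_trainer.model
backup_deepspeed = ppo_trainer.deepspeed
ppo_trainer.model = self._resolve_model_for_save(backup_model)
if ppo_trainer.is_deepspeed_enabled:
    ...  # DeepSpeed-specific save path
```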
timzsu (Collaborator) approved these changes on May 6, 2026
LGTM. Thanks for the explanation.
Purpose
This PR addresses the PPO multi-GPU training bug listed in #1 (comment) ("PPO multi-gpu training is broken"), where a run hangs partway through under DDP. It also fixes several related PPO/DPO template and executor-config issues that surfaced while reproducing the hang on single- and multi-GPU workers.
Changes
- `src/worker/executors/ppo_executor.py` — extract a `_build_ppo_config` helper; wire `save_strategy`, `save_steps`, `save_total_limit`, and `save_only_model` from the training config; install `save_model`/`_save_checkpoint` overrides that unwrap the policy under DDP and skip TRL's broken `create_model_card`; replace the per-device batch clip with a full local-batch normalizer that accounts for `world_size`, `gradient_accumulation_steps`, and `num_mini_batches`.
- `src/worker/executors/utils/huggingface.py` — propagate `training.padding_side` into tokenizer kwargs so PPO generation uses left padding.
- `templates/ppo_training_llama_1b.yaml`, `templates/ppo_training_llama_1b_multi_gpu.yaml` — retune learning rate / batch shape, add `padding_side: "left"` and `save_strategy`, and move `bf16`/`fp16` from `model.config` to the top-level `training` block.
- `templates/dpo_training_llama_1b.yaml`, `templates/dpo_training_llama_1b_multi_gpu.yaml` — same template format migration for the DPO Llama workflows.
- `templates/ppo_training_mistral.yaml` → `templates/ppo_training_ministral.yaml`, `templates/dpo_training_mistral.yaml` → `templates/dpo_training_ministral.yaml` — replace Mistral-7B with `mistralai/Ministral-3-3B-Instruct-2512` so the templates fit on a single commodity GPU, and rename the files / workflow names / tracker project to match.
Design
The PPO multi-GPU hang is not a single bug; reproducing it on the Llama-1B and Ministral-3B templates surfaced four independent issues that all stall a run, each with its own fix:
1. Stale template knobs were silently ignored. The PPO/DPO templates put `bf16`/`fp16` under `model.config`, but `pick_torch_dtype` reads them from the top-level `training` block, so the workers fell back to fp32 and OOM'd or stalled on small GPUs. The templates now set `bf16`/`fp16` (and `padding_side`) under `training`, where the executor actually reads them, and the tokenizer-builder propagates `padding_side` to HF.
2. Checkpointing deadlocked under DDP. TRL's `PPOTrainer` wraps the policy in `PolicyAndValueWrapper` and then in DDP, so `Trainer.save_model` saves the wrapper instead of the underlying policy, and `_save_checkpoint` calls `create_model_card`, which reads `self.model.config._name_or_path` — a field that does not exist on the DDP-wrapped object. Rank 0 errors out before the checkpoint barrier while the other ranks block on the collective, so the run hangs. We install two trainer overrides: `save_model` unwraps the policy before saving, and `_save_checkpoint` skips the broken `create_model_card` call (see the first sketch after this list).
3. Save behavior was not configurable. PPO ignored `save_strategy`, `save_steps`, `save_total_limit`, and `save_only_model` from the training config, so multi-GPU runs either saved on every step (slow) or not at all (no recovery). These are now wired through `_build_ppo_config`, extracted from the inline setup for readability (see the second sketch).
4. Empty dataloader stall. TRL's PPO dataloader uses `drop_last=True`, so when `world_size * per_device_batch * gradient_accumulation_steps` exceeds the dataset size the loader yields zero batches and the trainer hangs on the first epoch barrier. The previous code only clipped `per_device_batch` against the raw dataset size, which missed both the `world_size` factor and the `gradient_accumulation_steps` factor. We replace it with a normalizer that scales the local batch (and `num_mini_batches`) down to fit, logging the adjustment so users see what changed (see the third sketch).
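To make the checkpointing fix concrete, here is a minimal sketch of the two trainer overrides, assuming a hypothetical `unwrap_policy` helper (the PR's actual helper is `_resolve_model_for_save`, visible in the review thread above) and a version-agnostic passthrough signature for `_save_checkpoint`:

```python
import types

def install_save_overrides(ppo_trainer, unwrap_policy):
    """Sketch of the overrides described above, not the PR's exact code."""
    original_save_model = type(ppo_trainer).save_model
    original_save_checkpoint = type(ppo_trainer)._save_checkpoint

    def save_model(self, output_dir=None, _internal_call=False):
        # Swap DDP(PolicyAndValueWrapper(policy)) for the bare policy so
        # Trainer.save_model serializes the module we actually want back.
        backup = self.model
        self.model = unwrap_policy(backup)
        try:
            original_save_model(self, output_dir, _internal_call)
        finally:
            self.model = backup

    def _save_checkpoint(self, *args, **kwargs):
        # create_model_card reads self.model.config._name_or_path, which the
        # DDP wrapper lacks; no-op it so rank 0 reaches the checkpoint barrier
        # instead of erroring while the other ranks wait on the collective.
        backup_card = self.create_model_card
        self.create_model_card = lambda *a, **k: None
        try:
            original_save_checkpoint(self, *args, **kwargs)
        finally:
            self.create_model_card = backup_card

    ppo_trainer.save_model = types.MethodType(save_model, ppo_trainer)
    ppo_trainer._save_checkpoint = types.MethodType(_save_checkpoint, ppo_trainer)
```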
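The save wiring itself is plain plumbing. A sketch of the extracted helper's shape, assuming TRL's `PPOConfig` (a `TrainingArguments` subclass, so it accepts the standard save knobs); the defaults shown are placeholders, not necessarily the PR's:

```python
from trl import PPOConfig

def _build_ppo_config(training: dict) -> PPOConfig:
    # Field names come from TrainingArguments; defaults here are illustrative.
    return PPOConfig(
        learning_rate=training.get("learning_rate", 1e-5),
        save_strategy=training.get("save_strategy", "no"),
        save_steps=training.get("save_steps", 500),
        save_total_limit=training.get("save_total_limit"),
        save_only_model=training.get("save_only_model", False),
    )
```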
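Finally, a sketch of the normalizer's arithmetic. The names and the divisibility constraint on `num_mini_batches` are assumptions for illustration; the real logic lives in `ppo_executor.py`:

```python
import logging

logger = logging.getLogger(__name__)

def normalize_local_batch(dataset_size: int, world_size: int,
                          per_device_batch: int, grad_accum: int,
                          num_mini_batches: int) -> tuple[int, int]:
    """Shrink the local batch shape until one global PPO batch fits the
    dataset. With drop_last=True, a global batch larger than the dataset
    yields zero batches and hangs every rank on the first epoch barrier."""
    if world_size * per_device_batch * grad_accum <= dataset_size:
        return per_device_batch, num_mini_batches

    # Largest per-device batch whose global batch still fits. (If even
    # per_device_batch=1 overshoots, grad_accum would also need reducing,
    # which this sketch does not attempt.)
    per_device = max(1, dataset_size // (world_size * grad_accum))
    # Assumed constraint: num_mini_batches must divide the local batch.
    mini = min(num_mini_batches, per_device)
    while per_device % mini:
        mini -= 1
    logger.warning(
        "PPO batch normalizer: per_device_batch %d -> %d, "
        "num_mini_batches %d -> %d (dataset has %d samples)",
        per_device_batch, per_device, num_mini_batches, mini, dataset_size)
    return per_device, mini
```

Scaling down instead of erroring keeps small smoke-test datasets runnable on many GPUs, at the cost of a smaller effective batch; the warning keeps the adjustment visible in the run logs.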
Test Plan
Test Result
- `pre-commit run --all-files` passes.
- Unit tests pass (`uv run pytest tests/`).
- All three templates (`templates/ppo_training_llama_1b_multi_gpu.yaml`, `templates/ppo_training_llama_1b.yaml`, `templates/ppo_training_ministral.yaml`) run to completion without hanging.
Pre-submission Checklist
- I have run `pre-commit run --all-files` and fixed any issues.
- `uv run pytest tests/` passes locally.
- Dependencies are synced (`uv sync --all-extras --frozen`).
- If this PR introduces breaking changes, the title is marked `[BREAKING]` and migration steps are described above.