
fix: unblock PPO multi-GPU training #18

Merged
kaiitunnz merged 13 commits into main from kaiitunnz/fix/ppo-stuck on May 6, 2026

Conversation

@kaiitunnz
Collaborator

@kaiitunnz kaiitunnz commented May 5, 2026

Purpose

This PR addresses the PPO multi-GPU training bug listed in #1 (comment) ("PPO multi-gpu training is broken"), in which a run hangs partway through under DDP. It also fixes several related PPO/DPO template and executor-config issues that surfaced while reproducing the hang on single- and multi-GPU workers.

Changes

  • src/worker/executors/ppo_executor.py — extract _build_ppo_config helper; wire save_strategy, save_steps, save_total_limit, and save_only_model from the training config; install save_model / _save_checkpoint overrides that unwrap the policy under DDP and skip TRL's broken create_model_card; replace the per-device batch clip with a full local-batch normalizer that accounts for world_size, gradient_accumulation_steps, and num_mini_batches.
  • src/worker/executors/utils/huggingface.py — propagate training.padding_side into tokenizer kwargs so PPO generation uses left padding.
  • templates/ppo_training_llama_1b.yaml, templates/ppo_training_llama_1b_multi_gpu.yaml — retune learning rate / batch shape, add padding_side: "left" and save_strategy, and move bf16 / fp16 from model.config to the top-level training block.
  • templates/dpo_training_llama_1b.yaml, templates/dpo_training_llama_1b_multi_gpu.yaml — same template format migration for the DPO Llama workflows.
  • templates/ppo_training_mistral.yaml → templates/ppo_training_ministral.yaml, templates/dpo_training_mistral.yaml → templates/dpo_training_ministral.yaml — replace Mistral-7B with mistralai/Ministral-3-3B-Instruct-2512 so the templates fit on a single commodity GPU, and rename the files / workflow names / tracker project to match.

Design

The PPO multi-GPU hang is not a single bug; reproducing it on the Llama-1B and Ministral-3B templates surfaced four independent issues that all stall a run, each with its own fix:

  1. Stale template knobs were silently ignored. The PPO/DPO templates put bf16 / fp16 under model.config, but pick_torch_dtype reads them from the top-level training block, so the workers fell back to fp32 and OOM'd or stalled on small GPUs. The templates now set bf16 / fp16 (and padding_side) under training, where the executor actually reads them, and the tokenizer-builder propagates padding_side to HF.
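The failure mode in item 1 can be sketched in a few lines. The real pick_torch_dtype lives in the executor code and returns torch dtypes; the string return values, signature, and config shapes below are simplifications for illustration, not the repository's actual code:

```python
# Sketch of the dtype fallback: the picker only inspects the top-level
# training block, so flags nested under model.config are never seen.
def pick_torch_dtype(training_cfg: dict) -> str:
    if training_cfg.get("bf16"):
        return "bfloat16"
    if training_cfg.get("fp16"):
        return "float16"
    return "float32"  # silent fallback that OOM'd / stalled small GPUs

# Old templates: flags nested under model.config -> ignored, fp32 fallback.
old_template = {"model": {"config": {"bf16": True}}, "training": {}}
# New templates: flags live in the training block the executor reads.
new_template = {"model": {"config": {}}, "training": {"bf16": True}}
```

With the old layout, `pick_torch_dtype(old_template["training"])` falls through to fp32 even though the template "sets" bf16; the new layout makes the same call return bfloat16.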

  2. Checkpointing deadlocked under DDP. TRL's PPOTrainer wraps the policy in PolicyAndValueWrapper and then in DDP, so Trainer.save_model saves the wrapper instead of the underlying policy, and _save_checkpoint calls create_model_card which reads self.model.config._name_or_path — a field that does not exist on the DDP-wrapped object. Rank 0 errors out before the checkpoint barrier while the other ranks block on the collective, so the run hangs. We install two trainer overrides: save_model unwraps the policy before saving, and _save_checkpoint skips the broken create_model_card call.

  3. Save behavior was not configurable. PPO ignored save_strategy, save_steps, save_total_limit, and save_only_model from the training config, so multi-GPU runs either saved on every step (slow) or not at all (no recovery). These are now wired through _build_ppo_config (extracted from the inline setup for readability).

  4. Empty dataloader stall. TRL's PPO dataloader uses drop_last=True, so when world_size * per_device_batch * gradient_accumulation_steps exceeds the dataset size the loader yields zero batches and the trainer hangs on the first epoch barrier. The previous code only clipped per_device_batch against the raw dataset size, which missed both the world_size factor and the gradient_accumulation_steps factor. We replace it with a normalizer that scales the local batch (and num_mini_batches) down to fit, logging the adjustment so users see what changed.
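The normalizer in item 4 can be sketched as below. The condition it guards against follows from drop_last=True: if world_size * per_device_batch * grad_accum exceeds the dataset size, the loader yields zero batches. The function name and exact scaling policy here are assumptions, not the repository's signature:

```python
# Illustrative local-batch normalizer: shrink per_device_batch (and
# num_mini_batches, which must divide it) until at least one full global
# batch fits in the dataset, so the drop_last loader is never empty.
def normalize_local_batch(
    dataset_size: int,
    world_size: int,
    per_device_batch: int,
    grad_accum: int,
    num_mini_batches: int,
) -> tuple[int, int]:
    global_batch = world_size * per_device_batch * grad_accum
    if global_batch <= dataset_size:
        return per_device_batch, num_mini_batches  # nothing to adjust
    # Largest per-device batch that still yields at least one global batch.
    per_device_batch = max(1, dataset_size // (world_size * grad_accum))
    # num_mini_batches must evenly divide the local batch; shrink to fit.
    num_mini_batches = max(1, min(num_mini_batches, per_device_batch))
    while per_device_batch % num_mini_batches:
        num_mini_batches -= 1
    return per_device_batch, num_mini_batches
```

For example, a 10-sample dataset on 2 GPUs with per_device_batch=8 and grad_accum=4 implies a global batch of 64; the old per-device-only clip would still leave the loader empty, while the normalizer scales the local batch down to 1. The actual executor also logs the adjustment so users see what changed.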

Test Plan

uv run pre-commit run --all-files
uv run pytest tests/ --ignore=tests/worker/test_mp_executor_cleanup_gpu.py

# End-to-end PPO runs
flowmesh stack up
flowmesh stack worker up gpu -t 0
flowmesh workflow submit templates/ppo_training_llama_1b.yaml
flowmesh workflow submit templates/ppo_training_ministral.yaml
flowmesh stack worker down all
flowmesh stack worker up gpu -t 0,1
flowmesh workflow submit templates/ppo_training_llama_1b_multi_gpu.yaml
flowmesh stack clean

Test Result

  • pre-commit run --all-files passes.
  • 588 unit tests pass (uv run pytest tests/).
  • All three PPO workflows (templates/ppo_training_llama_1b_multi_gpu.yaml, templates/ppo_training_llama_1b.yaml, templates/ppo_training_ministral.yaml) run to completion without hanging.

Pre-submission Checklist
  • I have read the contribution guidelines.
  • I have run pre-commit run --all-files and fixed any issues.
  • I have added or updated tests covering my changes (if applicable).
  • I have verified that uv run pytest tests/ passes locally.
  • If I changed shared schemas or proto definitions, I have checked downstream compatibility across Server and Worker.
  • If I changed the SDK or CLI, I have verified the affected packages work (uv sync --all-extras --frozen).
  • If this is a breaking change, I have prefixed the PR title with [BREAKING] and described migration steps above.
  • I have updated documentation or config examples if user-facing behavior changed.

@kaiitunnz kaiitunnz requested a review from timzsu May 5, 2026 14:02
Collaborator

@timzsu timzsu left a comment


Skimmed through the changed code and left some quick questions. I will do a more comprehensive pass tomorrow.

Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread templates/dpo_training_llama_1b.yaml
Comment thread templates/ppo_training_ministral.yaml Outdated
kaiitunnz added 11 commits May 6, 2026 05:28
Signed-off-by: Noppanat Wadlom <noppanat.wad@gmail.com>
@kaiitunnz kaiitunnz force-pushed the kaiitunnz/fix/ppo-stuck branch from fd7b75e to f95f24d Compare May 6, 2026 05:28
Collaborator

@timzsu timzsu left a comment


Some comments to address. In addition, can we switch the default lr to 1e-5 and make the use of double quotes in templates consistent?

Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread src/worker/executors/ppo_executor.py Outdated
Comment thread templates/dpo_training_llama_1b.yaml
Signed-off-by: Noppanat Wadlom <noppanat.wad@gmail.com>
@kaiitunnz kaiitunnz requested a review from timzsu May 6, 2026 07:07
Collaborator

@timzsu timzsu left a comment


Sorry, I missed a block last time. Can you take a look at these two comments?

Comment thread src/worker/executors/ppo_executor.py Outdated
backup_model = ppo_trainer.model
backup_deepspeed = getattr(ppo_trainer, "deepspeed", None)
ppo_trainer.model = self._resolve_model_for_save(backup_model)
if getattr(ppo_trainer, "is_deepspeed_enabled", False):
Collaborator


is_deepspeed_enabled is a member of PPOTrainer, and deepspeed is a member of Trainer (and thus a member of PPOTrainer). Can we make them direct attribute access?

Collaborator Author


Fixed.

Comment thread src/worker/executors/ppo_executor.py
Signed-off-by: Noppanat Wadlom <noppanat.wad@gmail.com>
@kaiitunnz kaiitunnz requested a review from timzsu May 6, 2026 07:28
Collaborator

@timzsu timzsu left a comment


LGTM. Thanks for the explanation.

@kaiitunnz kaiitunnz merged commit 4a9a81d into main May 6, 2026
10 checks passed
@kaiitunnz kaiitunnz deleted the kaiitunnz/fix/ppo-stuck branch May 6, 2026 07:35
@timzsu timzsu mentioned this pull request May 13, 2026