2 changes: 1 addition & 1 deletion examples/templates/ppo_training_llama_1b.yaml
@@ -60,7 +60,7 @@ spec:
target_kl: 0.1
seed: 42
padding_side: left
-early_stopping: false
+early_stopping: true
save_strategy: steps
save_model: true
save_freq: 25 # Save checkpoint every 25 steps
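The template flip above arms KL early stopping for this recipe. A minimal sketch of how the two fields pair up in a training spec, assuming (as in the template shown) that they sit alongside the other PPO knobs; note that `early_stopping: true` without a positive `target_kl` is rejected by the executor (see `_install_kl_early_stopping` below):

```yaml
# Sketch of the relevant training fields only; all other keys omitted.
spec:
  target_kl: 0.1        # KL threshold watched during training
  early_stopping: true  # arms the executor's KL early-stop hook
```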
82 changes: 73 additions & 9 deletions src/worker/executors/ppo_executor.py
@@ -35,7 +35,7 @@

from shared.tasks.specs import PPOSpecStrict
from shared.utils.manifest import scratch_dir
-from shared.utils.parsing import safe_float, safe_int
+from shared.utils.parsing import safe_float, safe_int, to_bool
from worker.config import WorkerConfig
from worker.lifecycle import Lifecycle

@@ -324,6 +324,40 @@ def __getattr__(self, name):
return getattr(self.__dict__.get("_lm", object()), name)


class _EarlyStopSignal(Exception):
"""Internal signal: KL exceeded ``target_kl``; unwind ``PPOTrainer.train``."""


class _EarlyStopPPOTrainer(PPOTrainer):
"""``PPOTrainer`` subclass that halts when ``objective/kl`` exceeds a threshold.

TRL's PPO loop calls ``self.log(metrics)`` once per update step but
never checks ``control.should_training_stop``, so a stock
``TrainerCallback`` cannot end training. Overriding ``log`` and
raising ``_EarlyStopSignal`` instead lets the exception unwind the
loop cleanly; the executor catches it at the ``train()`` call site.

``target_kl`` defaults to ``None`` so the subclass is a safe
drop-in when early stopping is disabled.
"""

target_kl: float | None = None

def log(self, logs: dict[str, float], *args: Any, **kwargs: Any) -> None:
super().log(logs, *args, **kwargs)
if self.target_kl is None:
return
kl = logs.get("objective/kl")
if kl is None:
return
assert isinstance(kl, float), f"TRL logged non-float objective/kl: {kl!r}"
if kl > self.target_kl:
logger.info(
"PPO early stop: objective/kl=%.4f > target_kl=%.4f", kl, self.target_kl
)
raise _EarlyStopSignal()

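Why raise an exception instead of registering a callback: the trick generalizes to any framework loop that exposes a per-step hook but never consults a stop flag. A minimal, self-contained sketch of the same unwind mechanism, with toy `Loop`/`EarlyStopLoop` classes standing in for TRL's trainer:

```python
class _StopSignal(Exception):
    """Sentinel used only to unwind the loop; it never reaches callers."""


class Loop:
    """Stand-in for a framework loop that logs per step but never checks a stop flag."""

    def train(self) -> None:
        for step in range(1_000):
            self.log({"objective/kl": 0.01 * step})  # the hook we can override

    def log(self, logs: dict[str, float]) -> None:
        pass


class EarlyStopLoop(Loop):
    target_kl: float | None = None

    def log(self, logs: dict[str, float]) -> None:
        super().log(logs)
        kl = logs.get("objective/kl")
        # Strictly greater-than, so hitting the threshold exactly keeps training.
        if self.target_kl is not None and kl is not None and kl > self.target_kl:
            raise _StopSignal()


loop = EarlyStopLoop()
loop.target_kl = 0.5
try:
    loop.train()  # unwinds at step 51, the first step with kl > 0.5
except _StopSignal:
    pass          # the call site treats early stop as normal completion
```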

def _resolve_report_to(value: Any) -> str | list[str]:
"""Translate ``training.report_to`` into the value PPOConfig expects.

@@ -702,7 +736,7 @@ def _simple_collate(features):
# introspection to build arguments positionally/with kwargs to fit.
import inspect

-    def build_trainer() -> PPOTrainer:
+    def build_trainer() -> _EarlyStopPPOTrainer:
sig = inspect.signature(PPOTrainer.__init__)
params = list(sig.parameters.values())[1:] # drop self

@@ -772,15 +806,15 @@ def build_trainer() -> PPOTrainer:
if key in mapping:
legacy_seq.append(mapping[key])
try:
-                return PPOTrainer(*legacy_seq, **kwargs)
+                return _EarlyStopPPOTrainer(*legacy_seq, **kwargs)
except TypeError:
# Raise with details for debugging
raise TypeError(
"PPOTrainer signature mismatch; missing required params: "
f"{missing_required}"
)

-        return PPOTrainer(*positional, **kwargs)
+        return _EarlyStopPPOTrainer(*positional, **kwargs)

ppo_trainer = build_trainer()
self._ppo_trainer = ppo_trainer
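`build_trainer` exists because `PPOTrainer.__init__` has changed shape across TRL releases, so the executor introspects the live signature and maps known objects onto whatever parameters it finds. A toy sketch of that adaptation pattern; the `TrainerV1`/`TrainerV2` signatures here are hypothetical, not TRL's actual ones:

```python
import inspect
from typing import Any


class TrainerV1:
    def __init__(self, config: Any, model: Any) -> None:
        self.config, self.model = config, model


class TrainerV2:
    def __init__(self, args: Any, model: Any, data_collator: Any = None) -> None:
        self.args, self.model = args, model


def build(trainer_cls: type, available: dict[str, Any]) -> Any:
    """Map our known objects onto whatever parameter names the class asks for."""
    sig = inspect.signature(trainer_cls.__init__)
    kwargs: dict[str, Any] = {}
    for name, param in list(sig.parameters.items())[1:]:  # drop self
        if name in available:
            kwargs[name] = available[name]
        elif param.default is inspect.Parameter.empty:
            raise TypeError(f"{trainer_cls.__name__} requires unknown param {name!r}")
    return trainer_cls(**kwargs)


objects = {"config": "cfg", "args": "cfg", "model": "m", "data_collator": None}
build(TrainerV1, objects)  # fills (config, model)
build(TrainerV2, objects)  # fills (args, model, data_collator)
```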
@@ -813,9 +847,14 @@ def build_trainer() -> PPOTrainer:
if reward_is_external:
logger.info("External reward model enabled for PPO training")

+        self._install_kl_early_stopping(ppo_trainer, training_config)
+
logger.info("Starting PPO training...")
-        with reward_ctx:
-            ppo_trainer.train()
+        try:
+            with reward_ctx:
+                ppo_trainer.train()
+        except _EarlyStopSignal:
+            pass
logger.info("PPO training completed")

training_successful = True
@@ -1079,9 +1118,6 @@ def _build_ppo_config(
training_config.get("num_train_epochs"), default=1.0, minimum=1.0
)
kl_coef = safe_float(training_config.get("kl_coef"), minimum=0)
-    # TODO: wire training.target_kl and training.early_stopping
-    # into a FlowMesh-owned early-stop hook. The templates still set these
-    # fields so the spec doesn't churn between PRs.

max_seq_length = safe_int(
training_config.get("max_seq_length"), default=64, minimum=1
@@ -1373,6 +1409,34 @@ def _resolve_model_for_save(model: Any) -> Any:

return model

def _install_kl_early_stopping(
self, ppo_trainer: _EarlyStopPPOTrainer, training_config: dict[str, Any]
) -> None:
"""Set ``target_kl`` on the trainer when ``training.early_stopping`` is on.

The trainer is already an ``_EarlyStopPPOTrainer`` from ``build_trainer``;
we just stamp the threshold so its ``log`` override starts watching KL.
Enforces that ``early_stopping=True`` is paired with a positive
``target_kl``; if ``early_stopping`` is off, ``target_kl`` is ignored
but logged so users notice the gap.
"""
enabled = to_bool(training_config.get("early_stopping"), default=False)
target_kl = safe_float(training_config.get("target_kl"))
if not enabled:
if target_kl is not None and target_kl > 0:
logger.info(
"PPO training.target_kl=%.4f set without early_stopping=true; "
"no early-stop hook attached",
target_kl,
)
return
if target_kl is None or target_kl <= 0:
raise ExecutionError(
"training.early_stopping requires a positive training.target_kl"
)
ppo_trainer.target_kl = target_kl
logger.info("PPO KL early-stop enabled at target_kl=%.4f", target_kl)

def _install_trainer_save_overrides(self, ppo_trainer: PPOTrainer) -> None:
"""Patch PPO trainer saves to avoid TRL's DDP-unsafe checkpoint wrapper.

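The activation rules in `_install_kl_early_stopping` reduce to a small decision table: the flag alone controls arming, and an armed hook demands a positive threshold. A standalone sketch of that logic, with plain parsing standing in for the internal `to_bool`/`safe_float` helpers; the accepted truthy spellings are inferred from the tests below:

```python
TRUTHY = {"true", "yes", "on", "1"}


def resolve_target_kl(cfg: dict) -> float | None:
    """Return the armed threshold, None when disabled, or raise on bad config."""
    raw = cfg.get("early_stopping", False)
    enabled = raw.strip().lower() in TRUTHY if isinstance(raw, str) else bool(raw)
    target = cfg.get("target_kl")
    if not enabled:
        return None  # target_kl alone never arms the hook
    if target is None or float(target) <= 0:
        raise ValueError("early_stopping requires a positive target_kl")
    return float(target)


assert resolve_target_kl({"target_kl": 0.1}) is None
assert resolve_target_kl({"early_stopping": "yes", "target_kl": 0.1}) == 0.1
```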
97 changes: 97 additions & 0 deletions tests/worker/test_ppo_early_stopping.py
@@ -0,0 +1,97 @@
"""Unit tests for PPO KL-based early stopping."""

from typing import Any
from unittest.mock import patch

import pytest

pytest.importorskip("trl", reason="trl not installed (needs --extra training)")

from worker.executors.base_executor import ExecutionError
from worker.executors.ppo_executor import (
PPOExecutor,
_EarlyStopPPOTrainer,
_EarlyStopSignal,
)


def _make_trainer(target_kl: float | None) -> _EarlyStopPPOTrainer:
"""Build a bare ``_EarlyStopPPOTrainer`` without running ``__init__``.

``_EarlyStopPPOTrainer.log`` defers to ``super().log`` (TRL/HF Trainer); the
tests patch that with a no-op so they only exercise our override.
"""
trainer = _EarlyStopPPOTrainer.__new__(_EarlyStopPPOTrainer)
trainer.target_kl = target_kl
return trainer


@patch("trl.trainer.ppo_trainer.PPOTrainer.log", autospec=True)
def test_kl_above_target_raises(_super_log) -> None:
trainer = _make_trainer(target_kl=0.1)
with pytest.raises(_EarlyStopSignal):
trainer.log({"objective/kl": 0.2})


@patch("trl.trainer.ppo_trainer.PPOTrainer.log", autospec=True)
def test_kl_below_target_does_not_raise(_super_log) -> None:
trainer = _make_trainer(target_kl=0.1)
trainer.log({"objective/kl": 0.05})


@patch("trl.trainer.ppo_trainer.PPOTrainer.log", autospec=True)
def test_kl_equal_to_target_does_not_raise(_super_log) -> None:
trainer = _make_trainer(target_kl=0.1)
trainer.log({"objective/kl": 0.1})


@patch("trl.trainer.ppo_trainer.PPOTrainer.log", autospec=True)
def test_missing_kl_key_is_ignored(_super_log) -> None:
trainer = _make_trainer(target_kl=0.1)
trainer.log({"loss": 1.0})


@patch("trl.trainer.ppo_trainer.PPOTrainer.log", autospec=True)
def test_threshold_unset_is_no_op(_super_log) -> None:
"""When ``target_kl is None`` the override is a pure pass-through."""
trainer = _make_trainer(target_kl=None)
trainer.log({"objective/kl": 9.99})


# ---------------------------------------------------------------------------
# _install_kl_early_stopping activation rules
# ---------------------------------------------------------------------------


def _install(training_config: dict[str, Any]) -> _EarlyStopPPOTrainer:
"""Run the installer against a bare trainer and return it."""
executor = PPOExecutor.__new__(PPOExecutor)
trainer = _make_trainer(target_kl=None)
executor._install_kl_early_stopping(trainer, training_config)
return trainer


def test_install_no_op_when_flag_missing() -> None:
trainer = _install({"target_kl": 0.1})
assert trainer.target_kl is None


@pytest.mark.parametrize("flag", [False, "false", "False", 0, "no", "off"])
def test_install_no_op_when_flag_disabled(flag: Any) -> None:
trainer = _install({"early_stopping": flag, "target_kl": 0.1})
assert trainer.target_kl is None


@pytest.mark.parametrize("flag", [True, "true", "True", 1, "yes", "on"])
def test_install_arms_when_flag_enabled(flag: Any) -> None:
trainer = _install({"early_stopping": flag, "target_kl": 0.1})
assert trainer.target_kl == pytest.approx(0.1)


@pytest.mark.parametrize("bad_target", [None, 0, -0.5])
def test_install_rejects_enabled_without_positive_target(bad_target: Any) -> None:
cfg: dict[str, Any] = {"early_stopping": True}
if bad_target is not None:
cfg["target_kl"] = bad_target
with pytest.raises(ExecutionError):
_install(cfg)