TLDR: We can’t (cleanly) checkpoint everything yet (e.g., dataloader, replay buffer, RNG, etc.) because Titan’s checkpointer lives inside trainer.engine and exclusively controls the step-<N> folders, while those other pieces live in separate actors. There’s no safe, centralized way to co-write into the same step folder right now, so we will need to defer full multi-component checkpointing until after PTC.
What does work today: with #444, we can enable saving and resuming model weights + optimizer + LR scheduler via Titan’s checkpointer.
Scope: This RFC only discusses the challenges in saving a checkpoint. For loading, see the discussion in #425.
1) Context (today’s flow)
We spin up components in main:
(
    dataloader,
    policy,
    trainer,  # has ForgeEngine
    replay_buffer,
    compute_advantages,
    ref_model,
    reward_actor,
) = await asyncio.gather(...)
Model checkpointing is delegated to TorchTitan via the trainer’s engine:
- trainer creates the engine and loads a checkpoint:
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.step)
- Every training step (in train_step), the trainer asks Titan to save:
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
)
Titan’s checkpointer writes model weights, optimizer state, and LR scheduler state into:
<folder>/step-<N>/__0_0.distcp
For example:
./checkpoint/step-100/__0_0.distcp
Per issue #362, we also want to save/load the states for:
- data step,
- replay buffer data,
- RNG states,
- etc.
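For concreteness, here is a minimal sketch (assuming plain Python/PyTorch generators; the function name is illustrative and not from #362) of what capturing the RNG states above typically involves:
import random
import torch

def collect_rng_states() -> dict:
    # The generators a run typically needs to restore for a reproducible resume.
    return {
        "python": random.getstate(),
        "torch_cpu": torch.get_rng_state(),
        "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }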
2) The problems
Problem 1: Step-folder ownership
Titan owns one directory per saved step (e.g., step-200), created from inside trainer.engine.checkpointer. Other actors (dataloader, replay buffer) have no access to the trainer’s engine or to Titan’s private folder-naming method. That leaves us with two awkward choices:
- Two folders per step
checkpoint/
  step-200/          # Titan
    __0_0.distcp
  step-200-other/    # Ours
    dataloader.json
    replay_buffer.bin
    rng.json
Downsides: clunky UX, hard to purge/retain atomically, and easy to drift.
- Single folder per step (preferred)
To co-locate our files inside step-200/, we must either:
- call Titan’s private _create_checkpoint_id to learn the folder name, but other components (e.g., dataloader) don’t have an engine;
- re-implement a look-alike function and hope it never diverges; or
- (preferred) add a path parameter to checkpointer.save.
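As a sketch of the preferred option (the path argument is the proposed addition, not an existing Titan parameter, and checkpoint_folder is an illustrative variable), the trainer’s save call could become:
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
    path=f"{checkpoint_folder}/step-{self.step}",  # proposed: the caller pins the target folder
)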
Problem 2: No unified checkpoint scope
Currently, non-model states (dataloader, replay_buffer, RNG, etc.) live in separate actors/services, while the trainer, which owns Titan’s checkpointer, only manages the model, optimizer, and LR scheduler.
Because Titan’s checkpointing is embedded inside trainer.engine.checkpointer, there’s no single coordinator that can write all components into the same step-<N> folder in a clean, synchronized way.
3) Proposal for Problem 2
Option 1: Make trainer own all other components
class RLTrainer:
    def __init__(self):
        self.dataloader = ...
        self.replay_buffer = ...
        self.rng = ...
This is very fast to implement. However, it introduces heavy coupling, breaks actor/service boundaries, and hurts scalability and reuse.
Option 2: Reimplement checkpointing ourselves
Write our own model/optim/lr checkpointing.
It gives us full control over layout and atomicity.
Downsides
- Rebuilding the hardest part (distributed model/optim/LR checkpointing, DCP integration, async saves, fault tolerance).
- High risk, high effort; guaranteed divergence from Titan.
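To gauge the scope, here is a minimal sketch (assuming plain PyTorch Distributed Checkpoint, no async saves or fault tolerance; the function and path names are illustrative) of just the synchronous model/optimizer/LR portion we would be re-owning:
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, get_optimizer_state_dict

def save_model_states(model, optimizer, lr_scheduler, step, folder):
    # Collect shard-aware state dicts and write one DCP checkpoint per step.
    state = {
        "model": get_model_state_dict(model),
        "optim": get_optimizer_state_dict(model, optimizer),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
    dcp.save(state, checkpoint_id=f"{folder}/step-{step}")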
Option 3: Coordinator layered above current actors
Introduce a lightweight Checkpoint Coordinator that:
- calls Titan to save the model/optimizer/LR states to a specified path (requires one small API addition to Titan’s save: path=),
- then asks each actor for its state_dict() and writes it into the same folder (e.g., step-200/dataloader.json, etc.),
- on load, after Titan resolves the step it will load, tries to restore each component’s state by calling its load_state().
from typing import Dict

class CheckpointCoordinator:
    def __init__(self):
        self._trainer: RLTrainer | None = None
        self._components: Dict[str, ForgeActor | ServiceInterface] = {}

    def set_trainer(self, trainer: RLTrainer):
        self._trainer = trainer

    def register(self, name: str, comp: ForgeActor | ServiceInterface):
        self._components[name] = comp

    async def save(self, step: int, path: str):
        step_path = get_path(path, step)  # e.g. <path>/step-<step>
        if self._trainer:
            # relies on the proposed path= addition to Titan's save
            self._trainer.engine.checkpointer.save(path=step_path)
        for name, comp in self._components.items():
            states = comp.state_dict()
            dump_json(states, f"{step_path}/{name}.json")

    async def load(self, step: int, path: str):
        ...
The changes required in grpo/main:
coord = CheckpointCoordinator()
coord.set_trainer(trainer)
coord.register("dataloader", dataloader)
coord.register("replay_buffer", replay_buffer)
...
await coord.load(step, path=checkpoint_folder)
async def continuous_training():
    ...
    await coord.save(step, path=checkpoint_folder)
It is a relatively easy change and reuses most of the existing machinery.
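For reference, a hypothetical sketch of the per-component contract the coordinator relies on (the state_dict/load_state method names come from this proposal; the class name and fields are illustrative):
class DataloaderActor(ForgeActor):
    # Illustrative only: the minimal state a component exposes to the coordinator.
    def state_dict(self) -> dict:
        return {"epoch": self._epoch, "data_step": self._data_step}

    def load_state(self, state: dict) -> None:
        self._epoch = state["epoch"]
        self._data_step = state["data_step"]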
Downsides
- Nested structure: we still call coord.save, which in turn calls self._trainer.engine.checkpointer.save.
- Slightly specific to the existing grpo script; may have generalizability issues in the future.
Option 4: Standalone ForgeCheckpointManager
Longer-term, we could create a standalone ForgeCheckpointManager that inherits Titan’s CheckpointManager and orchestrates both Titan and Forge components in one save()/load() call. Actors register their export_state/import_state with this one object; main just calls it.
Open question: where does ForgeCheckpointManager live if the engine stays inside the trainer? And how does it read/write model/optim/lr state without re-nesting the trainer or breaking actor decoupling?
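Purely as a sketch of the shape (it assumes a subclass can reuse Titan’s CheckpointManager constructor and its private _create_checkpoint_id, and it reuses the illustrative dump_json helper from Option 3):
class ForgeCheckpointManager(CheckpointManager):
    # Sketch: one object owning both Titan state and Forge component state.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._components = {}

    def register(self, name, export_state, import_state):
        # Components hand in callables instead of the manager reaching into them.
        self._components[name] = (export_state, import_state)

    def save(self, curr_step: int, last_step: bool = False):
        # Titan writes model/optim/LR into its step folder first...
        super().save(curr_step=curr_step, last_step=last_step)
        # ...then everything else is co-located in the same folder.
        folder = self._create_checkpoint_id(curr_step)  # private Titan helper, assumed reachable from a subclass
        for name, (export_state, _) in self._components.items():
            dump_json(export_state(), f"{folder}/{name}.json")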