RFC: Checkpointing Beyond Model Weights (Why we can’t do it now, and what we’ll do next) #433

@DNXie

TLDR: We can’t (cleanly) checkpoint everything yet (e.g., dataloader, replay buffer, RNG, etc.) because Titan’s checkpointer lives inside trainer.engine and exclusively controls the step-<N> folders, while those other pieces live in separate actors. There’s no safe, centralized way to co-write into the same step right now, so we will need to defer full multi-component checkpointing until after PTC.

What does work today: with #444, we can enable saving and resuming model weights + optimizer + LR scheduler via Titan’s checkpointer.

Scope: This RFC only discusses the challenges of saving checkpoints. For loading, see the discussion in #425.


1) Context (today’s flow)

We spin up components in main:

(
    dataloader,     
    policy,       
    trainer,        # has ForgeEngine
    replay_buffer,  
    compute_advantages,  
    ref_model,       
    reward_actor,    
) = await asyncio.gather(...)

Model checkpointing is delegated to TorchTitan via the trainer’s engine:

  • trainer creates the engine and loads a checkpoint:
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.step)
  • Every training step (in train_step), the trainer asks Titan to save:
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
)

Titan’s checkpointer writes model weights, optimizer state, and LR scheduler state into:

<folder>/step-<N>/__0_0.distcp

For example:

./checkpoint/step-100/__0_0.distcp

Per issue #362, we also want to save/load the states for:

  • data step,
  • replay buffer data,
  • RNG states,
  • etc.
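
For concreteness, here is a rough sketch of the extra state this covers. Only the RNG getters below are standard stdlib/PyTorch APIs; data_step and the replay buffer's state_dict() are hypothetical placeholders for whatever those actors actually expose.

import random
from typing import Any, Dict

import torch

def collect_extra_state(data_step: int, replay_buffer) -> Dict[str, Any]:
    """Gather the non-model state we would like to save next to Titan's step-<N> files."""
    return {
        "dataloader": {"step": data_step},            # hypothetical: current data step
        "replay_buffer": replay_buffer.state_dict(),  # hypothetical state_dict() API
        "rng": {
            "python": random.getstate(),                   # stdlib RNG state
            "torch_cpu": torch.get_rng_state(),            # torch CPU RNG state
            "torch_cuda": torch.cuda.get_rng_state_all(),  # per-device CUDA RNG states
        },
    }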

2) The problems

Problem 1: Step-folder ownership

We have a Titan-owned directory per saved step (e.g., step-200), created from inside trainer.engine.checkpointer. Other actors (dataloader, replay buffer) have no access to the trainer’s engine or to Titan’s private folder-naming method. That leaves us with two awkward choices:

  1. Two folders per step

    checkpoint/
      step-200/          # Titan
        __0_0.distcp
      step-200-other/    # Ours
        dataloader.json
        replay_buffer.bin
        rng.json
    

    Downsides: clunky UX, hard to purge/retain atomically, and easy for the two folders to drift apart.

  2. Single folder per step (preferred)
    To co-locate our files inside step-200/, we must do one of the following:

    • call Titan’s private _create_checkpoint_id to learn the folder name, but other components (e.g., the dataloader) don’t have an engine.
    • re-implement a look-alike function and hope it never diverges.
    • (preferred) add a path parameter to checkpointer.save (sketched right after this list).
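
As an illustration (mirroring the train_step snippet above), the trainer’s save call could then become the following. The path argument does not exist in Titan today; it is exactly the small API addition this RFC proposes.

# Hypothetical: path= is the proposed addition to Titan's save().
step_path = f"{checkpoint_folder}/step-{self.step}"   # same folder the other actors write into
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
    path=step_path,
)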

Problem 2: No unified checkpoint scope

Currently, non-model states (dataloader, replay_buffer, RNG, etc.) live in separate actors/services, while the trainer, which owns Titan’s checkpointer, only manages the model, optimizer, and LR scheduler.

Because Titan’s checkpointing is embedded inside trainer.engine.checkpointer, there’s no single coordinator that can write all components into the same step-<N> folder in a clean, synchronized way.


3) Proposal for Problem 2

Option 1: Make trainer own all other components

class RLTrainer:
    def __init__(self):
        ...
        self.dataloader = ...
        self.replay_buffer = ...
        self.rng = ...

This is very fast to implement. However, it introduces heavy coupling, breaks actor/service boundaries, and hurts scalability and reuse.

Option 2: Reimplement checkpointing ourselves

Write our own model/optim/lr checkpointing.

It gives us full control over layout and atomicity.

Downsides

  • Rebuilding the hardest part (distributed model/optim/LR state, DCP integration, async checkpointing, fault tolerance).
  • High risk, high effort; guaranteed divergence from Titan.
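
To make the effort concrete, here is a minimal sketch of just the basic synchronous save path, using PyTorch Distributed Checkpoint (DCP) directly; everything Titan layers on top of this (async saves, fault tolerance, step-folder management, purging) would also have to be rebuilt. The function is illustrative, not an existing Forge API.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)

def save_model_optim_lr(model, optimizer, lr_scheduler, path: str) -> None:
    # Collect distributed-aware state dicts for the model and optimizer.
    state = {
        "model": get_model_state_dict(model),
        "optim": get_optimizer_state_dict(model, optimizer),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
    # Writes sharded .distcp files (e.g., __0_0.distcp) under `path`.
    dcp.save(state, checkpoint_id=path)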

Option 3: Coordinator layered above current actors

Introduce a lightweight Checkpoint Coordinator that:

  • calls Titan to save the model/optimizer/LR-scheduler state to a specified path (requires one small API addition to Titan’s save: path=),
  • then asks each actor for its state_dict() and writes it into the same folder (e.g., step-200/dataloader.json),
  • on load, after Titan resolves which step to load, restores each component’s state by calling its load_state().
class CheckpointCoordinator:
    def __init__(self):
        self._trainer: RLTrainer | None = None
        self._components: Dict[str, ForgeActor | ServiceInterface] = {}

    def set_trainer(self, trainer: RLTrainer):
        self._trainer = trainer

    def register(self, name: str, comp: ForgeActor | ServiceInterface):
        self._components[name] = comp

    async def save(self, step: int, path: str):
        step_path = get_path(path, step)  # e.g. <path>/step-<N>
        if self._trainer:
            # path= is the proposed addition to Titan's save()
            self._trainer.engine.checkpointer.save(curr_step=step, path=step_path)
        for name, comp in self._components.items():
            states = comp.state_dict()
            dump_json(states, f"{step_path}/{name}.json")

    async def load(self, step: int, path: str):
        ...
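
A possible shape for the load counterpart, continuing the sketch above; get_path, load_json, and the components’ load_state() are assumed helpers/APIs, not existing ones:

    async def load(self, step: int, path: str):
        # Let Titan restore model/optim/lr state for the requested step first.
        if self._trainer:
            self._trainer.engine.checkpointer.load(step=step)
        # Then restore each registered component from the same step folder.
        step_path = get_path(path, step)
        for name, comp in self._components.items():
            states = load_json(f"{step_path}/{name}.json")  # assumed helper
            comp.load_state(states)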

The changes required in grpo/main:

coord = CheckpointCoordinator()
coord.set_trainer(trainer) 

coord.register("dataloader", dataloader)
coord.register("replay_buffer", replay_buffer)
...

await coord.load(step, path=checkpoint_folder)

async def continuous_training():
    ...
    await coord.save(step, path=checkpoint_folder)

This is a relatively small change and reuses most of the existing machinery.

Downsides

  • Nested structure: coord.save still calls self._trainer.engine.checkpointer.save under the hood.
  • Somewhat specific to the existing GRPO script; it may have generalizability issues in the future.

Option 4: Standalone ForgeCheckpointManager

Longer-term, we could create a standalone ForgeCheckpointManager that inherits Titan’s CheckpointManager and orchestrates both Titan’s state and Forge’s components in a single save()/load() call. Actors register their export_state/import_state hooks with this one object, and main just calls it.
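
A very rough sketch of the shape this could take, assuming Titan’s CheckpointManager keeps the save(curr_step, last_step) / load(step) interface used above; the register/export_state/import_state hooks and the write_json/read_json/step_dir helpers are hypothetical, and the import path may differ by Titan version:

from torchtitan.components.checkpoint import CheckpointManager  # path may differ by Titan version

class ForgeCheckpointManager(CheckpointManager):
    """Saves Titan's model/optim/lr state plus registered Forge components per step."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra: dict[str, object] = {}   # name -> actor / service handle

    def register(self, name: str, component) -> None:
        # Component is expected to expose export_state() / import_state() (hypothetical).
        self._extra[name] = component

    def save(self, curr_step: int, last_step: bool = False) -> None:
        super().save(curr_step=curr_step, last_step=last_step)   # model/optim/lr via Titan
        for name, comp in self._extra.items():
            write_json(comp.export_state(), step_dir(curr_step), name)   # hypothetical helpers

    def load(self, step: int = -1) -> bool:
        loaded = super().load(step=step)
        if loaded:
            for name, comp in self._extra.items():
                comp.import_state(read_json(step_dir(step), name))       # hypothetical helpers
        return loaded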

Open question: where does ForgeCheckpointManager live if the engine stays inside the trainer? And how does it read/write model/optim/lr state without re-nesting the trainer or breaking actor decoupling?
