Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ cython_debug/
slogs/
slurm-*

# DCP checkpoints
model_state_dict/

# Celery stuff
celerybeat-schedule
celerybeat.pid
Expand Down
32 changes: 31 additions & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
Expand Down Expand Up @@ -155,6 +156,23 @@ def simple_grpo_loss(
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss
loss = self.loss(logprobs, ref_logprobs, advantages, mask)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

return loss.item()

@endpoint
async def push_weights(self, version: int):
    """Publish the trainer's current model weights to torchstore.

    Each parameter is stored under a key derived from ``version`` so the
    policy can later fetch exactly the weights for this training step.
    """
    t_start = time.perf_counter()
    for param_name, tensor in self.model.state_dict().items():
        await ts.put(get_param_key(version, param_name), tensor)
    elapsed = time.perf_counter() - t_start
    self.logger.debug(f"Pushed weights in {elapsed:.2f} seconds")


@dataclass
Expand Down Expand Up @@ -245,7 +263,7 @@ async def main(cfg: DictConfig):
mlogger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
project="yuxuanh-grpo-training-debug",
)

# ---- Setup services ---- #
Expand Down Expand Up @@ -351,8 +369,20 @@ async def continuous_training():
loss = await trainer.train_step.choose(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
start_time = time.perf_counter()
await trainer.push_weights.call(training_step)
mlogger.log(
"push_weights_time/training_step",
time.perf_counter() - start_time,
training_step,
)
start_time = time.perf_counter()
await policy.update_weights.call(training_step)
mlogger.log(
"update_weights_time/training_step",
time.perf_counter() - start_time,
training_step,
)

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all services support it
Expand Down
97 changes: 64 additions & 33 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

import asyncio

import logging
import os
import sys
Expand All @@ -19,8 +18,21 @@
import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.actors.torchstore_utils import (
extract_param_name,
get_param_key,
get_param_prefix,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
Expand All @@ -43,15 +55,7 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

logger: logging.Logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand Down Expand Up @@ -388,15 +392,6 @@ async def update_weights(self, policy_version: int):
self.policy_version = policy_version
logger.info(f"Weight update completed (now v{self.policy_version})")

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
"""Get the current model parameters. Only for testing purposes."""
val_mesh = await self.policy_worker._get_model_params.call()
sharded_state_dicts = {}
for idx, val in val_mesh.items():
sharded_state_dicts[idx["gpus"]] = val
return sharded_state_dicts

@endpoint
async def get_version(self) -> int:
"""Get the current policy version."""
Expand All @@ -406,6 +401,18 @@ async def get_version(self) -> int:
async def stop(self):
    """Clear the running flag.

    NOTE(review): presumably the serve/generation loop polls
    ``self.running`` and exits once it is False — confirm against
    the loop implementation (not visible in this chunk).
    """
    self.running = False

@endpoint
async def _test_save_model_params(self):
    """Snapshot model parameters before a weight update (testing only).

    Delegates to every PolicyWorker in the mesh so each worker caches
    its own pre-update parameters for later comparison.
    """
    logger.info("[Policy] start saving model parameters before update for testing")
    await self.policy_worker._test_save_model_params.call()

@endpoint
async def _test_validate_model_params(self, validate_fn):
    """Validate updated model params using ``validate_fn`` (testing only).

    Fans the call out to all PolicyWorkers; on each worker
    ``validate_fn`` is invoked with the saved pre-update parameters,
    the live model, and a logger. The per-worker results are returned.
    """
    logger.info("[Policy] start validating model parameters post update")
    return await self.policy_worker._test_validate_model_params.call(validate_fn)

def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
"""Convert a RequestOutput to a list of Completion objects."""
completions = []
Expand Down Expand Up @@ -449,6 +456,9 @@ class PolicyWorker(ForgeActor):
state_dict_key: str = "model_state_dict"
use_dcp: bool = True

# used for testing purposes only
# NOTE(review): class-level mutable dict is shared across PolicyWorker
# instances in one process — confirm a single worker per process is assumed.
_test_prev_params = {}

@endpoint
async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
Expand Down Expand Up @@ -498,12 +508,23 @@ async def _load_tensor_parallel_state_dict(
@endpoint
async def update(self, version: int):
    """Update model weights by reading the versioned state dict from torchstore.

    Lists every parameter key published under this policy version's
    prefix, recovers the original parameter names, and streams the
    tensors into the vLLM model one at a time.

    Args:
        version: Policy version whose weights should be loaded.
    """
    model = self.worker.model_runner.model
    prefix = get_param_prefix(version)
    self.logger.debug(f"{prefix=}")
    matching_keys = await ts.keys(prefix)
    self.logger.debug(f"{matching_keys=}")
    # TODO: find a way to save the original huggingface parameter names.
    hf_names = [extract_param_name(key) for key in matching_keys]
    self.logger.debug(f"{hf_names=}")
    loaded_weights = set()
    # We can't pass a generator since vllm load_weights is not async.
    # Instead, we just call load_weights with one parameter at a time,
    # dropping each tensor immediately to bound peak memory.
    for name in hf_names:
        param = await ts.get(get_param_key(version, name))
        loaded = model.load_weights([(name, param)])
        del param
        loaded_weights.update(loaded)
    self.logger.info(f"Updated {len(loaded_weights)} parameters")

@endpoint
async def setup_kv_cache(self):
Expand Down Expand Up @@ -536,15 +557,25 @@ async def setup_kv_cache(self):
return kv_cache_config

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
model = self.worker.model_runner.model
state_dict = {}
async def _test_save_model_params(self):
    """Snapshot model parameters before a weight update (testing only).

    Copies every named parameter of the vLLM model to CPU into
    ``self._test_prev_params`` so a later validation step can compare
    pre- and post-update weights.
    """
    logger.info(
        "[PolicyWorker] start saving model parameters before update for testing"
    )
    # detach().cpu() keeps an independent host-side copy that the
    # in-place device weight update cannot mutate.
    for name, param in self.worker.model_runner.model.named_parameters():
        self._test_prev_params[name] = param.detach().cpu()
    logger.info(
        "[PolicyWorker] finished saving model parameters, len = %d",
        len(self._test_prev_params),
    )

for name, param in model.named_parameters():
if "layers.0" not in name:
continue
state_dict[name] = param.cpu().detach()
return state_dict
@endpoint
async def _test_validate_model_params(self, validate_fn):
    """Validate updated model parameters with ``validate_fn`` (testing only).

    Args:
        validate_fn: Callable invoked as
            ``validate_fn(prev_params, model, logger)``, where
            ``prev_params`` is the CPU snapshot taken by
            ``_test_save_model_params`` and ``model`` is the live vLLM
            model. Its return value is forwarded to the caller.
    """
    logger.info("[PolicyWorker] start validating model parameters post update")
    return validate_fn(
        self._test_prev_params, self.worker.model_runner.model, logger
    )

def setup_worker(self):
"""Build and Instantiate vLLM worker"""
Expand Down
19 changes: 19 additions & 0 deletions src/forge/actors/torchstore_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Delimiter between the version prefix and the parameter name in a key.
# NOTE: HF parameter names themselves contain ".", so a name is recovered
# by stripping everything up to the FIRST delimiter only.
KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
    """Return the torchstore key prefix shared by all params of a version."""
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    """Return the torchstore key for one named parameter of a version."""
    # Built from get_param_prefix so prefix listing (ts.keys) and key
    # construction can never drift apart.
    return f"{get_param_prefix(policy_version)}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    """Recover the parameter name from a key built by ``get_param_key``."""
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
29 changes: 26 additions & 3 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.actors.torchstore_utils import (
extract_param_name,
get_param_key,
get_param_prefix,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device

from monarch.actor import current_rank, current_size, endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
Expand All @@ -36,9 +45,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -290,6 +296,23 @@ async def push_weights(self, policy_version: int) -> None:

logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds")

@endpoint
async def push_weights_hf_nonsharded(self, policy_version: int) -> None:
    """Push weights to torchstore in HF format, non-sharded.

    Flattens the trainer's model state dict, converts it to HF naming
    via the checkpointer's ``sd_adapter``, and stores each tensor under
    a key derived from ``policy_version``.

    Raises:
        RuntimeError: if the checkpointer holds no "model" state or has
            no ``sd_adapter`` to perform the HF conversion.
    """
    checkpointer = self.engine.checkpointer
    if "model" not in checkpointer.states:
        raise RuntimeError("Model state not found in checkpointer state")

    flattened, _ = flatten_state_dict(checkpointer.states["model"].state_dict())
    if checkpointer.sd_adapter is None:
        raise RuntimeError(
            "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
        )
    for name, tensor in checkpointer.sd_adapter.to_hf(flattened).items():
        await ts.put(get_param_key(policy_version, name), tensor)

@endpoint
async def cleanup(self) -> None:
if self.engine.checkpointer:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Loading
Loading