diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 7b90fed56..85b9a7d7c 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -17,6 +17,7 @@ import torchstore as ts
 from datasets import load_dataset
 from forge.actors.policy import Policy
+from forge.actors.reference_model import ReferenceModel  # noqa: F401
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import _qwen3_hf_to_vllm
 from forge.cli.config import parse
@@ -30,6 +31,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torchstore.state_dict_utils import DELIM
+from torchtitan.config.job_config import Model as TitanJobModelConfig
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -330,6 +332,7 @@ async def pad_token(self):
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
+    titan_model = TitanJobModelConfig(name="qwen3", flavor="1.7B")
     # Get parameters from config with fallbacks
     group_size = cfg.group_size
     model = cfg.model
@@ -381,6 +384,11 @@ async def main(cfg: DictConfig):
             RefModel,
             model_name=model,
         ),
+        # spawn_service(
+        #     ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+        #     ReferenceModel,
+        #     model=titan_model,
+        # ),
         spawn_service(
             ServiceConfig(**cfg.reward_actor.service),
             RewardActor,
diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py
index 70198120b..54e450cd7 100644
--- a/src/forge/actors/__init__.py
+++ b/src/forge/actors/__init__.py
@@ -24,9 +24,9 @@ def __getattr__(name):
         from .replay_buffer import ReplayBuffer
 
         return ReplayBuffer
-    elif name == "TitanRefModel":
-        from .reference_actor import TitanRefModel
+    elif name == "ReferenceModel":
+        from .reference_model import ReferenceModel
 
-        return TitanRefModel
+        return ReferenceModel
     else:
         raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/forge/actors/reference_actor.py b/src/forge/actors/reference_actor.py
deleted file mode 100644
index 28e0f9814..000000000
--- a/src/forge/actors/reference_actor.py
+++ /dev/null
@@ -1,377 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import asyncio
-import logging
-import math
-import os
-
-from collections import deque
-from collections.abc import Mapping
-from dataclasses import dataclass, field, fields
-
-from typing import Any
-
-import torch
-from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.config.job_config import Comm, Model, Parallelism
-from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.forge.engine import ForgeEngine
-from torchtitan.experiments.forge.job_config import ForgeJobConfig
-from transformers import AutoModelForCausalLM
-
-from forge.controller import ForgeActor
-
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-@dataclass
-class TitanRefModel(ForgeActor):
-    """
-    Represents a reference actor leveraging a torchtitan model for execution
-
-    Intended for generating reference_logprobs - for example in KL Divergence
-    """
-
-    # Refer to titan JobConfig for enabling more ForgeEngine configuration
-    model: Model = field(default_factory=Model)
-    parallelism: Parallelism = field(default_factory=Parallelism)
-
-    # Populated in setup
-    # TODO: Commented out since engine_config parsing extracts from class members
-    # engine: ForgeEngine | None = None
-
-    def __post_init__(self):
-        """Initializes config types and env variables."""
-        # Instantiate dict fields
-        for f in fields(self):
-            attr = getattr(self, f.name)
-            if isinstance(attr, Mapping):
-                setattr(self, f.name, f.type(**attr))
-            elif not isinstance(attr, f.type):
-                raise TypeError(
-                    f"{f.name} should be a {f.type} type or a dict like object"
-                )
-
-        """
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
- """ - self.rank = current_rank().rank - self.size = math.prod(current_size().values()) - - env = { - "RANK": str(self.rank), - "LOCAL_RANK": str(self.rank), - "LOCAL_WORLD_SIZE": str(self.size), - "GROUP_RANK": str(self.size), - "GROUP_WORLD_SIZE": str(self.size), - "ROLE_RANK": str(self.rank), - "ROLE_WORLD_SIZE": str(self.size), - "ROLE_NAME": "rank", - "WORLD_SIZE": str(self.size), - "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - } - os.environ.update(env) - - @endpoint - async def setup(self): - engine_config = {f.name: getattr(self, f.name) for f in fields(self)} - self.engine = ForgeEngine(ForgeJobConfig(**engine_config)) - - @endpoint - async def forward(self, request: list[int], response: list[int]) -> torch.Tensor: - """ - Given a request and response tokens, return the log_probability of the - token_ids, shape (completion_len, ) - - """ - model_parts = self.engine.model_parts - parallel_dims = self.engine.parallel_dims - - # Use provided token_ids directly - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.tensor( - request + response, dtype=torch.long, device=device - ).unsqueeze(0) - - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None - ) - - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet") - else: - # (jackkhuu) Not sure if either context are needed for inference here - with self.engine.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 - with self.engine.maybe_enable_amp: - # Titan Tranformer - logits = model_parts[0](input_ids) - - # Compute logprobs - input_ids = input_ids[:, len(request) :] - # (bsz=1, completion_len) - logprobs = compute_logprobs(logits, input_ids) - # (completion_len, ) - return logprobs.squeeze(0) - - return pred - - -# Based on torchtune's grpo -def compute_logprobs( - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 -) -> torch.Tensor: - """ - Compute log probs of the completion input_ids given the logits of the whole sequence. - Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts. - - Args: - logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model. - input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion. - - Returns: - torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens. - - Raises: - ValueError: If the inferred context length is less than or equal to 0. - """ - context_len = logits.shape[1] - input_ids.shape[1] - completion_len = input_ids.shape[1] - if context_len <= 0: - raise ValueError( - "Context length must be greater than 0. Otherwise the probability of the first token is undefined." 
-        )
-
-    # (bsz, completion_len, vocab_size)
-    logits = logits[:, context_len - 1 : -1, :]
-    assert logits.shape == (
-        input_ids.shape[0],
-        completion_len,
-        logits.shape[-1],
-    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
-    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    # (bsz, completion_len, 1)
-    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
-    # (bsz, completion_len)
-    logprobs = logprobs.squeeze(-1)
-
-    return logprobs
-
-
-# Maintained to keep Old GRPO app prior to full migration off of HF
-class HuggingFaceRefModel(ForgeActor):
-    """
-    Represents a reference actor leveraging HuggingFace for execution
-    """
-
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-    return sequence_log_probs
-
-
-"""
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Experimental: DO NOT USE (YET)
-
-ReferenceActor: Coordinate requests to reference models
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-"""
-
-
-@dataclass
-class ReferenceActor(ForgeActor):
-    """
-    DO NOT USE (YET)
-
-    Not updated/used; Original plan was to use this for coordination, but
-    it might be overkil if we can rely on the Service Replicas to handle
-    the queue.
-    We MAY need to still do this for DP and batching support
-
-    For now if you think you need this: directly spin up services of the
-    reference models
-    """
-
-    model: Model = field(default_factory=Model)
-    # parallelism: Parallelism = field(default_factory=Parallelism)
-    # comm: Comm = field(default_factory=Comm)
-
-    # For RefModel
-    ref_model: ForgeActor | None = None
-    device: torch.device | None = None
-
-    # For processing
-    running: bool = False
-    queue: deque | None = None
-
-    def __post_init__(self):
-        """Initializes config types and env variables.
-
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
-
-        """
-        # Instantiate dict fields
-        for f in fields(self):
-            attr = getattr(self, f.name)
-            if isinstance(attr, Mapping):
-                setattr(self, f.name, f.type(**attr))
-            elif not isinstance(attr, f.type):
-                raise TypeError(
-                    f"{f.name} should be a {f.type} type or a dict like object"
-                )
-
-        # This might need to be changed to a distributed friendly container
-        # We also don't have a traditional scheduler?
-        self.queue = deque()
-
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
-
-        env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
-
-    @endpoint
-    async def setup(self):
-        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-
-        # Spawn the RefModel
-        self.ref_model = await spawn_service(
-            default_service_cfg,
-            HuggingFaceRefModel,
-            model_name=self.model.name,
-            device=self.device,
-        )
-
-        # Kick off background processing
-        self.start_processing()
-
-    def start_processing(self):
-        """Start the replica's processing loop if not already running."""
-        if self._run_task is None or self._run_task.done():
-            self._run_task = asyncio.create_task(self.run())
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        """
-        Enque the tokens and await response
-        """
-        fut = asyncio.Future()
-        self.queue.append((token_ids, fut))
-        return await fut
-
-    async def run(self):
-        """
-        Simple loop to pass things along to the ref model
-        """
-
-        # TODO: Consider creating a unified base class for this pattern
-        self.running = True
-
-        while self.running:
-            request, fut = self.queue.popleft()
-            model_output = await self.ref_model.forward(request)
-            fut.set_result(model_output)
-
-    @endpoint
-    async def stop(self) -> None:
-        self.running = False
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
new file mode 100644
index 000000000..5c128fa31
--- /dev/null
+++ b/src/forge/actors/reference_model.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import math
+import os
+
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields
+
+import torch
+from monarch.actor import current_rank, current_size, endpoint
+
+from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism
+from torchtitan.distributed import utils as dist_utils
+from torchtitan.experiments.forge.engine import ForgeEngine
+from torchtitan.experiments.forge.job_config import ForgeJobConfig
+
+from forge.controller import ForgeActor
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+@dataclass
+class ReferenceModel(ForgeActor):
+    """
+    Represents a reference actor leveraging a torchtitan model for execution
+
+    Intended for generating reference_logprobs - for example in KL Divergence
+    """
+
+    # Refer to titan JobConfig for enabling more ForgeEngine configuration
+    model: Model = field(default_factory=Model)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    compile: Compile = field(default_factory=Compile)
+
+    # Populated in setup
+    # TODO: Commented out since engine_config parsing extracts from class members
+    # engine: ForgeEngine | None = None
+
+    def __post_init__(self):
+        """Initializes config types and env variables."""
+        # Instantiate dict fields
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        """
+        torchrun normally handles env variables, but we need to do it ourselves
+        in monarch for now.
+        """
+        self.rank = current_rank().rank
+        self.size = math.prod(current_size().values())
+
+        env = {
+            "RANK": str(self.rank),
+            "LOCAL_RANK": str(self.rank),
+            "LOCAL_WORLD_SIZE": str(self.size),
+            "GROUP_RANK": str(self.size),
+            "GROUP_WORLD_SIZE": str(self.size),
+            "ROLE_RANK": str(self.rank),
+            "ROLE_WORLD_SIZE": str(self.size),
+            "ROLE_NAME": "rank",
+            "WORLD_SIZE": str(self.size),
+            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
+        }
+        os.environ.update(env)
+
+    @endpoint
+    async def setup(self):
+        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+
+    @endpoint
+    async def forward(self, episode: "Episode") -> torch.Tensor:
+        """
+        Given an episode, return the log probabilities of the response
+        token_ids, shape (completion_len,)
+        """
+        req, res = episode.request_tensor, episode.response_tensor
+        model_parts = self.engine.model_parts
+        parallel_dims = self.engine.parallel_dims
+
+        # Use provided token_ids directly
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # input_ids = torch.tensor(
+        #     request + response, dtype=torch.long, device=device
+        # ).unsqueeze(0)
+        input_ids = torch.cat([req, res]).to(device).unsqueeze(0)
+
+        optional_context_parallel_ctx = (
+            dist_utils.create_context_parallel_ctx(
+                cp_mesh=parallel_dims.world_mesh["cp"],
+                # Shard the concatenated prompt+response tokens; the trainer-style
+                # `inputs`/`labels` buffers do not exist in this actor.
+                cp_buffers=[input_ids] + [m.freqs_cis for m in model_parts],
+                cp_seq_dims=[1] + [0 for _ in model_parts],
+                cp_no_restore_buffers={input_ids},
+                cp_rotate_method=self.parallelism.context_parallel_rotate_method,
+            )
+            if parallel_dims.cp_enabled
+            else None
+        )
+
+        if parallel_dims.pp_enabled:
+            raise NotImplementedError("PP not implemented yet")
+        else:
+            # (jackkhuu) Not sure if either context are needed for inference here
+            with self.engine.train_context(optional_context_parallel_ctx):
+                assert len(model_parts) == 1
+                with self.engine.maybe_enable_amp:
+                    # Titan Transformer
+                    logits = model_parts[0](input_ids)
+
+        # Compute logprobs
+        input_ids = input_ids[:, len(req) :]
+        # (bsz=1, completion_len)
+        logprobs = compute_logprobs(logits, input_ids)
+        # (completion_len, )
+        return logprobs.squeeze(0)
+
+
+# Based on torchtune's grpo
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    """
+    Compute log probs of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.
+
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )
+
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)
+
+    return logprobs
diff --git a/tests/unit_tests/test_reference_actor.py b/tests/unit_tests/test_reference_actor.py
index 403da7169..7052b5782 100644
--- a/tests/unit_tests/test_reference_actor.py
+++ b/tests/unit_tests/test_reference_actor.py
@@ -16,7 +16,7 @@ def _import_error():
     try:
-        import forge.actors.reference_actor  # noqa: F401
+        import forge.actors.reference_model  # noqa: F401
 
         return False
     except Exception:
@@ -32,7 +32,7 @@ class TestComputeLogprobs(unittest.TestCase):
     )
     def test_compute_logprobs_basic(self):
         """Test basic logprobs computation."""
-        from forge.actors.reference_actor import compute_logprobs
+        from forge.actors.reference_model import compute_logprobs
 
         batch_size = 1
         seq_len = 5
@@ -57,7 +57,7 @@ def test_compute_logprobs_basic(self):
     )
     def test_compute_logprobs_with_temperature(self):
         """Test logprobs computation with temperature scaling."""
-        from forge.actors.reference_actor import compute_logprobs
+        from forge.actors.reference_model import compute_logprobs
 
         batch_size = 1
         seq_len = 5
@@ -82,7 +82,7 @@ def test_compute_logprobs_with_temperature(self):
     )
     def test_compute_logprobs_single_token(self):
         """Test logprobs computation with single token response."""
-        from forge.actors.reference_actor import compute_logprobs
+        from forge.actors.reference_model import compute_logprobs
 
         batch_size = 1
         seq_len = 5
@@ -103,7 +103,7 @@ def test_compute_logprobs_single_token(self):
     )
     def test_compute_logprobs_empty_response(self):
         """Test logprobs computation with empty response."""
-        from forge.actors.reference_actor import compute_logprobs
+        from forge.actors.reference_model import compute_logprobs
 
         batch_size = 1
         seq_len = 5
@@ -123,7 +123,7 @@ def test_compute_logprobs_empty_response(self):
     )
     def test_compute_logprobs_empty_prompt(self):
         """Test logprobs computation with empty prompt."""
-        from forge.actors.reference_actor import compute_logprobs
+        from forge.actors.reference_model import compute_logprobs
 
         batch_size = 1
         vocab_size = 1000
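
Reviewer note: the snippet below is a minimal, standalone sanity check of the shifted-logits indexing that compute_logprobs relies on (and that the updated unit tests exercise). It assumes this patch is applied so the function is importable from forge.actors.reference_model; the tensor sizes are arbitrary placeholders, not values used anywhere in the patch.

import torch

from forge.actors.reference_model import compute_logprobs

prompt_len, completion_len, vocab_size = 3, 2, 11
torch.manual_seed(0)

# Logits for the full prompt + completion sequence, batch size 1.
logits = torch.randn(1, prompt_len + completion_len, vocab_size)
# Token ids of the completion only.
completion_ids = torch.randint(0, vocab_size, (1, completion_len))

logprobs = compute_logprobs(logits, completion_ids)
assert logprobs.shape == (1, completion_len)

# Manual reference: the logit at position i predicts token i + 1, so the
# completion tokens are scored by logits[:, prompt_len - 1 : -1, :].
expected = (
    torch.log_softmax(logits[:, prompt_len - 1 : -1, :], dim=-1)
    .gather(2, completion_ids.unsqueeze(-1))
    .squeeze(-1)
)
torch.testing.assert_close(logprobs, expected)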