From 1e88cedf3586399de8a414e171106491d5dbcf84 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 2 Sep 2025 16:31:40 -0700
Subject: [PATCH] Updates for RefModel post land

---
 src/forge/actors/__init__.py                   |   6 +-
 ...{reference_actor.py => reference_model.py}  | 285 ++++++++----------
 src/forge/actors/trainer.py                    |  54 +++-
 3 files changed, 178 insertions(+), 167 deletions(-)
 rename src/forge/actors/{reference_actor.py => reference_model.py} (53%)

diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py
index 70198120b..54e450cd7 100644
--- a/src/forge/actors/__init__.py
+++ b/src/forge/actors/__init__.py
@@ -24,9 +24,9 @@ def __getattr__(name):
         from .replay_buffer import ReplayBuffer
 
         return ReplayBuffer
-    elif name == "TitanRefModel":
-        from .reference_actor import TitanRefModel
+    elif name == "ReferenceModel":
+        from .reference_model import ReferenceModel
 
-        return TitanRefModel
+        return ReferenceModel
     else:
         raise AttributeError(f"module {__name__} has no attribute {name}")
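
Since forge.actors resolves its exports lazily through __getattr__, the rename only changes the public import path. A minimal sketch of the new import, assuming the package layout in this patch:

    # Hypothetical usage; the attribute resolves via the lazy __getattr__ above.
    from forge.actors import ReferenceModel

    ref = ReferenceModel()  # model/parallelism/checkpoint/compile fields keep their defaults
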
diff --git a/src/forge/actors/reference_actor.py b/src/forge/actors/reference_model.py
similarity index 53%
rename from src/forge/actors/reference_actor.py
rename to src/forge/actors/reference_model.py
index c0b6aad24..d840e0c1c 100644
--- a/src/forge/actors/reference_actor.py
+++ b/src/forge/actors/reference_model.py
@@ -17,6 +17,7 @@
 from typing import Any
 
 import torch
+from forge.actors.trainer import compute_logprobs, compute_sequence_logprobs
 
 from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
@@ -24,7 +25,7 @@
 from torch import nn
 
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.config.job_config import Comm, Model, Parallelism
+from torchtitan.config.job_config import Checkpoint, Comm, Compile, Model, Parallelism
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
@@ -36,7 +37,7 @@
 
 
 @dataclass
-class TitanRefModel(ForgeActor):
+class ReferenceModel(ForgeActor):
     """
     Represents a reference actor leveraging a torchtitan model for execution
     """
@@ -46,6 +47,8 @@ class TitanRefModel(ForgeActor):
     # Refer to titan JobConfig for enabling more ForgeEngine configuration
     model: Model = field(default_factory=Model)
     parallelism: Parallelism = field(default_factory=Parallelism)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    compile: Compile = field(default_factory=Compile)
 
     # Populated in setup
     # TODO: Commented out since engine_config parsing extracts from class members
@@ -90,7 +93,7 @@ async def setup(self):
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
 
     @endpoint
-    async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
+    async def generate(self, episode: "Episode") -> torch.Tensor:
         """
         Given a request and response tokens, return the log_probability of the token_ids
@@ -122,10 +125,8 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
         else:
             # (jackkhuu) Not sure if either context are needed for inference here
             with self.engine.train_context(optional_context_parallel_ctx):
-                assert len(model_parts) == 1
                 with self.engine.maybe_enable_amp:
-                    # Titan Tranformer
-                    logits = model_parts[0](input_ids)
+                    logits = self.forward(input_ids)
 
         # Compute logprobs
         input_ids = input_ids[:, len(response) :]
@@ -135,21 +136,12 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 
         return pred
 
-
-# Based on torchtune's grpo
-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
-
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-
-    return logprobs
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Pure forward pass through the model itself
+        """
+        assert len(self.engine.model_parts) == 1, "PP not implemented yet"
+        return self.engine.model_parts[0](input_ids)
 
 
 # Maintained to keep Old GRPO app prior to full migration off of HF
@@ -199,37 +191,6 @@ async def forward(self, token_ids: list[int]) -> torch.Tensor:
         )  # Remove batch dimension for single response
 
 
-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-        return sequence_log_probs
-
-
 """
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Experimental: DO NOT USE (YET)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 
@@ -239,113 +200,113 @@ def compute_sequence_logprobs(
-@dataclass
-class ReferenceActor(ForgeActor):
-    """
-    DO NOT USE (YET)
-
-    Not updated/used; Original plan was to use this for coordination, but
-    it might be overkil if we can rely on the Service Replicas to handle
-    the queue.
-    We MAY need to still do this for DP and batching support
-
-    For now if you think you need this: directly spin up services of the
-    reference models
-    """
-
-    model: Model = field(default_factory=Model)
-    # parallelism: Parallelism = field(default_factory=Parallelism)
-    # comm: Comm = field(default_factory=Comm)
-
-    # For RefModel
-    ref_model: ForgeActor | None = None
-    device: torch.device | None = None
-
-    # For processing
-    running: bool = False
-    queue: deque | None = None
-
-    def __post_init__(self):
-        """Initializes config types and env variables.
-
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
-
-        """
-        # Instantiate dict fields
-        for f in fields(self):
-            attr = getattr(self, f.name)
-            if isinstance(attr, Mapping):
-                setattr(self, f.name, f.type(**attr))
-            elif not isinstance(attr, f.type):
-                raise TypeError(
-                    f"{f.name} should be a {f.type} type or a dict like object"
-                )
-
-        # This might need to be changed to a distributed friendly container
-        # We also don't have a traditional scheduler?
-        self.queue = deque()
-
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
-
-        env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
-
-    @endpoint
-    async def setup(self):
-        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-
-        # Spawn the RefModel
-        self.ref_model = await spawn_service(
-            default_service_cfg,
-            HuggingFaceRefModel,
-            model_name=self.model.name,
-            device=self.device,
-        )
-
-        # Kick off background processing
-        self.start_processing()
-
-    def start_processing(self):
-        """Start the replica's processing loop if not already running."""
-        if self._run_task is None or self._run_task.done():
-            self._run_task = asyncio.create_task(self.run())
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        """
-        Enque the tokens and await response
-        """
-        fut = asyncio.Future()
-        self.queue.append((token_ids, fut))
-        return await fut
-
-    async def run(self):
-        """
-        Simple loop to pass things along to the ref model
-        """
-
-        # TODO: Consider creating a unified base class for this pattern
-        self.running = True
-
-        while self.running:
-            request, fut = self.queue.popleft()
-            model_output = await self.ref_model.forward(request)
-            fut.set_result(model_output)
-
-    @endpoint
-    async def stop(self) -> None:
-        self.running = False
+# @dataclass
+# class ReferenceActor(ForgeActor):
+#     """
+#     DO NOT USE (YET)
+
+#     Not updated/used; Original plan was to use this for coordination, but
+#     it might be overkill if we can rely on the Service Replicas to handle
+#     the queue.
+#     We MAY need to still do this for DP and batching support
+
+#     For now if you think you need this: directly spin up services of the
+#     reference models
+#     """
+
+#     model: Model = field(default_factory=Model)
+#     # parallelism: Parallelism = field(default_factory=Parallelism)
+#     # comm: Comm = field(default_factory=Comm)
+
+#     # For RefModel
+#     ref_model: ForgeActor | None = None
+#     device: torch.device | None = None
+
+#     # For processing
+#     running: bool = False
+#     queue: deque | None = None
+
+#     def __post_init__(self):
+#         """Initializes config types and env variables.
+
+#         torchrun normally handles env variables, but we need to do it ourselves
+#         in monarch for now.
+
+#         """
+#         # Instantiate dict fields
+#         for f in fields(self):
+#             attr = getattr(self, f.name)
+#             if isinstance(attr, Mapping):
+#                 setattr(self, f.name, f.type(**attr))
+#             elif not isinstance(attr, f.type):
+#                 raise TypeError(
+#                     f"{f.name} should be a {f.type} type or a dict like object"
+#                 )
+
+#         # This might need to be changed to a distributed friendly container
+#         # We also don't have a traditional scheduler?
+#         self.queue = deque()
+
+#         self.rank = current_rank().rank
+#         self.size = math.prod(current_size().values())
+
+#         env = {
+#             "RANK": str(self.rank),
+#             "LOCAL_RANK": str(self.rank),
+#             "LOCAL_WORLD_SIZE": str(self.size),
+#             "GROUP_RANK": str(self.size),
+#             "GROUP_WORLD_SIZE": str(self.size),
+#             "ROLE_RANK": str(self.rank),
+#             "ROLE_WORLD_SIZE": str(self.size),
+#             "ROLE_NAME": "rank",
+#             "WORLD_SIZE": str(self.size),
+#             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
+#         }
+#         os.environ.update(env)
+
+#     @endpoint
+#     async def setup(self):
+#         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+#         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+
+#         # Spawn the RefModel
+#         self.ref_model = await spawn_service(
+#             default_service_cfg,
+#             HuggingFaceRefModel,
+#             model_name=self.model.name,
+#             device=self.device,
+#         )
+
+#         # Kick off background processing
+#         self.start_processing()
+
+#     def start_processing(self):
+#         """Start the replica's processing loop if not already running."""
+#         if self._run_task is None or self._run_task.done():
+#             self._run_task = asyncio.create_task(self.run())
+
+#     @endpoint
+#     async def forward(self, token_ids: list[int]) -> torch.Tensor:
+#         """
+#         Enqueue the tokens and await response
+#         """
+#         fut = asyncio.Future()
+#         self.queue.append((token_ids, fut))
+#         return await fut
+
+#     async def run(self):
+#         """
+#         Simple loop to pass things along to the ref model
+#         """
+
+#         # TODO: Consider creating a unified base class for this pattern
+#         self.running = True
+
+#         while self.running:
+#             request, fut = self.queue.popleft()
+#             model_output = await self.ref_model.forward(request)
+#             fut.set_result(model_output)
+
+#     @endpoint
+#     async def stop(self) -> None:
+#         self.running = False
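
With the rename, callers move from TitanRefModel.forward(request, response) to ReferenceModel.generate(episode). A rough sketch of a call site, following the spawn_service pattern from the experimental block above (the Episode fields, model_cfg, and service config are assumptions, not defined in this patch):

    # Hypothetical driver code; spawn_service/default_service_cfg mirror the
    # experimental block, and episode is assumed to carry request/response ids.
    ref = await spawn_service(default_service_cfg, ReferenceModel, model=model_cfg)
    ref_logprobs = await ref.generate(episode)  # one logprob per response token
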
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..793cae183 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -12,6 +12,8 @@
 from dataclasses import dataclass, field, fields
 
 import torch
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
@@ -30,12 +32,60 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
+# Based on torchtune's grpo
+# Will be updated to match the definition in grpo/main.py after
+# https://github.com/meta-pytorch/forge/pull/97 lands
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    context_length = logits.shape[1] - input_ids.shape[1]
+
+    # Truncate request logits and drop last
+    logits = logits[:, context_length - 1 : -1]
+
+    # Compute logprobs
+    logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+
+    return logprobs
+
+
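To make the slicing concrete: with a 4-token request and a 3-token response, logits covers all 7 positions while input_ids holds only the 3 response tokens, so context_length is 4 and logits[:, 3:-1] selects exactly the positions that predict each response token (position i predicts token i + 1). A small self-contained check, with all shapes hypothetical:

    # Illustrative shapes only; vocab size and token ids are made up.
    import torch

    vocab = 32
    logits = torch.randn(1, 7, vocab)            # request (4) + response (3) positions
    response_ids = torch.randint(vocab, (1, 3))  # response token ids only

    ctx = logits.shape[1] - response_ids.shape[1]  # 4
    sliced = logits[:, ctx - 1 : -1]               # (1, 3, vocab): predictors of the response tokens
    out = torch.log_softmax(sliced, dim=-1)
    out = torch.gather(out, 2, response_ids.unsqueeze(-1)).squeeze(-1)
    assert out.shape == (1, 3)  # one logprob per response token
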
+# Maintained to keep Old GRPO app prior to full migration off of HF
+def compute_sequence_logprobs(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    requires_grad: bool = True,
+) -> torch.Tensor:
+    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
+
+    with context_manager:
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+        logits = outputs.logits
+
+        # Apply log softmax to get log probabilities
+        log_probs = torch.log_softmax(logits, dim=-1)
+
+        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
+        shifted_input_ids = input_ids[:, 1:]  # Remove first token
+        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
+
+        # Gather log probabilities for actual tokens
+        token_log_probs = torch.gather(
+            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
+        ).squeeze(-1)
+
+        # Sum log probabilities across sequence (masked by attention)
+        shifted_attention_mask = attention_mask[:, 1:]
+        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
+
+        return sequence_log_probs
+
+
 @dataclass
 class RLTrainer(ForgeActor):
     model: Model = field(default_factory=Model)
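
Unlike compute_logprobs, compute_sequence_logprobs returns one summed log probability per sequence (masked by attention) rather than per-token values. A minimal sketch of the HF-era usage, assuming any causal LM whose output exposes .logits (the model name is purely illustrative):

    # Hypothetical usage with an HF model; "gpt2" is a stand-in.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    batch = tok(["hello world"], return_tensors="pt")

    seq_lp = compute_sequence_logprobs(
        lm, batch["input_ids"], batch["attention_mask"], requires_grad=False
    )
    print(seq_lp.shape)  # torch.Size([1]); runs under no_grad since requires_grad=False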