132 changes: 75 additions & 57 deletions apps/grpo/main.py
Member
Is it out of scope to move the vllm processing loop into the policy instead of starting it in main.py?

Contributor Author
@allenwang28 allenwang28 Aug 29, 2025

Yeah, it should have been moved in this PR.
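
One way that could look, sketched as a hypothetical change to the Policy actor (the setup/shutdown hooks and the _processing_loop helper are placeholder names, not the actual Policy API):

import asyncio

from forge.controller.actor import ForgeActor
from monarch.actor import endpoint


class Policy(ForgeActor):
    @endpoint
    async def setup(self):
        # Start the vLLM processing loop inside the actor itself, so callers
        # in main.py no longer need to create this task.
        self._processing_task = asyncio.create_task(self._processing_loop())

    @endpoint
    async def shutdown(self):
        # Tear the background loop down along with the actor.
        self._processing_task.cancel()

    async def _processing_loop(self):
        # Placeholder for the existing run_processing logic.
        ...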

@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Callable
@@ -14,12 +15,15 @@
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, spawn_service
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_sequence_logprobs(
model: torch.nn.Module,
@@ -314,18 +318,18 @@ async def forward(self, token_ids: list[int]) -> torch.Tensor:
class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

def __init__(self, *args, **kwargs):
def __init__(
self, path: str, config_name: str, split: str, streaming: bool, **kwargs
):
super().__init__()
self._setup_dataset(*args, **kwargs)

def _setup_dataset(self, *args, **kwargs):
def gsm8k_to_messages(sample):
question = sample["question"]
full_answer: str = sample["answer"]
answer = full_answer.split("#### ")[1]
return {"question": question, "answer": answer}

ds = load_dataset(*args, **kwargs)
ds = load_dataset(path, config_name, split=split, streaming=streaming)
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle()
self._iterator = iter(ds)
@@ -351,66 +355,69 @@ async def main():
)

# ---- Setup services ---- #
policy = await spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
PolicyConfig(
num_workers=1,
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
(
dataloader,
policy,
trainer,
replay_buffer,
compute_advantages,
ref_model,
reward_actor,
) = await asyncio.gather(
Comment on lines +359 to +366
Contributor
Nice optimization

Will note that this is harder to read and more error-prone if the order gets changed or the services list is mutated. Not sure if there's a way to have our cake and eat it too.

Contributor Author

hmm we could do something like:

dataloader_task = spawn_service(...) # don't await yet
policy_task = spawn_service(...)

then do the bulk await at the end:

dataloader, policy = await asyncio.gather(...)

but I'm not sure if it fully solves the problem
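
Spelled out with the dataloader and policy services from this diff (illustrative only), that pattern keeps each pending spawn bound to a name before a single gather:

# Illustrative only: create the un-awaited spawns first, then gather once.
dataloader_task = spawn_service(
    ServiceConfig(procs_per_replica=1, num_replicas=1),
    DatasetActor,
    path="openai/gsm8k",
    config_name="main",
    split="train",
    streaming=True,
)
policy_task = spawn_service(
    ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
    Policy,
    config=PolicyConfig(
        worker_params=WorkerConfig(model=model),
        sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
    ),
)

# Reordering or adding services now only touches this line, and each result
# stays matched to its task by name.
dataloader, policy = await asyncio.gather(dataloader_task, policy_task)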

spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
DatasetActor,
path="openai/gsm8k",
config_name="main",
split="train",
streaming=True,
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
config=PolicyConfig(
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(
num_samples=group_size, max_tokens=16
),
),
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Trainer,
learning_rate=1e-5,
beta=0.1,
model_name=model,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
batch_size=4,
max_policy_age=1,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ComputeAdvantages,
gamma=0.99,
lambda_=0.95,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
model_name=model,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
RewardActor,
reward_functions=[MathReward(), ThinkingReward()],
),
)

trainer = await spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Trainer,
learning_rate=1e-5,
beta=0.1,
model_name=model,
)

replay_buffer = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
batch_size=4,
max_policy_age=1,
)

dataloader = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
DatasetActor,
"openai/gsm8k",
"main",
split="train",
streaming=True,
)

compute_advantages = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ComputeAdvantages,
gamma=0.99,
lambda_=0.95,
)

ref_model = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
model_name=model,
)

reward_actor = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
RewardActor,
reward_functions=[MathReward(), ThinkingReward()],
)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
# TODO: Move this into setup
asyncio.create_task(policy.run_processing.call())
while True:
sample = await dataloader.__next__.choose()
if sample is None:
@@ -481,6 +488,17 @@ async def continuous_training():
print("Training interrupted by user")
rollout_task.cancel()
training_task.cancel()
finally:
print("Shutting down...")
await asyncio.gather(
shutdown_service(policy),
shutdown_service(trainer),
shutdown_service(replay_buffer),
shutdown_service(dataloader),
shutdown_service(compute_advantages),
shutdown_service(ref_model),
shutdown_service(reward_actor),
)


if __name__ == "__main__":
39 changes: 20 additions & 19 deletions apps/vllm/main.py
@@ -16,7 +16,7 @@
from typing import List

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig, spawn_service
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import CompletionOutput


@@ -58,9 +58,11 @@ def parse_args() -> Namespace:


def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):

worker_size = 2
worker_params = WorkerConfig(
model=args.model,
tensor_parallel_size=2,
tensor_parallel_size=worker_size,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
@@ -72,35 +74,34 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
)

policy_config = PolicyConfig(
num_workers=2, worker_params=worker_params, sampling_params=sampling_params
worker_params=worker_params, sampling_params=sampling_params
)
service_config = ServiceConfig(
procs_per_replica=worker_size, num_replicas=1, with_gpus=True
)
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
Member
Can't we delete this now that we have the vllm app?

Contributor Author
Hmm, I'm not sure I'm following; we still need a ServiceConfig for spawning a service here regardless?

print("Spawning service...")
policy = await spawn_service(service_config, Policy, config=config)
session_id = await policy.start_session()

print("Starting background processing...")
processing_task = asyncio.create_task(policy.run_processing.call())
async with policy.session():
print("Requesting generation...")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

print("Requesting generation...")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(responses):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(responses):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)
print("\nShutting down...")

print("\nShutting down...")
await policy.shutdown.call()
await policy.terminate_session(session_id)
await shutdown_service(policy)


if __name__ == "__main__":