14 changes: 7 additions & 7 deletions apps/grpo/qwen3_32b.yaml
@@ -3,18 +3,18 @@
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

# Global configuration
group_size: 2
local_batch_size: 8 # per-device batch size
group_size: 8
local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm

# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
rollout_threads: 1

# Observability configuration
metric_logging:
@@ -37,7 +37,7 @@ dataset:
policy:
engine_config:
model: ${model}
tensor_parallel_size: 4
tensor_parallel_size: 8
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
@@ -120,7 +120,7 @@ services:
policy:
procs: ${policy.engine_config.tensor_parallel_size}
num_replicas: 1
hosts: 1
# hosts: 1
with_gpus: true
ref_model:
procs: ${ref_model.parallelism.tensor_parallel_degree}
@@ -137,7 +137,7 @@ actors:
with_gpus: false
trainer:
procs: 8
hosts: 1
# hosts: 1
with_gpus: true
replay_buffer:
procs: 1
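Taken together, the new values give each sampled prompt a generation budget of group_size × max_res_tokens = 8 × 1024 = 8192 response tokens, which is the figure the benchmark script below credits per generate call; with num_replicas: 1 and procs tied to tensor_parallel_size, the policy service now runs as a single 8-GPU replica. A minimal sketch of the token arithmetic, reusing the names from this config:

group_size = 8           # responses sampled per prompt
max_res_tokens = 1024    # response-token budget per sample
token_per_request = max_res_tokens * group_size  # 8192 tokens credited per generate() call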
203 changes: 175 additions & 28 deletions tests/sandbox/vllm/main.py
@@ -15,18 +15,144 @@

from forge.actors.policy import Policy
from forge.cli.config import parse

from forge.controller.provisioner import shutdown
from dataclasses import dataclass
from datasets import load_dataset

from forge.controller.actor import ForgeActor
from monarch.actor import endpoint
from vllm.transformers_utils.tokenizer import get_tokenizer
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown

from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.types import LauncherConfig, ProvisionerConfig
from omegaconf import DictConfig

from forge.observability.perf_tracker import Tracer
from forge.types import (
Launcher,
LauncherConfig,
ProcessConfig,
ProvisionerConfig,
ServiceConfig,
)

import time
from collections import deque

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


import statistics

class ThroughputTracker:
def __init__(self, window_size=60): # 60 second window
self.window_size = window_size
self.request_times = deque()
self.token_counts = deque()
self.latencies = deque() # Store latency for each request
self.last_print = time.time()
self.print_interval = 10 # Print every 10 seconds

def start_request(self):
"""Call this when starting a request. Returns the start time."""
return time.time()

def end_request(self, start_time, num_tokens):
"""Call this when a request completes."""
end_time = time.time()
latency = end_time - start_time

self.request_times.append(end_time)
self.token_counts.append(num_tokens)
self.latencies.append(latency)

# Remove old entries outside the window
cutoff_time = end_time - self.window_size
while self.request_times and self.request_times[0] < cutoff_time:
self.request_times.popleft()
self.token_counts.popleft()
self.latencies.popleft()

# Print throughput info periodically
if end_time - self.last_print >= self.print_interval:
self.print_metrics()
self.last_print = end_time

def print_metrics(self):
if not self.request_times:
return

time_window = time.time() - self.request_times[0] if len(self.request_times) > 1 else self.print_interval
requests_per_sec = len(self.request_times) / max(time_window, 1)
tokens_per_sec = sum(self.token_counts) / max(time_window, 1)

# Calculate latency statistics
if self.latencies:
avg_latency = statistics.mean(self.latencies)
sorted_latencies = sorted(self.latencies)
p50_latency = statistics.median(sorted_latencies)
            p95_latency = sorted_latencies[int(0.95 * len(sorted_latencies))]
            p99_latency = sorted_latencies[int(0.99 * len(sorted_latencies))]

print(f"📊 Throughput: {requests_per_sec:.2f} req/sec | {tokens_per_sec:.2f} tok/sec")
print(f"⏱️ Latency: avg={avg_latency:.1f}s | p50={p50_latency:.1f}s | p95={p95_latency:.1f}s | p99={p99_latency:.1f}s")


@dataclass
class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

path: str = "openai/gsm8k"
revision: str = "main"
data_split: str = "train"
streaming: bool = True
model: str = "Qwen/Qwen3-1.7B"

@endpoint
def setup(self):
self._tokenizer = get_tokenizer(self.model)

def gsm8k_transform(sample):
system_prompt = """
Put all your scratchpad work between <think> and </think> tags.
Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
"""
request: str = sample["question"]
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": request},
]
formatted_request = self._tokenizer.apply_chat_template(
as_chat,
tokenize=False,
add_generation_prompt=True,
)
target: str = sample["answer"]
formatted_target = target.split("#### ")[1]
return {"request": formatted_request, "target": formatted_target}

ds = load_dataset(
self.path, self.revision, split=self.data_split, streaming=self.streaming
)
ds = ds.map(gsm8k_transform)
ds = ds.shuffle()
self._iterator = iter(ds)

@endpoint
async def sample(self) -> dict[str, str] | None:
try:
sample = next(self._iterator)
return sample
except StopIteration:
return None
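
# Illustrative sample (not from the PR): GSM8K answers end with "#### <number>",
# which is why the target is taken from split("#### ")[1]. gsm8k_transform maps
#   {"question": "Natalia sold clips to 48 of her friends ...", "answer": "... #### 72"}
# to
#   {"request": "<chat-templated system + user prompt>", "target": "72"}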


async def run(cfg: DictConfig):
if cfg.get("provisioner", None) is not None:
await init_provisioner(
@@ -36,36 +36,57 @@ async def run(cfg: DictConfig):
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"

print("Spawning service...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

import time

print("Requesting generation...")
n = 100
start = time.time()
response_outputs: list[Completion] = await asyncio.gather(
*[policy.generate.route(prompt=prompt) for _ in range(n)]
(dataloader, policy) = await asyncio.gather(
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
)
end = time.time()

print(f"Generation of {n} requests completed in {end - start:.2f} seconds.")
print(
f"Generation with procs {cfg.services.policy.procs}, replicas {cfg.services.policy.num_replicas}"
)

print(f"\nGeneration Results (last one of {n} requests):")
print("=" * 80)
for batch, response in enumerate(response_outputs[-1]):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

max_res_tokens = cfg.get("max_res_tokens", None)
assert max_res_tokens is not None, "max_res_tokens must be specified in config"
group_size = cfg.get("group_size", None)
assert group_size is not None, "group_size must be specified in config"
token_per_request = max_res_tokens * group_size
num_rollout_threads = cfg.get("rollout_threads", 1)

throughput_tracker = ThroughputTracker()

async def continuous_rollouts():
print("Starting continuous rollouts")
print(f" {max_res_tokens=}")
print(f" {group_size=}")
print(f" {num_rollout_threads=}")
while True:
t = Tracer("main_perf/continuous_rollouts")
t.start()
sample = await dataloader.sample.call_one()
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return

t.step("data_loading")
prompt, target = sample["request"], sample["target"]

request_start_time = throughput_tracker.start_request()
responses = await policy.generate.route(prompt)
throughput_tracker.end_request(request_start_time, token_per_request)
t.step("policy_generation")

# print(f"#------ request ------#")
# print(prompt)
# print("#------ target ------#")
# print(target)
print(f"#------ responses ------#")
# print(responses[0].text)
# print()
assert len(responses) == group_size


rollout_tasks = [
asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)
]
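    # Each rollout "thread" is an asyncio task on the shared event loop, so raising
    # rollout_threads increases the number of in-flight policy.generate calls rather
    # than adding OS threads.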

await asyncio.gather(*rollout_tasks)
print("\nShutting down...")
await policy.shutdown()
await shutdown()