diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 8100a988b..3615185d6 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -3,10 +3,10 @@
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
# Global configuration
-group_size: 2
-local_batch_size: 8 # per-device batch size
+group_size: 8
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
-max_res_tokens: 512
+max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default
@@ -14,7 +14,7 @@ provisioner:
launcher: slurm
# Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 1
# Observability configuration
metric_logging:
@@ -37,7 +37,7 @@ dataset:
policy:
engine_config:
model: ${model}
- tensor_parallel_size: 4
+ tensor_parallel_size: 8
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
@@ -120,7 +120,7 @@ services:
policy:
procs: ${policy.engine_config.tensor_parallel_size}
num_replicas: 1
- hosts: 1
+ # hosts: 1
with_gpus: true
ref_model:
procs: ${ref_model.parallelism.tensor_parallel_degree}
@@ -137,7 +137,7 @@ actors:
with_gpus: false
trainer:
procs: 8
- hosts: 1
+ # hosts: 1
with_gpus: true
replay_buffer:
procs: 1
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index 0f3ce662c..3254dc87b 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -15,7 +15,14 @@
from forge.actors.policy import Policy
from forge.cli.config import parse
-
+from dataclasses import dataclass
+
+from datasets import load_dataset
+from monarch.actor import endpoint
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from forge.controller.actor import ForgeActor
+from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data_models.completion import Completion
@@ -23,10 +30,129 @@
from forge.types import LauncherConfig, ProvisionerConfig
from omegaconf import DictConfig
+from forge.observability.perf_tracker import Tracer
+from forge.types import (
+ Launcher,
+ LauncherConfig,
+ ProcessConfig,
+ ProvisionerConfig,
+ ServiceConfig,
+)
+
+import statistics
+import time
+from collections import deque
+
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
+
+
+class ThroughputTracker:
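+    """Sliding-window throughput and latency tracker for generation requests.
+
+    Keeps only samples from the most recent `window_size` seconds and
+    periodically prints requests/sec, tokens/sec, and latency percentiles.
+    """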
+ def __init__(self, window_size=60): # 60 second window
+ self.window_size = window_size
+ self.request_times = deque()
+ self.token_counts = deque()
+ self.latencies = deque() # Store latency for each request
+ self.last_print = time.time()
+ self.print_interval = 10 # Print every 10 seconds
+
+ def start_request(self):
+ """Call this when starting a request. Returns the start time."""
+ return time.time()
+
+ def end_request(self, start_time, num_tokens):
+ """Call this when a request completes."""
+ end_time = time.time()
+ latency = end_time - start_time
+
+ self.request_times.append(end_time)
+ self.token_counts.append(num_tokens)
+ self.latencies.append(latency)
+
+ # Remove old entries outside the window
+ cutoff_time = end_time - self.window_size
+ while self.request_times and self.request_times[0] < cutoff_time:
+ self.request_times.popleft()
+ self.token_counts.popleft()
+ self.latencies.popleft()
+
+ # Print throughput info periodically
+ if end_time - self.last_print >= self.print_interval:
+ self.print_metrics()
+ self.last_print = end_time
+
+ def print_metrics(self):
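+        """Print rolling throughput and latency statistics for the current window."""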
+ if not self.request_times:
+ return
+
+        time_window = (
+            time.time() - self.request_times[0]
+            if len(self.request_times) > 1
+            else self.print_interval
+        )
+ requests_per_sec = len(self.request_times) / max(time_window, 1)
+ tokens_per_sec = sum(self.token_counts) / max(time_window, 1)
+
+ # Calculate latency statistics
+ if self.latencies:
+ avg_latency = statistics.mean(self.latencies)
+ sorted_latencies = sorted(self.latencies)
+ p50_latency = statistics.median(sorted_latencies)
+            p95_latency = sorted_latencies[int(0.95 * len(sorted_latencies))]
+            p99_latency = sorted_latencies[int(0.99 * len(sorted_latencies))]
+
+ print(f"📊 Throughput: {requests_per_sec:.2f} req/sec | {tokens_per_sec:.2f} tok/sec")
+ print(f"⏱️ Latency: avg={avg_latency:.1f}s | p50={p50_latency:.1f}s | p95={p95_latency:.1f}s | p99={p99_latency:.1f}s")
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+ """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+ path: str = "openai/gsm8k"
+ revision: str = "main"
+ data_split: str = "train"
+ streaming: bool = True
+ model: str = "Qwen/Qwen3-1.7B"
+
+ @endpoint
+ def setup(self):
+ self._tokenizer = get_tokenizer(self.model)
+
+ def gsm8k_transform(sample):
+ system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+ """
+ request: str = sample["question"]
+ as_chat = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": request},
+ ]
+ formatted_request = self._tokenizer.apply_chat_template(
+ as_chat,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ target: str = sample["answer"]
+ formatted_target = target.split("#### ")[1]
+ return {"request": formatted_request, "target": formatted_target}
+
+ ds = load_dataset(
+ self.path, self.revision, split=self.data_split, streaming=self.streaming
+ )
+ ds = ds.map(gsm8k_transform)
+ ds = ds.shuffle()
+ self._iterator = iter(ds)
+
+ @endpoint
+ async def sample(self) -> dict[str, str] | None:
+ try:
+ sample = next(self._iterator)
+ return sample
+ except StopIteration:
+ return None
+
+
async def run(cfg: DictConfig):
if cfg.get("provisioner", None) is not None:
await init_provisioner(
@@ -36,36 +162,57 @@ async def run(cfg: DictConfig):
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
- if (prompt := cfg.get("prompt")) is None:
- gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
- prompt = "What is 3+5?" if gd else "Tell me a joke"
-
print("Spawning service...")
- policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
-
- import time
-
- print("Requesting generation...")
- n = 100
- start = time.time()
- response_outputs: list[Completion] = await asyncio.gather(
- *[policy.generate.route(prompt=prompt) for _ in range(n)]
+ (dataloader, policy) = await asyncio.gather(
+ DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
+ Policy.options(**cfg.services.policy).as_service(**cfg.policy),
)
- end = time.time()
-
- print(f"Generation of {n} requests completed in {end - start:.2f} seconds.")
- print(
- f"Generation with procs {cfg.services.policy.procs}, replicas {cfg.services.policy.num_replicas}"
- )
-
- print(f"\nGeneration Results (last one of {n} requests):")
- print("=" * 80)
- for batch, response in enumerate(response_outputs[-1]):
- print(f"Sample {batch + 1}:")
- print(f"User: {prompt}")
- print(f"Assistant: {response.text}")
- print("-" * 80)
+ max_res_tokens = cfg.get("max_res_tokens", None)
+ assert max_res_tokens is not None, "max_res_tokens must be specified in config"
+ group_size = cfg.get("group_size", None)
+ assert group_size is not None, "group_size must be specified in config"
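+    # Token budget per request: assumes each of the group_size completions uses the
+    # full max_res_tokens budget, so the tok/sec figure below is an upper-bound estimate.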
+ token_per_request = max_res_tokens * group_size
+ num_rollout_threads = cfg.get("rollout_threads", 1)
+
+ throughput_tracker = ThroughputTracker()
+
+ async def continuous_rollouts():
+ print("Starting continuous rollouts")
+ print(f" {max_res_tokens=}")
+ print(f" {group_size=}")
+ print(f" {num_rollout_threads=}")
+ while True:
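+            # Trace per-step timings (data loading, policy generation) for this iteration.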
+ t = Tracer("main_perf/continuous_rollouts")
+ t.start()
+ sample = await dataloader.sample.call_one()
+ if sample is None:
+ print("Dataloader is empty, exiting continuous rollout")
+ return
+
+ t.step("data_loading")
+ prompt, target = sample["request"], sample["target"]
+
+ request_start_time = throughput_tracker.start_request()
+ responses = await policy.generate.route(prompt)
+ throughput_tracker.end_request(request_start_time, token_per_request)
+ t.step("policy_generation")
+
+            # Uncomment to inspect individual rollouts:
+            # print("#------ request ------#")
+            # print(prompt)
+            # print("#------ target ------#")
+            # print(target)
+            # print("#------ responses ------#")
+            # print(responses[0].text)
+
+            # Sanity check: the policy should return one completion per group member.
+            assert len(responses) == group_size
+
+ rollout_tasks = [
+ asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)
+ ]
+
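+    # Run all rollout tasks concurrently; they share the same policy service.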
+ await asyncio.gather(*rollout_tasks)
print("\nShutting down...")
await policy.shutdown()
await shutdown()