diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 8100a988b..3615185d6 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -3,10 +3,10 @@
 # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
 
 # Global configuration
-group_size: 2
-local_batch_size: 8 # per-device batch size
+group_size: 8
+local_batch_size: 16 # per-device batch size
 max_req_tokens: 512
-max_res_tokens: 512
+max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
@@ -14,7 +14,7 @@ provisioner:
   launcher: slurm
 
 # Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 1
 
 # Observability configuration
 metric_logging:
@@ -37,7 +37,7 @@ dataset:
 policy:
   engine_config:
     model: ${model}
-    tensor_parallel_size: 4
+    tensor_parallel_size: 8
     pipeline_parallel_size: 1
     enforce_eager: false
   sampling_config:
@@ -120,7 +120,7 @@ services:
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
     num_replicas: 1
-    hosts: 1
+    # hosts: 1
     with_gpus: true
   ref_model:
     procs: ${ref_model.parallelism.tensor_parallel_degree}
@@ -137,7 +137,7 @@ actors:
     with_gpus: false
   trainer:
     procs: 8
-    hosts: 1
+    # hosts: 1
     with_gpus: true
   replay_buffer:
     procs: 1
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index 0f3ce662c..3254dc87b 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -15,7 +15,14 @@
 
 from forge.actors.policy import Policy
 from forge.cli.config import parse
-
+from forge.controller.provisioner import shutdown
+from dataclasses import dataclass
+from datasets import load_dataset
+
+from forge.controller.actor import ForgeActor
+from monarch.actor import endpoint
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data_models.completion import Completion
@@ -23,10 +30,129 @@
 from forge.types import LauncherConfig, ProvisionerConfig
 
 from omegaconf import DictConfig
+from forge.observability.perf_tracker import Tracer
+from forge.types import (
+    Launcher,
+    LauncherConfig,
+    ProcessConfig,
+    ProvisionerConfig,
+    ServiceConfig,
+)
+
+import time
+from collections import deque
+
 os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
 os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
 
+import time
+import statistics
+from collections import deque
+
+class ThroughputTracker:
+    def __init__(self, window_size=60):  # 60 second window
+        self.window_size = window_size
+        self.request_times = deque()
+        self.token_counts = deque()
+        self.latencies = deque()  # Store latency for each request
+        self.last_print = time.time()
+        self.print_interval = 10  # Print every 10 seconds
+
+    def start_request(self):
+        """Call this when starting a request. Returns the start time."""
+        return time.time()
+
+    def end_request(self, start_time, num_tokens):
+        """Call this when a request completes."""
+        end_time = time.time()
+        latency = end_time - start_time
+
+        self.request_times.append(end_time)
+        self.token_counts.append(num_tokens)
+        self.latencies.append(latency)
+
+        # Remove old entries outside the window
+        cutoff_time = end_time - self.window_size
+        while self.request_times and self.request_times[0] < cutoff_time:
+            self.request_times.popleft()
+            self.token_counts.popleft()
+            self.latencies.popleft()
+
+        # Print throughput info periodically
+        if end_time - self.last_print >= self.print_interval:
+            self.print_metrics()
+            self.last_print = end_time
+
+    def print_metrics(self):
+        if not self.request_times:
+            return
+
+        time_window = time.time() - self.request_times[0] if len(self.request_times) > 1 else self.print_interval
+        requests_per_sec = len(self.request_times) / max(time_window, 1)
+        tokens_per_sec = sum(self.token_counts) / max(time_window, 1)
+
+        # Calculate latency statistics
+        if self.latencies:
+            avg_latency = statistics.mean(self.latencies)
+            sorted_latencies = sorted(self.latencies)
+            p50_latency = statistics.median(sorted_latencies)
+            p95_latency = sorted_latencies[int(0.95 * len(sorted_latencies))] if len(sorted_latencies) > 0 else 0
+            p99_latency = sorted_latencies[int(0.99 * len(sorted_latencies))] if len(sorted_latencies) > 0 else 0
+
+        print(f"📊 Throughput: {requests_per_sec:.2f} req/sec | {tokens_per_sec:.2f} tok/sec")
+        print(f"⏱️ Latency: avg={avg_latency:.1f}s | p50={p50_latency:.1f}s | p95={p95_latency:.1f}s | p99={p99_latency:.1f}s")
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+    """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+    path: str = "openai/gsm8k"
+    revision: str = "main"
+    data_split: str = "train"
+    streaming: bool = True
+    model: str = "Qwen/Qwen3-1.7B"
+
+    @endpoint
+    def setup(self):
+        self._tokenizer = get_tokenizer(self.model)
+
+        def gsm8k_transform(sample):
+            system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+            """
+            request: str = sample["question"]
+            as_chat = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": request},
+            ]
+            formatted_request = self._tokenizer.apply_chat_template(
+                as_chat,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            target: str = sample["answer"]
+            formatted_target = target.split("#### ")[1]
+            return {"request": formatted_request, "target": formatted_target}
+
+        ds = load_dataset(
+            self.path, self.revision, split=self.data_split, streaming=self.streaming
+        )
+        ds = ds.map(gsm8k_transform)
+        ds = ds.shuffle()
+        self._iterator = iter(ds)
+
+    @endpoint
+    async def sample(self) -> dict[str, str] | None:
+        try:
+            sample = next(self._iterator)
+            return sample
+        except StopIteration:
+            return None
+
+
 async def run(cfg: DictConfig):
     if cfg.get("provisioner", None) is not None:
         await init_provisioner(
@@ -36,36 +162,57 @@ async def run(cfg: DictConfig):
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)
-    if (prompt := cfg.get("prompt")) is None:
-        gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
-        prompt = "What is 3+5?" if gd else "Tell me a joke"
if gd else "Tell me a joke" - print("Spawning service...") - policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy) - - import time - - print("Requesting generation...") - n = 100 - start = time.time() - response_outputs: list[Completion] = await asyncio.gather( - *[policy.generate.route(prompt=prompt) for _ in range(n)] + (dataloader, policy) = await asyncio.gather( + DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), ) - end = time.time() - - print(f"Generation of {n} requests completed in {end - start:.2f} seconds.") - print( - f"Generation with procs {cfg.services.policy.procs}, replicas {cfg.services.policy.num_replicas}" - ) - - print(f"\nGeneration Results (last one of {n} requests):") - print("=" * 80) - for batch, response in enumerate(response_outputs[-1]): - print(f"Sample {batch + 1}:") - print(f"User: {prompt}") - print(f"Assistant: {response.text}") - print("-" * 80) + max_res_tokens = cfg.get("max_res_tokens", None) + assert max_res_tokens is not None, "max_res_tokens must be specified in config" + group_size = cfg.get("group_size", None) + assert group_size is not None, "group_size must be specified in config" + token_per_request = max_res_tokens * group_size + num_rollout_threads = cfg.get("rollout_threads", 1) + + throughput_tracker = ThroughputTracker() + + async def continuous_rollouts(): + print("Starting continuous rollouts") + print(f" {max_res_tokens=}") + print(f" {group_size=}") + print(f" {num_rollout_threads=}") + while True: + t = Tracer("main_perf/continuous_rollouts") + t.start() + sample = await dataloader.sample.call_one() + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + + t.step("data_loading") + prompt, target = sample["request"], sample["target"] + + request_start_time = throughput_tracker.start_request() + responses = await policy.generate.route(prompt) + throughput_tracker.end_request(request_start_time, token_per_request) + t.step("policy_generation") + + # print(f"#------ request ------#") + # print(prompt) + # print("#------ target ------#") + # print(target) + print(f"#------ responses ------#") + # print(responses[0].text) + # print() + assert len(responses) == group_size + + + rollout_tasks = [ + asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + ] + + await asyncio.gather(*rollout_tasks) print("\nShutting down...") await policy.shutdown() await shutdown()