From b8628bb8e72a27a825225079d66d6fcf01b2e22a Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 1 Oct 2025 16:23:51 -0400 Subject: [PATCH 1/8] initial commit --- apps/grpo/main.py | 14 ++++++++++---- apps/grpo/qwen3_1_7b.yaml | 1 + apps/grpo/qwen3_8b.yaml | 5 +++-- apps/grpo/qwen3_multinode.yaml | 1 + src/forge/actors/reference_model.py | 14 +++++++++++--- 5 files changed, 26 insertions(+), 9 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 852989682..e08598c65 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -311,6 +311,7 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens + compute_logprobs_in_reference_model = cfg.compute_logprobs_in_reference_model # initialize before spawning services metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) @@ -402,13 +403,18 @@ async def continuous_rollouts(): t.step("reward_evaluation") # Calculate reference logprobs - ref_logits = await ref_model.forward.route(input_ids) - t.step("reference_model_forward") + if compute_logprobs_in_reference_model: + ref_logprobs = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=True) + t.step("reference_model_forward_return_logprobs") + else: + ref_logits = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=False) + t.step("reference_model_forward_return_logits") + ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:]) + del ref_logits - ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:]) for i, episode in enumerate(group.episodes): episode.ref_logprobs = ref_logprobs[i] - del ref_logits, ref_logprobs, input_ids + del ref_logprobs, input_ids t.step("compute_logprobs") # Calculate advantages and add to replay buffer diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 53eec5cfb..694c8123b 100644 --- a/apps/grpo/qwen3_1_7b.yaml 
+++ b/apps/grpo/qwen3_1_7b.yaml @@ -6,6 +6,7 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 +compute_logprobs_in_reference_model: true model: "Qwen/Qwen3-1.7B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index c46ee0620..4638bb33a 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -4,8 +4,9 @@ # Global configuration group_size: 8 batch_size: 16 -max_req_tokens: 512 -max_res_tokens: 512 +max_req_tokens: 468 +max_res_tokens: 468 +compute_logprobs_in_reference_model: true model: "Qwen/Qwen3-8B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_multinode.yaml b/apps/grpo/qwen3_multinode.yaml index 679442d2a..660aa675b 100644 --- a/apps/grpo/qwen3_multinode.yaml +++ b/apps/grpo/qwen3_multinode.yaml @@ -7,6 +7,7 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 +compute_logprobs_in_reference_model: true model: "Qwen/Qwen3-1.7B" # Observability configuration diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index f0777665b..be1c3ec18 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -29,6 +29,7 @@ from forge.controller import ForgeActor from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer +from forge.util.ops import compute_logprobs logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -90,7 +91,7 @@ async def setup(self): self.model.eval() @endpoint - async def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + async def forward(self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool) -> torch.Tensor: # Record reference model metrics record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM) @@ -133,5 +134,12 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor: if isinstance(logits, DTensor): logits = logits.full_tensor() 
t.step("forward") - t.stop() - return logits + + if not return_logprobs: + t.stop() + return logits + else: + logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:]) + t.step("compute_logprobs") + t.stop() + return logprobs From 91e5297d457e273ddbe63a09c4720948df1c39c6 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 1 Oct 2025 17:27:00 -0400 Subject: [PATCH 2/8] comment --- src/forge/actors/reference_model.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index be1c3ec18..a7ca39264 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -91,8 +91,20 @@ async def setup(self): self.model.eval() @endpoint - async def forward(self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool) -> torch.Tensor: - + async def forward( + self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool + ) -> torch.Tensor: + """ + Args: + return_logprobs (bool): whether to return og probabilities instead of raw logits. + + This flag significantly impacts the amount of data transferred to the caller: + - When False: Returns logits with shape [group_size, req + res_length, vocab_size]. + This includes the full vocabulary distribution for each token position. + + - When True: Returns log probabilities with shape [group_size, req_length]. + This only includes probabilities for the request tokens, significantly reducing memory usage and transfer overhead. 
+ """ # Record reference model metrics record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM) record_metric( From 9380c16cd35aa74f0273c37a7bceb0d1ef0831b9 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 1 Oct 2025 17:29:57 -0400 Subject: [PATCH 3/8] rename --- apps/grpo/main.py | 4 ++-- apps/grpo/qwen3_1_7b.yaml | 2 +- apps/grpo/qwen3_8b.yaml | 2 +- apps/grpo/qwen3_multinode.yaml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index e08598c65..cbefe0115 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -311,7 +311,7 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens - compute_logprobs_in_reference_model = cfg.compute_logprobs_in_reference_model + ref_model_return_logprobs = cfg.ref_model_return_logprobs # initialize before spawning services metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) @@ -403,7 +403,7 @@ async def continuous_rollouts(): t.step("reward_evaluation") # Calculate reference logprobs - if compute_logprobs_in_reference_model: + if ref_model_return_logprobs: ref_logprobs = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=True) t.step("reference_model_forward_return_logprobs") else: diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 694c8123b..0c5b1ec71 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -6,7 +6,7 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 -compute_logprobs_in_reference_model: true +ref_model_return_logprobs: true model: "Qwen/Qwen3-1.7B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 4638bb33a..4bf54a258 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -6,7 +6,7 @@ group_size: 8 batch_size: 16 max_req_tokens: 468 max_res_tokens: 468 
-compute_logprobs_in_reference_model: true +ref_model_return_logprobs: true model: "Qwen/Qwen3-8B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_multinode.yaml b/apps/grpo/qwen3_multinode.yaml index 660aa675b..65242fb4c 100644 --- a/apps/grpo/qwen3_multinode.yaml +++ b/apps/grpo/qwen3_multinode.yaml @@ -7,7 +7,7 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 -compute_logprobs_in_reference_model: true +ref_model_return_logprobs: true model: "Qwen/Qwen3-1.7B" # Observability configuration From d92a1fcc871311f921179612e12044877b068365 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 2 Oct 2025 08:54:16 -0400 Subject: [PATCH 4/8] remove options from the grpo app + doc string --- apps/grpo/main.py | 12 ++---------- apps/grpo/qwen3_1_7b.yaml | 1 - apps/grpo/qwen3_8b.yaml | 1 - apps/grpo/qwen3_multinode.yaml | 1 - src/forge/actors/reference_model.py | 7 +++++-- 5 files changed, 7 insertions(+), 15 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index cbefe0115..3fb7122db 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -311,7 +311,6 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens - ref_model_return_logprobs = cfg.ref_model_return_logprobs # initialize before spawning services metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) @@ -402,15 +401,8 @@ async def continuous_rollouts(): t.step("reward_evaluation") - # Calculate reference logprobs - if ref_model_return_logprobs: - ref_logprobs = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=True) - t.step("reference_model_forward_return_logprobs") - else: - ref_logits = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=False) - t.step("reference_model_forward_return_logits") - ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:]) - del ref_logits + 
ref_logprobs = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=True) + t.step("reference_model_calculate_logprobs") for i, episode in enumerate(group.episodes): episode.ref_logprobs = ref_logprobs[i] diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 0c5b1ec71..53eec5cfb 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -6,7 +6,6 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 -ref_model_return_logprobs: true model: "Qwen/Qwen3-1.7B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 4bf54a258..5436ba19e 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -6,7 +6,6 @@ group_size: 8 batch_size: 16 max_req_tokens: 468 max_res_tokens: 468 -ref_model_return_logprobs: true model: "Qwen/Qwen3-8B" off_by_n: 1 # Off by one by default diff --git a/apps/grpo/qwen3_multinode.yaml b/apps/grpo/qwen3_multinode.yaml index 65242fb4c..679442d2a 100644 --- a/apps/grpo/qwen3_multinode.yaml +++ b/apps/grpo/qwen3_multinode.yaml @@ -7,7 +7,6 @@ group_size: 8 batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 -ref_model_return_logprobs: true model: "Qwen/Qwen3-1.7B" # Observability configuration diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index a7ca39264..cc57e5246 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -96,14 +96,17 @@ async def forward( ) -> torch.Tensor: """ Args: + input_ids (torch.Tensor): input token ids with shape [group_size, req + res length]. + max_req_tokens (int): maximum request length. return_logprobs (bool): whether to return og probabilities instead of raw logits. 
- This flag significantly impacts the amount of data transferred to the caller: + return_logprobs flag significantly impacts the amount of data transferred to the caller: - When False: Returns logits with shape [group_size, req + res_length, vocab_size]. This includes the full vocabulary distribution for each token position. - When True: Returns log probabilities with shape [group_size, req_length]. - This only includes probabilities for the request tokens, significantly reducing memory usage and transfer overhead. + This only includes probabilities for the response tokens, significantly reducing memory + usage and transfer overhead. """ # Record reference model metrics record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM) From ed2b993b7892a77d623b56d7234d8459c24b9046 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 2 Oct 2025 13:40:20 -0400 Subject: [PATCH 5/8] merge conflict --- apps/grpo/qwen3_8b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 5436ba19e..c46ee0620 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -4,8 +4,8 @@ # Global configuration group_size: 8 batch_size: 16 -max_req_tokens: 468 -max_res_tokens: 468 +max_req_tokens: 512 +max_res_tokens: 512 model: "Qwen/Qwen3-8B" off_by_n: 1 # Off by one by default From 3651e16b10651d4096c3dbe67b8c3228a9bf817e Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 2 Oct 2025 13:43:44 -0400 Subject: [PATCH 6/8] Re-trigger PR From 012c1c7539c72e8ca9fee4844e94476e182460e9 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 2 Oct 2025 13:48:34 -0400 Subject: [PATCH 7/8] ...
From eb1427e6ca70bfd5cd86c1fd3af9f3c93cbef409 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 2 Oct 2025 13:54:41 -0400 Subject: [PATCH 8/8] format --- apps/grpo/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 3fb7122db..21e6168b7 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -401,7 +401,9 @@ async def continuous_rollouts(): t.step("reward_evaluation") - ref_logprobs = await ref_model.forward.route(input_ids, max_req_tokens, return_logprobs=True) + ref_logprobs = await ref_model.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) t.step("reference_model_calculate_logprobs") for i, episode in enumerate(group.episodes):