From 69e589cacaa6486f16378536c4fe44a231b6e84f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 14 Nov 2025 11:29:43 -0800 Subject: [PATCH 1/5] wandb hack Summary: Fixes wandb complaining about metadata being None. tested by running python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml Differential Revision: D87092141 --- src/forge/__init__.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/forge/__init__.py b/src/forge/__init__.py index b359f9c5b..93707e7a9 100644 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -17,3 +17,21 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" except ImportError: pass + + +# FIXME: remove this once wandb fixed this issue +# Patch importlib.metadata.distributions before wandb imports it +# to filter out packages with None metadata +import importlib.metadata + +_original_distributions = importlib.metadata.distributions + + +def _patched_distributions(): + """Filter out distributions with None metadata""" + for dist in _original_distributions(): + if dist.metadata is not None: + yield dist + + +importlib.metadata.distributions = _patched_distributions From fef0230f401d47de64713733fc81d53177c461fe Mon Sep 17 00:00:00 2001 From: casteryh <57782783+casteryh@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:28:15 -0800 Subject: [PATCH 2/5] Update src/forge/__init__.py Co-authored-by: Felipe Mello --- src/forge/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/forge/__init__.py b/src/forge/__init__.py index 93707e7a9..98002f06b 100644 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -20,6 +20,7 @@ # FIXME: remove this once wandb fixed this issue +# https://github.com/wandb/wandb/issues/10890 # Patch importlib.metadata.distributions before wandb imports it # to filter out packages with None metadata import importlib.metadata From 3627ec1c6dc6d9130110d0643003bb8abb8414cd Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 15 Nov 2025 11:25:48 -0800 Subject: [PATCH 3/5] fix --- .claude/settings.local.json | 14 ++++++++++++++ apps/grpo/main.py | 18 ++++++++++++++---- apps/grpo/qwen3_1_7b.yaml | 4 ++-- src/forge/__init__.py | 23 ++++++++++++----------- 4 files changed, 42 insertions(+), 17 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 000000000..526fb128e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,14 @@ +{ + "permissions": { + "allow": [ + "Bash(source ~/.bashrc)", + "Bash(conda activate forge-monarch-0-1-1)", + "Bash(pip install:*)", + "Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 pip install:*)", + "Bash(python -m pytest:*)", + "Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 git push:*)" + ], + "deny": [], + "ask": [] + } +} diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 693dc8d81..5e749d289 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -6,6 +6,19 @@ # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml +# Patch importlib.metadata.distributions before wandb imports it +# to filter out packages with None metadata +import importlib.metadata +_original_distributions = importlib.metadata.distributions + +def _patched_distributions(): + """Filter out distributions with None metadata""" + for dist in _original_distributions(): + if dist.metadata is not None: + yield dist + +importlib.metadata.distributions = _patched_distributions + import asyncio import time import uuid @@ -125,12 +138,9 @@ def simple_grpo_loss( ref_logprobs: torch.Tensor, advantages: torch.Tensor, padding_mask: torch.Tensor, - beta: float = 0.1, ) -> torch.Tensor: logprobs: torch.Tensor = compute_logprobs(logits, response) - kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 - per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages - per_token_loss = -(per_token_policy_loss - beta * kl) + per_token_loss = torch.exp(logprobs - logprobs.detach()) * advantages.detach() loss = ( ((per_token_loss * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0)) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index c6fc1613b..a5d1230e2 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size max_req_tokens: 1024 max_res_tokens: 1024 model: "Qwen/Qwen3-1.7B" -off_by_n: 1 # Off by one by default +off_by_n: 0 # Off by one by default # Main loop configuration -rollout_threads: 1 # Recommended to set equal to policy.num_replicas +rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration diff --git a/src/forge/__init__.py b/src/forge/__init__.py index 98002f06b..64cdd2f06 100644 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -25,14 +25,15 @@ # to filter out packages with None metadata import importlib.metadata -_original_distributions = importlib.metadata.distributions - - -def _patched_distributions(): - """Filter out distributions with None metadata""" - for dist in _original_distributions(): - if dist.metadata is not None: - yield dist - - -importlib.metadata.distributions = _patched_distributions +# Guard to ensure this runs only once +if not hasattr(importlib.metadata, "_distributions_patched"): + _original_distributions = importlib.metadata.distributions + + def _patched_distributions(): + """Filter out distributions with None metadata""" + for dist in _original_distributions(): + if dist.metadata is not None: + yield dist + + importlib.metadata.distributions = _patched_distributions + importlib.metadata._distributions_patched = True From 1f9e3498b69d124c70402a059f49207aabd4ab0e Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 15 Nov 2025 11:31:28 -0800 Subject: [PATCH 4/5] Revert unintended changes from previous commit Only src/forge/__init__.py was intended to be changed. --- .claude/settings.local.json | 14 -------------- apps/grpo/main.py | 18 ++++-------------- apps/grpo/qwen3_1_7b.yaml | 4 ++-- 3 files changed, 6 insertions(+), 30 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 526fb128e..000000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(source ~/.bashrc)", - "Bash(conda activate forge-monarch-0-1-1)", - "Bash(pip install:*)", - "Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 pip install:*)", - "Bash(python -m pytest:*)", - "Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 git push:*)" - ], - "deny": [], - "ask": [] - } -} diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 5e749d289..693dc8d81 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -6,19 +6,6 @@ # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml -# Patch importlib.metadata.distributions before wandb imports it -# to filter out packages with None metadata -import importlib.metadata -_original_distributions = importlib.metadata.distributions - -def _patched_distributions(): - """Filter out distributions with None metadata""" - for dist in _original_distributions(): - if dist.metadata is not None: - yield dist - -importlib.metadata.distributions = _patched_distributions - import asyncio import time import uuid @@ -138,9 +125,12 @@ def simple_grpo_loss( ref_logprobs: torch.Tensor, advantages: torch.Tensor, padding_mask: torch.Tensor, + beta: float = 0.1, ) -> torch.Tensor: logprobs: torch.Tensor = compute_logprobs(logits, response) - per_token_loss = torch.exp(logprobs - logprobs.detach()) * advantages.detach() + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - beta * kl) loss = ( ((per_token_loss * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0)) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index a5d1230e2..c6fc1613b 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size max_req_tokens: 1024 max_res_tokens: 1024 model: "Qwen/Qwen3-1.7B" -off_by_n: 0 # Off by one by default +off_by_n: 1 # Off by one by default # Main loop configuration -rollout_threads: 1 # Recommended to set equal to policy.num_replicas +rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration From c39362b00acd2c1499af0582c6d49308ad58cbd1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 15 Nov 2025 12:03:41 -0800 Subject: [PATCH 5/5] move --- src/forge/__init__.py | 20 -------------------- src/forge/util/logging.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/forge/__init__.py b/src/forge/__init__.py index 64cdd2f06..b359f9c5b 100644 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -17,23 +17,3 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" except ImportError: pass - - -# FIXME: remove this once wandb fixed this issue -# https://github.com/wandb/wandb/issues/10890 -# Patch importlib.metadata.distributions before wandb imports it -# to filter out packages with None metadata -import importlib.metadata - -# Guard to ensure this runs only once -if not hasattr(importlib.metadata, "_distributions_patched"): - _original_distributions = importlib.metadata.distributions - - def _patched_distributions(): - """Filter out distributions with None metadata""" - for dist in _original_distributions(): - if dist.metadata is not None: - yield dist - - importlib.metadata.distributions = _patched_distributions - importlib.metadata._distributions_patched = True diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py index 9eacf893d..8a7c1c99d 100644 --- a/src/forge/util/logging.py +++ b/src/forge/util/logging.py @@ -4,6 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# FIXME: remove this once wandb fixed this issue +# https://github.com/wandb/wandb/issues/10890 +# Patch importlib.metadata.distributions before wandb imports it +# to filter out packages with None metadata +import importlib.metadata + +# Guard to ensure this runs only once +if not hasattr(importlib.metadata, "_distributions_patched"): + _original_distributions = importlib.metadata.distributions + + def _patched_distributions(): + """Filter out distributions with None metadata""" + for distribution in _original_distributions(): + if distribution.metadata is not None: + yield distribution + + importlib.metadata.distributions = _patched_distributions + importlib.metadata._distributions_patched = True + import logging from functools import lru_cache