From 87ea8213d843b5d7e3ffb0acf526add6bd941a58 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 21 Oct 2025 15:18:03 -0700 Subject: [PATCH 1/7] add accumulator and test --- src/forge/observability/__init__.py | 5 + src/forge/observability/metrics.py | 99 ++++++++++++++++++- .../data/test_metrics_aggregator.py | 16 +++ 3 files changed, 119 insertions(+), 1 deletion(-) diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 555aa761e..474483426 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -24,8 +24,10 @@ record_metric, Reduce, reduce_metrics_states, + SampleAccumulator, StdAccumulator, SumAccumulator, + TopBottomKFilter, WandbBackend, ) from .perf_tracker import trace, Tracer @@ -64,4 +66,7 @@ "MaxAccumulator", "MinAccumulator", "StdAccumulator", + "SampleAccumulator", + # Filter classes + "TopBottomKFilter", ] diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 55a3c31a2..065d233be 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import heapq +import itertools import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Dict, List import pytz from monarch.actor import current_rank @@ -68,6 +70,7 @@ class Reduce(Enum): MAX = "max" MIN = "min" STD = "std" + SAMPLE = "sample" @property def accumulator_class(self): @@ -77,6 +80,7 @@ def accumulator_class(self): Reduce.MAX: MaxAccumulator, Reduce.MIN: MinAccumulator, Reduce.STD: StdAccumulator, + Reduce.SAMPLE: SampleAccumulator, } return mapping[self] @@ -182,6 +186,55 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri return reduced_metrics +################# +# SampleFilters # +################# + + +class TopBottomKFilter: + """Keep the top-k and bottom-k samples by a given key (e.g., reward).""" + + def __init__(self, top_k=1, bottom_k=1, key="reward"): + self.top_k = top_k + self.bottom_k = bottom_k + self.key = key + self._top_heap = [] # min-heap for top-k + self._bottom_heap = [] # max-heap for bottom-k (store -value) + self._counter = itertools.count() # tie-breaker id generator + + def filter_append(self, sample: Dict) -> bool: + val = sample.get(self.key, 0.0) + idx = next(self._counter) # unique tiebreaker + + # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none). 
+ # maintain top-k + if self.top_k > 0: + if len(self._top_heap) < self.top_k: + heapq.heappush(self._top_heap, (val, idx, sample)) + else: + heapq.heappushpop(self._top_heap, (val, idx, sample)) + + # maintain bottom-k + if self.bottom_k > 0: + if len(self._bottom_heap) < self.bottom_k: + heapq.heappush(self._bottom_heap, (-val, idx, sample)) + else: + heapq.heappushpop(self._bottom_heap, (-val, idx, sample)) + + # always return False here because we don't store in samples list + return False + + def filter_flush(self, samples: List[Dict]) -> List[Dict]: + tops = [s for _, _, s in self._top_heap] + bottoms = [s for _, _, s in self._bottom_heap] + return bottoms + tops + + def reset(self): + self._top_heap = [] + self._bottom_heap = [] + self._counter = itertools.count() + + ################ # Accumulators # ################ @@ -392,6 +445,50 @@ def reset(self) -> None: self.count = 0 +class SampleAccumulator(MetricAccumulator): + """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts). + Optionally uses a sample filter to decide what to keep at append/flush time. + """ + + def __init__(self, reduction: Reduce): + super().__init__(reduction) + self.samples: List[Dict[str, Any]] = [] + self.filter = TopBottomKFilter() + + def append(self, value: dict) -> None: + if not isinstance(value, dict): + raise ValueError(f"Expected dict, got {type(value)}") + + # Only keep the sample if filter_append returns True + if self.filter.filter_append(value): + self.samples.append(value) + + def get_value(self) -> list[dict]: + """Return locally collected (and optionally filtered) samples.""" + # Apply flush-time filter (e.g. heap selection, threshold trimming) + return self.filter.filter_flush(self.samples) + + def get_state(self) -> Dict[str, Any]: + """Serialize accumulator state for cross-rank reduction.""" + return { + "reduction_type": self.reduction_type.value, + "samples": self.get_value(), + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]: + """Merge sample states across ranks.""" + merged = [] + for s in states: + merged.extend(s.get("samples", [])) + return merged + + def reset(self) -> None: + """Clear local samples and reset filter state.""" + self.samples.clear() + self.filter.reset() + + ############# # Collector # ############# diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py index 5b847c92f..16c9f5f1a 100644 --- a/tests/unit_tests/data/test_metrics_aggregator.py +++ b/tests/unit_tests/data/test_metrics_aggregator.py @@ -246,6 +246,22 @@ def test_handler_replacement_warning(self, caplog): assert len(caplog.records) == 1 assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message + def test_sample_accumulator_with_topbottom_filter(self): + """Ensure SampleAccumulator integrates with TopBottomKFilter correctly.""" + from forge.observability.metrics import Reduce, SampleAccumulator + + acc = SampleAccumulator(Reduce.SAMPLE) + + rewards = [0.1, 0.9, 0.5, 0.7, 0.3] + for r in rewards: + acc.append({"reward": r, "prompt": "Q", "response": "A"}) + + result = acc.get_value() + result_rewards = sorted(s["reward"] for s in result) + + # Expect bottom-1 (0.1) and top-2 (0.7, 0.9) + assert result_rewards == [0.1, 0.9] + class TestDistributedMetricsAggregator(FSDPTest): """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" From 23b88c1e07560487cf43d93cf66251b41b8e5f6e Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 21 Oct 
2025 16:02:28 -0700 Subject: [PATCH 2/7] functions, tests --- apps/grpo/main.py | 16 ++- src/forge/observability/__init__.py | 2 + src/forge/observability/metric_actors.py | 13 +- src/forge/observability/metrics.py | 135 ++++++++++++++++-- .../unit_tests/observability/test_metrics.py | 69 ++++++--- 5 files changed, 203 insertions(+), 32 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index ef522e57b..d051a4f0f 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -29,7 +29,7 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger -from forge.observability.metrics import record_metric, Reduce +from forge.observability.metrics import record_episode_sample, record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig @@ -51,6 +51,7 @@ class Episode: completion: Completion | None = None ref_logprobs: torch.Tensor | None = None reward: float | None = None + reward_breakdown: dict[str, float] | None = None advantage: float | None = None @property @@ -151,8 +152,11 @@ class RewardActor(ForgeActor): reward_functions: list[Callable] @endpoint - async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + async def evaluate_response( + self, prompt: str, response: str, target: str + ) -> dict[str, float]: total_rewards = 0.0 + reward_breakdown = {} # reward breakdown by function for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) total_rewards += reward @@ -161,6 +165,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl reward_fn_name = getattr( reward_fn, "__name__", reward_fn.__class__.__name__ ) + reward_breakdown[reward_fn_name] = reward # per function reward record_metric( f"reward/evaluate_response/sum_{reward_fn_name}_reward", @@ -193,7 +198,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl ) avg_reward = total_rewards / len(self.reward_functions) - return avg_reward + reward_breakdown["reward"] = avg_reward + return reward_breakdown @dataclass @@ -386,9 +392,10 @@ async def continuous_rollouts(): target=target, completion=response, ) - episode.reward = await reward_actor.evaluate_response.route( + episode.reward_breakdown = await reward_actor.evaluate_response.route( prompt=prompt, response=response.text, target=target ) + episode.reward = episode.reward_breakdown["reward"] episodes.append(episode) # Build input_ids for reference logprobs @@ -411,6 +418,7 @@ async def continuous_rollouts(): for episode, advantage in zip(episodes, advantages): episode.advantage = advantage await replay_buffer.add.call_one(episode) + record_episode_sample("rollout/sample", episode) # Log metrics rollout_count += 1 diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 474483426..e80bd1860 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -21,6 +21,7 @@ MetricAccumulator, MetricCollector, MinAccumulator, + record_episode_sample, record_metric, Reduce, reduce_metrics_states, @@ -37,6 +38,7 @@ # Main API functions "record_metric", "reduce_metrics_states", + "record_episode_sample", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e1f0a65df..175b10096 
100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -423,9 +423,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) + # Split into scalar metrics and sample metrics + scalar_metrics = [ + m for m in reduced_metrics if m.reduction != Reduce.SAMPLE + ] + sample_metrics = { + m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE + } + # Log to global backends for backend_name, backend in self.global_logger_backends.items(): - await backend.log_batch(reduced_metrics, global_step) + if scalar_metrics: + await backend.log_batch(scalar_metrics, global_step) + if sample_metrics: + await backend.log_samples(sample_metrics, global_step) @endpoint def has_fetcher(self, proc_id: str) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 065d233be..fd9e56ead 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -139,12 +139,32 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri list[Metric]: List of reduced metrics Example: - states = [ - {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}}, - {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, - ] - reduce_metrics_states(states) - >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] + >>> states = [ + ... { + ... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"}, + ... "reward/sample": { + ... "reduction_type": "sample", + ... "samples": [{"episode_id": 1, "reward": 0.5}], + ... }, + ... }, + ... { + ... "loss": {"count": 10, "sum": 16, "reduction_type": "mean"}, + ... "reward/sample": { + ... "reduction_type": "sample", + ... "samples": [{"episode_id": 2, "reward": 1.0}], + ... }, + ... }, + ... ] + >>> metrics = reduce_metrics_states(states) + >>> for m in metrics: + ... print(m) + Metric(key='loss', value=2.0, reduction=Reduce.MEAN) + Metric( + key='reward/sample', + value=[{'episode_id': 1, 'reward': 0.5}, + {'episode_id': 2, 'reward': 1.0}], + reduction=Reduce.SAMPLE, + ) Raises: ValueError: on mismatched reduction types for the same metric key. @@ -186,6 +206,31 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri return reduced_metrics +def record_episode_sample(table_name: str, episode): + """ + Record a structured sample-level log for a single episode. + Args: + table_name (str): logging prefix (e.g. "rollout/sample"). + episode (Episode): episode object with filled attributes. 
+ """ + sample = { + "episode_id": episode.episode_id, + "policy_version": episode.policy_version, + "prompt": episode.request, + "response": episode.response, + "target": str(episode.target), + **( + episode.reward_breakdown or {} + ), # per-fn breakdown including the average reward + "advantage": episode.advantage, + "request_len": episode.request_len, + "response_len": episode.response_len, + "pad_id": episode.pad_id, + } + + record_metric(table_name, sample, Reduce.SAMPLE) + + ################# # SampleFilters # ################# @@ -656,7 +701,12 @@ def push(self, metric: Metric) -> None: # For PER_RANK_NO_REDUCE backends: stream without reduce for backend in self.per_rank_no_reduce_backends: - backend.log_stream(metric=metric, global_step=self.global_step) + if metric.reduction == Reduce.SAMPLE: + # Wrap singleton Metric into expected {key: [list_of_dicts]} format + sample = {metric.key: [metric.value]} + asyncio.create_task(backend.log_samples(sample, self.global_step)) + else: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -711,8 +761,21 @@ async def flush( if self.per_rank_reduce_backends: metrics_for_backends = reduce_metrics_states([states]) + # Split into scalar metrics and sample metrics + scalar_metrics = [ + m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE + ] + sample_metrics = { + m.key: m.value + for m in metrics_for_backends + if m.reduction == Reduce.SAMPLE + } + for backend in self.per_rank_reduce_backends: - await backend.log_batch(metrics_for_backends, global_step) + if scalar_metrics: + await backend.log_batch(scalar_metrics, global_step) + if sample_metrics: + await backend.log_samples(sample_metrics, global_step) # Update step counter for streaming backends # Note: This is incremented AFTER flush completes, so metrics recorded between @@ -831,6 +894,16 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: async def finish(self) -> None: pass + async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: + """Pretty-print sample-level logs to console.""" + import json + + logger.info(f"========== SAMPLE LOGS STEP {step} ==========") + for table_name, table_rows in samples.items(): + logger.info(f"[{table_name}] ({len(table_rows)} samples)") + logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False)) + logger.info("==============================================\n") + class WandbBackend(LoggerBackend): """ @@ -857,6 +930,7 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: self.run = None self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) + self._tables: dict[str, "wandb.Table"] = {} async def init( self, @@ -968,13 +1042,58 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: # note: here we dont use step since wandb keeps only the latest value for each step self.run.log(log_data) + async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: + """Log sample-level data incrementally to persistent WandB Tables.""" + import wandb + + if not self.run: + return + + for table_name, table_rows in samples.items(): + if not table_rows: + continue + + # If table doesn't exist yet, create it in INCREMENTAL mode + if table_name not in self._tables: + columns = list(table_rows[0].keys()) + table = wandb.Table(columns=columns, log_mode="INCREMENTAL") + 
self._tables[table_name] = table + logger.info( + f"WandbBackend: Created new incremental table: {table_name}" + ) + else: + table = self._tables[table_name] + + # Add rows (fill missing columns with None) + for s in table_rows: + values = [s.get(c) for c in table.columns] + table.add_data(*values) + + # Log the same table object (INCREMENTAL update) + self.run.log({f"{table_name}_table": table}) + logger.info( + f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}" + ) + def get_metadata_for_secondary_ranks(self) -> dict[str, Any]: if self.run and self.per_rank_share_run: return {"shared_run_id": self.run.id} return {} async def finish(self) -> None: + import wandb + if self.run: + # Convert each incremental table to immutable before finishing + for table_name, incr_table in self._tables.items(): + final_table = wandb.Table( + columns=incr_table.columns, + data=incr_table.data, + log_mode="IMMUTABLE", + ) + self.run.log({table_name: final_table}) + logger.info(f"WandbBackend: Finalized table {table_name}") + self.run.finish() logger.info(f"WandbBackend {self.process_name}: Finished run") diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index cda3679a5..412640b31 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -115,33 +115,64 @@ def test_empty_states(self): def test_single_state(self): """Test reduce_metrics_states with single state.""" - states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}] - result = reduce_metrics_states(states) - assert len(result) == 1 - assert result[0].key == "loss" - assert result[0].value == 5.0 - assert result[0].reduction == Reduce.MEAN + states = [ + { + "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}, + "rollout/sample": { + "reduction_type": "sample", + "samples": [{"id": 1, "reward": 0.5}], + }, + } + ] + metrics = reduce_metrics_states(states) + assert len(metrics) == 2 + # Convert to dict for easier testing + result_dict = {m.key: (m.value, m.reduction) for m in metrics} + + assert result_dict["loss"][0] == 5.0 + assert result_dict["loss"][1] == Reduce.MEAN + + assert result_dict["rollout/sample"][0] == [{"id": 1, "reward": 0.5}] + assert result_dict["rollout/sample"][1] == Reduce.SAMPLE def test_multiple_states(self): """Test reduce_metrics_states with multiple states.""" states = [ - {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, - {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}}, + { + "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}, + "rollout/sample": { + "reduction_type": "sample", + "samples": [{"id": 1, "reward": 0.5}], + }, + }, + { + "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}, + "rollout/sample": { + "reduction_type": "sample", + "samples": [{"id": 2, "reward": 0.8}], + }, + }, {"accuracy": {"reduction_type": "sum", "total": 15.0}}, ] - result = reduce_metrics_states(states) + metrics = reduce_metrics_states(states) + + assert len(metrics) == 3 # Convert to dict for easier testing - result_dict = {metric.key: metric.value for metric in result} - assert result_dict["loss"] == 30.0 / 5.0 # 6.0 - assert result_dict["accuracy"] == 15.0 - - # Also check reduction types - for metric in result: - if metric.key == "loss": - assert metric.reduction == Reduce.MEAN - elif metric.key == "accuracy": - assert metric.reduction == Reduce.SUM + result_dict = {m.key: (m.value, m.reduction) for m in 
metrics} + + # Check scalar reductions + assert result_dict["loss"][0] == 30.0 / 5.0 # 6.0 + assert result_dict["loss"][1] == Reduce.MEAN + assert result_dict["accuracy"][0] == 15.0 + assert result_dict["accuracy"][1] == Reduce.SUM + + # Check sample concatenation + assert result_dict["rollout/sample"][0] == [ + {"id": 1, "reward": 0.5}, + {"id": 2, "reward": 0.8}, + ] + assert result_dict["rollout/sample"][1] == Reduce.SAMPLE def test_mismatched_reduction_types_raises_error(self): """Test reduce_metrics_states raises error for mismatched reduction types.""" From 42f0708cc358516d9ec3001c1ff6473cca75cdc0 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 21 Oct 2025 19:02:38 -0700 Subject: [PATCH 3/7] fix error + some debug messages --- src/forge/observability/metric_actors.py | 3 +++ src/forge/observability/metrics.py | 30 ++++++++++++++++++------ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 175b10096..c8851a959 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -26,6 +26,7 @@ LoggerBackend, LoggingMode, MetricCollector, + Reduce, reduce_metrics_states, ) @@ -423,6 +424,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) + print(f"[DEBUG] reduced_metrics: {reduced_metrics}") # Split into scalar metrics and sample metrics scalar_metrics = [ m for m in reduced_metrics if m.reduction != Reduce.SAMPLE @@ -434,6 +436,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Log to global backends for backend_name, backend in self.global_logger_backends.items(): if scalar_metrics: + print(f"[DEBUG] calling log_batch from GlobalLoggerActor") await backend.log_batch(scalar_metrics, global_step) if sample_metrics: await backend.log_samples(sample_metrics, global_step) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index fd9e56ead..6cd6e5051 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -228,7 +228,15 @@ def record_episode_sample(table_name: str, episode): "pad_id": episode.pad_id, } + print( + "[DEBUG] Adding sample to table via record_metric, episode_id: ", + episode.episode_id, + ) record_metric(table_name, sample, Reduce.SAMPLE) + print( + "[DEBUG] Added sample to table via record_metric, episode_id: ", + episode.episode_id, + ) ################# @@ -499,11 +507,13 @@ def __init__(self, reduction: Reduce): super().__init__(reduction) self.samples: List[Dict[str, Any]] = [] self.filter = TopBottomKFilter() + self.is_reset = True def append(self, value: dict) -> None: if not isinstance(value, dict): raise ValueError(f"Expected dict, got {type(value)}") + self.is_reset = False # Only keep the sample if filter_append returns True if self.filter.filter_append(value): self.samples.append(value) @@ -511,7 +521,8 @@ def append(self, value: dict) -> None: def get_value(self) -> list[dict]: """Return locally collected (and optionally filtered) samples.""" # Apply flush-time filter (e.g. 
heap selection, threshold trimming) - return self.filter.filter_flush(self.samples) + results = self.filter.filter_flush(self.samples) + return results def get_state(self) -> Dict[str, Any]: """Serialize accumulator state for cross-rank reduction.""" @@ -530,6 +541,7 @@ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dic def reset(self) -> None: """Clear local samples and reset filter state.""" + self.is_reset = True self.samples.clear() self.filter.reset() @@ -701,12 +713,12 @@ def push(self, metric: Metric) -> None: # For PER_RANK_NO_REDUCE backends: stream without reduce for backend in self.per_rank_no_reduce_backends: - if metric.reduction == Reduce.SAMPLE: - # Wrap singleton Metric into expected {key: [list_of_dicts]} format - sample = {metric.key: [metric.value]} - asyncio.create_task(backend.log_samples(sample, self.global_step)) - else: - backend.log_stream(metric=metric, global_step=self.global_step) + # if metric.reduction == Reduce.SAMPLE: + # # Wrap singleton Metric into expected {key: [list_of_dicts]} format + # sample = {metric.key: [metric.value]} + # asyncio.create_task(backend.log_samples(sample, self.global_step)) + # else: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -773,6 +785,7 @@ async def flush( for backend in self.per_rank_reduce_backends: if scalar_metrics: + print(f"[DEBUG] calling log_batch from MetricCollector") await backend.log_batch(scalar_metrics, global_step) if sample_metrics: await backend.log_samples(sample_metrics, global_step) @@ -880,6 +893,7 @@ async def init( async def log_batch( self, metrics: list[Metric], global_step: int, *args, **kwargs ) -> None: + print(f"[DEBUG] calling log_batch with {len(metrics)} metrics") metrics_str = "\n".join( f" {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) @@ -898,6 +912,8 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: """Pretty-print sample-level logs to console.""" import json + print(f"[DEBUG] calling log_samples with {len(samples)} samples") + logger.info(f"========== SAMPLE LOGS STEP {step} ==========") for table_name, table_rows in samples.items(): logger.info(f"[{table_name}] ({len(table_rows)} samples)") From 76ac3d2da466ddbeebb8af263bd23c24c0e93140 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 21 Oct 2025 19:29:51 -0700 Subject: [PATCH 4/7] fix hanging issue: missing entries --- apps/grpo/main.py | 4 ++++ src/forge/observability/metrics.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index d051a4f0f..95f846b29 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -47,6 +47,8 @@ class Episode: request_len: int response_len: int target: Any | None = None + request: str | None = None + response: str | None = None # Processed data completion: Completion | None = None ref_logprobs: torch.Tensor | None = None @@ -390,6 +392,8 @@ async def continuous_rollouts(): request_len=max_req_tokens, response_len=max_res_tokens, target=target, + request=prompt, + response=response.text, completion=response, ) episode.reward_breakdown = await reward_actor.evaluate_response.route( diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 6cd6e5051..05ad9abee 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -231,6 +231,8 @@ def record_episode_sample(table_name: str, episode): print( "[DEBUG] Adding sample to 
table via record_metric, episode_id: ", episode.episode_id, + # "episode: ", + # episode, ) record_metric(table_name, sample, Reduce.SAMPLE) print( From f45ac2a0fb63d09a01596dfd94c4195cbebe8b07 Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 24 Oct 2025 10:11:23 -0700 Subject: [PATCH 5/7] per_rank_no_reduce mode --- src/forge/observability/metrics.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 05ad9abee..5da5045d7 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import heapq import itertools import logging @@ -715,12 +716,13 @@ def push(self, metric: Metric) -> None: # For PER_RANK_NO_REDUCE backends: stream without reduce for backend in self.per_rank_no_reduce_backends: - # if metric.reduction == Reduce.SAMPLE: - # # Wrap singleton Metric into expected {key: [list_of_dicts]} format - # sample = {metric.key: [metric.value]} - # asyncio.create_task(backend.log_samples(sample, self.global_step)) - # else: - backend.log_stream(metric=metric, global_step=self.global_step) + + if metric.reduction == Reduce.SAMPLE: + # Wrap singleton Metric into expected {key: [list_of_dicts]} format + sample = {metric.key: [metric.value]} + asyncio.create_task(backend.log_samples(sample, self.global_step)) + else: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key From aac01884d9820aa4c360110be44d18b550f8b5a3 Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 24 Oct 2025 10:13:14 -0700 Subject: [PATCH 6/7] clean up debug prints --- src/forge/observability/metric_actors.py | 2 -- src/forge/observability/metrics.py | 17 +---------------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index c8851a959..950dbba85 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -424,7 +424,6 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) - print(f"[DEBUG] reduced_metrics: {reduced_metrics}") # Split into scalar metrics and sample metrics scalar_metrics = [ m for m in reduced_metrics if m.reduction != Reduce.SAMPLE @@ -436,7 +435,6 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Log to global backends for backend_name, backend in self.global_logger_backends.items(): if scalar_metrics: - print(f"[DEBUG] calling log_batch from GlobalLoggerActor") await backend.log_batch(scalar_metrics, global_step) if sample_metrics: await backend.log_samples(sample_metrics, global_step) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 5da5045d7..6cedbdbd0 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -228,18 +228,7 @@ def record_episode_sample(table_name: str, episode): "response_len": episode.response_len, "pad_id": episode.pad_id, } - - print( - "[DEBUG] Adding sample to table via record_metric, episode_id: ", - episode.episode_id, - # "episode: ", - # episode, - ) record_metric(table_name, sample, Reduce.SAMPLE) - print( - "[DEBUG] Added sample to table via record_metric, episode_id: ", - 
episode.episode_id, - ) ################# @@ -789,7 +778,6 @@ async def flush( for backend in self.per_rank_reduce_backends: if scalar_metrics: - print(f"[DEBUG] calling log_batch from MetricCollector") await backend.log_batch(scalar_metrics, global_step) if sample_metrics: await backend.log_samples(sample_metrics, global_step) @@ -897,7 +885,6 @@ async def init( async def log_batch( self, metrics: list[Metric], global_step: int, *args, **kwargs ) -> None: - print(f"[DEBUG] calling log_batch with {len(metrics)} metrics") metrics_str = "\n".join( f" {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) @@ -916,12 +903,10 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: """Pretty-print sample-level logs to console.""" import json - print(f"[DEBUG] calling log_samples with {len(samples)} samples") - logger.info(f"========== SAMPLE LOGS STEP {step} ==========") for table_name, table_rows in samples.items(): logger.info(f"[{table_name}] ({len(table_rows)} samples)") - logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False)) + logger.info(json.dumps(table_rows)) logger.info("==============================================\n") From 328f9b37d15c2551d92554b5d16b5355c74b39ae Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 24 Oct 2025 10:17:01 -0700 Subject: [PATCH 7/7] json pprint --- src/forge/observability/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 6cedbdbd0..0da731a7d 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -906,7 +906,7 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: logger.info(f"========== SAMPLE LOGS STEP {step} ==========") for table_name, table_rows in samples.items(): logger.info(f"[{table_name}] ({len(table_rows)} samples)") - logger.info(json.dumps(table_rows)) + logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False)) logger.info("==============================================\n")
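
Usage sketch (illustrative only, not part of the patches above): it assumes the forge.observability.metrics API exactly as added in PATCH 1/7, and the reward values are made up.

    from forge.observability.metrics import Reduce, SampleAccumulator

    # SampleAccumulator wraps a TopBottomKFilter(top_k=1, bottom_k=1, key="reward") by default.
    acc = SampleAccumulator(Reduce.SAMPLE)
    for r in (0.1, 0.9, 0.5, 0.7, 0.3):
        acc.append({"reward": r, "prompt": "Q", "response": "A"})

    # filter_append tracks candidates in the two heaps, so filter_flush returns
    # the bottom-1 and top-1 samples by reward: 0.1 and 0.9.
    local_samples = acc.get_value()

    # Cross-rank reduction concatenates the per-rank sample lists.
    states = [acc.get_state(), acc.get_state()]
    merged = SampleAccumulator.get_reduced_value_from_states(states)

Because filter_append always returns False, SampleAccumulator.samples stays empty and the kept samples live only in the filter's heaps; filter_flush ignores the list it is passed.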
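
The scalar/sample split performed by MetricCollector.flush and GlobalLoggingActor in PATCH 2/7 can also be exercised in isolation; a minimal sketch, assuming the patched reduce_metrics_states and a backend exposing log_batch/log_samples (the state dicts below are invented examples):

    from forge.observability.metrics import Reduce, reduce_metrics_states

    states = [
        {
            "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
            "rollout/sample": {
                "reduction_type": "sample",
                "samples": [{"episode_id": 1, "reward": 0.5}],
            },
        },
        {
            "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3},
            "rollout/sample": {
                "reduction_type": "sample",
                "samples": [{"episode_id": 2, "reward": 1.0}],
            },
        },
    ]

    reduced = reduce_metrics_states(states)
    scalar_metrics = [m for m in reduced if m.reduction != Reduce.SAMPLE]
    sample_metrics = {m.key: m.value for m in reduced if m.reduction == Reduce.SAMPLE}

    # Scalars go to the existing path, samples to the new one:
    #   await backend.log_batch(scalar_metrics, global_step)
    #   await backend.log_samples(sample_metrics, global_step)  # {table_name: [row, ...]}

ConsoleBackend pretty-prints these rows as JSON; WandbBackend appends them to a wandb.Table created with log_mode="INCREMENTAL", re-logs the same table object each step, and converts it to an IMMUTABLE table in finish().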