diff --git a/apps/grpo/main.py b/apps/grpo/main.py index ef522e57b..95f846b29 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -29,7 +29,7 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger -from forge.observability.metrics import record_metric, Reduce +from forge.observability.metrics import record_episode_sample, record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig @@ -47,10 +47,13 @@ class Episode: request_len: int response_len: int target: Any | None = None + request: str | None = None + response: str | None = None # Processed data completion: Completion | None = None ref_logprobs: torch.Tensor | None = None reward: float | None = None + reward_breakdown: dict[str, float] | None = None advantage: float | None = None @property @@ -151,8 +154,11 @@ class RewardActor(ForgeActor): reward_functions: list[Callable] @endpoint - async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + async def evaluate_response( + self, prompt: str, response: str, target: str + ) -> dict[str, float]: total_rewards = 0.0 + reward_breakdown = {} # reward breakdown by function for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) total_rewards += reward @@ -161,6 +167,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl reward_fn_name = getattr( reward_fn, "__name__", reward_fn.__class__.__name__ ) + reward_breakdown[reward_fn_name] = reward # per function reward record_metric( f"reward/evaluate_response/sum_{reward_fn_name}_reward", @@ -193,7 +200,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl ) avg_reward = total_rewards / len(self.reward_functions) - return avg_reward + reward_breakdown["reward"] = avg_reward + return reward_breakdown @dataclass @@ -384,11 +392,14 @@ async def continuous_rollouts(): request_len=max_req_tokens, response_len=max_res_tokens, target=target, + request=prompt, + response=response.text, completion=response, ) - episode.reward = await reward_actor.evaluate_response.route( + episode.reward_breakdown = await reward_actor.evaluate_response.route( prompt=prompt, response=response.text, target=target ) + episode.reward = episode.reward_breakdown["reward"] episodes.append(episode) # Build input_ids for reference logprobs @@ -411,6 +422,7 @@ async def continuous_rollouts(): for episode, advantage in zip(episodes, advantages): episode.advantage = advantage await replay_buffer.add.call_one(episode) + record_episode_sample("rollout/sample", episode) # Log metrics rollout_count += 1 diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 555aa761e..e80bd1860 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -21,11 +21,14 @@ MetricAccumulator, MetricCollector, MinAccumulator, + record_episode_sample, record_metric, Reduce, reduce_metrics_states, + SampleAccumulator, StdAccumulator, SumAccumulator, + TopBottomKFilter, WandbBackend, ) from .perf_tracker import trace, Tracer @@ -35,6 +38,7 @@ # Main API functions "record_metric", "reduce_metrics_states", + "record_episode_sample", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -64,4 +68,7 @@ "MaxAccumulator", "MinAccumulator", "StdAccumulator", + "SampleAccumulator", + # Filter 
classes + "TopBottomKFilter", ] diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e1f0a65df..950dbba85 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -26,6 +26,7 @@ LoggerBackend, LoggingMode, MetricCollector, + Reduce, reduce_metrics_states, ) @@ -423,9 +424,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) + # Split into scalar metrics and sample metrics + scalar_metrics = [ + m for m in reduced_metrics if m.reduction != Reduce.SAMPLE + ] + sample_metrics = { + m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE + } + # Log to global backends for backend_name, backend in self.global_logger_backends.items(): - await backend.log_batch(reduced_metrics, global_step) + if scalar_metrics: + await backend.log_batch(scalar_metrics, global_step) + if sample_metrics: + await backend.log_samples(sample_metrics, global_step) @endpoint def has_fetcher(self, proc_id: str) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 55a3c31a2..0da731a7d 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -4,13 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio +import heapq +import itertools import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Dict, List import pytz from monarch.actor import current_rank @@ -68,6 +71,7 @@ class Reduce(Enum): MAX = "max" MIN = "min" STD = "std" + SAMPLE = "sample" @property def accumulator_class(self): @@ -77,6 +81,7 @@ def accumulator_class(self): Reduce.MAX: MaxAccumulator, Reduce.MIN: MinAccumulator, Reduce.STD: StdAccumulator, + Reduce.SAMPLE: SampleAccumulator, } return mapping[self] @@ -135,12 +140,32 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri list[Metric]: List of reduced metrics Example: - states = [ - {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}}, - {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, - ] - reduce_metrics_states(states) - >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] + >>> states = [ + ... { + ... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"}, + ... "reward/sample": { + ... "reduction_type": "sample", + ... "samples": [{"episode_id": 1, "reward": 0.5}], + ... }, + ... }, + ... { + ... "loss": {"count": 10, "sum": 16, "reduction_type": "mean"}, + ... "reward/sample": { + ... "reduction_type": "sample", + ... "samples": [{"episode_id": 2, "reward": 1.0}], + ... }, + ... }, + ... ] + >>> metrics = reduce_metrics_states(states) + >>> for m in metrics: + ... print(m) + Metric(key='loss', value=2.0, reduction=Reduce.MEAN) + Metric( + key='reward/sample', + value=[{'episode_id': 1, 'reward': 0.5}, + {'episode_id': 2, 'reward': 1.0}], + reduction=Reduce.SAMPLE, + ) Raises: ValueError: on mismatched reduction types for the same metric key. @@ -182,6 +207,79 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri return reduced_metrics +def record_episode_sample(table_name: str, episode): + """ + Record a structured sample-level log for a single episode. 
+ Args: + table_name (str): logging prefix (e.g. "rollout/sample"). + episode (Episode): episode object with filled attributes. + """ + sample = { + "episode_id": episode.episode_id, + "policy_version": episode.policy_version, + "prompt": episode.request, + "response": episode.response, + "target": str(episode.target), + **( + episode.reward_breakdown or {} + ), # per-fn breakdown including the average reward + "advantage": episode.advantage, + "request_len": episode.request_len, + "response_len": episode.response_len, + "pad_id": episode.pad_id, + } + record_metric(table_name, sample, Reduce.SAMPLE) + + +################# +# SampleFilters # +################# + + +class TopBottomKFilter: + """Keep the top-k and bottom-k samples by a given key (e.g., reward).""" + + def __init__(self, top_k=1, bottom_k=1, key="reward"): + self.top_k = top_k + self.bottom_k = bottom_k + self.key = key + self._top_heap = [] # min-heap for top-k + self._bottom_heap = [] # max-heap for bottom-k (store -value) + self._counter = itertools.count() # tie-breaker id generator + + def filter_append(self, sample: Dict) -> bool: + val = sample.get(self.key, 0.0) + idx = next(self._counter) # unique tiebreaker + + # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none). + # maintain top-k + if self.top_k > 0: + if len(self._top_heap) < self.top_k: + heapq.heappush(self._top_heap, (val, idx, sample)) + else: + heapq.heappushpop(self._top_heap, (val, idx, sample)) + + # maintain bottom-k + if self.bottom_k > 0: + if len(self._bottom_heap) < self.bottom_k: + heapq.heappush(self._bottom_heap, (-val, idx, sample)) + else: + heapq.heappushpop(self._bottom_heap, (-val, idx, sample)) + + # always return False here because we don't store in samples list + return False + + def filter_flush(self, samples: List[Dict]) -> List[Dict]: + tops = [s for _, _, s in self._top_heap] + bottoms = [s for _, _, s in self._bottom_heap] + return bottoms + tops + + def reset(self): + self._top_heap = [] + self._bottom_heap = [] + self._counter = itertools.count() + + ################ # Accumulators # ################ @@ -392,6 +490,54 @@ def reset(self) -> None: self.count = 0 +class SampleAccumulator(MetricAccumulator): + """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts). + Optionally uses a sample filter to decide what to keep at append/flush time. + """ + + def __init__(self, reduction: Reduce): + super().__init__(reduction) + self.samples: List[Dict[str, Any]] = [] + self.filter = TopBottomKFilter() + self.is_reset = True + + def append(self, value: dict) -> None: + if not isinstance(value, dict): + raise ValueError(f"Expected dict, got {type(value)}") + + self.is_reset = False + # Only keep the sample if filter_append returns True + if self.filter.filter_append(value): + self.samples.append(value) + + def get_value(self) -> list[dict]: + """Return locally collected (and optionally filtered) samples.""" + # Apply flush-time filter (e.g. 
heap selection, threshold trimming) + results = self.filter.filter_flush(self.samples) + return results + + def get_state(self) -> Dict[str, Any]: + """Serialize accumulator state for cross-rank reduction.""" + return { + "reduction_type": self.reduction_type.value, + "samples": self.get_value(), + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]: + """Merge sample states across ranks.""" + merged = [] + for s in states: + merged.extend(s.get("samples", [])) + return merged + + def reset(self) -> None: + """Clear local samples and reset filter state.""" + self.is_reset = True + self.samples.clear() + self.filter.reset() + + ############# # Collector # ############# @@ -559,7 +705,13 @@ def push(self, metric: Metric) -> None: # For PER_RANK_NO_REDUCE backends: stream without reduce for backend in self.per_rank_no_reduce_backends: - backend.log_stream(metric=metric, global_step=self.global_step) + + if metric.reduction == Reduce.SAMPLE: + # Wrap singleton Metric into expected {key: [list_of_dicts]} format + sample = {metric.key: [metric.value]} + asyncio.create_task(backend.log_samples(sample, self.global_step)) + else: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -614,8 +766,21 @@ async def flush( if self.per_rank_reduce_backends: metrics_for_backends = reduce_metrics_states([states]) + # Split into scalar metrics and sample metrics + scalar_metrics = [ + m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE + ] + sample_metrics = { + m.key: m.value + for m in metrics_for_backends + if m.reduction == Reduce.SAMPLE + } + for backend in self.per_rank_reduce_backends: - await backend.log_batch(metrics_for_backends, global_step) + if scalar_metrics: + await backend.log_batch(scalar_metrics, global_step) + if sample_metrics: + await backend.log_samples(sample_metrics, global_step) # Update step counter for streaming backends # Note: This is incremented AFTER flush completes, so metrics recorded between @@ -734,6 +899,16 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: async def finish(self) -> None: pass + async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: + """Pretty-print sample-level logs to console.""" + import json + + logger.info(f"========== SAMPLE LOGS STEP {step} ==========") + for table_name, table_rows in samples.items(): + logger.info(f"[{table_name}] ({len(table_rows)} samples)") + logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False)) + logger.info("==============================================\n") + class WandbBackend(LoggerBackend): """ @@ -760,6 +935,7 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: self.run = None self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) + self._tables: dict[str, "wandb.Table"] = {} async def init( self, @@ -871,13 +1047,58 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: # note: here we dont use step since wandb keeps only the latest value for each step self.run.log(log_data) + async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None: + """Log sample-level data incrementally to persistent WandB Tables.""" + import wandb + + if not self.run: + return + + for table_name, table_rows in samples.items(): + if not table_rows: + continue + + # If table 
doesn't exist yet, create it in INCREMENTAL mode
+            if table_name not in self._tables:
+                columns = list(table_rows[0].keys())
+                table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
+                self._tables[table_name] = table
+                logger.info(
+                    f"WandbBackend: Created new incremental table: {table_name}"
+                )
+            else:
+                table = self._tables[table_name]
+
+            # Add rows (fill missing columns with None)
+            for s in table_rows:
+                values = [s.get(c) for c in table.columns]
+                table.add_data(*values)
+
+            # Log the same table object (INCREMENTAL update)
+            self.run.log({f"{table_name}_table": table})
+            logger.info(
+                f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
+            )
+
     def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
         if self.run and self.per_rank_share_run:
             return {"shared_run_id": self.run.id}
         return {}
 
     async def finish(self) -> None:
+        import wandb
+
         if self.run:
+            # Convert each incremental table to immutable before finishing
+            for table_name, incr_table in self._tables.items():
+                final_table = wandb.Table(
+                    columns=incr_table.columns,
+                    data=incr_table.data,
+                    log_mode="IMMUTABLE",
+                )
+                self.run.log({table_name: final_table})
+                logger.info(f"WandbBackend: Finalized table {table_name}")
+
             self.run.finish()
             logger.info(f"WandbBackend {self.process_name}: Finished run")
diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py
index 5b847c92f..16c9f5f1a 100644
--- a/tests/unit_tests/data/test_metrics_aggregator.py
+++ b/tests/unit_tests/data/test_metrics_aggregator.py
@@ -246,6 +246,22 @@ def test_handler_replacement_warning(self, caplog):
         assert len(caplog.records) == 1
         assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message
 
+    def test_sample_accumulator_with_topbottom_filter(self):
+        """Ensure SampleAccumulator integrates with TopBottomKFilter correctly."""
+        from forge.observability.metrics import Reduce, SampleAccumulator
+
+        acc = SampleAccumulator(Reduce.SAMPLE)
+
+        rewards = [0.1, 0.9, 0.5, 0.7, 0.3]
+        for r in rewards:
+            acc.append({"reward": r, "prompt": "Q", "response": "A"})
+
+        result = acc.get_value()
+        result_rewards = sorted(s["reward"] for s in result)
+
+        # Default TopBottomKFilter(top_k=1, bottom_k=1) keeps bottom-1 (0.1) and top-1 (0.9)
+        assert result_rewards == [0.1, 0.9]
+
 
 class TestDistributedMetricsAggregator(FSDPTest):
     """Distributed tests for MetricsAggregator using FSDPTest infrastructure."""
diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py
index cda3679a5..412640b31 100644
--- a/tests/unit_tests/observability/test_metrics.py
+++ b/tests/unit_tests/observability/test_metrics.py
@@ -115,33 +115,64 @@ def test_empty_states(self):
 
     def test_single_state(self):
         """Test reduce_metrics_states with single state."""
-        states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}]
-        result = reduce_metrics_states(states)
-        assert len(result) == 1
-        assert result[0].key == "loss"
-        assert result[0].value == 5.0
-        assert result[0].reduction == Reduce.MEAN
+        states = [
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 1, "reward": 0.5}],
+                },
+            }
+        ]
+        metrics = reduce_metrics_states(states)
+        assert len(metrics) == 2
+        # Convert to dict for easier testing
+        result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+        assert result_dict["loss"][0] == 5.0
+        assert result_dict["loss"][1] == Reduce.MEAN
+
+        assert 
result_dict["rollout/sample"][0] == [{"id": 1, "reward": 0.5}] + assert result_dict["rollout/sample"][1] == Reduce.SAMPLE def test_multiple_states(self): """Test reduce_metrics_states with multiple states.""" states = [ - {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, - {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}}, + { + "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}, + "rollout/sample": { + "reduction_type": "sample", + "samples": [{"id": 1, "reward": 0.5}], + }, + }, + { + "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}, + "rollout/sample": { + "reduction_type": "sample", + "samples": [{"id": 2, "reward": 0.8}], + }, + }, {"accuracy": {"reduction_type": "sum", "total": 15.0}}, ] - result = reduce_metrics_states(states) + metrics = reduce_metrics_states(states) + + assert len(metrics) == 3 # Convert to dict for easier testing - result_dict = {metric.key: metric.value for metric in result} - assert result_dict["loss"] == 30.0 / 5.0 # 6.0 - assert result_dict["accuracy"] == 15.0 - - # Also check reduction types - for metric in result: - if metric.key == "loss": - assert metric.reduction == Reduce.MEAN - elif metric.key == "accuracy": - assert metric.reduction == Reduce.SUM + result_dict = {m.key: (m.value, m.reduction) for m in metrics} + + # Check scalar reductions + assert result_dict["loss"][0] == 30.0 / 5.0 # 6.0 + assert result_dict["loss"][1] == Reduce.MEAN + assert result_dict["accuracy"][0] == 15.0 + assert result_dict["accuracy"][1] == Reduce.SUM + + # Check sample concatenation + assert result_dict["rollout/sample"][0] == [ + {"id": 1, "reward": 0.5}, + {"id": 2, "reward": 0.8}, + ] + assert result_dict["rollout/sample"][1] == Reduce.SAMPLE def test_mismatched_reduction_types_raises_error(self): """Test reduce_metrics_states raises error for mismatched reduction types."""