
[RFC] Add Sample-level Logging API #301

@DNXie

Description


This issue proposes adding sample-level logging to complement the existing scalar metrics, so users can directly inspect the concrete prompts, responses, targets, and rewards produced during RL training.

This RFC is structured in two stages.

  • Stage 1 introduces the core API for sample-level logging, enabling users to record and inspect concrete prompts, responses, targets, and rewards alongside existing scalar metrics. The focus here is to provide a minimal, consistent mechanism without adding configuration complexity.
  • Stage 2 looks ahead to user-defined filters, allowing researchers to control what subset of samples are logged (e.g., by reward thresholds, frequency, or custom predicates). This keeps Stage 1 lightweight while outlining a clear path for more flexible logging in the future.

Stage 1: Core API for Sample-level Logging

Motivation

As discussed in #187, we currently use reward/loss as our main proxy for RL correctness, which makes it hard to tell whether the RL loop is truly learning or simply reward hacking. Numeric metrics alone don’t show the actual examples that led to a particular reward signal.

By logging structured sample-level information (e.g., prompt, response, target, reward, advantage), users can directly examine the concrete cases that produced specific reward signals. This makes it easier to validate that rewards align with the intended task rather than artifacts, and enables qualitative checks such as:

  • Debugging zero-reward responses.
  • Spotting distribution drift in generated responses.
  • Inspecting sample-level statistics (e.g., rewards).

Benefits: This enables qualitative inspection of responses in WandB tables with minimal overhead, since it only logs structured dicts at rollout time. It also stays consistent with the existing metrics system by reusing the record_metric API.

Scope

  • Only enable the API for sample-level logging, not token-level.
  • Loss is not included: see reasons below.

Design

What to Log

episode_id: int
policy_version: int
prompt: str
response: str
request_len: int
response_len: int
target: str
reward: float
ref_logprobs: float
advantage: float
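
For illustration, a single logged sample could look like the dict below. The field names follow the schema above; the concrete values are made up:

# Hypothetical example of one logged sample (values are illustrative only).
sample = {
    "episode_id": 17,
    "policy_version": 3,
    "prompt": "2+2?",
    "response": "<answer>5</answer>",
    "request_len": 6,
    "response_len": 9,
    "target": "4",
    "reward": 0.0,
    "ref_logprobs": -2.31,
    "advantage": -1.2,
}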

Alternative Considered

  • Log per-sample loss (loss, policy_loss, kl_loss):
    Not in scope for this proposal. Supporting this would require structural changes to the loss function, forward_backward, and potentially train_step. Currently the loss function only returns batch-level loss; to enable per-sample logging, it would need to return a list of per-sample losses, which would then be propagated and logged inside train_step.
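
Purely to illustrate why this is a structural change rather than a logging tweak (it is not part of this proposal), the loss function would have to expose an unreduced per-sample vector in addition to the scalar it returns today. The sketch below is hypothetical and not tied to the actual loss code:

import torch

def loss_fn(per_token_logps: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Illustrative per-sample reduction followed by the batch-level mean.
    per_sample = -(per_token_logps * mask).sum(-1) / mask.sum(-1)  # shape [batch]
    return per_sample.mean()  # today only this scalar leaves the function

# To log per-sample losses, `per_sample` would also need to be returned and
# threaded through forward_backward / train_step before reaching record_metric.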

Form of Log

  • ConsoleBackend → Pretty-print dicts via pprint.
    Example:

    === SAMPLE LOGS STEP 42 ===
    {
      'id': 1,
      'prompt': '2+2?',
      'response': '<answer>5</answer>',
      'target': '4',
      'reward': 0.0,
      'advantage': -1.2
      ...
    }
    
  • WandbBackend → wandb.Table
    Columns: [id, prompt, response, target, reward, advantage, ...]
    Rows: one per logged sample, browsable in the WandB UI.

Where to Insert the Log

Samples are logged at rollout time, where the prompt, response, target, reward, and advantage for an episode are all available.

How to Log

We reuse the existing record_metric API with a new Reduce.SAMPLE:

record_metric(
    "reward/samples",
    {"episode_id": ep.episode_id, "prompt": prompt, ...},
    Reduce.SAMPLE,
)
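
A minimal sketch of how Reduce.SAMPLE could be wired up, assuming Reduce is an enum whose members expose an accumulator_class property (which the flush code below already relies on); everything apart from the SAMPLE member itself is illustrative:

from enum import Enum

class Reduce(Enum):
    # Existing members (MEAN, SUM, ...) elided; SAMPLE is the new addition.
    SAMPLE = "sample"

    @property
    def accumulator_class(self):
        # Illustrative mapping; the real property also covers existing reductions.
        if self is Reduce.SAMPLE:
            return SampleAccumulator  # defined below
        raise NotImplementedError
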
New Accumulator: SampleAccumulator
from typing import Any, Dict, List

# MetricAccumulator and Reduce come from the existing metrics module.
class SampleAccumulator(MetricAccumulator):
    def __init__(self, reduction: Reduce):
        super().__init__(reduction)
        self.samples = []

    def append(self, value: dict) -> None:
        assert isinstance(value, dict)
        self.samples.append(value)

    def get_value(self) -> list[dict]:
        return self.samples

    def get_state(self) -> Dict[str, Any]:
        return {
            "reduction_type": self.reduction_type.value,
            "samples": self.samples,
        }

    @classmethod
    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
        # Each state looks like {"reduction_type": "sample", "samples": [ ... ]}
        merged = []
        for s in states:
            merged.extend(s.get("samples", []))
        return merged

    def reset(self) -> None:
        self.samples = []
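
As a quick sanity check of the merge semantics, states from two accumulators (e.g., two ranks) concatenate into one flat list. The snippet assumes the class above plus the Reduce.SAMPLE member:

acc_a = SampleAccumulator(Reduce.SAMPLE)
acc_b = SampleAccumulator(Reduce.SAMPLE)
acc_a.append({"episode_id": 0, "reward": 1.0})
acc_b.append({"episode_id": 1, "reward": 0.0})

merged = SampleAccumulator.get_reduced_value_from_states(
    [acc_a.get_state(), acc_b.get_state()]
)
# merged == [{"episode_id": 0, "reward": 1.0}, {"episode_id": 1, "reward": 0.0}]
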
Flush Behavior
async def flush(self, step: int, return_state: bool = False):
    ...
    metrics = {}
    samples = {}   # 👈 new

    for key, state in states.items():
        reduction = state["reduction_type"]
        acc_class = Reduce(reduction).accumulator_class

        if reduction == Reduce.SAMPLE.value:
            samples[key] = acc_class.get_reduced_value_from_states([state])   # 👈 new
        else:
            metrics[key] = acc_class.get_reduced_value_from_states([state])

    for backend in self.logger_backends:
        if metrics:
            await backend.log(metrics, step)
        if samples:
            await backend.log_samples(samples, step)   # 👈 new hook
  • With Reduce.SAMPLE, flush also collects the per-sample dicts from SampleAccumulator.
  • Backends get a new hook, log_samples(samples, step), that handles rendering (sketched below):
    • Console → print JSON.
    • WandB → create/update a wandb.Table.
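
A rough sketch of the log_samples hook for both backends. The backend class shapes and constructors are assumptions about the existing logger interface; wandb.Table, add_data, and wandb.log are standard W&B calls:

import json
import wandb  # only needed by the WandB backend

class ConsoleBackend:
    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
        # samples maps metric key -> list of per-sample dicts
        print(f"=== SAMPLE LOGS STEP {step} ===")
        for key, rows in samples.items():
            for row in rows:
                print(json.dumps(row, indent=2, default=str))

class WandbBackend:
    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
        for key, rows in samples.items():
            if not rows:
                continue
            columns = list(rows[0].keys())
            table = wandb.Table(columns=columns)
            for row in rows:
                table.add_data(*[row.get(col) for col in columns])
            wandb.log({key: table}, step=step)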

Stage 2: User-Defined Filters

We can extend sample logging with user-defined filters to control what and how much gets logged. Logging every sample is often too expensive, and users may want to focus only on specific cases.

Examples of filters to expose:

  • Frequency: log every Nth sample, or random subsampling.
  • Reward thresholds: log only when reward ≤ 0 or ≥ X.
  • Advantage thresholds: log only extreme positive/negative advantages.
  • By step/policy version: log only at certain intervals.

User-facing config (conceptual):

sample_logging:
  enabled: true
  filters:
    reward_below: 0.1
    freq: 0.05     # log 5% of samples
    step: 100      # log every 100 steps

This would give users flexible, configurable control while keeping Stage 1 simple and minimal.
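
One possible shape for evaluating such filters at the call site, gating record_metric on a predicate built from the config above (hypothetical helper, not part of Stage 1):

import random

def should_log_sample(sample: dict, step: int, cfg: dict) -> bool:
    # Hypothetical filter evaluation mirroring the conceptual config above.
    if not cfg.get("enabled", False):
        return False
    filters = cfg.get("filters", {})
    if "reward_below" in filters and sample["reward"] >= filters["reward_below"]:
        return False
    if "step" in filters and step % filters["step"] != 0:
        return False
    if "freq" in filters and random.random() > filters["freq"]:
        return False
    return True

# At rollout time (sketch):
# if should_log_sample(sample, step, cfg["sample_logging"]):
#     record_metric("reward/samples", sample, Reduce.SAMPLE)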

Labels: enhancement (New feature or request)