This issue proposes adding sample-level logging to complement existing scalar metrics, so users can directly inspect concrete prompts, responses, targets, and rewards produced during RL training.
This RFC is structured in two stages.
- Stage 1 introduces the core API for sample-level logging, enabling users to record and inspect concrete prompts, responses, targets, and rewards alongside existing scalar metrics. The focus here is to provide a minimal, consistent mechanism without adding configuration complexity.
- Stage 2 looks ahead to user-defined filters, allowing researchers to control what subset of samples are logged (e.g., by reward thresholds, frequency, or custom predicates). This keeps Stage 1 lightweight while outlining a clear path for more flexible logging in the future.
Stage 1: Core API for Sample-level Logging
Motivation
As discussed in #187, we're currently using reward/loss as our main proxy for RL correctness, which makes it hard to tell whether the RL loop is truly learning or simply reward hacking. Numeric metrics alone don't show the actual examples that led to a particular reward signal.
By logging structured sample-level information (e.g., prompt, response, target, reward, advantage), users can directly examine the concrete cases that produced specific reward signals. This makes it easier to validate that rewards align with the intended task rather than artifacts, and enables qualitative checks such as:
- Debugging zero-reward responses.
- Spotting distribution drift in generated responses.
- Inspecting sample-level statistics (e.g., rewards).
Benefits: This enables qualitative inspection of responses in WandB tables with minimal overhead, since it only logs structured dicts at rollout time. It also stays consistent with the existing metrics system by reusing the `record_metric` API.
Scope
- Only enable the API for sample-level logging, not token-level.
- Loss is not included: see reasons below.
Design
What to Log
```
episode_id: int
policy_version: int
prompt: str
response: str
request_len: int
response_len: int
target: str
reward: float
ref_logprobs: float
advantage: float
```
Alternative Considered
- Log per-sample loss (`loss`, `policy_loss`, `kl_loss`): Not in scope for this proposal. Supporting this would require structural changes to the loss function, `forward_backward`, and potentially `train_step`. Currently the loss function only returns a batch-level loss; to enable per-sample logging, it would need to return a list of per-sample losses, which would then be propagated and logged inside `train_step`.
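For intuition only, a rough sketch of the kind of change this would imply (the function names and shapes below are hypothetical, not the actual loss API):

```python
import torch

# Today (conceptually): the loss reduces to a single batch scalar,
# so per-sample values are gone before train_step ever sees them.
def batch_loss(logprobs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    per_sample = -(logprobs * advantages)  # shape [batch]
    return per_sample.mean()               # scalar

# Per-sample logging would require also returning the unreduced losses,
# which forward_backward / train_step would then have to propagate and log.
def batch_loss_with_samples(
    logprobs: torch.Tensor, advantages: torch.Tensor
) -> tuple[torch.Tensor, list[float]]:
    per_sample = -(logprobs * advantages)
    return per_sample.mean(), per_sample.detach().tolist()
```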
Form of Log
- ConsoleBackend → Pretty-print dicts via `pprint`.
  Example:

  ```
  === SAMPLE LOGS STEP 42 ===
  {
      'id': 1,
      'prompt': '2+2?',
      'response': '<answer>5</answer>',
      'target': '4',
      'reward': 0.0,
      'advantage': -1.2,
      ...
  }
  ```
- WandbBackend → `wandb.Table`
  - Columns: `[id, prompt, response, target, reward, advantage, ...]`
  - Rows: one per logged sample, browsable in the WandB UI.
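A minimal sketch of how the WandB backend could build such a table inside the `log_samples` hook proposed below (the surrounding class is illustrative; `wandb.Table`, `add_data`, and `wandb.log` are standard wandb calls):

```python
import wandb

class WandbBackend:
    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
        # One table per metric key, e.g. "reward/samples".
        for key, rows in samples.items():
            if not rows:
                continue
            columns = list(rows[0].keys())
            table = wandb.Table(columns=columns)
            for row in rows:
                table.add_data(*[row.get(col) for col in columns])
            wandb.log({key: table}, step=step)
```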
Where to Insert the Log
- Preferred: `continuous_rollouts` (has full context including `advantage` and `ref_logprobs`).
- Alternative: `RewardActor.evaluate_response` (if we don't want `advantage` and `ref_logprobs`).
How to Log
We reuse the existing `record_metric` API with a new `Reduce.SAMPLE`:
```python
record_metric(
    "reward/samples",
    {"episode_id": ep.episode_id, "prompt": prompt, ...},
    Reduce.SAMPLE,
)
```
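For context, a hypothetical sketch of how `Reduce.SAMPLE` could plug into the existing enum, given that `flush` looks up accumulators via `Reduce(reduction).accumulator_class` (the mapping here is illustrative; existing members are elided, and `SampleAccumulator` is defined in the next section):

```python
from enum import Enum

class Reduce(Enum):
    # Existing numeric members (MEAN, SUM, ...) are elided from this sketch.
    SAMPLE = "sample"  # new: collect raw sample dicts instead of reducing numerically

    @property
    def accumulator_class(self):
        # Illustrative mapping; the real registry covers every member.
        return {Reduce.SAMPLE: SampleAccumulator}[self]
```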
New Accumulator: SampleAccumulator
```python
from typing import Any, Dict, List

class SampleAccumulator(MetricAccumulator):
    """Collects raw sample dicts instead of reducing them to a scalar."""

    def __init__(self, reduction: Reduce):
        super().__init__(reduction)
        self.samples: list[dict] = []

    def append(self, value: dict) -> None:
        assert isinstance(value, dict)
        self.samples.append(value)

    def get_value(self) -> list[dict]:
        return self.samples

    def get_state(self) -> Dict[str, Any]:
        return {
            "reduction_type": self.reduction_type.value,
            "samples": self.samples,
        }

    @classmethod
    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
        # Each state looks like {"reduction_type": "sample", "samples": [ ... ]}
        merged = []
        for s in states:
            merged.extend(s.get("samples", []))
        return merged

    def reset(self) -> None:
        self.samples = []
```
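For illustration, merging sample states (e.g., gathered from multiple ranks) simply concatenates the lists:

```python
# Two hypothetical states, shaped like SampleAccumulator.get_state() output.
states = [
    {"reduction_type": "sample", "samples": [{"episode_id": 0, "reward": 1.0}]},
    {"reduction_type": "sample", "samples": [{"episode_id": 1, "reward": 0.0}]},
]

merged = SampleAccumulator.get_reduced_value_from_states(states)
# merged == [{"episode_id": 0, "reward": 1.0}, {"episode_id": 1, "reward": 0.0}]
```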
Flush Behavior
```python
async def flush(self, step: int, return_state: bool = False):
    ...
    metrics = {}
    samples = {}  # 👈 new

    for key, state in states.items():
        reduction = state["reduction_type"]
        acc_class = Reduce(reduction).accumulator_class
        if reduction == Reduce.SAMPLE.value:
            samples[key] = acc_class.get_reduced_value_from_states([state])  # 👈 new
        else:
            metrics[key] = acc_class.get_reduced_value_from_states([state])

    for backend in self.logger_backends:
        if metrics:
            await backend.log(metrics, step)
        if samples:
            await backend.log_samples(samples, step)  # 👈 new hook
```
- With `SAMPLE`: `flush` will also collect dicts from `SampleAccumulator`.
- Backends get a new hook `log_samples(samples, step)` that handles rendering (see the console sketch below):
  - Console → print JSON.
  - WandB → create/update a `wandb.Table`.
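A minimal sketch of the console side of that hook (the class name mirrors the backend mentioned above; the body is illustrative):

```python
from pprint import pformat

class ConsoleBackend:
    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
        print(f"=== SAMPLE LOGS STEP {step} ===")
        for key, rows in samples.items():
            print(f"[{key}]")
            for row in rows:
                print(pformat(row))
```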
Stage 2: User-Defined Filters
We can extend sample logging with user-defined filters to control what and how much gets logged. Logging every sample is often too expensive, and users may want to focus only on specific cases.
Examples of filters to expose:
- Frequency: log every Nth sample, or random subsampling.
- Reward thresholds: log only when reward ≤ 0 or ≥ X.
- Advantage thresholds: log only extreme positive/negative advantages.
- By step/policy version: log only at certain intervals.
User-facing config (conceptual):
```yaml
sample_logging:
  enabled: true
  filters:
    reward_below: 0.1
    freq: 0.05   # log 5% of samples
    step: 100    # log every 100 steps
```
This would give users flexible, configurable control while keeping Stage 1 simple and minimal.
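As a rough sketch of how such a config could gate the `record_metric` call at rollout time (all names here, e.g. `SampleFilterConfig` and `should_log_sample`, are hypothetical):

```python
import random
from dataclasses import dataclass
from typing import Optional

@dataclass
class SampleFilterConfig:
    enabled: bool = True
    reward_below: Optional[float] = None  # log only if reward <= this threshold
    freq: Optional[float] = None          # log roughly this fraction of samples
    step_every: Optional[int] = None      # log only every N training steps

def should_log_sample(cfg: SampleFilterConfig, reward: float, step: int) -> bool:
    if not cfg.enabled:
        return False
    if cfg.reward_below is not None and reward > cfg.reward_below:
        return False
    if cfg.step_every is not None and step % cfg.step_every != 0:
        return False
    if cfg.freq is not None and random.random() > cfg.freq:
        return False
    return True
```

The rollout loop would then call something like `should_log_sample(cfg, reward, step)` before issuing `record_metric(..., Reduce.SAMPLE)`.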