20 changes: 16 additions & 4 deletions apps/grpo/main.py
@@ -29,7 +29,7 @@
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
-from forge.observability.metrics import record_metric, Reduce
+from forge.observability.metrics import record_episode_sample, record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
 from forge.types import LauncherConfig, ProvisionerConfig
@@ -47,10 +47,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None
 
     @property
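The three new `Episode` fields carry the raw prompt/response text and the per-function reward breakdown through the rollout pipeline so they can later be logged as samples. For orientation, a minimal sketch of a populated episode (a stand-in dataclass with illustrative values, not the full `Episode` with its tensor and `Completion` fields):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class EpisodeSketch:
    # Only the fields relevant to sample logging; illustrative, not Forge's Episode.
    episode_id: str
    target: Any | None = None
    request: str | None = None
    response: str | None = None
    reward: float | None = None
    reward_breakdown: dict[str, float] | None = None


ep = EpisodeSketch(
    episode_id="ep-0",
    target="42",
    request="What is 6 * 7?",
    response="<think>6 * 7 = 42</think> 42",
    reward=0.75,  # mean of the per-function rewards below
    reward_breakdown={"MathReward": 1.0, "ThinkingReward": 0.5, "reward": 0.75},
)
```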
@@ -151,8 +154,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]
 
     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> dict[str, float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -161,6 +167,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -193,7 +200,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             )
 
         avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        reward_breakdown["reward"] = avg_reward
+        return reward_breakdown
 
 
 @dataclass
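Concretely, `evaluate_response` now returns one dict with a key per reward function plus the mean under the reserved `"reward"` key. A self-contained sketch of that contract (the two reward functions here are toy stand-ins, not Forge's `MathReward`/`ThinkingReward`):

```python
from typing import Callable


def evaluate_response_sketch(
    prompt: str, response: str, target: str, reward_functions: list[Callable]
) -> dict[str, float]:
    total, breakdown = 0.0, {}
    for fn in reward_functions:
        reward = fn(prompt, response, target)
        breakdown[getattr(fn, "__name__", fn.__class__.__name__)] = reward
        total += reward
    breakdown["reward"] = total / len(reward_functions)  # mean across functions
    return breakdown


def exact_match(prompt: str, response: str, target: str) -> float:
    return float(target in response)


def has_thinking(prompt: str, response: str, target: str) -> float:
    return float("<think>" in response)


print(evaluate_response_sketch("2+2?", "4", "4", [exact_match, has_thinking]))
# {'exact_match': 1.0, 'has_thinking': 0.0, 'reward': 0.5}
```

One caveat of this scheme: a reward function literally named `reward` would collide with the reserved aggregate key.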
@@ -384,11 +392,14 @@ async def continuous_rollouts():
                     request_len=max_req_tokens,
                     response_len=max_res_tokens,
                     target=target,
+                    request=prompt,
+                    response=response.text,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                episode.reward_breakdown = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episode.reward = episode.reward_breakdown["reward"]
                 episodes.append(episode)
 
             # Build input_ids for reference logprobs
@@ -411,6 +422,7 @@ async def continuous_rollouts():
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
+                record_episode_sample("rollout/sample", episode)
 
             # Log metrics
             rollout_count += 1
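`record_episode_sample` itself is not shown in this diff. A plausible sketch of what the call site implies — flattening an episode's loggable fields into one dict and recording it under the `SAMPLE` reduction (the field selection, and `record_metric` accepting dict values, are assumptions):

```python
from forge.observability.metrics import record_metric, Reduce


def record_episode_sample_sketch(key: str, episode) -> None:
    # Hypothetical: mirrors what record_episode_sample plausibly does. Tensors and
    # the raw Completion are skipped since they do not belong in a sample table.
    sample = {
        "episode_id": episode.episode_id,
        "request": episode.request,
        "response": episode.response,
        "target": episode.target,
        "advantage": episode.advantage,
        **(episode.reward_breakdown or {}),  # per-function rewards + "reward"
    }
    record_metric(key, sample, Reduce.SAMPLE)
```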
7 changes: 7 additions & 0 deletions src/forge/observability/__init__.py
@@ -21,11 +21,14 @@
     MetricAccumulator,
     MetricCollector,
     MinAccumulator,
+    record_episode_sample,
     record_metric,
     Reduce,
     reduce_metrics_states,
+    SampleAccumulator,
     StdAccumulator,
     SumAccumulator,
+    TopBottomKFilter,
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
@@ -35,6 +38,7 @@
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
+    "record_episode_sample",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking
@@ -64,4 +68,7 @@
     "MaxAccumulator",
     "MinAccumulator",
     "StdAccumulator",
+    "SampleAccumulator",
+    # Filter classes
+    "TopBottomKFilter",
 ]
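`TopBottomKFilter` is only re-exported here, so its behavior is inferred from the name: keep the k highest- and k lowest-scoring samples (say, by reward) so sample logs stay small while still surfacing the best and worst rollouts. A minimal sketch under that assumption (class name and API are hypothetical):

```python
class TopBottomKFilterSketch:
    """Keep the k best and k worst samples by a score key (assumed behavior)."""

    def __init__(self, k: int = 2, score_key: str = "reward"):
        self.k = k
        self.score_key = score_key

    def filter(self, samples: list[dict]) -> list[dict]:
        if len(samples) <= 2 * self.k:
            return samples  # nothing to drop
        ranked = sorted(samples, key=lambda s: s[self.score_key])
        return ranked[: self.k] + ranked[-self.k :]


# Keep the single worst and single best of four rollout samples.
f = TopBottomKFilterSketch(k=1)
print(f.filter([{"reward": r} for r in (0.1, 0.9, 0.5, 0.3)]))
# [{'reward': 0.1}, {'reward': 0.9}]
```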
16 changes: 15 additions & 1 deletion src/forge/observability/metric_actors.py
@@ -26,6 +26,7 @@
     LoggerBackend,
     LoggingMode,
     MetricCollector,
+    Reduce,
     reduce_metrics_states,
 )
 
@@ -423,9 +424,22 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)
 
+        print(f"[DEBUG] reduced_metrics: {reduced_metrics}")
+        # Split into scalar metrics and sample metrics
+        scalar_metrics = [
+            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
+        ]
+        sample_metrics = {
+            m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+        }
+
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
-            await backend.log_batch(reduced_metrics, global_step)
+            if scalar_metrics:
+                print(f"[DEBUG] calling log_batch from GlobalLoggerActor")
+                await backend.log_batch(scalar_metrics, global_step)
+            if sample_metrics:
+                await backend.log_samples(sample_metrics, global_step)
 
     @endpoint
     def has_fetcher(self, proc_id: str) -> bool:
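`log_samples` is a new per-backend hook that this hunk relies on; its implementation is not part of the diff. A sketch of what a W&B-flavored backend might do with the `{metric_key: [sample, ...]}` mapping — one table per key (the `wandb.Table`/`wandb.log` calls are real W&B API, but the class and method shape are assumptions, not Forge's `WandbBackend`):

```python
import wandb


class WandbSampleLoggerSketch:
    """Hypothetical backend hook; illustrates one way to render sample metrics."""

    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
        # One wandb.Table per sample metric, e.g. "rollout/sample" -> episode rows.
        for key, rows in samples.items():
            if not rows:
                continue
            columns = sorted({col for row in rows for col in row})
            table = wandb.Table(
                columns=columns,
                data=[[row.get(col) for col in columns] for row in rows],
            )
            wandb.log({key: table}, step=step)
```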