Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 45 additions & 44 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce

# from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.types import LauncherConfig, ProvisionerConfig
Expand Down Expand Up @@ -161,36 +162,36 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
reward_fn_name = getattr(
reward_fn, "__name__", reward_fn.__class__.__name__
)
# per function reward
record_metric(
f"reward/evaluate_response/sum_{reward_fn_name}_reward",
reward,
Reduce.SUM,
)
record_metric(
f"reward/evaluate_response/avg_{reward_fn_name}_reward",
reward,
Reduce.MEAN,
)
record_metric(
f"reward/evaluate_response/std_{reward_fn_name}_reward",
reward,
Reduce.STD,
)

# avg total reward
record_metric(
"reward/evaluate_response/avg_total_reward",
reward,
Reduce.MEAN,
)

# count fn calls
record_metric(
f"reward/evaluate_response/count_{reward_fn_name}_calls",
1,
Reduce.SUM,
)
# # per function reward
# record_metric(
# f"reward/evaluate_response/sum_{reward_fn_name}_reward",
# reward,
# Reduce.SUM,
# )
# record_metric(
# f"reward/evaluate_response/avg_{reward_fn_name}_reward",
# reward,
# Reduce.MEAN,
# )
# record_metric(
# f"reward/evaluate_response/std_{reward_fn_name}_reward",
# reward,
# Reduce.STD,
# )

# # avg total reward
# record_metric(
# "reward/evaluate_response/avg_total_reward",
# reward,
# Reduce.MEAN,
# )

# # count fn calls
# record_metric(
# f"reward/evaluate_response/count_{reward_fn_name}_calls",
# 1,
# Reduce.SUM,
# )

avg_reward = total_rewards / len(self.reward_functions)
return avg_reward
Expand Down Expand Up @@ -256,12 +257,12 @@ async def sample(self) -> dict[str, str] | None:
sample = next(self._iterator)

# Record dataset metrics
record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
record_metric(
"dataset/sample/avg_sample_len",
len(sample["request"]),
Reduce.MEAN,
)
# record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
# record_metric(
# "dataset/sample/avg_sample_len",
# len(sample["request"]),
# Reduce.MEAN,
# )

return sample
except StopIteration:
Expand Down Expand Up @@ -304,9 +305,9 @@ async def main(cfg: DictConfig):
else:
provisioner = await init_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
# metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
# mlogger = await get_or_create_metric_logger()
# await mlogger.init_backends.call_one(metric_logging_cfg)

# ---- Setup services ---- #

Expand Down Expand Up @@ -414,9 +415,9 @@ async def continuous_rollouts():

# Log metrics
rollout_count += 1
record_metric(
"main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
)
# record_metric(
# "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
# )
t.stop()

async def continuous_training():
Expand Down Expand Up @@ -458,7 +459,7 @@ async def continuous_training():
restart_tracer = True

# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)
# await mlogger.flush.call_one(training_step)

print(
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
Expand Down
16 changes: 8 additions & 8 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ off_by_n: 1 # Off by one by default
rollout_threads: 1 # Recommended to set equal to policy.num_replicas


# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
console:
reduce_across_ranks: True
# # Observability configuration
# metric_logging:
# wandb:
# project: "grpo-training"
# group: "grpo_exp_${oc.env:USER}"
# reduce_across_ranks: True
# console:
# reduce_across_ranks: True

# Dataset configuration
dataset:
Expand Down
Loading