diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 80e6ee10a..773a16892 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -12,6 +12,7 @@ import torch from datasets import load_dataset from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig +from forge.actors.replay_buffer import ReplayBuffer from forge.controller import ServiceConfig, spawn_service from forge.controller.actor import ForgeActor from monarch.actor import endpoint