43 changes: 30 additions & 13 deletions src/forge/actors/policy.py
@@ -101,18 +101,21 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
+    store: MultiProcessStore | None = None

     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
+        self.weights_version: int = 0

     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Policy"],
         *,
         process_config: ProcessConfig,
         config: PolicyConfig,
+        store: MultiProcessStore | None = None,
Contributor: Does spawn services need to know to pass this in?

Member Author: Yes, it has to be passed in from the top level so you can use the same store for the Trainer.
         **kwargs,
     ) -> "Policy":
         # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -132,7 +135,11 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
         policy = await policy_proc.spawn(
-            actor_name, cls, config=config, policy_worker=workers
+            actor_name,
+            cls,
+            config=config,
+            policy_worker=workers,
+            store=store,
         )
         policy._policy_proc = policy_proc
         policy._worker_procs = worker_procs
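The author's reply above implies the store is created once at the top level and shared by both actors. A minimal sketch of that wiring, assuming a `MultiProcessStore.create_store()` factory and a `Trainer` with a matching `launch` signature (both are assumptions; only `Policy.launch` is taken from this diff):

    # Hypothetical top-level wiring: one store shared by trainer and policy.
    store = await MultiProcessStore.create_store()

    policy = await Policy.launch(
        process_config=process_config,
        config=policy_config,
        store=store,  # the same store the trainer publishes weights into
    )
    trainer = await Trainer.launch(process_config=trainer_config, store=store)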
@@ -160,7 +167,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
     async def setup(self):
         # Set up policy_worker
         assert self.policy_worker is not None, "Policy worker should not be None"
-        await self.policy_worker.setup.call()
+        await self.policy_worker.setup.call(store=self.store)

         self.request_id = 0
         self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
@@ -313,9 +320,21 @@ async def run(self):
                 fut.set_result(request_output.outputs)

     @endpoint
-    async def update_weights(self):
+    async def update_weights(self) -> int:
         """Update the policy weights."""
-        pass
+        # Wait for all current requests to finish, then publish model weights
Contributor: Should there be a check that the new version exists on the store?

Member Author: I was just thinking that it would fail when trying to look up the key, and that would be obvious enough for now, but I could make it more explicit.
+        futures = [fut for _, fut in self.requests.values()]
+        if futures:
+            await asyncio.gather(*futures)
+        new_version = self.weights_version + 1
+        await self.policy_worker.update.call(version=new_version)
+        self.weights_version = new_version
+        return self.weights_version

     @endpoint
+    async def get_version(self) -> int:
+        """Get the current policy version."""
+        return self.weights_version
+
+    @endpoint
     async def stop(self):
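A hypothetical driver-side view of the two endpoints above, assuming the trainer has already written the next version's state dict into the store before `update_weights` is called (the trainer endpoint name is made up, and unwrapping of `.call()` results is simplified):

    # Hypothetical driver loop: publish weights, then tell the policy to load
    # them; update_weights returns the version it just switched to.
    await trainer.push_weights.call()  # assumed trainer endpoint
    new_version = await policy.update_weights.call()
    assert await policy.get_version.call() == new_version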
@@ -383,7 +402,9 @@ async def setup(self, store: MultiProcessStore = None):
     async def execute_model(self, schedule: SchedulerOutput):
         return self.worker.execute_model(schedule)

-    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
+    async def _load_tensor_parallel_state_dict(
+        self, current_state_dict: dict, version: int
+    ):
         """
         Load full state dict from torchstore into tensor parallel model with deterministic sharding.
         """
@@ -398,7 +419,7 @@ async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
             # Load the full tensor from torchstore
             # TODO: only get the part of the tensor that is needed
             stored_tensor = await self.torchstore.get(
-                f"{self.state_dict_key}{DELIM}{param_name}"
+                f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
             )
             sharding.load_from_source_to_target(
                 param_name,
@@ -409,23 +430,19 @@ async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
             updated_count += 1

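The versioned read path above implies the writer must publish under the same key scheme. A minimal sketch of the trainer-side counterpart, assuming torchstore exposes a `put` mirroring the `get` used here (the helper itself is hypothetical):

    # Hypothetical trainer-side publish matching the read path above: every
    # parameter lands under {key_prefix}{DELIM}{version}{DELIM}{param_name}.
    async def push_state_dict(store, state_dict: dict, key_prefix: str, version: int):
        for param_name, tensor in state_dict.items():
            await store.put(f"{key_prefix}{DELIM}{version}{DELIM}{param_name}", tensor)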
     @endpoint
-    async def update(self):
+    async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""

         if self.torchstore is None:
             raise Exception("No torchstore configured, skipping model update")

         logger.debug(
-            f"Starting model update from torchstore with key: {self.state_dict_key}"
+            f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}"
         )

         model = self.worker.model_runner.model
         current_state_dict = model.state_dict()

         logger.debug(f"Current state dict has {len(current_state_dict)} parameters")
-
-        await self._load_tensor_parallel_state_dict(current_state_dict)
-
+        await self._load_tensor_parallel_state_dict(current_state_dict, version)
         logger.debug("Successfully updated model weights from torchstore")

     @endpoint
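Following up on the review thread above: the author relies on a missing key raising on lookup, but a more explicit check could look like the sketch below (the helper, its probe strategy, and the error type are hypothetical):

    # Hypothetical fail-fast probe: fetch one known parameter key for the
    # requested version and surface a clearer error if it is absent.
    async def _check_version_exists(self, version: int, probe_param: str) -> None:
        key = f"{self.state_dict_key}{DELIM}{version}{DELIM}{probe_param}"
        try:
            await self.torchstore.get(key)
        except Exception as err:
            raise RuntimeError(
                f"No weights found in torchstore for version {version}"
            ) from err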