diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 6c2efd5e6..164933082 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -253,16 +253,27 @@ async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): for handle in state_dict.values(): handle.drop() - async def _fetch_weights( + async def _fetch_weights_parallel( self, - version: int, + param_names: list[str], + *, + version: int | None = None, + dcp_key: str | None = None, + tracer_name: str, ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("generator_perf/_fetch_weights") + """Fetch weights in parallel using multiple fetcher processes. + + Args: + param_names: List of parameter names to fetch + version: Version number for individual tensor loading (mutually exclusive with dcp_key) + dcp_key: Torchstore key for DCP handle (mutually exclusive with version) + tracer_name: Name for the performance tracer + + Returns: + Dictionary mapping parameter names to SharedTensorHandles + """ + t = Tracer(tracer_name) t.start() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - hf_param_names = [extract_param_name(key) for key in matching_keys] n_fetchers = self.weight_fetchers.size() @@ -270,10 +281,15 @@ def split_keys(keys): return [keys[i::n_fetchers] for i in range(n_fetchers)] futures = [] - for i, names in enumerate(split_keys(hf_param_names)): - fut = self.weight_fetchers.slice(procs=i).fetch.call_one( - version=version, param_names=names - ) + for i, names in enumerate(split_keys(param_names)): + if dcp_key is not None: + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + dcp_key=dcp_key, param_names=names + ) + else: + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) futures.append(fut) sub_state_dicts = [await fut for fut in futures] @@ -286,6 +302,38 @@ def split_keys(keys): return state_dict + async def _fetch_weights( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + + return await self._fetch_weights_parallel( + param_names=hf_param_names, + version=version, + tracer_name="generator_perf/_fetch_weights", + ) + + async def _fetch_weights_dcp( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from DCP checkpoint and return a dict of {name: SharedTensorHandle}.""" + # Get the DCP handle from torchstore to access param names + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + dcp_handle = await ts.get(dcp_whole_state_dict_key) + hf_param_names = dcp_handle.param_names + + # Pass the DCP torchstore key so each fetcher can get the handle + return await self._fetch_weights_parallel( + param_names=hf_param_names, + dcp_key=dcp_whole_state_dict_key, + tracer_name="generator_perf/_fetch_weights_dcp", + ) + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -439,12 +487,25 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ - # TODO: enable shared memory prefetch for DCP-based weight sync - if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: - logger.info(f"[Generator] Fetching weights for v{version} to shared memory") - fetch_fut = asyncio.create_task(self._fetch_weights(version)) - else: - fetch_fut = None + # Prefetch weights to shared memory if enabled + fetch_fut = None + if self.prefetch_weights_to_shm: + # Check if DCP is being used for this version + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys + + if use_dcp_for_weight_sync: + logger.info( + f"[Generator] Fetching weights for v{version} from DCP to shared memory" + ) + fetch_fut = asyncio.create_task(self._fetch_weights_dcp(version)) + else: + logger.info( + f"[Generator] Fetching weights for v{version} to shared memory" + ) + fetch_fut = asyncio.create_task(self._fetch_weights(version)) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -733,17 +794,41 @@ class _WeightFetcher(ForgeActor): async def fetch( self, *, - version: int, + version: int | None = None, + dcp_key: str | None = None, param_names: list[str], ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and load them into shared memory.""" + """Fetch weights and load them into shared memory. + + Args: + version: Version number for individual tensor loading (mutually exclusive with dcp_key) + dcp_key: Torchstore key for DCP handle (mutually exclusive with version) + param_names: List of parameter names to fetch + + Returns: + Dictionary mapping parameter names to SharedTensorHandles + """ sd = {} + + # Setup for DCP loading if dcp_key is provided + if dcp_key is not None: + # Get the DCP handle from torchstore - this gives us the metadata and checkpoint path + dcp_handle = await ts.get(dcp_key) + + # Fetch each parameter for name in param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) + if dcp_key is not None: + # Load tensor from DCP checkpoint + param = load_tensor_from_dcp(dcp_handle, name) + else: + # Load tensor from torchstore + param_key = get_param_key(version, name) + param = await ts.get(param_key) + # Use context manager to ensure cleanup after getting handle with SharedTensor(tensor=param) as shared_tensor: handle = shared_tensor.get_handle() sd[name] = handle del param # Explicitly free the tensor after copying to shared memory + return sd