127 changes: 106 additions & 21 deletions src/forge/actors/generator.py
@@ -253,27 +253,43 @@ async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
for handle in state_dict.values():
handle.drop()

async def _fetch_weights(
async def _fetch_weights_parallel(
self,
version: int,
param_names: list[str],
*,
version: int | None = None,
dcp_key: str | None = None,
tracer_name: str,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
t = Tracer("generator_perf/_fetch_weights")
"""Fetch weights in parallel using multiple fetcher processes.

Args:
param_names: List of parameter names to fetch
version: Version number for individual tensor loading (mutually exclusive with dcp_key)
dcp_key: Torchstore key for DCP handle (mutually exclusive with version)
tracer_name: Name for the performance tracer

Returns:
Dictionary mapping parameter names to SharedTensorHandles
"""
t = Tracer(tracer_name)
t.start()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]

n_fetchers = self.weight_fetchers.size()

def split_keys(keys):
return [keys[i::n_fetchers] for i in range(n_fetchers)]

futures = []
for i, names in enumerate(split_keys(hf_param_names)):
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
version=version, param_names=names
)
for i, names in enumerate(split_keys(param_names)):
if dcp_key is not None:
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
dcp_key=dcp_key, param_names=names
)
else:
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
version=version, param_names=names
)
futures.append(fut)

sub_state_dicts = [await fut for fut in futures]
@@ -286,6 +302,38 @@ def split_keys(keys):

return state_dict
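A minimal sketch of the round-robin striping that split_keys performs (illustrative values only; any number of fetchers works the same way):

names = ["w0", "w1", "w2", "w3", "w4", "w5", "w6"]
n_fetchers = 3
shards = [names[i::n_fetchers] for i in range(n_fetchers)]
# shards == [["w0", "w3", "w6"], ["w1", "w4"], ["w2", "w5"]]
# Each shard i is dispatched to weight_fetchers.slice(procs=i) above.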

async def _fetch_weights(
self,
version: int,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]

return await self._fetch_weights_parallel(
param_names=hf_param_names,
version=version,
tracer_name="generator_perf/_fetch_weights",
)
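For context, a hypothetical sketch of the torchstore key layout these helpers assume; the real get_param_prefix, get_param_key, and extract_param_name are defined elsewhere in forge and may use a different format:

def get_param_prefix(version: int) -> str:
    return f"weights/v{version}"  # assumed key format, for illustration only

def get_param_key(version: int, name: str) -> str:
    return f"{get_param_prefix(version)}/{name}"

def extract_param_name(key: str) -> str:
    return key.split("/", 2)[-1]  # inverse of get_param_key under this assumed format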

async def _fetch_weights_dcp(
self,
version: int,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from DCP checkpoint and return a dict of {name: SharedTensorHandle}."""
# Get the DCP handle from torchstore to access param names
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
dcp_handle = await ts.get(dcp_whole_state_dict_key)
hf_param_names = dcp_handle.param_names

# Pass the DCP torchstore key so each fetcher can get the handle
return await self._fetch_weights_parallel(
param_names=hf_param_names,
dcp_key=dcp_whole_state_dict_key,
tracer_name="generator_perf/_fetch_weights_dcp",
)
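Illustrative call sites for the two prefetch paths; both return a dict of {param_name: SharedTensorHandle} that is later released via _drop_shared_memory:

handles = await self._fetch_weights(version)      # per-tensor torchstore path
handles = await self._fetch_weights_dcp(version)  # DCP checkpoint path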

@endpoint
async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
"""Generate a response for the given prompt
@@ -439,12 +487,25 @@ async def update_weights(self, version: int) -> None:
>>> await trainer.push_weights()
>>> generator.update_weights(version)
"""
# TODO: enable shared memory prefetch for DCP-based weight sync
if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
fetch_fut = asyncio.create_task(self._fetch_weights(version))
else:
fetch_fut = None
# Prefetch weights to shared memory if enabled
fetch_fut = None
if self.prefetch_weights_to_shm:
# Check if DCP is being used for this version
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys

if use_dcp_for_weight_sync:
logger.info(
f"[Generator] Fetching weights for v{version} from DCP to shared memory"
)
fetch_fut = asyncio.create_task(self._fetch_weights_dcp(version))
else:
logger.info(
f"[Generator] Fetching weights for v{version} to shared memory"
)
fetch_fut = asyncio.create_task(self._fetch_weights(version))
# Serialize updates (only one update at a time)
async with self.update_lock:
# Grab the lock to stop accepting requests and wait on pending requests
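The DCP detection above can also be read as a small predicate; a hypothetical helper (not part of this PR), assuming the module's existing ts, get_param_prefix, and get_dcp_whole_state_dict_key imports:

async def _should_use_dcp(self, version: int) -> bool:
    prefix = get_param_prefix(version)
    matching_keys = await ts.keys(prefix)
    return get_dcp_whole_state_dict_key(version) in matching_keys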
@@ -733,17 +794,41 @@ class _WeightFetcher(ForgeActor):
async def fetch(
self,
*,
version: int,
version: int | None = None,
dcp_key: str | None = None,
param_names: list[str],
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and load them into shared memory."""
"""Fetch weights and load them into shared memory.

Args:
version: Version number for individual tensor loading (mutually exclusive with dcp_key)
dcp_key: Torchstore key for DCP handle (mutually exclusive with version)
param_names: List of parameter names to fetch

Returns:
Dictionary mapping parameter names to SharedTensorHandles
"""
sd = {}

# Set up DCP loading if dcp_key is provided
if dcp_key is not None:
# Get the DCP handle from torchstore - this gives us the metadata and checkpoint path
dcp_handle = await ts.get(dcp_key)

# Fetch each parameter
for name in param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
if dcp_key is not None:
# Load tensor from DCP checkpoint
param = load_tensor_from_dcp(dcp_handle, name)
else:
# Load tensor from torchstore
param_key = get_param_key(version, name)
param = await ts.get(param_key)

# Use context manager to ensure cleanup after getting handle
with SharedTensor(tensor=param) as shared_tensor:
handle = shared_tensor.get_handle()
sd[name] = handle
del param # Explicitly free the tensor after copying to shared memory

return sd
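Illustrative calls to this endpoint, mirroring how _fetch_weights_parallel dispatches to each slice (fetcher below stands for one weight_fetchers slice; exactly one of version / dcp_key is passed):

fut = fetcher.fetch.call_one(version=version, param_names=names)  # torchstore path
fut = fetcher.fetch.call_one(dcp_key=dcp_whole_state_dict_key, param_names=names)  # DCP path
handles = await fut  # {name: SharedTensorHandle}, freed later via handle.drop()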