7 changes: 4 additions & 3 deletions torchrec/distributed/embedding_types.py
@@ -15,6 +15,7 @@
Any,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
@@ -329,7 +330,7 @@ def shardings(self) -> Dict[str, FeatureShardingMixIn]:


Out = TypeVar("Out")
CompIn = TypeVar("CompIn")
CompIn = TypeVar("CompIn", KJTList, ListOfKJTList, KeyedJaggedTensor)
DistOut = TypeVar("DistOut")
ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable)

@@ -358,14 +359,14 @@ def __init__(

def prefetch(
self,
dist_input: KJTList,
dist_input: CompIn,
forward_stream: Optional[Union[torch.cuda.Stream, torch.mtia.Stream]] = None,
ctx: Optional[ShrdCtx] = None,
) -> None:
"""
Prefetch input features for each lookup module.
"""

assert isinstance(dist_input, Iterable)
for feature, emb_lookup in zip(dist_input, self._lookups):
while isinstance(emb_lookup, DistributedDataParallel):
emb_lookup = emb_lookup.module
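As background on the `CompIn` change in this file (an illustrative sketch, not part of the diff): a constrained `TypeVar` must bind to exactly one of the listed types, which is why `prefetch` can be annotated with `CompIn` while still keeping a runtime `Iterable` assert before zipping over lookups. The classes below are hypothetical stand-ins, not the real torchrec types:

```python
from typing import Generic, Iterable, TypeVar

# Hypothetical stand-ins for the torchrec types named in the constraint.
class KJTList(list): ...
class ListOfKJTList(list): ...
class KeyedJaggedTensor: ...  # a single jagged container, not a per-feature list

# A constrained TypeVar: CompIn must bind to exactly one of these three types.
CompIn = TypeVar("CompIn", KJTList, ListOfKJTList, KeyedJaggedTensor)

class PrefetchingModule(Generic[CompIn]):
    def prefetch(self, dist_input: CompIn) -> None:
        # Not every legal binding of CompIn is iterable per feature,
        # hence the runtime guard before looping over lookups.
        assert isinstance(dist_input, Iterable)
        for feature in dist_input:
            print(type(feature).__name__)

PrefetchingModule[KJTList]().prefetch(KJTList([KeyedJaggedTensor()]))
```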
5 changes: 4 additions & 1 deletion torchrec/distributed/quant_state.py
@@ -26,6 +26,9 @@
from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
from torchrec.distributed.embedding_types import (
GroupedEmbeddingConfig,
KeyedJaggedTensor,
KJTList,
ListOfKJTList,
ShardedEmbeddingModule,
)
from torchrec.distributed.types import ParameterSharding, ShardingType
@@ -34,7 +37,7 @@
from torchrec.tensor_types import UInt2Tensor, UInt4Tensor

Out = TypeVar("Out")
CompIn = TypeVar("CompIn")
CompIn = TypeVar("CompIn", KJTList, ListOfKJTList, KeyedJaggedTensor)
DistOut = TypeVar("DistOut")
ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable)

108 changes: 105 additions & 3 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -10,11 +10,11 @@
import copy

import unittest
from contextlib import ExitStack
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from functools import partial
from typing import cast, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import MagicMock
from typing import cast, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import MagicMock, patch

import torch
from hypothesis import assume, given, settings, strategies as st, Verbosity
@@ -29,6 +29,7 @@
FeatureProcessedEmbeddingBagCollectionSharder,
ShardedFeatureProcessedEmbeddingBagCollection,
)
from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.sharding_plan import (
construct_module_sharding_plan,
table_wise,
@@ -86,6 +87,9 @@
from torchrec.streamable import Pipelineable


T = TypeVar("T")


@dataclass
class ModelInputSimple(Pipelineable):
float_features: torch.Tensor
@@ -692,6 +696,104 @@ def custom_model_fwd(
self.assertEqual(pred_pipeline.size(0), 64)


class TrainPipelineSparseDist2DShardingTest(unittest.TestCase):
@contextmanager
def _mocked_pipeline(self, obj: T) -> Generator[T, None, None]:
disabled_methods = [
"fill_pipeline",
"_wait_for_batch",
"enqueue_batch",
"dequeue_batch",
"start_sparse_data_dist",
"wait_sparse_data_dist",
]

with ExitStack() as stack:
for method in disabled_methods:
stack.enter_context(
patch.object(obj.__class__, method, return_value=None)
)
yield obj

def test_dmp_collection_sync(self) -> None:
dmp = MagicMock(spec=DMPCollection)
dmp.training = True
dmp.return_value = (
torch.tensor(0.1, requires_grad=True),
torch.tensor(2),
) # loss, output

optimizer = MagicMock(spec=torch.optim.Optimizer)
data_iter = MagicMock()
mock_data: MagicMock = MagicMock(spec=Pipelineable)

def _add_context(pipeline: TrainPipelineSparseDist) -> None: # pyre-ignore
context = TrainPipelineContext()
context.index = 10
for _ in range(3):
pipeline.batches.append(mock_data)
pipeline.contexts.append(context)

# disable
pipeline = TrainPipelineSparseDist(
dmp,
optimizer,
device=torch.device("cpu"),
dmp_collection_sync_interval_batches=None,
)
_add_context(pipeline)
with self._mocked_pipeline(pipeline):
pipeline.progress(data_iter)

dmp.sync.assert_not_called()

# enable
dmp.reset_mock()
pipeline_with_dmp_sync = TrainPipelineSparseDist(
dmp,
optimizer,
device=torch.device("cpu"),
dmp_collection_sync_interval_batches=10,
)
_add_context(pipeline_with_dmp_sync)
with self._mocked_pipeline(pipeline_with_dmp_sync):
pipeline_with_dmp_sync.progress(data_iter)

dmp.assert_called_once()
dmp.sync.assert_called_once()

def test_sync_disabled_if_dmp_collection_is_not_used(self) -> None:
dmp = MagicMock(spec=DistributedModelParallel)
dmp.training = True
dmp.return_value = (
torch.tensor(0.1, requires_grad=True),
torch.tensor(2),
) # loss, output

optimizer = MagicMock(spec=torch.optim.Optimizer)
data_iter = MagicMock()
mock_data: MagicMock = MagicMock(spec=Pipelineable)

# set interval but pass in raw DMP
# interval will be ignored
pipeline = TrainPipelineSparseDist(
dmp,
optimizer,
device=torch.device("cpu"),
dmp_collection_sync_interval_batches=10,
)
context = TrainPipelineContext()
context.index = 10
for _ in range(3):
pipeline.batches.append(mock_data)
pipeline.contexts.append(context)
with self._mocked_pipeline(pipeline):
# no exception
pipeline.progress(data_iter)

dmp.assert_called_once()


class TrainPipelineAttachDetachTest(TrainPipelineSparseDistTestBase):
@unittest.skipIf(
not torch.cuda.is_available(),
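For readers unfamiliar with the patching pattern used by `_mocked_pipeline` above, here is a standalone sketch of the same idea with illustrative names (`Pipeline`, `mocked`), not torchrec APIs: an `ExitStack` enters one `patch.object` per method and restores all of them when the block exits.

```python
from contextlib import ExitStack, contextmanager
from typing import Generator, List, TypeVar
from unittest.mock import patch

T = TypeVar("T")

class Pipeline:
    def fill_pipeline(self) -> str:
        return "real work"

@contextmanager
def mocked(obj: T, methods: List[str]) -> Generator[T, None, None]:
    # Patch each listed method on the object's class with a no-op mock;
    # ExitStack unwinds every patch on exit, even if the body raises.
    with ExitStack() as stack:
        for name in methods:
            stack.enter_context(patch.object(obj.__class__, name, return_value=None))
        yield obj

p = Pipeline()
with mocked(p, ["fill_pipeline"]):
    assert p.fill_pipeline() is None      # patched: the mock returns None
assert p.fill_pipeline() == "real work"   # original behavior restored afterwards
```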
75 changes: 75 additions & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -76,6 +76,17 @@
torch.ops.import_module("fbgemm_gpu.sparse_ops")


# Note: this fallback is not ideal, but it is better than raising.
# Some users end up with a broken dependency when using torch package,
# and we cannot fix that on their behalf.
has_2d_support = True
try:
from torchrec.distributed.model_parallel import DMPCollection
except ImportError:
logger.warning("DMPCollection is not available. 2D sharding is not supported.")
has_2d_support = False


class ModelDetachedException(Exception):
pass

@@ -85,6 +96,39 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
pass

def sync_embeddings(
self,
model: torch.nn.Module,
interval_batches: Optional[int],
context: Optional[TrainPipelineContext] = None,
) -> None:
"""
Sync the embedding weights and fused optimizer states across replicas.
Only enabled if DMPCollection is used to shard the model;
otherwise this is a no-op.
"""
if (
not has_2d_support
or not isinstance(model, DMPCollection)
or interval_batches is None
):
return

if not context:
logger.warning(
f"{self.__class__.__name__} does not support context (not expected). "
"Embedding weight sync is disabled."
Copilot AI review comment (May 1, 2025) on lines +119 to +120:
[nitpick] Consider clarifying or documenting in the warning message when the context is absent so that it's clear under what conditions embedding sync is disabled.

Suggested change (replace the two warning-message lines above with):
    f"{self.__class__.__name__}: Embedding weight synchronization requires a valid "
    "TrainPipelineContext. No context was provided, so embedding sync is disabled. "
    "Ensure that a TrainPipelineContext is passed to enable this feature."
)
return

index = context.index
assert (
index is not None
), f"{self.__class__.__name__} context does not provide number of batches: {context=}"
if index % interval_batches == 0:
with record_function("## dmp_collection_sync ##"):
model.sync()


@dataclass
class TorchCompileConfig:
@@ -343,6 +387,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
execute_all_batches (bool): executes remaining batches in pipeline after
exhausting dataloader iterator.
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
dmp_collection_sync_interval_batches (Optional[int]):
(applicable to 2D sharding only)
if set and the model is wrapped in DMPCollection for 2D sharding,
sync DMP replicas every N batches (defaults to 1, i.e. every batch; None disables syncing)
"""

# The PipelinedForward class that is used in _rewrite_model
@@ -361,6 +409,7 @@ def __init__(
custom_model_fwd: Optional[
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
] = None,
dmp_collection_sync_interval_batches: Optional[int] = 1,
) -> None:
self._model = model
self._optimizer = optimizer
@@ -416,6 +465,15 @@ def __init__(
self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
custom_model_fwd if custom_model_fwd else model
)
self._pipelined_forward_type = PipelinedForward
self._dmp_collection_sync_interval_batches = (
dmp_collection_sync_interval_batches
)
if self._dmp_collection_sync_interval_batches is not None:
logger.info(
f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
f"{self._dmp_collection_sync_interval_batches} batches"
)

# DEPRECATED FIELDS
self._batch_i: Optional[In] = None
@@ -610,6 +668,12 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
# backward
self._backward(losses)

self.sync_embeddings(
self._model,
self._dmp_collection_sync_interval_batches,
self.contexts[0],
)

# update
with record_function("## optimizer ##"):
self._optimizer.step()
@@ -823,6 +887,10 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
start_batch (int): batch to begin semi-sync training. Typically a small period of synchronous training reduces early-stage NEX.
stash_gradients (bool): if True, will store gradients for each parameter to ensure true "Semi-Sync"
training. If False, will update the dense optimizer as soon as gradients are available (naive "Semi-Sync").
dmp_collection_sync_interval_batches (Optional[int]):
(applicable to 2D sharding only)
if set and the model is wrapped in DMPCollection for 2D sharding,
sync DMP replicas every N batches (defaults to 1, i.e. every batch; None disables syncing)
"""

# The PipelinedForward class that is used in _rewrite_model
Expand All @@ -842,6 +910,7 @@ def __init__(
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
] = None,
strict: bool = False,
dmp_collection_sync_interval_batches: Optional[int] = 1,
) -> None:
super().__init__(
model=model,
@@ -852,6 +921,7 @@
context_type=EmbeddingTrainPipelineContext,
pipeline_postproc=pipeline_postproc,
custom_model_fwd=custom_model_fwd,
dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
)
self._start_batch = start_batch
self._stash_gradients = stash_gradients
@@ -972,6 +1042,11 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
# pyre-ignore [6]
self.embedding_backward(context)

self.sync_embeddings(
self._model,
self._dmp_collection_sync_interval_batches,
context,
)
del context # context is no longer needed, deleting to free up memory

with record_function(f"## optimizer {iteration - 1} ##"):
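For orientation (not part of the diff), the interval gating that `sync_embeddings` applies can be summarized in isolation. The hypothetical `should_sync` helper below mirrors the modulo check only; in the pipeline itself the batch index comes from `context.index`, the interval from the `dmp_collection_sync_interval_batches` constructor argument, and `model.sync()` is only called when the model is a `DMPCollection`:

```python
from typing import Optional

def should_sync(batch_index: int, interval_batches: Optional[int]) -> bool:
    # Simplified mirror of the gate in sync_embeddings: never sync when the
    # interval is disabled (None); otherwise sync on every Nth batch.
    if interval_batches is None:
        return False
    return batch_index % interval_batches == 0

assert should_sync(10, 10) is True    # e.g. context.index == 10, interval == 10
assert should_sync(7, 10) is False    # off-interval batch: no sync this step
assert should_sync(7, None) is False  # interval of None disables syncing
```

As the tests above show, enabling the feature amounts to passing `dmp_collection_sync_interval_batches` (e.g. 10) to `TrainPipelineSparseDist` while the model is wrapped in a `DMPCollection`; the default interval of 1 syncs on every batch.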
9 changes: 5 additions & 4 deletions torchrec/distributed/train_pipeline/utils.py
@@ -1731,15 +1731,14 @@ def _prefetch_embeddings(
context: PrefetchTrainPipelineContext,
pipelined_modules: List[ShardedModule],
device: torch.device,
stream_context: torch.Stream,
stream_context: Callable[[Optional[torch.Stream]], torch.cuda.StreamContext],
data_dist_stream: Optional[torch.Stream],
default_stream: Optional[torch.Stream],
) -> Dict[str, KJTList]:
data_per_sharded_module = {}
for sharded_module in pipelined_modules:
forward = sharded_module.forward
assert isinstance(forward, PrefetchPipelinedForward)

assert forward._name in context.input_dist_tensors_requests
request = context.input_dist_tensors_requests.pop(forward._name)
assert isinstance(request, Awaitable)
@@ -1762,10 +1761,12 @@
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
data.record_stream(cur_stream)
data.record_stream(default_stream)
if default_stream:
data.record_stream(default_stream)

module_context.record_stream(cur_stream)
module_context.record_stream(default_stream)
if default_stream:
module_context.record_stream(default_stream)

sharded_module.prefetch(
ctx=module_context,
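As a side note on the annotation fix above (an illustrative sketch, not part of the diff): `stream_context` is a factory such as `torch.cuda.stream`, i.e. a callable that turns an optional stream into a context manager, which is what the corrected `Callable[[Optional[torch.Stream]], torch.cuda.StreamContext]` type captures. The `double_on_stream` helper below is hypothetical:

```python
from typing import Callable, Optional

import torch

# A stream-context factory: callers pass torch.cuda.stream itself (the callable),
# not an already-entered context, which is what the corrected annotation captures.
StreamContextFn = Callable[[Optional[torch.Stream]], torch.cuda.StreamContext]

def double_on_stream(
    stream_context: StreamContextFn,
    stream: Optional[torch.Stream],
    x: torch.Tensor,
) -> torch.Tensor:
    with stream_context(stream):
        # Work launched here is enqueued on `stream`.
        return x * 2

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    y = double_on_stream(torch.cuda.stream, s, torch.ones(4, device="cuda"))
    torch.cuda.current_stream().wait_stream(s)
    print(y)
```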
8 changes: 8 additions & 0 deletions torchrec/distributed/types.py
@@ -1003,6 +1003,14 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
pass

def prefetch(
self,
dist_input: CompIn,
forward_stream: Optional[torch.Stream] = None,
ctx: Optional[ShrdCtx] = None,
) -> None:
return None

def compute_and_output_dist(
self, ctx: ShrdCtx, input: CompIn
) -> LazyAwaitable[Out]: