59 changes: 50 additions & 9 deletions torchrec/distributed/fp_embeddingbag.py
@@ -8,7 +8,18 @@
# pyre-strict

from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import torch
from torch import nn
@@ -31,14 +42,20 @@
ShardingEnv,
ShardingType,
)
from torchrec.distributed.utils import append_prefix, init_parameters
from torchrec.distributed.utils import (
append_prefix,
init_parameters,
modify_input_for_feature_processor,
)
from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
from torchrec.modules.fp_embedding_modules import (
apply_feature_processors_to_kjt,
FeatureProcessedEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

_T = TypeVar("_T")


def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
)
)

self._row_wise_sharded: bool = False
for param_sharding in table_name_to_parameter_sharding.values():
if param_sharding.sharding_type in [
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.GRID_SHARD.value,
]:
self._row_wise_sharded = True
break

self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

self._is_collection: bool = False
@@ -96,6 +123,11 @@ def __init__(
def input_dist(
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
) -> Awaitable[Awaitable[KJTList]]:
if not self.is_pipelined and self._row_wise_sharded:
# transform input to support row based sharding when not pipelined
modify_input_for_feature_processor(
features, self._feature_processors, self._is_collection
)
return self._embedding_bag_collection.input_dist(ctx, features)

def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -105,10 +137,7 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
kjt_list.append(self._feature_processors(features))
else:
kjt_list.append(
apply_feature_processors_to_kjt(
features,
self._feature_processors,
)
apply_feature_processors_to_kjt(features, self._feature_processors)
)
return KJTList(kjt_list)

@@ -117,7 +146,6 @@ def compute(
ctx: EmbeddingBagCollectionContext,
dist_input: KJTList,
) -> List[torch.Tensor]:

fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
return self._embedding_bag_collection.compute(ctx, fp_features)

@@ -166,6 +194,18 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa
self._embedding_bag_collection._initialize_torch_state(skip_registering)

def preprocess_input(
self, args: List[_T], kwargs: Mapping[str, _T]
) -> Tuple[List[_T], Mapping[str, _T]]:
for x in args + list(kwargs.values()):
if isinstance(x, KeyedJaggedTensor):
modify_input_for_feature_processor(
features=x,
feature_processors=self._feature_processors,
is_collection=self._is_collection,
)
return args, kwargs


class FeatureProcessedEmbeddingBagCollectionSharder(
BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -191,7 +231,6 @@ def shard(
device: Optional[torch.device] = None,
module_fqn: Optional[str] = None,
) -> ShardedFeatureProcessedEmbeddingBagCollection:

if device is None:
device = torch.device("cuda")

@@ -228,12 +267,14 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
if compute_device_type in {"mtia"}:
return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

# No row wise because position weighted FP and RW don't play well together.
types = [
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.GRID_SHARD.value,
]

return types
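
Reviewer context: the sketch below summarizes how the two code paths in this file divide the feature-processor transform. It is a minimal illustration, not the TorchRec implementation; `SimpleShardedFPEbc` and `apply_fp_transform` are hypothetical stand-ins for `ShardedFeatureProcessedEmbeddingBagCollection` and `modify_input_for_feature_processor`, and the real hook only touches `KeyedJaggedTensor` arguments.

```python
# Minimal sketch of the control flow added in this file (assumed names:
# SimpleShardedFPEbc and apply_fp_transform are hypothetical stand-ins).
from typing import Any, List, Mapping, Tuple


def apply_fp_transform(features: Any) -> None:
    """Stand-in for modify_input_for_feature_processor."""
    ...


class SimpleShardedFPEbc:
    def __init__(self, row_wise_sharded: bool) -> None:
        # The train pipeline's _rewrite_model sets this to True when the
        # module is pipelined.
        self.is_pipelined = False
        self._row_wise_sharded = row_wise_sharded

    def preprocess_input(
        self, args: List[Any], kwargs: Mapping[str, Any]
    ) -> Tuple[List[Any], Mapping[str, Any]]:
        # Pipelined path: invoked by _start_data_dist before input_dist, so
        # inputs are transformed exactly once (the real hook filters to
        # KeyedJaggedTensor arguments).
        for x in list(args) + list(kwargs.values()):
            apply_fp_transform(x)
        return args, kwargs

    def input_dist(self, ctx: Any, features: Any) -> Any:
        # Non-pipelined path: transform inline, but only for row-based
        # sharding (ROW_WISE / TABLE_ROW_WISE / GRID_SHARD).
        if not self.is_pipelined and self._row_wise_sharded:
            apply_fp_transform(features)
        # ...then delegate to the wrapped EmbeddingBagCollection's input_dist.
```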
1 change: 0 additions & 1 deletion torchrec/distributed/tests/test_fp_embeddingbag.py
@@ -231,7 +231,6 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
def test_sharding_ebc(
self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
) -> None:

import hypothesis

# don't need to test entire matrix
7 changes: 6 additions & 1 deletion torchrec/distributed/tests/test_fp_embeddingbag_utils.py
@@ -86,7 +86,12 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
pred = torch.cat(
[
fp_ebc_out[key]
for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
for key in [
"feature_0",
"feature_1",
"feature_2",
"feature_3",
]
],
dim=1,
)
165 changes: 164 additions & 1 deletion torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -22,7 +22,10 @@
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embedding_types import (
EmbeddingComputeKernel,
EmbeddingTableConfig,
)
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fp_embeddingbag import (
FeatureProcessedEmbeddingBagCollectionSharder,
@@ -31,8 +34,13 @@
from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.sharding_plan import (
construct_module_sharding_plan,
row_wise,
table_wise,
)
from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.distributed.test_utils.test_model import (
ModelInput,
TestEBCSharder,
@@ -331,6 +339,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
torch.testing.assert_close(pred_gpu.cpu(), pred)


def fp_ebc(
rank: int,
world_size: int,
tables: List[EmbeddingTableConfig],
weighted_tables: List[EmbeddingTableConfig],
data: List[Tuple[ModelInput, List[ModelInput]]],
backend: str = "nccl",
local_size: Optional[int] = None,
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
assert ctx.pg is not None
sharder = cast(
ModuleSharder[nn.Module],
FeatureProcessedEmbeddingBagCollectionSharder(),
)

class DummyWrapper(nn.Module):
def __init__(self, sparse_arch):
super().__init__()
self.m = sparse_arch

def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
return self.m(model_input.idlist_features)

max_feature_lengths = [10, 10, 12, 12]
sparse_arch = DummyWrapper(
create_module_and_freeze(
tables=tables, # pyre-ignore[6]
device=ctx.device,
use_fp_collection=False,
max_feature_lengths=max_feature_lengths,
)
)

# compute_kernel = EmbeddingComputeKernel.FUSED.value
module_sharding_plan = construct_module_sharding_plan(
sparse_arch.m._fp_ebc,
per_param_sharding={
"table_0": row_wise(),
"table_1": row_wise(),
"table_2": row_wise(),
"table_3": row_wise(),
},
world_size=2,
device_type=ctx.device.type,
sharder=sharder,
)
sharded_sparse_arch_pipeline = DistributedModelParallel(
module=copy.deepcopy(sparse_arch),
plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6]
sharders=[sharder],
device=ctx.device,
)
sharded_sparse_arch_no_pipeline = DistributedModelParallel(
module=copy.deepcopy(sparse_arch),
plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6]
sharders=[sharder],
device=ctx.device,
)

batches = []
for d in data:
batches.append(d[1][ctx.rank].to(ctx.device))
dataloader = iter(batches)

optimizer_no_pipeline = optim.SGD(
sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
)
optimizer_pipeline = optim.SGD(
sharded_sparse_arch_pipeline.parameters(), lr=0.1
)

pipeline = TrainPipelineSparseDist(
sharded_sparse_arch_pipeline,
optimizer_pipeline,
ctx.device,
)

for batch in batches[:-2]:
batch = batch.to(ctx.device)
optimizer_no_pipeline.zero_grad()
loss, pred = sharded_sparse_arch_no_pipeline(batch)
loss.backward()
optimizer_no_pipeline.step()

pred_pipeline = pipeline.progress(dataloader)
torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())


class TrainPipelineGPUTest(MultiProcessTestBase):
def setUp(self, backend: str = "nccl") -> None:
super().setUp()

self.pipeline_class = TrainPipelineSparseDist
num_features = 4
num_weighted_features = 4
self.tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 100,
embedding_dim=(i + 1) * 4,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(num_features)
]
self.weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 100,
embedding_dim=(i + 1) * 4,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(num_weighted_features)
]

self.backend = backend
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")

if self.backend == "nccl" and self.device == torch.device("cpu"):
self.skipTest("NCCL not supported on CPUs.")

def _generate_data(
self,
num_batches: int = 5,
batch_size: int = 1,
max_feature_lengths: Optional[List[int]] = None,
) -> List[Tuple[ModelInput, List[ModelInput]]]:
return [
ModelInput.generate(
tables=self.tables,
weighted_tables=self.weighted_tables,
batch_size=batch_size,
world_size=2,
num_float_features=10,
max_feature_lengths=max_feature_lengths,
)
for i in range(num_batches)
]

def test_fp_ebc_rw(self) -> None:
data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
self._run_multi_process_test(
callable=fp_ebc,
world_size=2,
tables=self.tables,
weighted_tables=self.weighted_tables,
data=data,
)


class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
@@ -40,7 +40,7 @@ def setUp(self) -> None:
self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1)

num_features = 4
num_weighted_features = 2
num_weighted_features = 4
self.tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 100,
3 changes: 3 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
@@ -147,6 +147,7 @@ def _start_data_dist(
# and this info was done in the _rewrite_model by tracing the
# entire model to get the arg_info_list
args, kwargs = forward.args.build_args_kwargs(batch)
args, kwargs = module.preprocess_input(args, kwargs)

# Start input distribution.
module_ctx = module.create_context()
Expand Down Expand Up @@ -382,6 +383,8 @@ def _rewrite_model( # noqa C901
logger.info(f"Module '{node.target}' will be pipelined")
child = sharded_modules[node.target]
original_forwards.append(child.forward)
# Set pipelining flag on the child module
child.is_pipelined = True
# pyre-ignore[8] Incompatible attribute type
child.forward = pipelined_forward(
node.target,
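
For reference, the call order these utils.py changes establish in the train pipeline is roughly as follows. This is an illustrative sketch only; `sharded_module`, `arg_info`, and `batch` are hypothetical stand-ins for the objects `_start_data_dist` and `_rewrite_model` actually operate on.

```python
# Rough sketch of the pipeline-side ordering after this change. Only the
# sequence (is_pipelined flag -> preprocess_input -> input_dist) mirrors the
# diff; the names below are hypothetical simplifications.
from typing import Any


def rewrite_model_sketch(sharded_module: Any) -> None:
    # _rewrite_model marks the module so its own input_dist skips the
    # inline feature-processor transform.
    sharded_module.is_pipelined = True


def start_data_dist_sketch(sharded_module: Any, arg_info: Any, batch: Any) -> Any:
    # Build the forward args/kwargs for the sharded module from the batch.
    args, kwargs = arg_info.build_args_kwargs(batch)
    # New in this diff: give the module a chance to transform its KJT inputs.
    args, kwargs = sharded_module.preprocess_input(args, kwargs)
    # Input distribution then runs on the already-transformed input.
    ctx = sharded_module.create_context()
    return sharded_module.input_dist(ctx, *args, **kwargs)
```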