66 changes: 66 additions & 0 deletions distributed/cw_sharding.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

from typing import List, Optional

import torch
import torch.distributed as dist
from torchrec.distributed.embedding_types import (
ShardedEmbeddingTable,
ShardedEmbeddingTableShard,
)
from torchrec.distributed.tw_sharding import TwEmbeddingSharding
from torchrec.distributed.types import (
ShardedTensorMetadata,
ShardMetadata,
)


class CwEmbeddingSharding(TwEmbeddingSharding):
"""
    Shards embedding bags column-wise, i.e. a given embedding table is partitioned along its embedding dimension (columns) and each column shard is placed on a selected rank.
"""

def __init__(
self,
sharded_tables: List[ShardedEmbeddingTable],
pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
) -> None:
super().__init__(sharded_tables, pg, device)

def _shard(
self, tables: List[ShardedEmbeddingTable]
) -> List[List[ShardedEmbeddingTableShard]]:
world_size = self._pg.size()
tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [
[] for i in range(world_size)
]
for table in tables:
# pyre-fixme [16]
shards: List[ShardMetadata] = table.sharding_spec.shards

# construct the global sharded_tensor_metadata
global_metadata = ShardedTensorMetadata(
shards_metadata=shards,
size=torch.Size([table.num_embeddings, table.embedding_dim]),
)
# pyre-fixme [6]
for i, rank in enumerate(table.ranks):
tables_per_rank[rank].append(
ShardedEmbeddingTableShard(
num_embeddings=table.num_embeddings,
embedding_dim=table.embedding_dim,
name=table.name,
embedding_names=table.embedding_names,
data_type=table.data_type,
feature_names=table.feature_names,
pooling=table.pooling,
compute_kernel=table.compute_kernel,
is_weighted=table.is_weighted,
local_rows=table.num_embeddings,
local_cols=shards[i].shard_lengths[1],
local_metadata=shards[i],
global_metadata=global_metadata,
)
)
return tables_per_rank
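
For intuition, here is a minimal, framework-free sketch of the kind of column-wise shard layout `_shard` consumes: every shard keeps all rows of the table but only a block of columns, described by `[row, col]` offsets and lengths in the same layout as the `ShardMetadata` fields read above (`shards[i].shard_lengths[1]` is the shard's column width). The class and helper below are hypothetical illustrations, not torchrec APIs; in the PR the shard metadata comes from the table's `sharding_spec` produced by the planner.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ColumnShard:
    # Mirrors the ShardMetadata fields used above: [row, col] layout.
    shard_offsets: List[int]   # [row_offset, col_offset]
    shard_lengths: List[int]   # [num_rows, num_cols]
    rank: int                  # rank that owns this column block


def split_columns(
    num_embeddings: int, embedding_dim: int, ranks: List[int], block_size: int
) -> List[ColumnShard]:
    """Split one table column-wise into fixed-width blocks, one block per rank."""
    assert embedding_dim == block_size * len(ranks), "dim must divide evenly in this sketch"
    return [
        ColumnShard(
            shard_offsets=[0, i * block_size],          # every shard keeps all rows
            shard_lengths=[num_embeddings, block_size],
            rank=rank,
        )
        for i, rank in enumerate(ranks)
    ]


# Example: a 1000 x 64 table split into four 16-column blocks, one per rank.
for shard in split_columns(1000, 64, ranks=[0, 1, 2, 3], block_size=16):
    print(shard)
```

This also shows why, in the constructor above, `local_rows` stays `table.num_embeddings` while `local_cols` becomes the shard's column width.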
24 changes: 18 additions & 6 deletions distributed/dp_sharding.py
@@ -4,22 +4,27 @@

import torch
import torch.distributed as dist
from torch.distributed._sharding_spec import ShardMetadata
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_lookup import (
GroupedPooledEmbeddingsLookup,
GroupedEmbeddingsLookup,
)
from torchrec.distributed.embedding_sharding import (
EmbeddingSharding,
ShardedEmbeddingTable,
group_tables,
BasePooledEmbeddingDist,
BaseSequenceEmbeddingDist,
BaseSparseFeaturesDist,
SequenceShardingContext,
BaseEmbeddingLookup,
)
from torchrec.distributed.embedding_types import GroupedEmbeddingConfig, SparseFeatures
from torchrec.distributed.embedding_types import (
GroupedEmbeddingConfig,
SparseFeatures,
ShardedEmbeddingTable,
ShardedEmbeddingTableShard,
)
from torchrec.distributed.types import Awaitable, NoWait


@@ -88,15 +93,15 @@ def __init__(

def _shard(
self, tables: List[ShardedEmbeddingTable]
) -> List[List[ShardedEmbeddingTable]]:
) -> List[List[ShardedEmbeddingTableShard]]:
world_size = self._pg.size()
tables_per_rank: List[List[ShardedEmbeddingTable]] = [
tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [
[] for i in range(world_size)
]
for table in tables:
for rank in range(world_size):
tables_per_rank[rank].append(
ShardedEmbeddingTable(
ShardedEmbeddingTableShard(
num_embeddings=table.num_embeddings,
embedding_dim=table.embedding_dim,
name=table.name,
@@ -106,7 +111,6 @@ def _shard(
pooling=table.pooling,
compute_kernel=table.compute_kernel,
is_weighted=table.is_weighted,
rank=table.rank,
local_rows=table.num_embeddings,
local_cols=table.embedding_dim,
)
@@ -168,6 +172,14 @@ def embedding_names(self) -> List[str]:
embedding_names.extend(grouped_config.embedding_names())
return embedding_names

def embedding_metadata(self) -> List[Optional[ShardMetadata]]:
embedding_metadata = []
for grouped_config in self._grouped_embedding_configs:
embedding_metadata.extend(grouped_config.embedding_metadata())
for grouped_config in self._score_grouped_embedding_configs:
embedding_metadata.extend(grouped_config.embedding_metadata())
return embedding_metadata

def id_list_feature_names(self) -> List[str]:
id_list_feature_names = []
for grouped_config in self._grouped_embedding_configs:
9 changes: 6 additions & 3 deletions distributed/embedding.py
@@ -9,6 +9,7 @@
from torch import nn
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.cw_sharding import CwEmbeddingSharding
from torchrec.distributed.dp_sharding import DpEmbeddingSharding
from torchrec.distributed.embedding_sharding import (
EmbeddingSharding,
@@ -52,6 +53,8 @@ def create_embedding_sharding(
return DpEmbeddingSharding(sharded_tables, pg, device)
elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
return TwRwEmbeddingSharding(sharded_tables, pg, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return CwEmbeddingSharding(sharded_tables, pg, device)
else:
raise ValueError(f"Sharding not supported {sharding_type}")

@@ -115,9 +118,9 @@ def _create_sharded_table_configs(
embedding_names=embedding_names,
compute_kernel=compute_kernel,
is_weighted=module.is_weighted,
rank=parameter_sharding.rank
if parameter_sharding.rank is not None
else 0,
# pyre-fixme [6]
sharding_spec=parameter_sharding.sharding_spec,
ranks=parameter_sharding.ranks,
)
)
return sharding_type_to_sharded_tables
15 changes: 10 additions & 5 deletions distributed/embedding_sharding.py
@@ -7,13 +7,14 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._sharding_spec import ShardMetadata
from torchrec.distributed.dist_data import KJTAllToAll
from torchrec.distributed.embedding_types import (
ShardedEmbeddingTable,
GroupedEmbeddingConfig,
BaseEmbeddingLookup,
SparseFeatures,
EmbeddingComputeKernel,
ShardedEmbeddingTableShard,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.embedding_configs import (
@@ -102,10 +103,10 @@ def forward(

# group tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
tables_per_rank: List[List[ShardedEmbeddingTableShard]],
) -> Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]:
def _group_tables_helper(
embedding_tables: List[ShardedEmbeddingTable],
embedding_tables: List[ShardedEmbeddingTableShard],
) -> Tuple[List[GroupedEmbeddingConfig], List[GroupedEmbeddingConfig]]:
grouped_embedding_configs: List[GroupedEmbeddingConfig] = []
score_grouped_embedding_configs: List[GroupedEmbeddingConfig] = []
@@ -119,8 +120,8 @@ def _group_tables_helper(
EmbeddingComputeKernel.BATCHED_FUSED,
EmbeddingComputeKernel.SSD,
]:
grouped_tables: List[ShardedEmbeddingTable] = []
grouped_score_tables: List[ShardedEmbeddingTable] = []
grouped_tables: List[ShardedEmbeddingTableShard] = []
grouped_score_tables: List[ShardedEmbeddingTableShard] = []
for table in embedding_tables:
if table.compute_kernel in [
EmbeddingComputeKernel.BATCHED_FUSED_UVM,
@@ -256,6 +257,10 @@ def create_lookup(
def embedding_dims(self) -> List[int]:
pass

@abc.abstractmethod
def embedding_metadata(self) -> List[Optional[ShardMetadata]]:
pass

@abc.abstractmethod
def embedding_names(self) -> List[str]:
pass
43 changes: 33 additions & 10 deletions distributed/embedding_types.py
@@ -7,17 +7,16 @@

import torch
from torch import nn
from torch.distributed._sharded_tensor import ShardedTensorMetadata
from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec
from torchrec.distributed.types import (
ModuleSharder,
ShardingType,
ShardMetadata,
ShardedTensorMetadata,
ParameterStorage,
)
from torchrec.modules.embedding_configs import (
PoolingType,
DataType,
BaseEmbeddingConfig,
EmbeddingTableConfig,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -60,18 +59,36 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:

@dataclass
class ShardedConfig:
compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE
embedding_names: List[str] = field(default_factory=list)
rank: int = 0
local_rows: int = 0
local_cols: int = 0


@dataclass
class ShardedMetaConfig(ShardedConfig):
local_metadata: Optional[ShardMetadata] = None
global_metadata: Optional[ShardedTensorMetadata] = None


@dataclass
class ShardedEmbeddingTable(ShardedConfig, EmbeddingTableConfig):
class EmbeddingAttributes:
embedding_names: List[str] = field(default_factory=list)
compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE


@dataclass
class ShardedEmbeddingTable(
EmbeddingAttributes,
EmbeddingTableConfig,
):
ranks: Optional[List[int]] = None
sharding_spec: Optional[EnumerableShardingSpec] = None


@dataclass
class ShardedEmbeddingTableShard(
ShardedMetaConfig, EmbeddingAttributes, EmbeddingTableConfig
):
pass


@dataclass
Expand All @@ -80,7 +97,7 @@ class GroupedEmbeddingConfig:
pooling: PoolingType
is_weighted: bool
compute_kernel: EmbeddingComputeKernel
embedding_tables: List[ShardedEmbeddingTable]
embedding_tables: List[ShardedEmbeddingTableShard]

def feature_hash_sizes(self) -> List[int]:
feature_hash_sizes = []
@@ -97,7 +114,7 @@ def num_features(self) -> int:
def dim_sum(self) -> int:
dim_sum = 0
for table in self.embedding_tables:
dim_sum += table.num_features() * table.embedding_dim
dim_sum += table.num_features() * table.local_cols
return dim_sum

def feature_names(self) -> List[str]:
@@ -109,7 +126,7 @@ def feature_names(self) -> List[str]:
def embedding_dims(self) -> List[int]:
embedding_dims = []
for table in self.embedding_tables:
embedding_dims.extend([table.embedding_dim] * table.num_features())
embedding_dims.extend([table.local_cols] * table.num_features())
return embedding_dims

def embedding_names(self) -> List[str]:
@@ -118,6 +135,12 @@ def embedding_names(self) -> List[str]:
embedding_names.extend(table.embedding_names)
return embedding_names

def embedding_metadata(self) -> List[Optional[ShardMetadata]]:
embedding_metadata: List[Optional[ShardMetadata]] = []
for table in self.embedding_tables:
embedding_metadata.append(table.local_metadata)
return embedding_metadata


class BaseEmbeddingLookup(abc.ABC, nn.Module):
"""
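
The refactor above replaces the single `ShardedEmbeddingTable` dataclass with composable pieces: `EmbeddingAttributes` plus `EmbeddingTableConfig` describe the logical (pre-shard) table, optionally carrying `ranks` and a `sharding_spec`, while `ShardedEmbeddingTableShard` additionally mixes in `ShardedMetaConfig` to hold per-shard sizes and metadata. As a rough, self-contained illustration of how Python dataclass multiple inheritance merges these field sets, here is a toy with simplified stand-in classes (the names and fields are placeholders, not the torchrec definitions):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TableConfig:  # stand-in for EmbeddingTableConfig
    num_embeddings: int = 0
    embedding_dim: int = 0


@dataclass
class Attributes:  # stand-in for EmbeddingAttributes
    embedding_names: List[str] = field(default_factory=list)


@dataclass
class MetaConfig:  # stand-in for ShardedMetaConfig (per-shard view)
    local_rows: int = 0
    local_cols: int = 0
    local_metadata: Optional[str] = None


@dataclass
class TableShard(MetaConfig, Attributes, TableConfig):
    # Fields are gathered base-first (reverse MRO): table config, then
    # attributes, then the local shard sizes/metadata.
    pass


# One column shard of a 1000 x 64 table that keeps 16 columns locally.
shard = TableShard(num_embeddings=1000, embedding_dim=64, local_rows=1000, local_cols=16)
print(shard)
```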
3 changes: 1 addition & 2 deletions distributed/model_parallel.py
@@ -17,7 +17,6 @@
ShardingPlan,
ModuleSharder,
ShardedModule,
ShardedTensor,
)
from torchrec.optim.fused import FusedOptimizerModule
from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer
@@ -88,7 +87,7 @@ def __init__(

# 2. Call ShardingPlanner.plan passing all found modules and corresponding sharders.
if plan is None:
plan = EmbeddingShardingPlanner(self._pg, self.device).collective_plan(
plan = EmbeddingShardingPlanner(self._pg, self.device).plan(
module, sharders
)

4 changes: 3 additions & 1 deletion distributed/planner/cost_functions.py
@@ -3,13 +3,15 @@
import math
from typing import Dict

from torchrec.distributed.embedding_types import ShardingType, EmbeddingComputeKernel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import CostInput
from torchrec.distributed.types import ShardingType


# Constants
COMMS_MULTIPLER: Dict[str, int] = {
ShardingType.TABLE_WISE.value: 2,
ShardingType.COLUMN_WISE.value: 2,
ShardingType.ROW_WISE.value: 5,
ShardingType.TABLE_ROW_WISE.value: 3,
ShardingType.DATA_PARALLEL.value: 1,