diff --git a/distributed/cw_sharding.py b/distributed/cw_sharding.py
new file mode 100644
index 000000000..96a93ea98
--- /dev/null
+++ b/distributed/cw_sharding.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from torchrec.distributed.embedding_types import (
+    ShardedEmbeddingTable,
+    ShardedEmbeddingTableShard,
+)
+from torchrec.distributed.tw_sharding import TwEmbeddingSharding
+from torchrec.distributed.types import (
+    ShardedTensorMetadata,
+    ShardMetadata,
+)
+
+
+class CwEmbeddingSharding(TwEmbeddingSharding):
+    """
+    Shards embedding bags column-wise, i.e. a given embedding table is split along its embedding dimension and each column shard is placed on a selected rank.
+    """
+
+    def __init__(
+        self,
+        sharded_tables: List[ShardedEmbeddingTable],
+        pg: dist.ProcessGroup,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(sharded_tables, pg, device)
+
+    def _shard(
+        self, tables: List[ShardedEmbeddingTable]
+    ) -> List[List[ShardedEmbeddingTableShard]]:
+        world_size = self._pg.size()
+        tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [
+            [] for i in range(world_size)
+        ]
+        for table in tables:
+            # pyre-fixme [16]
+            shards: List[ShardMetadata] = table.sharding_spec.shards
+
+            # construct the global sharded_tensor_metadata
+            global_metadata = ShardedTensorMetadata(
+                shards_metadata=shards,
+                size=torch.Size([table.num_embeddings, table.embedding_dim]),
+            )
+            # pyre-fixme [6]
+            for i, rank in enumerate(table.ranks):
+                tables_per_rank[rank].append(
+                    ShardedEmbeddingTableShard(
+                        num_embeddings=table.num_embeddings,
+                        embedding_dim=table.embedding_dim,
+                        name=table.name,
+                        embedding_names=table.embedding_names,
+                        data_type=table.data_type,
+                        feature_names=table.feature_names,
+                        pooling=table.pooling,
+                        compute_kernel=table.compute_kernel,
+                        is_weighted=table.is_weighted,
+                        local_rows=table.num_embeddings,
+                        local_cols=shards[i].shard_lengths[1],
+                        local_metadata=shards[i],
+                        global_metadata=global_metadata,
+                    )
+                )
+        return tables_per_rank
diff --git a/distributed/dp_sharding.py b/distributed/dp_sharding.py
index c13d867d3..5d7908102 100644
--- a/distributed/dp_sharding.py
+++ b/distributed/dp_sharding.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
+from torch.distributed._sharding_spec import ShardMetadata
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_lookup import (
     GroupedPooledEmbeddingsLookup,
@@ -11,7 +12,6 @@
 )
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
-    ShardedEmbeddingTable,
     group_tables,
     BasePooledEmbeddingDist,
     BaseSequenceEmbeddingDist,
@@ -19,7 +19,12 @@
     SequenceShardingContext,
     BaseEmbeddingLookup,
 )
-from torchrec.distributed.embedding_types import GroupedEmbeddingConfig, SparseFeatures
+from torchrec.distributed.embedding_types import (
+    GroupedEmbeddingConfig,
+    SparseFeatures,
+    ShardedEmbeddingTable,
+    ShardedEmbeddingTableShard,
+)
 from torchrec.distributed.types import Awaitable, NoWait
@@ -88,15 +93,15 @@ def __init__(
     def _shard(
         self, tables: List[ShardedEmbeddingTable]
-    ) -> List[List[ShardedEmbeddingTable]]:
+    ) -> List[List[ShardedEmbeddingTableShard]]:
         world_size = self._pg.size()
-        tables_per_rank: List[List[ShardedEmbeddingTable]] = [
+        tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [
            [] for i in range(world_size)
        ]
        for table in tables:
            for rank in range(world_size):
                tables_per_rank[rank].append(
-                    ShardedEmbeddingTable(
+                    ShardedEmbeddingTableShard(
                        num_embeddings=table.num_embeddings,
embedding_dim=table.embedding_dim, name=table.name, @@ -106,7 +111,6 @@ def _shard( pooling=table.pooling, compute_kernel=table.compute_kernel, is_weighted=table.is_weighted, - rank=table.rank, local_rows=table.num_embeddings, local_cols=table.embedding_dim, ) @@ -168,6 +172,14 @@ def embedding_names(self) -> List[str]: embedding_names.extend(grouped_config.embedding_names()) return embedding_names + def embedding_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_metadata = [] + for grouped_config in self._grouped_embedding_configs: + embedding_metadata.extend(grouped_config.embedding_metadata()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_metadata.extend(grouped_config.embedding_metadata()) + return embedding_metadata + def id_list_feature_names(self) -> List[str]: id_list_feature_names = [] for grouped_config in self._grouped_embedding_configs: diff --git a/distributed/embedding.py b/distributed/embedding.py index 25ce8b611..f76b7e19f 100644 --- a/distributed/embedding.py +++ b/distributed/embedding.py @@ -9,6 +9,7 @@ from torch import nn from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.cw_sharding import CwEmbeddingSharding from torchrec.distributed.dp_sharding import DpEmbeddingSharding from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, @@ -52,6 +53,8 @@ def create_embedding_sharding( return DpEmbeddingSharding(sharded_tables, pg, device) elif sharding_type == ShardingType.TABLE_ROW_WISE.value: return TwRwEmbeddingSharding(sharded_tables, pg, device) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return CwEmbeddingSharding(sharded_tables, pg, device) else: raise ValueError(f"Sharding not supported {sharding_type}") @@ -115,9 +118,9 @@ def _create_sharded_table_configs( embedding_names=embedding_names, compute_kernel=compute_kernel, is_weighted=module.is_weighted, - rank=parameter_sharding.rank - if parameter_sharding.rank is not None - else 0, + # pyre-fixme [6] + sharding_spec=parameter_sharding.sharding_spec, + ranks=parameter_sharding.ranks, ) ) return sharding_type_to_sharded_tables diff --git a/distributed/embedding_sharding.py b/distributed/embedding_sharding.py index a606d998b..8a6b69e02 100644 --- a/distributed/embedding_sharding.py +++ b/distributed/embedding_sharding.py @@ -7,13 +7,14 @@ import torch import torch.distributed as dist from torch import nn +from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.dist_data import KJTAllToAll from torchrec.distributed.embedding_types import ( - ShardedEmbeddingTable, GroupedEmbeddingConfig, BaseEmbeddingLookup, SparseFeatures, EmbeddingComputeKernel, + ShardedEmbeddingTableShard, ) from torchrec.distributed.types import Awaitable from torchrec.modules.embedding_configs import ( @@ -102,10 +103,10 @@ def forward( # group tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel. 
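To make the new _shard contract above concrete: for a column-wise table the planner's EnumerableShardingSpec carries one ShardMetadata per column block, and CwEmbeddingSharding._shard turns each block into a ShardedEmbeddingTableShard on its assigned rank, taking local_cols from shard_lengths[1]. The following is a minimal, dependency-free sketch of that split; CwShard and split_columns are illustrative stand-ins (not torchrec or torch.distributed APIs) that only mirror the shard_lengths/shard_offsets/placement fields used in this diff.

from dataclasses import dataclass
from typing import List


@dataclass
class CwShard:
    # stand-in for ShardMetadata: [rows, cols], [row_offset, col_offset], "rank:<r>/<device>"
    shard_lengths: List[int]
    shard_offsets: List[int]
    placement: str


def split_columns(
    num_embeddings: int,
    embedding_dim: int,
    block_size: int,
    ranks: List[int],
    device: str = "cuda:0",
) -> List[CwShard]:
    # cut the embedding dimension into block_size-wide shards; the last shard keeps any residual
    sizes: List[int] = []
    remaining = embedding_dim
    while remaining > 0:
        sizes.append(min(block_size, remaining))
        remaining -= block_size
    assert len(sizes) == len(ranks), "expect one (possibly merged) rank per column shard"
    shards: List[CwShard] = []
    offset = 0
    for size, rank in zip(sizes, ranks):
        shards.append(
            CwShard(
                shard_lengths=[num_embeddings, size],
                shard_offsets=[0, offset],
                placement=f"rank:{rank}/{device}",
            )
        )
        offset += size
    return shards


# e.g. a 100 x 128 table split into two 64-column shards on ranks 0 and 1
for shard in split_columns(100, 128, 64, [0, 1]):
    print(shard)

Rows are never split in this scheme: shard_offsets[0] stays 0 and shard_lengths[0] is always the full num_embeddings, only the column offset advances, which is why _shard can keep local_rows=table.num_embeddings and vary only local_cols.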
def group_tables( - tables_per_rank: List[List[ShardedEmbeddingTable]], + tables_per_rank: List[List[ShardedEmbeddingTableShard]], ) -> Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]: def _group_tables_helper( - embedding_tables: List[ShardedEmbeddingTable], + embedding_tables: List[ShardedEmbeddingTableShard], ) -> Tuple[List[GroupedEmbeddingConfig], List[GroupedEmbeddingConfig]]: grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] score_grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] @@ -119,8 +120,8 @@ def _group_tables_helper( EmbeddingComputeKernel.BATCHED_FUSED, EmbeddingComputeKernel.SSD, ]: - grouped_tables: List[ShardedEmbeddingTable] = [] - grouped_score_tables: List[ShardedEmbeddingTable] = [] + grouped_tables: List[ShardedEmbeddingTableShard] = [] + grouped_score_tables: List[ShardedEmbeddingTableShard] = [] for table in embedding_tables: if table.compute_kernel in [ EmbeddingComputeKernel.BATCHED_FUSED_UVM, @@ -256,6 +257,10 @@ def create_lookup( def embedding_dims(self) -> List[int]: pass + @abc.abstractmethod + def embedding_metadata(self) -> List[Optional[ShardMetadata]]: + pass + @abc.abstractmethod def embedding_names(self) -> List[str]: pass diff --git a/distributed/embedding_types.py b/distributed/embedding_types.py index cf578c328..8bad96ae1 100644 --- a/distributed/embedding_types.py +++ b/distributed/embedding_types.py @@ -7,17 +7,16 @@ import torch from torch import nn +from torch.distributed._sharded_tensor import ShardedTensorMetadata +from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec from torchrec.distributed.types import ( ModuleSharder, ShardingType, - ShardMetadata, - ShardedTensorMetadata, ParameterStorage, ) from torchrec.modules.embedding_configs import ( PoolingType, DataType, - BaseEmbeddingConfig, EmbeddingTableConfig, ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -60,18 +59,36 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: @dataclass class ShardedConfig: - compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE - embedding_names: List[str] = field(default_factory=list) - rank: int = 0 local_rows: int = 0 local_cols: int = 0 + + +@dataclass +class ShardedMetaConfig(ShardedConfig): local_metadata: Optional[ShardMetadata] = None global_metadata: Optional[ShardedTensorMetadata] = None @dataclass -class ShardedEmbeddingTable(ShardedConfig, EmbeddingTableConfig): +class EmbeddingAttributes: embedding_names: List[str] = field(default_factory=list) + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE + + +@dataclass +class ShardedEmbeddingTable( + EmbeddingAttributes, + EmbeddingTableConfig, +): + ranks: Optional[List[int]] = None + sharding_spec: Optional[EnumerableShardingSpec] = None + + +@dataclass +class ShardedEmbeddingTableShard( + ShardedMetaConfig, EmbeddingAttributes, EmbeddingTableConfig +): + pass @dataclass @@ -80,7 +97,7 @@ class GroupedEmbeddingConfig: pooling: PoolingType is_weighted: bool compute_kernel: EmbeddingComputeKernel - embedding_tables: List[ShardedEmbeddingTable] + embedding_tables: List[ShardedEmbeddingTableShard] def feature_hash_sizes(self) -> List[int]: feature_hash_sizes = [] @@ -97,7 +114,7 @@ def num_features(self) -> int: def dim_sum(self) -> int: dim_sum = 0 for table in self.embedding_tables: - dim_sum += table.num_features() * table.embedding_dim + dim_sum += table.num_features() * table.local_cols return dim_sum def feature_names(self) -> List[str]: @@ -109,7 
+126,7 @@ def feature_names(self) -> List[str]: def embedding_dims(self) -> List[int]: embedding_dims = [] for table in self.embedding_tables: - embedding_dims.extend([table.embedding_dim] * table.num_features()) + embedding_dims.extend([table.local_cols] * table.num_features()) return embedding_dims def embedding_names(self) -> List[str]: @@ -118,6 +135,12 @@ def embedding_names(self) -> List[str]: embedding_names.extend(table.embedding_names) return embedding_names + def embedding_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_metadata: List[Optional[ShardMetadata]] = [] + for table in self.embedding_tables: + embedding_metadata.append(table.local_metadata) + return embedding_metadata + class BaseEmbeddingLookup(abc.ABC, nn.Module): """ diff --git a/distributed/model_parallel.py b/distributed/model_parallel.py index 254cefef5..814ac3a6d 100644 --- a/distributed/model_parallel.py +++ b/distributed/model_parallel.py @@ -17,7 +17,6 @@ ShardingPlan, ModuleSharder, ShardedModule, - ShardedTensor, ) from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer @@ -88,7 +87,7 @@ def __init__( # 2. Call ShardingPlanner.plan passing all found modules and corresponding sharders. if plan is None: - plan = EmbeddingShardingPlanner(self._pg, self.device).collective_plan( + plan = EmbeddingShardingPlanner(self._pg, self.device).plan( module, sharders ) diff --git a/distributed/planner/cost_functions.py b/distributed/planner/cost_functions.py index 9a2c2ce69..94ace9791 100644 --- a/distributed/planner/cost_functions.py +++ b/distributed/planner/cost_functions.py @@ -3,13 +3,15 @@ import math from typing import Dict -from torchrec.distributed.embedding_types import ShardingType, EmbeddingComputeKernel +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.planner.types import CostInput +from torchrec.distributed.types import ShardingType # Constants COMMS_MULTIPLER: Dict[str, int] = { ShardingType.TABLE_WISE.value: 2, + ShardingType.COLUMN_WISE.value: 2, ShardingType.ROW_WISE.value: 5, ShardingType.TABLE_ROW_WISE.value: 3, ShardingType.DATA_PARALLEL.value: 1, diff --git a/distributed/planner/embedding_planner.py b/distributed/planner/embedding_planner.py index 437627ab2..63e74a333 100644 --- a/distributed/planner/embedding_planner.py +++ b/distributed/planner/embedding_planner.py @@ -10,6 +10,7 @@ from torchrec.distributed.collective_utils import ( invoke_on_rank_and_broadcast_result, ) +from torchrec.distributed.comm import get_local_size from torchrec.distributed.planner.cost_functions import ( cost_func_compute_based, ) @@ -30,6 +31,7 @@ deallocate_param, param_sort_key, to_plan, + MIN_DIM, ) from torchrec.distributed.types import ( ShardingPlan, @@ -50,6 +52,7 @@ def __init__( cost_functions: Optional[List[Callable[[CostInput], int]]] = None, ) -> None: self._world_size: int = dist.get_world_size(pg) + self._local_size: int = get_local_size() self._hints: Dict[str, ParameterHints] = hints if hints else {} self._input_stats: Dict[str, ParameterInputStats] = ( input_stats if input_stats else {} @@ -119,7 +122,12 @@ def plan( if not self._place(unplaced_param_infos, placed_param_infos): self._backtrack(unplaced_param_infos, placed_param_infos) - return to_plan([param_info for _, param_info in placed_param_infos]) + return to_plan( + [param_info for _, param_info in placed_param_infos], + self._device, + self._world_size, + self._local_size, + ) def _place( self, @@ -135,25 +143,32 @@ def 
_place( heapq.heapify(candidate_devices) sort_key, param_info = heapq.heappop(unplaced_param_infos) sharding_option = param_info.sharding_options[0] + shards_count = sharding_option.shards_count is_placed = False - if sharding_option.sharding_type == ShardingType.TABLE_WISE.value: + if sharding_option.sharding_type in [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ]: constrained_devices = [] + ranks = [] while candidate_devices: candidate_device = heapq.heappop(candidate_devices) if is_enough_storage(sharding_option, self._topology, candidate_device): - sharding_option.rank = candidate_device.rank + ranks.append(candidate_device.rank) + sharding_option.ranks = ranks allocate_param(sharding_option, self._topology) - heapq.heappush( - placed_param_infos, - ( - param_sort_key(param_info, self._world_size, "storage"), - param_info, - ), - ) heapq.heappush(candidate_devices, candidate_device) - is_placed = True - break + if len(ranks) == shards_count: + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + is_placed = True + break constrained_devices.append(candidate_device) for constrained_device in constrained_devices: @@ -163,6 +178,7 @@ def _place( devices_per_host = len(self._topology.hosts[0].devices) candidate_hosts = [0] * num_hosts constrained_devices = [] + ranks = [] while candidate_devices: candidate_device = heapq.heappop(candidate_devices) host_idx, _ = self._topology.host_and_device_by_rank[ @@ -172,7 +188,8 @@ def _place( if candidate_hosts[host_idx] == devices_per_host and is_enough_storage( sharding_option, self._topology, candidate_device ): - sharding_option.rank = candidate_device.rank + ranks.append(candidate_device.rank) + sharding_option.ranks = ranks allocate_param(sharding_option, self._topology) heapq.heappush( placed_param_infos, @@ -194,6 +211,7 @@ def _place( ShardingType.ROW_WISE.value, ]: if is_enough_storage(sharding_option, self._topology): + sharding_option.ranks = None allocate_param(sharding_option, self._topology) heapq.heappush( placed_param_infos, @@ -246,7 +264,6 @@ def _backtrack( if len(placed_param_info.sharding_options) > 1 and not is_option_discarded: placed_param_info.sharding_options.popleft() is_option_discarded = True - placed_param_info.sharding_options[0].rank = None heapq.heappush( unplaced_param_infos, ( @@ -280,6 +297,9 @@ def _get_param_infos( for sharding_type in self._filter_sharding_types( name, sharder.sharding_types ): + shards_count, shard_size = self._get_shards_count_size( + name, param, sharding_type + ) for compute_kernel in self._filter_compute_kernels( name, sharder.compute_kernels(sharding_type, self._device) ): @@ -304,6 +324,8 @@ def _get_param_infos( storage_usage=sharder.storage_usage( param, self._device, compute_kernel ), + shards_count=shards_count, + sharded_dim_block_size=shard_size, ) ) param_infos.append( @@ -343,3 +365,23 @@ def _filter_compute_kernels( f"No available compute kernels after applying hints for {name}" ) return compute_kernels + + def _get_shards_count_size( + self, name: str, param: torch.Tensor, sharding_type: str + ) -> Tuple[Optional[int], Optional[int]]: + shards_count = None + shard_dim = None + if sharding_type == ShardingType.COLUMN_WISE.value: + _hint = self._hints.get(name, None) + shard_dim_hint = None if _hint is None else _hint.shard_dim + shard_dim = shard_dim_hint if shard_dim_hint is not None else MIN_DIM + # column-wise shard the weights + shards_count, residual = divmod(param.shape[1], 
shard_dim) + assert ( + shards_count > 0 + ), f"the table {name} cannot be column-wise sharded into shards of {shard_dim} dimensions" + if residual > 0: + shards_count += 1 + elif sharding_type == ShardingType.TABLE_WISE.value: + shards_count = 1 + return shards_count, shard_dim diff --git a/distributed/planner/parameter_sharding.py b/distributed/planner/parameter_sharding.py new file mode 100644 index 000000000..8cf41564a --- /dev/null +++ b/distributed/planner/parameter_sharding.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +import abc +import itertools +import math +from typing import List, Tuple, Optional + +import torch +from torch.distributed._sharding_spec import EnumerableShardingSpec, ShardMetadata +from torchrec.distributed.planner.types import ParameterInfo +from torchrec.distributed.types import ShardingType, ParameterSharding + + +def _twrw_shard_table_rows( + table_node: int, + hash_size: int, + embedding_dim: int, + world_size: int, + local_size: int, +) -> Tuple[List[int], List[int], List[int]]: + block_size = math.ceil(hash_size / local_size) + last_block_size = hash_size - block_size * (local_size - 1) + first_local_rank = (table_node) * local_size + last_local_rank = first_local_rank + local_size - 1 + local_rows: List[int] = [] + local_cols: List[int] = [] + local_row_offsets: List[int] = [] + cumul_row_offset = 0 + for rank in range(world_size): + if rank < first_local_rank: + local_row = 0 + local_col = 0 + elif rank < last_local_rank: + local_row = block_size + local_col = embedding_dim + elif rank == last_local_rank: + local_row = last_block_size + local_col = embedding_dim + else: + cumul_row_offset = 0 + local_row = 0 + local_col = 0 + local_rows.append(local_row) + local_cols.append(local_col) + local_row_offsets.append(cumul_row_offset) + cumul_row_offset += local_row + + return (local_rows, local_cols, local_row_offsets) + + +def _rw_shard_table_rows(hash_size: int, world_size: int) -> Tuple[List[int], int, int]: + block_size = (hash_size + world_size - 1) // world_size + last_rank = hash_size // block_size + last_block_size = hash_size - block_size * last_rank + local_rows: List[int] = [] + for rank in range(world_size): + if rank < last_rank: + local_row = block_size + elif rank == last_rank: + local_row = last_block_size + else: + local_row = 0 + local_rows.append(local_row) + return (local_rows, block_size, last_rank) + + +class ParameterShardingFactory(abc.ABC): + @staticmethod + def shard_parameters( + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + sharding_type = sharding_option.sharding_type + if sharding_type == ShardingType.TABLE_WISE.value: + parameter_sharding = TwParameterSharding.shard_parameters( + param_info, device, world_size, local_size + ) + elif sharding_type == ShardingType.ROW_WISE.value: + parameter_sharding = RwParameterSharding.shard_parameters( + param_info, device, world_size, local_size + ) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + parameter_sharding = TwRwParameterSharding.shard_parameters( + param_info, device, world_size, local_size + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + parameter_sharding = CwParameterSharding.shard_parameters( + param_info, device, world_size, local_size + ) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + parameter_sharding = DpParameterSharding.shard_parameters( + param_info, device, world_size, local_size + ) + else: + raise 
ValueError( + f"unsupported {sharding_option.sharding_type} sharding type" + ) + return parameter_sharding + + +class TwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [16] + rank = sharding_option.ranks[0] + shards = [ + ShardMetadata( + shard_lengths=[ + tensor.shape[0], + tensor.shape[1], + ], + shard_offsets=[0, 0], + placement=f"rank:{rank}/{device}", + ) + ] + return ParameterSharding( + sharding_spec=EnumerableShardingSpec(shards), + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + ) + + +class RwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + local_rows, block_size, last_rank = _rw_shard_table_rows( + tensor.shape[0], world_size + ) + shards = [ + ShardMetadata( + shard_lengths=[ + local_rows[rank], + tensor.shape[1], + ], + shard_offsets=[block_size * min(rank, last_rank), 0], + placement=f"rank:{rank}/{device}", + ) + for rank in range(world_size) + ] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class TwRwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [16] + rank = sharding_option.ranks[0] + table_node = rank // local_size + local_rows, local_cols, local_row_offsets = _twrw_shard_table_rows( + table_node=table_node, + hash_size=tensor.shape[0], + embedding_dim=tensor.shape[1], + world_size=world_size, + # pyre-fixme [6] + local_size=local_size, + ) + shards = [ + ShardMetadata( + shard_lengths=[ + local_rows[rank], + local_cols[rank], + ], + shard_offsets=[local_row_offsets[rank], 0], + placement=f"rank:{rank}/{device}", + ) + for rank in range(table_node * local_size, (table_node + 1) * local_size) + ] + + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class CwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [6] + ranks = sorted(sharding_option.ranks) + block_size = sharding_option.sharded_dim_block_size + shards_count, residual = divmod(tensor.shape[1], block_size) + sizes = [block_size] * shards_count + if residual > 0: + sizes += [residual] + merged_sizes = [] + merged_ranks = [] + for i, rank in enumerate(ranks): + if rank not in merged_ranks: + merged_ranks.append(rank) + merged_sizes.append(sizes[i]) + else: + merged_sizes[-1] += sizes[i] + offsets = [0] + list(itertools.accumulate(merged_sizes))[:-1] + shards = [ + ShardMetadata( + shard_lengths=[ + 
tensor.shape[0], + merged_sizes[i], + ], + shard_offsets=[0, offsets[i]], + placement=f"rank:{rank}/{device}", + ) + for i, rank in enumerate(merged_ranks) + ] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=merged_ranks, + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class DpParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + device: torch.device, + world_size: int, + local_size: Optional[int], + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + ) diff --git a/distributed/planner/tests/test_embedding_planner.py b/distributed/planner/tests/test_embedding_planner.py index 28d44bcf6..e7723989e 100644 --- a/distributed/planner/tests/test_embedding_planner.py +++ b/distributed/planner/tests/test_embedding_planner.py @@ -6,17 +6,50 @@ import torch import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec from torchrec.distributed.embedding import EmbeddingBagCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.parameter_sharding import _rw_shard_table_rows +from torchrec.distributed.planner.types import ParameterHints +from torchrec.distributed.planner.utils import MIN_DIM from torchrec.distributed.tests.test_model import TestSparseNN -from torchrec.distributed.types import ParameterSharding, ShardingType, ShardingPlan +from torchrec.distributed.types import ParameterSharding, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, ) +class CWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + @property + def sharding_types(self) -> List[str]: + return [ShardingType.COLUMN_WISE.value] + + """ + Restricts to single impl. + """ + + def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class DPCWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + @property + def sharding_types(self) -> List[str]: + return [ + ShardingType.COLUMN_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + + """ + Restricts to single impl. 
+ """ + + def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + class TWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): @property def sharding_types(self) -> List[str]: @@ -81,32 +114,78 @@ def test_allocation_planner_balanced(self) -> None: for i in range(4) ] storage = {"hbm": 1} - expected_plan: ShardingPlan = ShardingPlan( - { - "sparse.ebc": { - "table_3": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=0, + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] ), - "table_2": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - "table_1": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[2].num_embeddings, + tables[2].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - "table_0": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=0, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[3].num_embeddings, + tables[3].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] ), - } + ), } - ) + } model = TestSparseNN(tables=tables, weighted_tables=[]) planner = EmbeddingShardingPlanner( @@ -115,8 +194,8 @@ def test_allocation_planner_balanced(self) -> None: sharders = [TWSharder()] # pyre-ignore [6] - plan = planner.plan(model, sharders) - self.assertEqual(plan, expected_plan) + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) def test_allocation_planner_one_big_rest_small(self) -> None: big_hash = int(1024 * 1024 * 1024 / 16 / 4) @@ -133,32 +212,78 @@ def test_allocation_planner_one_big_rest_small(self) -> None: storage = {"hbm": 1} - expected_plan: ShardingPlan = ShardingPlan( - { - "sparse.ebc": { - "table_3": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] ), - "table_2": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + ), + 
"table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - "table_1": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[2].num_embeddings, + tables[2].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - "table_0": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=0, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[3].num_embeddings, + tables[3].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - } + ), } - ) + } model = TestSparseNN(tables=tables, weighted_tables=[]) planner = EmbeddingShardingPlanner( @@ -166,8 +291,8 @@ def test_allocation_planner_one_big_rest_small(self) -> None: ) sharders = [DPTWSharder()] # pyre-ignore [6] - plan = planner.plan(model, sharders) - self.assertEqual(plan, expected_plan) + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) def test_allocation_planner_two_big_rest_small(self) -> None: big_hash = int(1024 * 1024 * 1024 / 16 / 4) @@ -184,32 +309,56 @@ def test_allocation_planner_two_big_rest_small(self) -> None: storage = {"hbm": 1.1} - expected_plan: ShardingPlan = ShardingPlan( - { - "sparse.ebc": { - "table_3": ParameterSharding( - sharding_type=ShardingType.DATA_PARALLEL.value, - compute_kernel="dense", - rank=None, - ), - "table_2": ParameterSharding( - sharding_type=ShardingType.DATA_PARALLEL.value, - compute_kernel="dense", - rank=None, + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] ), - "table_1": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=1, + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:0", + ) + ] ), - "table_0": ParameterSharding( - sharding_type=ShardingType.TABLE_WISE.value, - compute_kernel="dense", - rank=0, - ), - } + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + sharding_spec=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + sharding_spec=None, + ), } - ) + } model = TestSparseNN(tables=tables, weighted_tables=[]) planner = EmbeddingShardingPlanner( @@ -221,8 +370,8 @@ def 
test_allocation_planner_two_big_rest_small(self) -> None: ) sharders = [DPRWTWSharder()] # pyre-ignore [6] - plan = planner.plan(model, sharders) - self.assertEqual(plan, expected_plan) + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) def test_allocation_planner_rw_two_big_rest_small(self) -> None: big_hash = int(1024 * 1024 * 1024 / 16 / 4) @@ -236,37 +385,263 @@ def test_allocation_planner_rw_two_big_rest_small(self) -> None: ) for i in range(4) ] - dist.get_world_size = MagicMock(return_value=4) - + local_rows, block_size, last_rank = _rw_shard_table_rows(big_hash, 4) storage = {"hbm": 0.6} - expected_plan: ShardingPlan = ShardingPlan( - { - "sparse.ebc": { - "table_3": ParameterSharding( - sharding_type=ShardingType.DATA_PARALLEL.value, - compute_kernel="dense", - rank=None, - ), - "table_2": ParameterSharding( - sharding_type=ShardingType.DATA_PARALLEL.value, - compute_kernel="dense", - rank=None, + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.ROW_WISE.value, + compute_kernel="dense", + ranks=None, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + local_rows[0], + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[1], + tables[0].embedding_dim, + ], + shard_offsets=[block_size, 0], + placement="rank:1/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[2], + tables[0].embedding_dim, + ], + shard_offsets=[2 * block_size, 0], + placement="rank:2/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[3], + tables[0].embedding_dim, + ], + shard_offsets=[3 * block_size, 0], + placement="rank:3/cuda:0", + ), + ], ), - "table_1": ParameterSharding( - sharding_type=ShardingType.ROW_WISE.value, - compute_kernel="dense", - rank=None, + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.ROW_WISE.value, + compute_kernel="dense", + ranks=None, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + local_rows[0], + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[1], + tables[1].embedding_dim, + ], + shard_offsets=[block_size, 0], + placement="rank:1/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[2], + tables[1].embedding_dim, + ], + shard_offsets=[2 * block_size, 0], + placement="rank:2/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[3], + tables[1].embedding_dim, + ], + shard_offsets=[3 * block_size, 0], + placement="rank:3/cuda:0", + ), + ], ), - "table_0": ParameterSharding( - sharding_type=ShardingType.ROW_WISE.value, - compute_kernel="dense", - rank=None, + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + } + } + model = TestSparseNN(tables=tables, weighted_tables=[]) + + planner = EmbeddingShardingPlanner( + pg=self.pg, + device=self.device, + # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd + # param but got `Dict[str, float]`. 
+ storage=storage, + ) + sharders = [DPRWTWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + def test_allocation_planner_cw_balanced(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=128, + name="table_0", + feature_names=["feature_0"], + ) + ] + storage = {"hbm": 1} + block_size, residual = divmod(128, 2) + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[ + 0, + 1, + ], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + block_size, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + block_size, + ], + shard_offsets=[0, block_size], + placement="rank:1/cuda:0", + ), + ] ), - } + ), } + } + + model = TestSparseNN(tables=tables, weighted_tables=[]) + planner = EmbeddingShardingPlanner( + pg=self.pg, + device=self.device, + storage=storage, + hints={ + "table_0": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + ), + }, ) + + sharders = [CWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + def test_allocation_planner_cw_two_big_rest_small_with_residual(self) -> None: + big_hash = int(1024 * 1024 * 1024 / 16 / 4) + small_hash = 1000 + tables = [ + EmbeddingBagConfig( + num_embeddings=(big_hash if i <= 1 else small_hash) // 4, + embedding_dim=62, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + dist.get_world_size = MagicMock(return_value=4) + block_size, residual = divmod(62, MIN_DIM) + + storage = {"hbm": 0.6} + + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[0, 1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + MIN_DIM, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + residual, + ], + shard_offsets=[0, MIN_DIM], + placement="rank:1/cuda:0", + ), + ] + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[2, 3], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + MIN_DIM, + ], + shard_offsets=[0, 0], + placement="rank:2/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + residual, + ], + shard_offsets=[0, MIN_DIM], + placement="rank:3/cuda:0", + ), + ] + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + } + } model = TestSparseNN(tables=tables, weighted_tables=[]) planner = EmbeddingShardingPlanner( @@ -275,8 +650,16 @@ def test_allocation_planner_rw_two_big_rest_small(self) -> None: # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd # param but got `Dict[str, float]`. 
storage=storage, + hints={ + "table_0": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + ), + "table_1": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + ), + }, ) - sharders = [DPRWTWSharder()] + sharders = [DPCWSharder()] # pyre-ignore [6] - plan = planner.plan(model, sharders) - self.assertEqual(plan, expected_plan) + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) diff --git a/distributed/planner/types.py b/distributed/planner/types.py index 0ef1fe51b..26ca549dc 100644 --- a/distributed/planner/types.py +++ b/distributed/planner/types.py @@ -16,6 +16,7 @@ class ParameterHints: sharding_types: Optional[List[str]] = None compute_kernels: Optional[List[str]] = None + shard_dim: Optional[int] = None @dataclass @@ -80,7 +81,9 @@ class ShardingOption: compute_kernel: str storage_usage: Dict[str, int] cost: int = 0 - rank: Optional[int] = None + ranks: Optional[List[int]] = None + shards_count: Optional[int] = None + sharded_dim_block_size: Optional[int] = None def __lt__(self, other: "ShardingOption") -> bool: """ diff --git a/distributed/planner/utils.py b/distributed/planner/utils.py index 60d479aff..e81be97ab 100644 --- a/distributed/planner/utils.py +++ b/distributed/planner/utils.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 import math -from typing import Any, Type, Dict, Optional, List, cast +from typing import Any, Type, Dict, Optional, List, cast, Tuple import torch import torch.distributed as dist from torchrec.distributed.comm import get_local_size, get_num_groups +from torchrec.distributed.planner.parameter_sharding import ParameterShardingFactory from torchrec.distributed.planner.types import ( ShardingOption, Topology, @@ -19,16 +20,17 @@ ParameterStorage, ShardingPlan, ShardingType, - ParameterSharding, ) MAX_DDR_STORAGE: int = 4 * 1024 * 1024 * 1024 * 1024 # 4 TB +MIN_DIM: int = 32 SHARDING_PREFERENCE: Dict[str, int] = { ShardingType.DATA_PARALLEL.value: 0, ShardingType.TABLE_WISE.value: 1, ShardingType.TABLE_ROW_WISE.value: 2, ShardingType.ROW_WISE.value: 3, + ShardingType.COLUMN_WISE.value: 4, } @@ -74,9 +76,19 @@ def is_enough_storage( ), "Sharding option must have a device for TW storage calcuation" device_ranks = [device.rank] host_ranks = [topology.host_and_device_by_rank[device.rank][0]] + elif sharding_option.sharding_type == ShardingType.COLUMN_WISE.value: + assert ( + device is not None + ), "Sharding option must have a device for CW storage calcuation" + device_ranks = [device.rank] + host_ranks = [topology.host_and_device_by_rank[device.rank][0]] + storage = { + # pyre-fixme[58] + k: math.ceil(v / sharding_option.shards_count) + for k, v in storage.items() + } else: raise ValueError(f"unsupported sharding_type {sharding_option.sharding_type}") - for storage_type, storage_usage in storage.items(): if storage_type == ParameterStorage.HBM.value: for device_rank in device_ranks: @@ -115,16 +127,16 @@ def allocate_param( } elif sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value: assert ( - sharding_option.rank is not None + sharding_option.ranks is not None ), "Sharding option must have a device for TWRW storage calcuation" device_ranks = [ device.rank # pyre-fixme[22]: The cast is redundant. - for device in topology.get_host(cast(int, sharding_option.rank)).devices + for device in topology.get_host(cast(int, sharding_option.ranks[0])).devices ] host_ranks = [ # pyre-fixme[22]: The cast is redundant. 
- topology.host_and_device_by_rank[cast(int, sharding_option.rank)][0] + topology.host_and_device_by_rank[cast(int, sharding_option.ranks[0])][0] ] storage = { k: math.ceil(v / len(device_ranks if k == "hbm" else host_ranks)) @@ -132,11 +144,24 @@ def allocate_param( } elif sharding_option.sharding_type == ShardingType.TABLE_WISE.value: assert ( - sharding_option.rank is not None + sharding_option.ranks is not None ), "Sharding option must have a device for TW storage calcuation" # pyre-fixme[22]: The cast is redundant. - device_ranks = [cast(int, sharding_option.rank)] - host_ranks = [topology.host_and_device_by_rank[sharding_option.rank][0]] + device_ranks = [cast(int, sharding_option.ranks[0])] + # pyre-fixme[16] + host_ranks = [topology.host_and_device_by_rank[sharding_option.ranks[0]][0]] + elif sharding_option.sharding_type == ShardingType.COLUMN_WISE.value: + assert ( + sharding_option.ranks is not None + ), "Sharding option must have at least one device for CW storage calcuation" + # for col-wise sharding, we allocate one shard at a time + device_ranks = [sharding_option.ranks[-1]] + host_ranks = [topology.host_and_device_by_rank[sharding_option.ranks[-1]][0]] + storage = { + # pyre-fixme[58] + k: math.ceil(v / sharding_option.shards_count) + for k, v in storage.items() + } else: raise ValueError(f"unsupported sharding_type {sharding_option.sharding_type}") @@ -189,15 +214,18 @@ def param_sort_key( def to_plan( parameter_infos: List[ParameterInfo], + device: torch.device, + world_size: int, + local_size: Optional[int], ) -> ShardingPlan: plan = {} for parameter_info in parameter_infos: - sharding_option = parameter_info.sharding_options[0] shards = plan.get(parameter_info.prefix, {}) - shards[parameter_info.name] = ParameterSharding( - sharding_type=sharding_option.sharding_type, - compute_kernel=sharding_option.compute_kernel, - rank=sharding_option.rank, + shards[parameter_info.name] = ParameterShardingFactory.shard_parameters( + param_info=parameter_info, + device=device, + world_size=world_size, + local_size=local_size, ) plan[parameter_info.prefix] = shards return ShardingPlan(plan) diff --git a/distributed/rw_sharding.py b/distributed/rw_sharding.py index 910426e27..9ea826a62 100644 --- a/distributed/rw_sharding.py +++ b/distributed/rw_sharding.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple, Dict, Any +from typing import List, Optional, Dict, Any import torch import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.dist_data import ( PooledEmbeddingsReduceScatter, SequenceEmbeddingAllToAll, @@ -15,7 +16,6 @@ from torchrec.distributed.embedding_sharding import ( group_tables, SparseFeaturesAllToAll, - ShardedEmbeddingTable, BasePooledEmbeddingDist, BaseSparseFeaturesDist, EmbeddingSharding, @@ -23,10 +23,14 @@ SequenceShardingContext, BaseEmbeddingLookup, ) -from torchrec.distributed.embedding_types import GroupedEmbeddingConfig, SparseFeatures +from torchrec.distributed.embedding_types import ( + ShardedEmbeddingTable, + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTableShard, +) from torchrec.distributed.types import ( ShardedTensorMetadata, - ShardMetadata, Awaitable, ) @@ -189,54 +193,25 @@ def __init__( def _shard( self, tables: List[ShardedEmbeddingTable] - ) -> List[List[ShardedEmbeddingTable]]: + ) -> List[List[ShardedEmbeddingTableShard]]: world_size = self._pg.size() - tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + tables_per_rank: 
List[List[ShardedEmbeddingTableShard]] = [ [] for i in range(world_size) ] - def _shard_table_rows( - hash_size: int, world_size: int - ) -> Tuple[List[int], int, int]: - block_size = (hash_size + world_size - 1) // world_size - last_rank = hash_size // block_size - last_block_size = hash_size - block_size * last_rank - local_rows: List[int] = [] - for rank in range(world_size): - if rank < last_rank: - local_row = block_size - elif rank == last_rank: - local_row = last_block_size - else: - local_row = 0 - local_rows.append(local_row) - return (local_rows, block_size, last_rank) - for table in tables: - local_rows, block_size, last_rank = _shard_table_rows( - table.num_embeddings, world_size - ) - shards = [ - ShardMetadata( - shard_lengths=[ - local_rows[rank], - table.embedding_dim, - ], - shard_offsets=[block_size * min(rank, last_rank), 0], - placement=f"rank:{rank}/{self._device}", - ) - for rank in range(world_size) - ] + # pyre-fixme [16] + shards = table.sharding_spec.shards # construct the global sharded_tensor_metadata - table.global_metadata = ShardedTensorMetadata( + global_metadata = ShardedTensorMetadata( shards_metadata=shards, size=torch.Size([table.num_embeddings, table.embedding_dim]), ) for rank in range(world_size): tables_per_rank[rank].append( - ShardedEmbeddingTable( + ShardedEmbeddingTableShard( num_embeddings=table.num_embeddings, embedding_dim=table.embedding_dim, name=table.name, @@ -246,11 +221,10 @@ def _shard_table_rows( pooling=table.pooling, compute_kernel=table.compute_kernel, is_weighted=table.is_weighted, - rank=table.rank, - local_rows=local_rows[rank], + local_rows=shards[rank].shard_lengths[0], local_cols=table.embedding_dim, local_metadata=shards[rank], - global_metadata=table.global_metadata, + global_metadata=global_metadata, ) ) return tables_per_rank @@ -314,6 +288,14 @@ def embedding_names(self) -> List[str]: embedding_names.extend(grouped_config.embedding_names()) return embedding_names + def embedding_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_metadata = [] + for grouped_config in self._grouped_embedding_configs: + embedding_metadata.extend(grouped_config.embedding_metadata()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_metadata.extend(grouped_config.embedding_metadata()) + return embedding_metadata + def id_list_feature_names(self) -> List[str]: id_list_feature_names = [] for grouped_config in self._grouped_embedding_configs: diff --git a/distributed/tests/test_model_parallel.py b/distributed/tests/test_model_parallel.py index b00f0b10f..9c09a1848 100644 --- a/distributed/tests/test_model_parallel.py +++ b/distributed/tests/test_model_parallel.py @@ -4,7 +4,7 @@ import os import unittest from collections import OrderedDict -from typing import List, Tuple, Optional, Callable, cast +from typing import List, Tuple, Optional, Callable, Dict, cast import hypothesis.strategies as st import numpy as np @@ -14,6 +14,8 @@ from hypothesis import Verbosity, given, settings from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterHints from torchrec.distributed.tests.test_model import ( TestSparseNN, TestEBCSharder, @@ -23,7 +25,6 @@ ModuleSharder, ShardedTensor, ShardingPlan, - ShardingPlanner, ShardingType, ) from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ 
-283,7 +284,7 @@ def setUp(self) -> None: self.tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4, + embedding_dim=(i + 2) * 4, name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -292,7 +293,7 @@ def setUp(self) -> None: self.weighted_tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4, + embedding_dim=(i + 2) * 4, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], ) @@ -307,6 +308,7 @@ def _run_multi_process_test( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], backend: str, + hints: Optional[Dict[str, ParameterHints]] = None, local_size: Optional[int] = None, ) -> List[torch.Tensor]: mgr = multiprocessing.Manager() @@ -324,6 +326,7 @@ def _run_multi_process_test( sharders, outputs, backend, + hints, ), kwargs={ "local_size": local_size, @@ -359,6 +362,7 @@ def _test_sharding( backend: str = "gloo", world_size: int = 2, local_size: Optional[int] = None, + hints: Optional[Dict[str, ParameterHints]] = None, ) -> None: # Run distributed training and collect predictions. @@ -373,6 +377,7 @@ def _test_sharding( # but got `List[TestEBCSharder]`. sharders=sharders, backend=backend, + hints=hints, ) full_pred = self._gen_full_pred_after_one_step( @@ -393,7 +398,7 @@ def _test_sharding_single_rank( sharders: List[ModuleSharder[nn.Module]], outputs: List[torch.Tensor], backend: str, - planner: Optional[ShardingPlanner] = None, + hints: Optional[Dict[str, ParameterHints]] = None, local_size: Optional[int] = None, ) -> None: # Generate model & inputs. @@ -404,7 +409,6 @@ def _test_sharding_single_rank( # Instantiate lazy modules. with torch.no_grad(): global_model(local_input) - if backend == "nccl": device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) @@ -417,7 +421,7 @@ def _test_sharding_single_rank( backend=backend, local_size=local_size, ) - + planner = EmbeddingShardingPlanner(pg, device, hints) # Shard model. 
local_model = TestSparseNN( sparse_device=torch.device("meta"), @@ -426,10 +430,8 @@ def _test_sharding_single_rank( dense_device=device, ) plan: Optional[ShardingPlan] - if planner: - plan = planner.collective_plan(local_model, sharders) - else: - plan = None + plan = planner.plan(local_model, sharders) + local_model = DistributedModelParallel( local_model, pg=pg, diff --git a/distributed/tw_sharding.py b/distributed/tw_sharding.py index 1db51228a..b9c6a1ff0 100644 --- a/distributed/tw_sharding.py +++ b/distributed/tw_sharding.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.dist_data import ( PooledEmbeddingsAllToAll, SequenceEmbeddingAllToAll, @@ -15,7 +16,6 @@ from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, SparseFeaturesAllToAll, - ShardedEmbeddingTable, group_tables, BasePooledEmbeddingDist, BaseSequenceEmbeddingDist, @@ -23,10 +23,14 @@ SequenceShardingContext, BaseEmbeddingLookup, ) -from torchrec.distributed.embedding_types import GroupedEmbeddingConfig, SparseFeatures +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTable, + ShardedEmbeddingTableShard, +) from torchrec.distributed.types import ( ShardedTensorMetadata, - ShardMetadata, Awaitable, ) @@ -130,34 +134,25 @@ def __init__( def _shard( self, tables: List[ShardedEmbeddingTable] - ) -> List[List[ShardedEmbeddingTable]]: + ) -> List[List[ShardedEmbeddingTableShard]]: world_size = self._pg.size() - tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [ [] for i in range(world_size) ] for table in tables: - rank = table.rank - shards: List[ShardMetadata] = [] - local_metadata = ShardMetadata( - shard_lengths=[ - table.num_embeddings, - table.embedding_dim, - ], - shard_offsets=[0, 0], - placement=f"rank:{rank}/{self._device}", - ) - for r in range(world_size): - if r == rank: - shards.append(local_metadata) + # pyre-fixme [16] + rank = table.ranks[0] + # pyre-fixme [16] + shards = table.sharding_spec.shards # construct the global sharded_tensor_metadata - table.global_metadata = ShardedTensorMetadata( + global_metadata = ShardedTensorMetadata( shards_metadata=shards, size=torch.Size([table.num_embeddings, table.embedding_dim]), ) tables_per_rank[rank].append( - ShardedEmbeddingTable( + ShardedEmbeddingTableShard( num_embeddings=table.num_embeddings, embedding_dim=table.embedding_dim, name=table.name, @@ -167,11 +162,10 @@ def _shard( pooling=table.pooling, compute_kernel=table.compute_kernel, is_weighted=table.is_weighted, - rank=table.rank, local_rows=table.num_embeddings, local_cols=table.embedding_dim, - local_metadata=local_metadata, - global_metadata=table.global_metadata, + local_metadata=shards[0], + global_metadata=global_metadata, ) ) return tables_per_rank @@ -256,6 +250,18 @@ def embedding_names(self) -> List[str]: embedding_names.extend(grouped_config.embedding_names()) return embedding_names + def embedding_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_metadata = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ): + for grouped_config in grouped_embedding_configs: + embedding_metadata.extend(grouped_config.embedding_metadata()) + for grouped_config in score_grouped_embedding_configs: + 
embedding_metadata.extend(grouped_config.embedding_metadata()) + return embedding_metadata + def id_list_feature_names(self) -> List[str]: id_list_feature_names = [] for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: diff --git a/distributed/twrw_sharding.py b/distributed/twrw_sharding.py index 50d746d23..bdec4d839 100644 --- a/distributed/twrw_sharding.py +++ b/distributed/twrw_sharding.py @@ -2,10 +2,11 @@ import itertools import math -from typing import List, Optional, Tuple, Dict, Any +from typing import List, Optional, Dict, Any import torch import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.comm import intra_and_cross_node_pg from torchrec.distributed.dist_data import ( PooledEmbeddingsReduceScatter, @@ -15,17 +16,20 @@ from torchrec.distributed.embedding_sharding import ( group_tables, SparseFeaturesAllToAll, - ShardedEmbeddingTable, BasePooledEmbeddingDist, BaseSequenceEmbeddingDist, BaseSparseFeaturesDist, EmbeddingSharding, BaseEmbeddingLookup, ) -from torchrec.distributed.embedding_types import GroupedEmbeddingConfig, SparseFeatures +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTable, + ShardedEmbeddingTableShard, +) from torchrec.distributed.types import ( ShardedTensorMetadata, - ShardMetadata, Awaitable, ) @@ -229,74 +233,21 @@ def __init__( def _shard( self, tables: List[ShardedEmbeddingTable] - ) -> List[List[ShardedEmbeddingTable]]: + ) -> List[List[ShardedEmbeddingTableShard]]: world_size = self._world_size local_size = self._local_size - tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + tables_per_rank: List[List[ShardedEmbeddingTableShard]] = [ [] for i in range(world_size) ] - def _shard_table_rows( - table_node: int, - hash_size: int, - embedding_dim: int, - world_size: int, - local_size: int, - ) -> Tuple[List[int], List[int], List[int]]: - block_size = math.ceil(hash_size / local_size) - last_block_size = hash_size - block_size * (local_size - 1) - first_local_rank = (table_node) * local_size - last_local_rank = first_local_rank + local_size - 1 - local_rows: List[int] = [] - local_cols: List[int] = [] - local_row_offsets: List[int] = [] - cumul_row_offset = 0 - for rank in range(world_size): - if rank < first_local_rank: - local_row = 0 - local_col = 0 - elif rank < last_local_rank: - local_row = block_size - local_col = embedding_dim - elif rank == last_local_rank: - local_row = last_block_size - local_col = embedding_dim - else: - cumul_row_offset = 0 - local_row = 0 - local_col = 0 - local_rows.append(local_row) - local_cols.append(local_col) - local_row_offsets.append(cumul_row_offset) - cumul_row_offset += local_row - - return (local_rows, local_cols, local_row_offsets) - for table in tables: - table_node = table.rank // local_size - local_rows, local_cols, local_row_offsets = _shard_table_rows( - table_node=table_node, - embedding_dim=table.embedding_dim, - hash_size=table.num_embeddings, - world_size=world_size, - local_size=local_size, - ) - shards = [ - ShardMetadata( - shard_lengths=[ - local_rows[rank], - local_cols[rank], - ], - shard_offsets=[local_row_offsets[rank], 0], - placement=f"rank:{rank}/{self._device}", - ) - for rank in range( - table_node * local_size, (table_node + 1) * local_size - ) - ] + # pyre-ignore [16] + table_node = table.ranks[0] // local_size + # pyre-fixme [16] + shards = table.sharding_spec.shards # construct the global sharded_tensor_metadata - 
table.global_metadata = ShardedTensorMetadata(
+            global_metadata = ShardedTensorMetadata(
                 shards_metadata=shards,
                 size=torch.Size([table.num_embeddings, table.embedding_dim]),
             )
@@ -307,7 +258,7 @@ def _shard_table_rows(
             ):
                 rank_idx = rank - (table_node * local_size)
                 tables_per_rank[rank].append(
-                    ShardedEmbeddingTable(
+                    ShardedEmbeddingTableShard(
                         num_embeddings=table.num_embeddings,
                         embedding_dim=table.embedding_dim,
                         name=table.name,
@@ -317,11 +268,10 @@ def _shard_table_rows(
                         pooling=table.pooling,
                         compute_kernel=table.compute_kernel,
                         is_weighted=table.is_weighted,
-                        rank=table.rank,
-                        local_rows=local_rows[rank],
+                        local_rows=shards[rank_idx].shard_lengths[0],
                         local_cols=table.embedding_dim,
                         local_metadata=shards[rank_idx],
-                        global_metadata=table.global_metadata,
+                        global_metadata=global_metadata,
                     )
                 )
         return tables_per_rank
@@ -397,6 +347,16 @@ def embedding_names(self) -> List[str]:
             embedding_names.extend(grouped_config.embedding_names())
         return embedding_names
 
+    def embedding_metadata(self) -> List[Optional[ShardMetadata]]:
+        embedding_metadata = []
+        for grouped_config in self._grouped_embedding_configs_per_node:
+            for config in grouped_config:
+                embedding_metadata.extend(config.embedding_metadata())
+        for grouped_config in self._score_grouped_embedding_configs_per_node:
+            for config in grouped_config:
+                embedding_metadata.extend(config.embedding_metadata())
+        return embedding_metadata
+
     def id_list_feature_names(self) -> List[str]:
         id_list_feature_names = []
         for grouped_config in self._grouped_embedding_configs_per_node:
diff --git a/distributed/types.py b/distributed/types.py
index a0e9d7c8e..87ada2cc4 100644
--- a/distributed/types.py
+++ b/distributed/types.py
@@ -6,6 +6,8 @@
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Optional, TypeVar, List, Type
 
+from torch.distributed._sharding_spec import ShardingSpec
+
 try:
     # For python 3.6 and below, GenericMeta will be used by
     # other metaclasses (i.e. AwaitableMeta) for customized
@@ -46,6 +48,8 @@ class ShardingType(Enum):
     DATA_PARALLEL = "data_parallel"
     # Placed on a single rank
     TABLE_WISE = "table_wise"
+    # Split along the embedding dimension and placed on multiple ranks as separate sharded tables
+    COLUMN_WISE = "column_wise"
     # Range-split on the first dimension across all ranks
     ROW_WISE = "row_wise"
     # Row-wise on the same node and table-wise across nodes
@@ -250,10 +254,12 @@ class ParameterSharding:
     """
     ShardingType.TABLE_WISE - rank where this embedding is placed
+    ShardingType.COLUMN_WISE - ranks where the embedding column shards are placed; each shard is treated as an individual table
     ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed
     ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused
     """
-    rank: Optional[int] = None
+    ranks: Optional[List[int]] = None
+    sharding_spec: Optional[ShardingSpec] = None
 
 
 @dataclass
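As a sanity check on the column-wise planner math introduced in this change, here is a small self-contained sketch (plain Python, no torchrec imports; cw_shard_sizes and merge_by_rank are illustrative helper names, not planner APIs) that mirrors _get_shards_count_size and CwParameterSharding: the embedding dimension is cut into MIN_DIM-wide blocks plus one residual block, the placer assigns one rank per block, and blocks that land on the same rank are merged into a single wider shard before ShardMetadata is emitted.

import itertools
from typing import List, Tuple

MIN_DIM = 32  # default column block width when no shard_dim hint is given


def cw_shard_sizes(embedding_dim: int, block_size: int = MIN_DIM) -> List[int]:
    # full blocks plus one residual block, as in _get_shards_count_size
    shards_count, residual = divmod(embedding_dim, block_size)
    assert shards_count > 0, "embedding_dim must be at least one block wide"
    sizes = [block_size] * shards_count
    if residual > 0:
        sizes.append(residual)
    return sizes


def merge_by_rank(sizes: List[int], ranks: List[int]) -> Tuple[List[int], List[int]]:
    # blocks assigned to the same rank collapse into one wider shard,
    # mirroring the merged_sizes/merged_ranks loop in CwParameterSharding
    ranks = sorted(ranks)
    merged_sizes: List[int] = []
    merged_ranks: List[int] = []
    for size, rank in zip(sizes, ranks):
        if rank not in merged_ranks:
            merged_ranks.append(rank)
            merged_sizes.append(size)
        else:
            merged_sizes[-1] += size
    return merged_sizes, merged_ranks


# 62-dim table from the residual test: shards of 32 and 30 columns on ranks 0 and 1
sizes = cw_shard_sizes(62)                                     # [32, 30]
merged_sizes, merged_ranks = merge_by_rank(sizes, ranks=[0, 1])
offsets = [0] + list(itertools.accumulate(merged_sizes))[:-1]  # [0, 32]
print(merged_ranks, merged_sizes, offsets)                     # [0, 1] [32, 30] [0, 32]

# 128-dim balanced test: four 32-column blocks spread over two ranks
# merge back into one 64-column shard per rank (offsets [0, 0] and [0, 64])
sizes = cw_shard_sizes(128)                                    # [32, 32, 32, 32]
merged_sizes, merged_ranks = merge_by_rank(sizes, ranks=[0, 1, 0, 1])
print(merged_ranks, merged_sizes)                              # [0, 1] [64, 64]

The rank list itself comes from the placer loop in _place, which keeps popping the least-loaded device until shards_count ranks are collected; in the balanced test that hands the four blocks out across the two devices, which is consistent with the expected plan of exactly one 64-column shard per rank.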