From 4ea42e2f44f4485fd35865323236a0b277873216 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Sun, 31 Mar 2024 19:50:32 -0700 Subject: [PATCH] mean pooling in EBC/VBE (#1772) Summary: This diff supports mean pooling for Row Wise/Table Row Wise sharding schemes. This is achieved through applying mean pooling post reduce scatter collective as the KeyedTensor awaitable is created. The implementation is done through a callback. Reviewed By: dstaay-fb Differential Revision: D54656612 --- .../distributed/batched_embedding_kernel.py | 12 +- torchrec/distributed/embedding_lookup.py | 8 +- torchrec/distributed/embeddingbag.py | 189 ++++++++++++++++-- torchrec/distributed/sharding/rw_sharding.py | 2 + .../distributed/sharding/twrw_sharding.py | 2 + .../test_utils/test_model_parallel.py | 91 ++++++++- torchrec/modules/embedding_configs.py | 33 ++- 7 files changed, 312 insertions(+), 25 deletions(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index e1091b961..0d558edb1 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -50,6 +50,7 @@ Shard, ShardedTensor, ShardedTensorMetadata, + ShardingType, ShardMetadata, TensorProperties, ) @@ -720,13 +721,16 @@ def __init__( config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self._config = config self._pg = pg - self._pooling: PoolingMode = pooling_type_to_pooling_mode(config.pooling) + self._pooling: PoolingMode = pooling_type_to_pooling_mode( + config.pooling, sharding_type # pyre-ignore[6] + ) self._local_rows: List[int] = [] self._weight_init_mins: List[float] = [] @@ -859,8 +863,9 @@ def __init__( config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: - super().__init__(config, pg, device) + super().__init__(config, pg, device, sharding_type) managed: List[EmbeddingLocation] = [] compute_devices: List[ComputeDevice] = [] @@ -962,8 +967,9 @@ def __init__( config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: - super().__init__(config, pg, device) + super().__init__(config, pg, device, sharding_type) weights_precision = data_type_to_sparse_type(config.data_type) fused_params = config.fused_params or {} diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 9601bbf3b..9a39fa90f 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -52,7 +52,7 @@ QuantBatchedEmbedding, QuantBatchedEmbeddingBag, ) -from torchrec.distributed.types import ShardedTensor +from torchrec.distributed.types import ShardedTensor, ShardingType from torchrec.sparse.jagged_tensor import KeyedJaggedTensor logger: logging.Logger = logging.getLogger(__name__) @@ -344,23 +344,27 @@ def __init__( pg: Optional[dist.ProcessGroup] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, scale_weight_gradients: bool = True, + sharding_type: Optional[ShardingType] = None, ) -> None: # TODO rename to _create_embedding_kernel def _create_lookup( config: GroupedEmbeddingConfig, device: Optional[torch.device] = None, + 
sharding_type: Optional[ShardingType] = None, ) -> BaseEmbedding: if config.compute_kernel == EmbeddingComputeKernel.DENSE: return BatchedDenseEmbeddingBag( config=config, pg=pg, device=device, + sharding_type=sharding_type, ) elif config.compute_kernel == EmbeddingComputeKernel.FUSED: return BatchedFusedEmbeddingBag( config=config, pg=pg, device=device, + sharding_type=sharding_type, ) else: raise ValueError( @@ -370,7 +374,7 @@ def _create_lookup( super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config, device)) + self._emb_modules.append(_create_lookup(config, device, sharding_type)) self._feature_splits: List[int] = [] for config in grouped_configs: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 2ec8833c2..54e703205 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -8,10 +8,11 @@ # pyre-strict import copy -from collections import OrderedDict +from collections import defaultdict, OrderedDict from dataclasses import dataclass, field from typing import ( Any, + Callable, cast, Dict, Iterator, @@ -27,6 +28,7 @@ import torch from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings from torch import nn, Tensor +from torch.autograd.profiler import record_function from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.embedding_sharding import ( @@ -79,7 +81,7 @@ ) from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -378,6 +380,7 @@ class EmbeddingBagCollectionContext(Multistreamable): ) inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None variable_batch_per_feature: bool = False + mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None def record_stream(self, stream: torch.cuda.streams.Stream) -> None: for ctx in self.sharding_contexts: @@ -415,13 +418,22 @@ def __init__( self._embedding_bag_configs: List[EmbeddingBagConfig] = ( module.embedding_bag_configs() ) - self._table_names: List[str] = [ - config.name for config in self._embedding_bag_configs - ] - self._table_name_to_config: Dict[str, EmbeddingBagConfig] = { - config.name: config for config in self._embedding_bag_configs - } + self._table_names: List[str] = [] + self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list) + self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {} + + for config in self._embedding_bag_configs: + self._table_names.append(config.name) + self._table_name_to_config[config.name] = config + + if table_name_to_parameter_sharding[config.name].sharding_type in [ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.ROW_WISE.value, + ]: + self._pooling_type_to_rs_features[config.pooling.value].extend( + config.feature_names + ) self.module_sharding_plan: EmbeddingModuleShardingPlan = cast( EmbeddingModuleShardingPlan, @@ -472,6 +484,16 @@ def __init__( self._uncombined_embedding_names: List[str] = [] self._uncombined_embedding_dims: List[int] = [] self._inverse_indices_permute_indices: Optional[torch.Tensor] = None + # to support mean pooling callback 
hook
+        self._has_mean_pooling_callback: bool = (
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
+        )
+        self._dim_per_key: Optional[torch.Tensor] = None
+        self._kjt_key_indices: Dict[str, int] = {}
+        self._kjt_inverse_order: Optional[torch.Tensor] = None
+        self._kt_key_ordering: Optional[torch.Tensor] = None
+
         # to support the FP16 hook
         self._create_output_dist()
@@ -720,6 +742,38 @@ def _create_input_dist(
             persistent=False,
         )
 
+    def _init_mean_pooling_callback(
+        self,
+        input_feature_names: List[str],
+        inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
+    ) -> None:
+        # account for shared features
+        feature_names: List[str] = [
+            feature_name
+            for sharding in self._sharding_type_to_sharding.values()
+            for feature_name in sharding.feature_names()
+        ]
+
+        for i, key in enumerate(feature_names):
+            if key not in self._kjt_key_indices:  # index of first occurrence
+                self._kjt_key_indices[key] = i
+
+        keyed_tensor_ordering = []
+        for key in self._embedding_names:
+            if "@" in key:
+                key = key.split("@")[0]
+            keyed_tensor_ordering.append(self._kjt_key_indices[key])
+        self._kt_key_ordering = torch.tensor(keyed_tensor_ordering, device=self._device)
+
+        if inverse_indices:
+            key_to_inverse_index = {
+                name: i for i, name in enumerate(inverse_indices[0])
+            }
+            self._kjt_inverse_order = torch.tensor(
+                [key_to_inverse_index[key] for key in feature_names],
+                device=self._device,
+            )
+
     def _create_lookups(
         self,
     ) -> None:
@@ -737,6 +791,7 @@ def _create_output_dist(self) -> None:
             )
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+        self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
         ]
@@ -789,12 +844,31 @@ def input_dist(
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
+            if self._has_mean_pooling_callback:
+                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
                     self._features_order,
                     self._features_order_tensor,
                 )
+            if self._has_mean_pooling_callback:
+                ctx.mean_pooling_callback = _create_mean_pooling_callback(
+                    lengths=features.lengths(),
+                    stride=features.stride(),
+                    keys=features.keys(),
+                    pooling_type_to_rs_features=self._pooling_type_to_rs_features,
+                    stride_per_key=features.stride_per_key(),
+                    dim_per_key=self._dim_per_key,  # pyre-ignore[6]
+                    embedding_names=self._embedding_names,
+                    embedding_dims=self._embedding_dims,
+                    variable_batch_per_feature=ctx.variable_batch_per_feature,
+                    kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
+                    kjt_key_indices=self._kjt_key_indices,
+                    kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
+                    inverse_indices=ctx.inverse_indices,
+                )
+
         features_by_shards = features.split(
             self._feature_splits,
         )
@@ -840,7 +914,7 @@ def output_dist(
             assert (
                 ctx.inverse_indices is not None
             ), "inverse indices must be provided from KJT if using variable batch size per feature."
- return VariableBatchEmbeddingBagCollectionAwaitable( + awaitable = VariableBatchEmbeddingBagCollectionAwaitable( awaitables=awaitables, inverse_indices=ctx.inverse_indices, inverse_indices_permute_indices=self._inverse_indices_permute_indices, @@ -851,12 +925,18 @@ def output_dist( permute_op=self._permute_op, ) else: - return EmbeddingBagCollectionAwaitable( + awaitable = EmbeddingBagCollectionAwaitable( awaitables=awaitables, embedding_dims=self._embedding_dims, embedding_names=self._embedding_names, ) + # register callback if there are features that need mean pooling + if self._has_mean_pooling_callback: + awaitable.callbacks.append(ctx.mean_pooling_callback) + + return awaitable + def compute_and_output_dist( self, ctx: EmbeddingBagCollectionContext, input: KJTList ) -> LazyAwaitable[KeyedTensor]: @@ -879,7 +959,7 @@ def compute_and_output_dist( assert ( ctx.inverse_indices is not None ), "inverse indices must be provided from KJT if using variable batch size per feature." - return VariableBatchEmbeddingBagCollectionAwaitable( + awaitable = VariableBatchEmbeddingBagCollectionAwaitable( awaitables=awaitables, inverse_indices=ctx.inverse_indices, inverse_indices_permute_indices=self._inverse_indices_permute_indices, @@ -890,12 +970,18 @@ def compute_and_output_dist( permute_op=self._permute_op, ) else: - return EmbeddingBagCollectionAwaitable( + awaitable = EmbeddingBagCollectionAwaitable( awaitables=awaitables, embedding_dims=self._embedding_dims, embedding_names=self._embedding_names, ) + # register callback if there are features that need mean pooling + if self._has_mean_pooling_callback: + awaitable.callbacks.append(ctx.mean_pooling_callback) + + return awaitable + @property def fused_optimizer(self) -> KeyedOptimizer: return self._optim @@ -1166,3 +1252,82 @@ def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Paramete @property def module_type(self) -> Type[nn.EmbeddingBag]: return nn.EmbeddingBag + + +def _create_mean_pooling_callback( + lengths: torch.Tensor, + keys: List[str], + stride: int, + stride_per_key: List[int], + dim_per_key: torch.Tensor, + pooling_type_to_rs_features: Dict[str, List[str]], + embedding_names: List[str], + embedding_dims: List[int], + variable_batch_per_feature: bool, + kjt_inverse_order: torch.Tensor, + kjt_key_indices: Dict[str, int], + kt_key_ordering: torch.Tensor, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, +) -> Callable[[KeyedTensor], KeyedTensor]: + with record_function("## ebc create mean pooling callback ##"): + batch_size = ( + inverse_indices[1].size(dim=1) if variable_batch_per_feature else stride # pyre-ignore[16] + ) + + if variable_batch_per_feature: + device = inverse_indices[1].device + inverse_indices_t = inverse_indices[1] + if len(keys) != len(inverse_indices[0]): + inverse_indices_t = torch.index_select( + inverse_indices[1], 0, kjt_inverse_order + ) + offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[ + :-1 + ].unsqueeze(-1) + indices = (inverse_indices_t + offsets).flatten() + lengths = torch.index_select(input=lengths, dim=0, index=indices) + + # only convert the sum pooling features to be 1 lengths + for feature in pooling_type_to_rs_features[PoolingType.SUM.value]: + feature_index = kjt_key_indices[feature] + feature_index = feature_index * batch_size + lengths[feature_index : feature_index + batch_size] = 1 + + if len(embedding_names) != len(keys): + lengths = torch.index_select( + lengths.reshape(-1, batch_size), + 0, + kt_key_ordering, + ).reshape(-1) + + # 
transpose to align features with keyed tensor dim_per_key
+        lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
+        output_size = sum(embedding_dims)
+
+        divisor = torch.repeat_interleave(
+            input=lengths,
+            repeats=dim_per_key,
+            dim=1,
+            output_size=output_size,
+        )
+        eps = 1e-6  # used to safeguard against division by zero
+        divisor = divisor + eps
+
+    # pyre-ignore[53]
+    def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor:
+        """
+        Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
+        This function is applied as a callback to the awaitable.
+        """
+        with record_function("## ebc apply mean pooling ##"):
+            mean_pooled_values = (
+                keyed_tensor.values() / divisor
+            )  # [batch size, num_features * embedding dim]
+            return KeyedTensor(
+                keys=keyed_tensor.keys(),
+                values=mean_pooled_values,
+                length_per_key=keyed_tensor.length_per_key(),
+                key_dim=1,
+            )
+
+    return _apply_mean_pooling
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index afbfba94c..e78019fe7 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -47,6 +47,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingType,
     ShardMetadata,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -488,6 +489,7 @@ def create_lookup(
             pg=self._pg,
             device=device if device is not None else self._device,
             feature_processor=feature_processor,
+            sharding_type=ShardingType.ROW_WISE,
         )
 
     def create_output_dist(
diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py
index c174e4b7f..22651f75a 100644
--- a/torchrec/distributed/sharding/twrw_sharding.py
+++ b/torchrec/distributed/sharding/twrw_sharding.py
@@ -44,6 +44,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingType,
     ShardMetadata,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -634,6 +635,7 @@ def create_lookup(
             pg=self._pg,
             device=device if device is not None else self._device,
             feature_processor=feature_processor,
+            sharding_type=ShardingType.TABLE_ROW_WISE,
         )
 
     def create_output_dist(
diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py
index e16399804..bf7339192 100644
--- a/torchrec/distributed/test_utils/test_model_parallel.py
+++ b/torchrec/distributed/test_utils/test_model_parallel.py
@@ -25,7 +25,7 @@
     sharding_single_rank_test,
 )
 from torchrec.distributed.types import ModuleSharder, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType
 from torchrec.test_utils import seed_and_log, skip_if_asan_class
 
 
@@ -58,6 +58,29 @@ def setUp(self, backend: str = "nccl") -> None:
         ]
         self.tables += shared_features_tables
 
+        self.mean_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 10,
+                embedding_dim=(i + 2) * 8,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                pooling=PoolingType.MEAN,
+            )
+            for i in range(num_features)
+        ]
+
+        shared_features_tables_mean = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 10,
+                embedding_dim=(i + 2) * 8,
+                name="table_" + str(i + num_features),
+                feature_names=["feature_" + str(i)],
+                pooling=PoolingType.MEAN,
+            )
+            for i in range(shared_features)
+        ]
+        self.mean_tables += shared_features_tables_mean
+
         self.weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -105,13 +128,14 @@ 
def _test_sharding( variable_batch_per_feature: bool = False, has_weighted_tables: bool = True, global_constant_batch: bool = False, + pooling: PoolingType = PoolingType.SUM, ) -> None: self._run_multi_process_test( callable=sharding_single_rank_test, world_size=world_size, local_size=local_size, model_class=model_class, - tables=self.tables, + tables=self.tables if pooling == PoolingType.SUM else self.mean_tables, weighted_tables=self.weighted_tables if has_weighted_tables else None, embedding_groups=self.embedding_groups, sharders=sharders, @@ -168,8 +192,9 @@ def setUp(self, backend: str = "nccl") -> None: ] ), variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), ) - @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_rw( self, sharder_type: str, @@ -179,6 +204,7 @@ def test_sharding_rw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + pooling: PoolingType, ) -> None: if self.backend == "gloo": self.skipTest( @@ -190,6 +216,7 @@ def test_sharding_rw( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size ) + self._test_sharding( sharders=[ cast( @@ -207,6 +234,7 @@ def test_sharding_rw( backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + pooling=pooling, ) # pyre-fixme[56] @@ -510,8 +538,9 @@ def test_sharding_tw( ] ), variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), ) - @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_twrw( self, sharder_type: str, @@ -521,6 +550,7 @@ def test_sharding_twrw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + pooling: PoolingType, ) -> None: if self.backend == "gloo": self.skipTest( @@ -547,6 +577,7 @@ def test_sharding_twrw( qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + pooling=pooling, ) @unittest.skipIf( @@ -559,14 +590,12 @@ def test_sharding_twrw( [ ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.TABLE_ROW_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, ] ), global_constant_batch=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_variable_batch( self, sharding_type: str, @@ -596,3 +625,51 @@ def test_sharding_variable_batch( has_weighted_tables=False, global_constant_batch=global_constant_batch, ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ] + ), + global_constant_batch=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_sharding_variable_batch_twrw( + self, + sharding_type: str, + global_constant_batch: bool, + pooling: PoolingType, + ) -> None: + if self.backend == "gloo": + # error is from FBGEMM, it says CPU even if we are on GPU. 
+ self.skipTest( + "bounds_check_indices on CPU does not support variable length (batch size)" + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type=SharderType.EMBEDDING_BAG_COLLECTION.value, + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED.value, + device=self.device, + ), + ], + backend=self.backend, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + variable_batch_per_feature=True, + has_weighted_tables=False, + global_constant_batch=global_constant_batch, + pooling=pooling, + ) diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py index 43ea508e3..01231681d 100644 --- a/torchrec/modules/embedding_configs.py +++ b/torchrec/modules/embedding_configs.py @@ -26,6 +26,28 @@ class PoolingType(Enum): NONE = "NONE" +# TODO - duplicated, move elsewhere to remove circular dependencies +class ShardingType(Enum): + """ + Well-known sharding types, used by inter-module optimizations. + """ + + # Replicated on all ranks + DATA_PARALLEL = "data_parallel" + # Placed on a single rank + TABLE_WISE = "table_wise" + # Placed on multiple ranks as different sharded tables + COLUMN_WISE = "column_wise" + # Range-split on the first dimension across all ranks + ROW_WISE = "row_wise" + # Row-wise on the same node and table-wise across nodes + # Useful when having multiple ranks per node + # and comms within a single node are more efficient than across nodes. + TABLE_ROW_WISE = "table_row_wise" + # Column-wise on the same node and table-wise across nodes + TABLE_COLUMN_WISE = "table_column_wise" + + DATA_TYPE_NUM_BITS: Dict[DataType, int] = { DataType.FP32: 32, DataType.FP16: 16, @@ -60,10 +82,19 @@ def dtype_to_data_type(dtype: torch.dtype) -> DataType: raise Exception(f"Invalid data type {dtype}") -def pooling_type_to_pooling_mode(pooling_type: PoolingType) -> PoolingMode: +def pooling_type_to_pooling_mode( + pooling_type: PoolingType, sharding_type: Optional[ShardingType] = None +) -> PoolingMode: if pooling_type.value == PoolingType.SUM.value: return PoolingMode.SUM elif pooling_type.value == PoolingType.MEAN.value: + if sharding_type is not None and sharding_type.value in [ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.ROW_WISE.value, + ]: + # Mean pooling is not supported in TBE for TWRW/RW sharding. + # Pass 'SUM' as a workaround, and apply mean pooling as a callback in EBC. + return PoolingMode.SUM return PoolingMode.MEAN elif pooling_type.value == PoolingType.NONE.value: return PoolingMode.NONE
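
Editor's note (illustration only, not part of the patch): the callback built by _create_mean_pooling_callback turns the SUM-pooled output back into a mean by dividing each sample's pooled values by that sample's bag length, repeated across the feature's embedding dimension. Below is a minimal sketch of that arithmetic for the non-variable-batch case; the toy sizes and the names apply_mean_pooling, f1 and f2 are hypothetical, and only torch plus torchrec's KeyedTensor are assumed.

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

batch_size = 2
keys = ["f1", "f2"]
dims = torch.tensor([4, 8])  # embedding dim per feature (dim_per_key)
# per-feature, per-sample bag lengths in KJT layout: [f1_b0, f1_b1, f2_b0, f2_b1]
lengths = torch.tensor([2.0, 3.0, 1.0, 4.0])

# SUM-pooled lookup output: [batch_size, sum(dims)]
summed = torch.randn(batch_size, int(dims.sum()))

# [num_features * batch_size] -> [batch_size, num_features], then widen each
# feature's length across its embedding dim, mirroring the divisor in the patch.
divisor = torch.repeat_interleave(
    lengths.reshape(-1, batch_size).T, repeats=dims, dim=1
) + 1e-6  # eps guards against empty bags

def apply_mean_pooling(kt: KeyedTensor) -> KeyedTensor:
    # same shape contract as the callback registered on the awaitable
    return KeyedTensor(
        keys=kt.keys(),
        values=kt.values() / divisor,
        length_per_key=kt.length_per_key(),
        key_dim=1,
    )

kt = KeyedTensor(keys=keys, values=summed, length_per_key=dims.tolist(), key_dim=1)
mean_kt = apply_mean_pooling(kt)
assert mean_kt.values().shape == (batch_size, int(dims.sum()))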
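
Editor's note (illustration only, not part of the patch): a small sketch of how the sharding-aware pooling_type_to_pooling_mode behaves after this change, assuming a torchrec build with this patch applied is importable. Under RW/TWRW sharding, MEAN is lowered to SUM inside the TBE kernel, and the mean is finished later by the EBC mean-pooling callback.

from torchrec.modules.embedding_configs import (
    PoolingType,
    ShardingType,  # the duplicated enum added by this patch
    pooling_type_to_pooling_mode,
)

# MEAN stays MEAN when no sharding type is passed (e.g. TW/CW lookups).
assert pooling_type_to_pooling_mode(PoolingType.MEAN).name == "MEAN"
# MEAN is rewritten to SUM for row-wise shardings; the callback divides later.
assert pooling_type_to_pooling_mode(PoolingType.MEAN, ShardingType.ROW_WISE).name == "SUM"
assert (
    pooling_type_to_pooling_mode(PoolingType.MEAN, ShardingType.TABLE_ROW_WISE).name
    == "SUM"
)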