diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index bd71dcb3f..051e7ccbb 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -13,6 +13,7 @@
 from functools import partial
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
     Iterator,
@@ -37,6 +38,7 @@
     EmbeddingShardingInfo,
     KJTListSplitsAwaitable,
     Multistreamable,
+    USE_ONE_TBE_PER_TABLE,
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -73,6 +75,7 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
+    BaseEmbeddingConfig,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
@@ -141,7 +144,6 @@ def replace_placement_with_meta_device(


 def create_embedding_bag_sharding(
-    sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
     env: ShardingEnv,
     device: Optional[torch.device] = None,
@@ -150,6 +152,7 @@ def create_embedding_bag_sharding(
 ) -> EmbeddingSharding[
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
+    sharding_type = sharding_infos[0].param_sharding.sharding_type
     if device is not None and device.type == "meta":
         replace_placement_with_meta_device(sharding_infos)
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -195,12 +198,48 @@ def create_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")


-def create_sharding_infos_by_sharding(
+def get_sharding_group(
+    config: BaseEmbeddingConfig,
+    param_sharding: ParameterSharding,
+    fused_params: Optional[Dict[str, Any]] = None,
+) -> str:
+    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
+        return config.name
+    if param_sharding.sharding_type in {
+        ShardingType.COLUMN_WISE.value,
+        ShardingType.TABLE_COLUMN_WISE.value,
+    }:
+        assert param_sharding.ranks
+        num_ranks = len(param_sharding.ranks)
+        assert config.embedding_dim % num_ranks == 0
+        dim = config.embedding_dim // num_ranks
+    else:
+        dim = config.embedding_dim
+
+    group = f"{param_sharding.sharding_type}@{param_sharding.compute_kernel}"
+    if (
+        param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
+        and (
+            (fused_params and fused_params.get("prefetch_pipeline", False))
+            or (
+                param_sharding.cache_params
+                and param_sharding.cache_params.prefetch_pipeline
+            )
+        )
+    ):
+        group += f"@{dim}"
+    return group
+
+
+def create_sharding_infos_by_group(
     module: EmbeddingBagCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     prefix: str,
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
+    group_fn: Optional[
+        Callable[[EmbeddingBagConfig, ParameterSharding, Optional[Dict[str, Any]]], str]
+    ] = None,
 ) -> Dict[str, List[EmbeddingShardingInfo]]:

     if fused_params is None:
@@ -216,7 +255,7 @@ def create_sharding_infos_by_sharding(
         else:
             shared_feature[feature_name] = True

-    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
+    group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)

     # state_dict returns parameter.Tensor, which loses parameter level attributes
     parameter_by_name = dict(module.named_parameters())
@@ -249,9 +288,6 @@ def create_sharding_infos_by_sharding(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
-            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
-
         optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
getattr(param, "_optimizer_kwargs", [{}]) optimizer_classes = getattr(param, "_optimizer_classes", [None]) @@ -273,28 +309,32 @@ def create_sharding_infos_by_sharding( ) per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) - sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( - EmbeddingShardingInfo( - embedding_config=EmbeddingTableConfig( - num_embeddings=config.num_embeddings, - embedding_dim=config.embedding_dim, - name=config.name, - data_type=config.data_type, - feature_names=copy.deepcopy(config.feature_names), - pooling=config.pooling, - is_weighted=module.is_weighted(), - has_feature_processor=False, - embedding_names=embedding_names, - weight_init_max=config.weight_init_max, - weight_init_min=config.weight_init_min, - pruning_indices_remapping=config.pruning_indices_remapping, - ), - param_sharding=parameter_sharding, - param=param, - fused_params=per_table_fused_params, - ) + group = ( + group_fn(config, parameter_sharding, fused_params) + if group_fn is not None + else parameter_sharding.sharding_type ) - return sharding_type_to_sharding_infos + sharding_info = EmbeddingShardingInfo( + embedding_config=EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=copy.deepcopy(config.feature_names), + pooling=config.pooling, + is_weighted=module.is_weighted(), + has_feature_processor=False, + embedding_names=embedding_names, + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + pruning_indices_remapping=config.pruning_indices_remapping, + ), + param_sharding=parameter_sharding, + param=param, + fused_params=per_table_fused_params, + ) + group_to_sharding_infos[group].append(sharding_info) + return group_to_sharding_infos def create_sharding_infos_by_sharding_device_group( @@ -571,31 +611,30 @@ def __init__( ) self._env = env - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( + group_to_sharding_infos = create_sharding_infos_by_group( module, table_name_to_parameter_sharding, "embedding_bags.", fused_params, + group_fn=get_sharding_group, ) - self._sharding_type_to_sharding: Dict[ - str, + self._embedding_shardings: List[ EmbeddingSharding[ EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor, - ], - ] = { - sharding_type: create_embedding_bag_sharding( - sharding_type, + ] + ] = [ + create_embedding_bag_sharding( embedding_configs, env, device, permute_embeddings=True, qcomm_codecs_registry=self.qcomm_codecs_registry, ) - for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items() - } + for embedding_configs in group_to_sharding_infos.values() + ] self._is_weighted: bool = module.is_weighted() self._device = device @@ -640,15 +679,12 @@ def __init__( optims.append(("", tbe_module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) - for index, (sharding, lookup) in enumerate( - zip( - self._sharding_type_to_sharding.values(), - self._lookups, - ) + for i, (sharding, lookup) in enumerate( + zip(self._embedding_shardings, self._lookups) ): # TODO: can move this into DpPooledEmbeddingSharding once all modules are composable if isinstance(sharding, DpPooledEmbeddingSharding): - self._lookups[index] = DistributedDataParallel( + self._lookups[i] = DistributedDataParallel( module=lookup, device_ids=( [device] @@ -770,10 +806,8 @@ def _initialize_torch_state(self) -> None: # noqa table.embedding_dim, ) - for sharding_type, 
-            self._sharding_type_to_sharding.keys(), self._lookups
-        ):
-            if sharding_type == ShardingType.DATA_PARALLEL.value:
+        for lookup, sharding in zip(self._lookups, self._embedding_shardings):
+            if isinstance(sharding, DpPooledEmbeddingSharding):
                 # unwrap DDP
                 lookup = lookup.module
             else:
@@ -864,7 +898,7 @@ def _create_input_dist(
         input_feature_names: List[str],
     ) -> None:
         feature_names: List[str] = []
-        for sharding in self._sharding_type_to_sharding.values():
+        for sharding in self._embedding_shardings:
             self._input_dists.append(sharding.create_input_dist())
             feature_names.extend(sharding.feature_names())
             self._feature_splits.append(len(sharding.feature_names()))
@@ -890,7 +924,7 @@ def _init_mean_pooling_callback(
         # account for shared features
         feature_names: List[str] = [
             feature_name
-            for sharding in self._sharding_type_to_sharding.values()
+            for sharding in self._embedding_shardings
             for feature_name in sharding.feature_names()
         ]

@@ -917,12 +951,12 @@ def _init_mean_pooling_callback(
     def _create_lookups(
         self,
     ) -> None:
-        for sharding in self._sharding_type_to_sharding.values():
+        for sharding in self._embedding_shardings:
             self._lookups.append(sharding.create_lookup())

     def _create_output_dist(self) -> None:
         embedding_shard_metadata: List[Optional[ShardMetadata]] = []
-        for sharding in self._sharding_type_to_sharding.values():
+        for sharding in self._embedding_shardings:
             self._output_dists.append(sharding.create_output_dist(device=self._device))
             self._embedding_names.extend(sharding.embedding_names())
             self._embedding_dims.extend(sharding.embedding_dims())
@@ -1236,7 +1270,6 @@ def __init__(
         self._embedding_sharding: EmbeddingSharding[
             EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
         ] = create_embedding_bag_sharding(
-            sharding_type=self.parameter_sharding.sharding_type,
             sharding_infos=[
                 EmbeddingShardingInfo(
                     embedding_config=embedding_table_config,
diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py
index ec77d7cbc..12c44ae8e 100644
--- a/torchrec/distributed/fused_embeddingbag.py
+++ b/torchrec/distributed/fused_embeddingbag.py
@@ -65,7 +65,7 @@ def __init__(
         )

         for index, (sharding, lookup) in enumerate(
-            zip(self._sharding_type_to_sharding.values(), self._lookups)
+            zip(self._embedding_shardings, self._lookups)
         ):
             if isinstance(sharding, DpPooledEmbeddingSharding):
                 self._lookups[index] = DistributedDataParallel(
diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py
index df5211a3c..563e70fcf 100644
--- a/torchrec/distributed/mc_embedding_modules.py
+++ b/torchrec/distributed/mc_embedding_modules.py
@@ -109,14 +109,18 @@ def __init__(
         # TODO: This is a hack since _embedding_module doesn't need input
         # dist, so eliminating it so all fused a2a will ignore it.
         self._embedding_module._has_uninitialized_input_dist = False
+        embedding_shardings = (
+            self._embedding_module._embedding_shardings
+            if isinstance(self._embedding_module, ShardedEmbeddingBagCollection)
+            else list(self._embedding_module._sharding_type_to_sharding.values())
+        )
         self._managed_collision_collection: ShardedManagedCollisionCollection = (
             mc_sharder.shard(
                 module._managed_collision_collection,
                 table_name_to_parameter_sharding,
                 env=env,
                 device=device,
-                # pyre-ignore
-                sharding_type_to_sharding=self._embedding_module._sharding_type_to_sharding,
+                embedding_shardings=embedding_shardings,
             )
         )
         self._return_remapped_features: bool = module._return_remapped_features
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index 4b40e3385..aa7a2c07e 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -133,14 +133,13 @@ def __init__(
         table_name_to_parameter_sharding: Dict[str, ParameterSharding],
         env: ShardingEnv,
         device: torch.device,
-        sharding_type_to_sharding: Dict[
-            str,
+        embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
                 KeyedJaggedTensor,
                 torch.Tensor,
                 torch.Tensor,
-            ],
+            ]
         ],
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
@@ -152,13 +151,13 @@ def __init__(
             copy.deepcopy(table_name_to_parameter_sharding)
         )
         # TODO: create a MCSharding type instead of leveraging EmbeddingSharding
-        self._sharding_type_to_sharding = sharding_type_to_sharding
+        self._embedding_shardings = embedding_shardings

         self._embedding_names_per_sharding: List[List[str]] = []
-        for sharding_type, sharding in self._sharding_type_to_sharding.items():
+        for sharding in self._embedding_shardings:
             # TODO: support TWRW sharding
-            assert (
-                sharding_type == ShardingType.ROW_WISE.value
+            assert isinstance(
+                sharding, BaseRwEmbeddingSharding
             ), "Only ROW_WISE sharding is supported."
             self._embedding_names_per_sharding.append(sharding.embedding_names())

@@ -263,71 +262,70 @@ def _create_managed_collision_modules(
         ] = defaultdict(lambda: defaultdict(list))
         self._feature_to_offset: Dict[str, int] = {}

-        for sharding_type, sharding in self._sharding_type_to_sharding.items():
-            if sharding_type == ShardingType.ROW_WISE.value:
-                assert isinstance(sharding, BaseRwEmbeddingSharding)
+        for sharding in self._embedding_shardings:
+            assert isinstance(sharding, BaseRwEmbeddingSharding)

-                grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
-                    sharding._grouped_embedding_configs
-                )
-                for group_config in grouped_embedding_configs:
-                    for table in group_config.embedding_tables:
-                        # pyre-ignore [16]
-                        new_min_output_id = table.local_metadata.shard_offsets[0]
-                        # pyre-ignore [16]
-                        new_range_size = table.local_metadata.shard_sizes[0]
-
-                        mc_module = module._managed_collision_modules[table.name]
-
-                        # TODO:
-                        # 1) need to make TBE accept global indices for now force to local indices
-                        # 2) MCH is particularly nasty with a portion of each shard; ideally dont do this
-                        # 3) now create a feature_to_offset and pass into awaitable callbacks to act as raw id adder
-                        self._managed_collision_modules[table.name] = (
-                            mc_module.rebuild_with_output_id_range(
-                                output_id_range=(
-                                    0,  # new_min_output_id,
-                                    new_range_size,  # new_min_output_id + new_range_size,
-                                ),
-                                device=self._device,
-                            )
+            grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
+                sharding._grouped_embedding_configs
+            )
+            for group_config in grouped_embedding_configs:
+                for table in group_config.embedding_tables:
+                    # pyre-ignore [16]
+                    new_min_output_id = table.local_metadata.shard_offsets[0]
+                    # pyre-ignore [16]
+                    new_range_size = table.local_metadata.shard_sizes[0]
+
+                    mc_module = module._managed_collision_modules[table.name]
+
+                    # TODO:
+                    # 1) need to make TBE accept global indices for now force to local indices
+                    # 2) MCH is particularly nasty with a portion of each shard; ideally dont do this
+                    # 3) now create a feature_to_offset and pass into awaitable callbacks to act as raw id adder
+                    self._managed_collision_modules[table.name] = (
+                        mc_module.rebuild_with_output_id_range(
+                            output_id_range=(
+                                0,  # new_min_output_id,
+                                new_range_size,  # new_min_output_id + new_range_size,
+                            ),
+                            device=self._device,
                         )
-                        zch_size = self._managed_collision_modules[table.name]._zch_size
-
-                        zch_size_by_rank = [
-                            torch.zeros(1, dtype=torch.int64, device=self._device)
-                            for _ in range(self._env.world_size)
-                        ]
-                        if self._env.world_size > 1:
-                            dist.all_gather(
-                                zch_size_by_rank,
-                                torch.tensor(
-                                    [zch_size], dtype=torch.int64, device=self._device
-                                ),
-                                group=self._env.process_group,
-                            )
-                        else:
-                            zch_size_by_rank[0] = torch.tensor(
+                    )
+                    zch_size = self._managed_collision_modules[table.name]._zch_size
+
+                    zch_size_by_rank = [
+                        torch.zeros(1, dtype=torch.int64, device=self._device)
+                        for _ in range(self._env.world_size)
+                    ]
+                    if self._env.world_size > 1:
+                        dist.all_gather(
+                            zch_size_by_rank,
+                            torch.tensor(
                                 [zch_size], dtype=torch.int64, device=self._device
-                            )
+                            ),
+                            group=self._env.process_group,
+                        )
+                    else:
+                        zch_size_by_rank[0] = torch.tensor(
+                            [zch_size], dtype=torch.int64, device=self._device
+                        )

-                        # Calculate the sum of all ZCH sizes from rank 0 to list
-                        # index. The last item is the sum of all elements in zch_size_by_rank
-                        zch_size_cumsum = torch.cumsum(
-                            torch.cat(zch_size_by_rank), dim=0
-                        ).tolist()
+                    # Calculate the sum of all ZCH sizes from rank 0 to list
+                    # index. The last item is the sum of all elements in zch_size_by_rank
+                    zch_size_cumsum = torch.cumsum(
+                        torch.cat(zch_size_by_rank), dim=0
+                    ).tolist()

-                        zch_size_sum_before_this_rank = (
-                            zch_size_cumsum[self._env.rank] - zch_size
-                        )
+                    zch_size_sum_before_this_rank = (
+                        zch_size_cumsum[self._env.rank] - zch_size
+                    )

-                        self._mc_module_name_shard_metadata[table.name] = (
-                            zch_size_sum_before_this_rank,
-                            zch_size,
-                            zch_size_cumsum[-1],
-                        )
-                        for feature in table.feature_names:
-                            self._feature_to_offset[feature] = new_min_output_id
+                    self._mc_module_name_shard_metadata[table.name] = (
+                        zch_size_sum_before_this_rank,
+                        zch_size,
+                        zch_size_cumsum[-1],
+                    )
+                    for feature in table.feature_names:
+                        self._feature_to_offset[feature] = new_min_output_id

     def _create_input_dists(
         self,
@@ -335,31 +333,26 @@ def _create_input_dists(
     ) -> None:
         feature_names: List[str] = []
         self._feature_splits: List[int] = []
-        for sharding_type, sharding in self._sharding_type_to_sharding.items():
-            if sharding_type == ShardingType.ROW_WISE.value:
-                feature_hash_sizes: List[int] = [
-                    self._managed_collision_modules[
-                        self._feature_to_table[f]
-                    ].input_size()
-                    for f in sharding.feature_names()
-                ]
-
-                input_dist = RwSparseFeaturesDist(
-                    # pyre-ignore [16]
-                    pg=sharding._pg,
-                    # pyre-ignore [16]
-                    num_features=sharding._get_num_features(),
-                    feature_hash_sizes=feature_hash_sizes,
-                    # pyre-ignore [16]
-                    device=sharding._device,
-                    is_sequence=True,
-                    # pyre-ignore [16]
-                    has_feature_processor=sharding._has_feature_processor,
-                    need_pos=False,
-                )
-                self._input_dists.append(input_dist)
-                feature_names.extend(sharding.feature_names())
-                self._feature_splits.append(len(sharding.feature_names()))
+        for sharding in self._embedding_shardings:
+            assert isinstance(sharding, BaseRwEmbeddingSharding)
+            feature_hash_sizes: List[int] = [
+                self._managed_collision_modules[self._feature_to_table[f]].input_size()
+                for f in sharding.feature_names()
+            ]
+
+            input_dist = RwSparseFeaturesDist(
+                # pyre-ignore [6]
+                pg=sharding._pg,
+                num_features=sharding._get_num_features(),
+                feature_hash_sizes=feature_hash_sizes,
+                device=sharding._device,
+                is_sequence=True,
+                has_feature_processor=sharding._has_feature_processor,
+                need_pos=False,
+            )
+            self._input_dists.append(input_dist)
+            feature_names.extend(sharding.feature_names())
+            self._feature_splits.append(len(sharding.feature_names()))

         self._features_order: List[int] = []
         for f in feature_names:
@@ -378,18 +371,16 @@ def _create_input_dists(
     def _create_output_dists(
         self,
     ) -> None:
-        for sharding_type, sharding in self._sharding_type_to_sharding.items():
-            if sharding_type == ShardingType.ROW_WISE.value:
-                self._output_dists.append(
-                    RwSequenceEmbeddingDist(
-                        # pyre-ignore [16]
-                        sharding._pg,
-                        # pyre-ignore [16]
-                        sharding._get_num_features(),
-                        # pyre-ignore [16]
-                        sharding._device,
-                    )
+        for sharding in self._embedding_shardings:
+            assert isinstance(sharding, BaseRwEmbeddingSharding)
+            self._output_dists.append(
+                RwSequenceEmbeddingDist(
+                    # pyre-ignore [6]
+                    sharding._pg,
+                    sharding._get_num_features(),
+                    sharding._device,
                 )
+            )

     # pyre-ignore [14]
     def input_dist(
@@ -541,14 +532,13 @@ def shard(
         module: ManagedCollisionCollection,
        params: Dict[str, ParameterSharding],
         env: ShardingEnv,
-        sharding_type_to_sharding: Dict[
-            str,
+        embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
                 KeyedJaggedTensor,
                 torch.Tensor,
                 torch.Tensor,
-            ],
+            ]
         ],
         device: Optional[torch.device] = None,
     ) -> ShardedManagedCollisionCollection:
@@ -561,7 +551,7 @@ def shard(
             params,
             env=env,
             device=device,
-            sharding_type_to_sharding=sharding_type_to_sharding,
+            embedding_shardings=embedding_shardings,
         )

     def shardable_parameters(
diff --git a/torchrec/distributed/planner/tests/test_embeddingbag_utils.py b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py
index 4a37161dd..9b9ebfb29 100644
--- a/torchrec/distributed/planner/tests/test_embeddingbag_utils.py
+++ b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py
@@ -11,7 +11,7 @@
 import unittest

 from torchrec.distributed.embeddingbag import (
-    create_sharding_infos_by_sharding,
+    create_sharding_infos_by_group,
     EmbeddingBagCollectionSharder,
 )
 from torchrec.distributed.planner import (
@@ -79,21 +79,21 @@ def setUp(self) -> None:
         )
         self.expected_plan = planner.plan(self.model, [self.sharder])
         # pyre-ignore[6]
-        self.expected_sharding_infos = create_sharding_infos_by_sharding(
+        self.expected_sharding_infos = create_sharding_infos_by_group(
             self.model,
             self.expected_plan.get_plan_for_module(""),  # pyre-ignore[6]
             prefix="embedding_bags.",
             fused_params=None,
         )

-    def test_create_sharding_infos_by_sharding_override(self) -> None:
+    def test_create_sharding_infos_by_group_override(self) -> None:
         """
         Test that fused_params from sharders get overridden.
         """
         # with sharder fused params that will get overridden
         sharder_fused_params = {"enforce_hbm": False}
-        overriden_sharding_infos = create_sharding_infos_by_sharding(
+        overriden_sharding_infos = create_sharding_infos_by_group(
             self.model,
             self.expected_plan.get_plan_for_module(""),
             prefix="embedding_bags.",
@@ -106,7 +106,7 @@ def test_create_sharding_infos_by_sharding_override(self) -> None:

         # with sharder fused params that won't get overridden
         sharder_fused_params = {"ABC": True}
-        not_overriden_sharding_infos = create_sharding_infos_by_sharding(
+        not_overriden_sharding_infos = create_sharding_infos_by_group(
             self.model,
             self.expected_plan.get_plan_for_module(""),
             prefix="embedding_bags.",
@@ -120,7 +120,7 @@ def test_create_sharding_infos_by_sharding_override(self) -> None:
         for a, b in zip(expected_sharding_info, not_overriden_sharding_info):
             self.assertNotEqual(a.fused_params, b.fused_params)

-    def test_create_sharding_infos_by_sharding_combine(self) -> None:
+    def test_create_sharding_infos_by_group_combine(self) -> None:
         """
         Test that fused_params can get info from both sharder and constraints.
""" @@ -141,7 +141,7 @@ def test_create_sharding_infos_by_sharding_combine(self) -> None: # provide that two fused params from sharder sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False} - combined_sharding_infos = create_sharding_infos_by_sharding( + combined_sharding_infos = create_sharding_infos_by_group( self.model, new_plan.get_plan_for_module(""), # pyre-ignore[6] prefix="embedding_bags.", @@ -156,7 +156,7 @@ def test_create_sharding_infos_by_sharding_combine(self) -> None: # provide that two fused params from sharder wrongly sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True} - wrong_combined_sharding_infos = create_sharding_infos_by_sharding( + wrong_combined_sharding_infos = create_sharding_infos_by_group( self.model, new_plan.get_plan_for_module(""), # pyre-ignore[6] prefix="embedding_bags.", diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index c6d81b8c2..7ea244a66 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -15,6 +15,7 @@ from fbgemm_gpu.split_embedding_configs import EmbOptimType from hypothesis import assume, given, settings, strategies as st, Verbosity from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig from torchrec.distributed.planner import ParameterConstraints from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase @@ -630,3 +631,34 @@ def test_sharding_variable_batch( global_constant_batch=global_constant_batch, pooling=pooling, ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given(sharding_type=st.just(ShardingType.COLUMN_WISE.value)) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_multiple_kernels(self, sharding_type: str) -> None: + if self.backend == "gloo": + self.skipTest("ProcessGroupGloo does not support reduce_scatter") + + constraints = { + table.name: ParameterConstraints( + min_partition=4, + compute_kernels=( + [EmbeddingComputeKernel.FUSED.value] + if i % 2 == 0 + else [EmbeddingComputeKernel.FUSED_UVM_CACHING.value] + ), + ) + for i, table in enumerate(self.tables) + } + self._test_sharding( + # pyre-ignore[6] + sharders=[EmbeddingBagCollectionSharder()], + backend=self.backend, + constraints=constraints, + variable_batch_per_feature=True, + has_weighted_tables=False, + )