diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 45ed9a793..ae41259b5 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -41,6 +41,7 @@ SplitTableBatchedEmbeddingBagsCodegen, ) from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( PartiallyMaterializedTensor, ) @@ -50,7 +51,11 @@ from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, ) -from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict +from torchrec.distributed.embedding_kernel import ( + BaseEmbedding, + create_virtual_sharded_tensors, + get_state_dict, +) from torchrec.distributed.embedding_types import ( compute_kernel_to_embedding_location, DTensorMetadata, @@ -65,7 +70,7 @@ ShardMetadata, TensorProperties, ) -from torchrec.distributed.utils import append_prefix +from torchrec.distributed.utils import append_prefix, none_throws from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, pooling_type_to_pooling_mode, @@ -169,6 +174,24 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: return ssd_tbe_params +def _populate_zero_collision_tbe_params( + tbe_params: Dict[str, Any], + sharded_local_buckets: List[Tuple[int, int, int]], +) -> None: + """ + Construct Zero Collision TBE params from config and fused params dict. + """ + bucket_offsets: List[Tuple[int, int]] = [ + (offset_start, offset_end) + for offset_start, offset_end, _ in sharded_local_buckets + ] + bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets] + + tbe_params["kv_zch_params"] = KVZCHParams( + bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes + ) + + class KeyValueEmbeddingFusedOptimizer(FusedOptimizer): def __init__( self, @@ -676,24 +699,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: self._emb_module.update_hyper_parameters(params_dict) -def _gen_named_parameters_by_table_ssd( - emb_module: SSDTableBatchedEmbeddingBags, - table_name_to_count: Dict[str, int], - config: GroupedEmbeddingConfig, - pg: Optional[dist.ProcessGroup] = None, -) -> Iterator[Tuple[str, nn.Parameter]]: - """ - Return an empty tensor to indicate that the table is on remote device. - """ - for table in config.embedding_tables: - table_name = table.name - # placeholder - weight: nn.Parameter = nn.Parameter(torch.empty(0)) - # pyre-ignore - weight._in_backward_optimizers = [EmptyFusedOptimizer()] - yield (table_name, weight) - - def _gen_named_parameters_by_table_ssd_pmt( emb_module: SSDTableBatchedEmbeddingBags, table_name_to_count: Dict[str, int], @@ -956,6 +961,10 @@ def __init__( **ssd_tbe_params, ).to(device) + logger.info( + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" + ) + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( config, self._emb_module, @@ -1064,6 +1073,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat Return an iterator over embedding tables, yielding both the table name as well as the embedding table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid RocksDB snapshot to support windowed access. 
+ optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch + optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch """ for config, tensor in zip( self._config.embedding_tables, @@ -1095,6 +1106,280 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ return self.emb_module.split_embedding_weights(no_snapshot) +class ZeroCollisionKeyValueEmbedding( + BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule +): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + backend_type: BackendType = BackendType.SSD, + ) -> None: + super().__init__(config, pg, device) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + + ssd_tbe_params = _populate_ssd_tbe_params(config) + self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets() + _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._num_embeddings, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=PoolingMode.NONE, + backend_type=backend_type, + **ssd_tbe_params, + ).to(device) + + logger.info( + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" + ) + + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + pg, + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd_pmt( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + # every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db + # use split weights result cache so that multiple calls in the same train iteration will only trigger once + self._split_weights_res: Optional[ + Tuple[ + List[ShardedTensor], + List[ShardedTensor], + List[ShardedTensor], + ] + ] = None + + def init_parameters(self) -> None: + """ + An advantage of KV TBE is that we don't need to init weights. Hence skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. 
+ """ + return self._optim + + def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: + """ + utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec + """ + sharded_local_buckets: List[Tuple[int, int, int]] = [] + world_size = dist.get_world_size(self._pg) + local_rank = dist.get_rank(self._pg) + + for table in self._config.embedding_tables: + total_num_buckets = none_throws(table.total_num_buckets) + assert ( + total_num_buckets % world_size == 0 + ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" + assert ( + table.total_num_buckets + and table.num_embeddings % table.total_num_buckets == 0 + ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" + bucket_offset_start = total_num_buckets // world_size * local_rank + bucket_offset_end = min( + total_num_buckets, total_num_buckets // world_size * (local_rank + 1) + ) + bucket_size = ( + table.num_embeddings + total_num_buckets - 1 + ) // total_num_buckets + sharded_local_buckets.append( + (bucket_offset_start, bucket_offset_end, bucket_size) + ) + logger.info( + f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" + ) + return sharded_local_buckets + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + no_snapshot: bool = True, + ) -> Dict[str, Any]: + """ + Args: + no_snapshot (bool): the tensors in the returned dict are + PartiallyMaterializedTensors. this argument controls wether the + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the + PartiallyMaterializedTensor has a RocksDB snapshot handle + """ + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() + + emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Only allowed ways to get state_dict. 
+ """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + # pyre-ignore [6] + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + # pyre-ignore [15] + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights()[0], + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[ + Tuple[ + str, + Union[ShardedTensor, PartiallyMaterializedTensor], + Optional[ShardedTensor], + Optional[ShardedTensor], + ] + ]: + """ + Return an iterator over embedding tables, for each table yielding + table name, + PMT for embedding table with a valid RocksDB snapshot to support tensor IO + optional ShardedTensor for weight_id + optional ShardedTensor for bucket_cnt + """ + if self._split_weights_res is not None: + pmt_sharded_t_list = self._split_weights_res[0] + # pyre-ignore + weight_id_sharded_t_list = self._split_weights_res[1] + bucket_cnt_sharded_t_list = self._split_weights_res[2] + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): + table_config = self._config.embedding_tables[table_idx] + key = append_prefix(prefix, f"{table_config.name}") + + yield key, pmt_sharded_t, weight_id_sharded_t_list[ + table_idx + ], bucket_cnt_sharded_t_list[table_idx] + return + + pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights( + no_snapshot=False + ) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + + pmt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, pmt_list, self._pg, prefix + ) + weight_id_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore + ) + bucket_cnt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, bucket_cnt_list, self._pg, prefix # pyre-ignore + ) + # pyre-ignore + assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list) + assert ( + len(pmt_sharded_t_list) + == len(weight_id_sharded_t_list) + == len(bucket_cnt_sharded_t_list) + ) + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): + table_config = self._config.embedding_tables[table_idx] + key = append_prefix(prefix, f"{table_config.name}") + + yield key, pmt_sharded_t, weight_id_sharded_t_list[ + table_idx + ], bucket_cnt_sharded_t_list[table_idx] + + self._split_weights_res = ( + pmt_sharded_t_list, + weight_id_sharded_t_list, + bucket_cnt_sharded_t_list, + ) + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. 
+ self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + # pyre-ignore [15] + def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ + List[PartiallyMaterializedTensor], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + ]: + return self.emb_module.split_embedding_weights(no_snapshot) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + # reset split weights during training + self._split_weights_res = None + + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + ) + + class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): def __init__( self, @@ -1563,6 +1848,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat Return an iterator over embedding tables, yielding both the table name as well as the embedding table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid RocksDB snapshot to support windowed access. + optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch + optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch """ for config, tensor in zip( self._config.embedding_tables, diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index cc381fca8..8b4465a53 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -32,6 +32,7 @@ from torch.distributed._tensor import DTensor from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embedding_lookup import PartiallyMaterializedTensor from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingInfo, @@ -231,6 +232,8 @@ def create_sharding_infos_by_sharding_device_group( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -592,6 +595,8 @@ def create_grouped_sharding_infos( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -669,6 +674,28 @@ def _pre_load_state_dict_hook( to transform from ShardedTensors/DTensors into tensors """ for table_name in self._model_parallel_name_to_local_shards.keys(): + if self._table_name_to_config[table_name].use_virtual_table: + # weight_id and bucket are generated at the runtime of state_dict instead of registered class + # so we need to erase them before passing into load_state_dict + weight_key = f"{prefix}embeddings.{table_name}.weight" + weight_id_key = f"{prefix}embeddings.{table_name}.weight_id" + bucket_key = f"{prefix}embeddings.{table_name}.bucket" + if weight_id_key in state_dict: + del state_dict[weight_id_key] + if bucket_key in state_dict: + del state_dict[bucket_key] + assert weight_key in state_dict + assert ( + len(self._model_parallel_name_to_local_shards[table_name]) == 1 + ), "currently only support 1 shard per rank" + + # for loading state_dict into virtual table, we skip the weights assignment + # if needed, for now this should be handled separately outside of load_state_dict call + state_dict[weight_key] = 
self._model_parallel_name_to_local_shards[ + table_name + ][0].tensor + continue + key = f"{prefix}embeddings.{table_name}.weight" # gather model shards from both DTensor and ShardedTensor maps model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[ @@ -823,6 +850,8 @@ def _initialize_torch_state(self) -> None: # noqa shards_wrapper["global_stride"] = v.stride() shards_wrapper["placements"] = v.placements elif isinstance(v, ShardedTensor): + # for virtual table, we only populate the shardedTensor for Embedding Table during + # initial state_dict calls, skip weight id and bucket tensor self._model_parallel_name_to_local_shards[table_name].extend( v.local_shards() ) @@ -832,6 +861,10 @@ def _initialize_torch_state(self) -> None: # noqa # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `named_parameters_by_table`. ) in lookup.named_parameters_by_table(): + # for virtual table, currently we don't expose id tensor and bucket tensor + # because they are not updated in real time, and they are created on the fly + # whenever state_dict is called + # reference: ƒbgs _gen_named_parameters_by_table_ssd_pmt self.embeddings[table_name].register_parameter("weight", tbe_slice) for table_name in self._model_parallel_name_to_local_shards.keys(): local_shards = self._model_parallel_name_to_local_shards[table_name] @@ -852,7 +885,9 @@ def _initialize_torch_state(self) -> None: # noqa if self._output_dtensor: assert _model_parallel_name_to_compute_kernel[table_name] not in { - EmbeddingComputeKernel.KEY_VALUE.value + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, } if shards_wrapper_map["local_tensors"]: self._model_parallel_name_to_dtensor[table_name] = ( @@ -897,7 +932,17 @@ def _initialize_torch_state(self) -> None: # noqa self._model_parallel_name_to_sharded_tensor[table_name] = ( ShardedTensor._init_from_local_shards( local_shards, - self._name_to_table_size[table_name], + ( + [ + # assuming virtual table only supports rw sharding for now + 0 if dim == 0 else dim_size + for dim, dim_size in enumerate( + self._name_to_table_size[table_name] + ) + ] + if self._table_name_to_config[table_name].use_virtual_table + else self._name_to_table_size[table_name] + ), process_group=( self._env.sharding_pg if isinstance(self._env, ShardingEnv2D) @@ -916,7 +961,9 @@ def extract_sharded_kvtensors( sharded_t, ) in module._model_parallel_name_to_sharded_tensor.items(): if _model_parallel_name_to_compute_kernel[table_name] in { - EmbeddingComputeKernel.KEY_VALUE.value + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: ret[table_name] = sharded_t return ret @@ -940,7 +987,6 @@ def post_state_dict_hook( ) in module._model_parallel_name_to_dtensor.items(): destination_key = f"{prefix}embeddings.{table_name}.weight" destination[destination_key] = d_tensor - # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid # snapshot for read access. 
sharded_kvtensors = extract_sharded_kvtensors(module) @@ -948,26 +994,80 @@ def post_state_dict_hook( return sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + virtual_table_sharded_t_map: Optional[ + Dict[str, Tuple[ShardedTensor, ShardedTensor]] + ] = None for lookup, sharding_type in zip( module._lookups, module._sharding_type_to_sharding.keys() ): if sharding_type != ShardingType.DATA_PARALLEL.value: for ( - key, - v, - _, - _, + table_name, + weights_t, + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, ) in ( lookup.get_named_split_embedding_weights_snapshot() # pyre-ignore ): - assert key in sharded_kvtensors_copy - sharded_kvtensors_copy[key].local_shards()[0].tensor = v + assert table_name in sharded_kvtensors_copy + if self._table_name_to_config[table_name].use_virtual_table: + assert isinstance(weights_t, ShardedTensor) + if virtual_table_sharded_t_map is None: + virtual_table_sharded_t_map = {} + assert ( + weight_ids_sharded_t is not None + and id_cnt_per_bucket_sharded_t is not None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + sharded_kvtensors_copy[table_name] = weights_t + virtual_table_sharded_t_map[table_name] = ( + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, + ) + else: + assert isinstance(weights_t, PartiallyMaterializedTensor) + assert ( + weight_ids_sharded_t is None + and id_cnt_per_bucket_sharded_t is None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + # pyre-ignore + sharded_kvtensors_copy[table_name].local_shards()[ + 0 + ].tensor = weights_t + + def update_destination( + table_name: str, + tensor_name: str, + destination: Dict[str, torch.Tensor], + value: torch.Tensor, + ) -> None: + destination_key = f"{prefix}embeddings.{table_name}.{tensor_name}" + destination[destination_key] = value + for ( table_name, sharded_kvtensor, ) in sharded_kvtensors_copy.items(): - destination_key = f"{prefix}embeddings.{table_name}.weight" - destination[destination_key] = sharded_kvtensor + update_destination(table_name, "weight", destination, sharded_kvtensor) + if ( + virtual_table_sharded_t_map + and table_name in virtual_table_sharded_t_map + ): + update_destination( + table_name, + "weight_id", + destination, + virtual_table_sharded_t_map[table_name][0], + ) + update_destination( + table_name, + "bucket", + destination, + virtual_table_sharded_t_map[table_name][1], + ) self.register_state_dict_pre_hook(self._pre_state_dict_hook) self._register_state_dict_hook(post_state_dict_hook) @@ -984,6 +1084,8 @@ def reset_parameters(self) -> None: for table_config in self._embedding_configs: if self.module_sharding_plan[table_config.name].compute_kernel in { EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: continue assert table_config.init_fn is not None diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 428e9fd3e..58492b622 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -20,7 +20,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( SplitTableBatchedEmbeddingBagsCodegen, ) -from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.ssd.training import BackendType, 
SSDTableBatchedEmbeddingBags from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( PartiallyMaterializedTensor, ) @@ -38,6 +38,7 @@ BatchedFusedEmbeddingBag, KeyValueEmbedding, KeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, ) from torchrec.distributed.comm_ops import get_gradient_division from torchrec.distributed.composable.table_batched_embedding_slice import ( @@ -214,6 +215,8 @@ def _create_embedding_kernel( if ( table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE + or table.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE + or table.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE ): self._need_prefetch = True if config.compute_kernel == EmbeddingComputeKernel.DENSE: @@ -228,14 +231,28 @@ def _create_embedding_kernel( pg=pg, device=device, ) - elif config.compute_kernel in { - EmbeddingComputeKernel.KEY_VALUE, - }: + elif config.compute_kernel == EmbeddingComputeKernel.KEY_VALUE: return KeyValueEmbedding( config=config, pg=pg, device=device, ) + elif config.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE: + # for ssd kv + return ZeroCollisionKeyValueEmbedding( + config=config, + pg=pg, + device=device, + backend_type=BackendType.SSD, + ) + elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE: + # for dram kv + return ZeroCollisionKeyValueEmbedding( + config=config, + pg=pg, + device=device, + backend_type=BackendType.DRAM, + ) else: raise ValueError(f"Compute kernel not supported {config.compute_kernel}") @@ -364,7 +381,9 @@ def get_named_split_embedding_weights_snapshot( RocksDB snapshot to support windowed access. """ for emb_module in self._emb_modules: - if isinstance(emb_module, KeyValueEmbedding): + if isinstance(emb_module, KeyValueEmbedding) or isinstance( + emb_module, ZeroCollisionKeyValueEmbedding + ): yield from emb_module.get_named_split_embedding_weights_snapshot() def flush(self) -> None: diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 98fa2d15f..cfc1a915e 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -468,6 +468,8 @@ def _prefetch_and_cached( """ if table.compute_kernel in { EmbeddingComputeKernel.KEY_VALUE, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE, }: return True diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 40d9f2308..86d81585b 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -79,8 +79,16 @@ class EmbeddingComputeKernel(Enum): QUANT = "quant" QUANT_UVM = "quant_uvm" QUANT_UVM_CACHING = "quant_uvm_caching" - KEY_VALUE = "key_value" + KEY_VALUE = ( + "key_value" # ssd as kv backend storage for fully materialized embedding table + ) CUSTOMIZED_KERNEL = "customized_kernel" + SSD_VIRTUAL_TABLE = ( + "ssd_virtual_table" # ssd as kv backend storage for virtual table + ) + DRAM_VIRTUAL_TABLE = ( + "dram_virtual_table" # dram as kv backend storage for virtual table + ) def compute_kernel_to_embedding_location( @@ -91,6 +99,8 @@ def compute_kernel_to_embedding_location( EmbeddingComputeKernel.FUSED, EmbeddingComputeKernel.QUANT, EmbeddingComputeKernel.KEY_VALUE, # use hbm for cache + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, # use hbm for cache + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE, # use hbm for cache ]: return EmbeddingLocation.DEVICE elif 
compute_kernel in [ @@ -472,6 +482,8 @@ def compute_kernels( EmbeddingComputeKernel.FUSED_UVM.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, ] else: # TODO re-enable model parallel and dense diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py index 7dd86b060..ef76d92d4 100644 --- a/torchrec/distributed/planner/constants.py +++ b/torchrec/distributed/planner/constants.py @@ -93,7 +93,10 @@ def kernel_bw_lookup( caching_ratio * hbm_mem_bw + (1 - caching_ratio) * hbm_to_ddr_mem_bw ) / 10, + # TODO: revisit whether this estimation makes sense ("cuda", EmbeddingComputeKernel.KEY_VALUE.value): hbm_to_ddr_mem_bw, + ("cuda", EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value): hbm_to_ddr_mem_bw, + ("cuda", EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value): hbm_to_ddr_mem_bw, } if ( diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 66ea9ee2d..7676c5d45 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -43,7 +43,9 @@ # compute kernels that should only be used if users specified them GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = { - EmbeddingComputeKernel.KEY_VALUE + EmbeddingComputeKernel.KEY_VALUE, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE, } diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index d3f230e78..3abef9fcc 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -1178,10 +1178,18 @@ def calculate_shard_storages( EmbeddingComputeKernel.FUSED_UVM_CACHING.value, EmbeddingComputeKernel.QUANT_UVM_CACHING.value, EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: + # TODO(wangj): for ssd/dram kv, most likely we use absolute L1 cache size instead of caching ratio, as denominator is huge hbm_storage = round(ddr_storage * caching_ratio) table_cached = True - if compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}: + if compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, + }: + # TODO(wangj): update this to the L2 cache usage and add SSD usage ddr_storage = 0 optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0] diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index b62609da1..c6037a9e5 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -217,6 +217,8 @@ def _shard( weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning, + total_num_buckets=info.embedding_config.total_num_buckets, + use_virtual_table=info.embedding_config.use_virtual_table, ) ) return tables_per_rank diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index e06821e47..b652699f5 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn from tensordict 
import TensorDict +from torchrec.distributed.embedding import EmbeddingCollectionSharder from torchrec.distributed.embedding_tower_sharding import ( EmbeddingTowerCollectionSharder, EmbeddingTowerSharder, @@ -1671,6 +1672,38 @@ def forward( return pred +class TestECSharder(EmbeddingCollectionSharder): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + if fused_params is None: + fused_params = {} + + self._sharding_type = sharding_type + self._kernel_type = kernel_type + super().__init__(fused_params, qcomm_codecs_registry) + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + class TestEBCSharder(EmbeddingBagCollectionSharder): def __init__( self, diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py index bd918d97e..b37111cb8 100644 --- a/torchrec/distributed/test_utils/test_model_parallel_base.py +++ b/torchrec/distributed/test_utils/test_model_parallel_base.py @@ -16,6 +16,9 @@ import torch import torch.nn as nn from fbgemm_gpu.split_embedding_configs import EmbOptimType +from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( + PartiallyMaterializedTensor, +) from hypothesis import given, settings, strategies as st, Verbosity from torch import distributed as dist from torch.distributed._shard.sharded_tensor import ShardedTensor @@ -396,20 +399,46 @@ def _compare_models( m1: DistributedModelParallel, m2: DistributedModelParallel, is_deterministic: bool = True, + use_virtual_table: bool = False, ) -> None: sd1 = m1.state_dict() - for key, value in m2.state_dict().items(): + sd2 = m2.state_dict() + for key, value in sd2.items(): v2 = sd1[key] if isinstance(value, ShardedTensor): assert isinstance(v2, ShardedTensor) self.assertEqual(len(value.local_shards()), len(v2.local_shards())) - for dst, src in zip(value.local_shards(), v2.local_shards()): + for local_shard_id, (dst, src) in enumerate( + zip(value.local_shards(), v2.local_shards()) + ): + src_tensor = None + dst_tensor = None + if isinstance(dst.tensor, PartiallyMaterializedTensor): + assert isinstance(src.tensor, PartiallyMaterializedTensor) + if use_virtual_table: + # kvz zch emb table comparison, id is non-continuous + wid_key = key[: key.rfind(".")] + ".weight_id" + src_wid = sd1[wid_key].local_shards()[local_shard_id].tensor + dst_wid = sd2[wid_key].local_shards()[local_shard_id].tensor + + sorted_src_wid, _ = torch.sort(src_wid.view(-1)) + sorted_dst_wid, _ = torch.sort(dst_wid.view(-1)) + assert torch.equal(sorted_src_wid, sorted_dst_wid) + src_tensor = src.tensor.get_weights_by_ids(src_wid) + dst_tensor = dst.tensor.get_weights_by_ids(dst_wid) + else: + # normal ssd offloading emb table comparison + src_tensor = src.tensor.full_tensor() + dst_tensor = dst.tensor.full_tensor() + else: + src_tensor = src.tensor + dst_tensor = dst.tensor if is_deterministic: - self.assertTrue(torch.equal(src.tensor, dst.tensor)) + self.assertTrue(torch.equal(src_tensor, dst_tensor)) else: - rtol, atol = _get_default_rtol_and_atol(src.tensor, dst.tensor) + rtol, atol = _get_default_rtol_and_atol(src_tensor, dst_tensor) 
torch.testing.assert_close( - src.tensor, dst.tensor, rtol=rtol, atol=atol + src_tensor, dst_tensor, rtol=rtol, atol=atol ) elif isinstance(value, DTensor): assert isinstance(v2, DTensor) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index e4a8469c2..cb28791fc 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -36,6 +36,7 @@ ModelInput, TestEBCSharder, TestEBSharder, + TestECSharder, TestETCSharder, TestETSharder, TestSparseNNBase, @@ -63,6 +64,7 @@ class SharderType(Enum): EMBEDDING_BAG_COLLECTION = "embedding_bag_collection" EMBEDDING_TOWER = "embedding_tower" EMBEDDING_TOWER_COLLECTION = "embedding_tower_collection" + EMBEDDING_COLLECTION = "embedding_collection" def create_test_sharder( @@ -72,7 +74,7 @@ def create_test_sharder( fused_params: Optional[Dict[str, Any]] = None, qcomms_config: Optional[QCommsConfig] = None, device: Optional[torch.device] = None, -) -> Union[TestEBSharder, TestEBCSharder, TestETSharder, TestETCSharder]: +) -> Union[TestEBSharder, TestEBCSharder, TestETSharder, TestETCSharder, TestECSharder]: if fused_params is None: fused_params = {} qcomm_codecs_registry = {} @@ -91,6 +93,10 @@ def create_test_sharder( fused_params, qcomm_codecs_registry, ) + elif sharder_type == SharderType.EMBEDDING_COLLECTION.value: + return TestECSharder( + sharding_type, kernel_type, fused_params, qcomm_codecs_registry + ) elif sharder_type == SharderType.EMBEDDING_TOWER.value: return TestETSharder( sharding_type, kernel_type, fused_params, qcomm_codecs_registry diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py index 466cf1a16..618bf9273 100644 --- a/torchrec/distributed/tests/test_embedding_sharding.py +++ b/torchrec/distributed/tests/test_embedding_sharding.py @@ -11,16 +11,16 @@ import random import unittest from typing import Any, Dict, List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import hypothesis.strategies as st import torch from hypothesis import given, settings +from torchrec.distributed.batched_embedding_kernel import ZeroCollisionKeyValueEmbedding from torchrec.distributed.embedding import EmbeddingCollectionContext from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel - from torchrec.distributed.embedding_sharding import ( _get_compute_kernel_type, _get_grouping_fused_params, @@ -38,6 +38,8 @@ from torchrec.modules.embedding_configs import DataType, PoolingType from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +WORLD_SIZE = 4 + class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase): def test_get_avg_cache_load_factor_hbm(self) -> None: @@ -532,3 +534,97 @@ def test_set_sharding_context_post_a2a(self) -> None: _set_sharding_context_post_a2a(kjts, ctx) for context, result in zip(ctx.sharding_contexts, results): self.assertEqual(context.batch_size_per_rank_per_feature, result) + + +class TestECBucketMetadata(unittest.TestCase): + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + data_type=st.sampled_from([DataType.FP16, DataType.FP32]), + embedding_dim=st.sampled_from(list(range(160, 320, 40))), + total_bucket=st.sampled_from([14, 20, 32, 40]), + my_rank=st.integers(min_value=0, max_value=WORLD_SIZE), + ) + @settings(max_examples=10, deadline=10000) + def test_bucket_metadata_calculation_util( + self, 
data_type: DataType, embedding_dim: int, total_bucket: int, my_rank: int + ) -> None: + compute_kernels = [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + ] + fused_params_groups = [ + {"cache_load_factor": 0.5}, + {"cache_load_factor": 0.5}, + {"cache_load_factor": 0.5}, + {"cache_load_factor": 0.5}, + ] + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=data_type, + pooling=PoolingType.NONE, + has_feature_processor=False, + fused_params=fused_params_groups[i], + feature_names=[f"feature_{i}"], + compute_kernel=compute_kernels[i], + embedding_dim=embedding_dim, + local_rows=10000 * (2 * i + 1) // WORLD_SIZE, + local_cols=embedding_dim, + num_embeddings=10000 * (2 * i + 1), + total_num_buckets=total_bucket, + use_virtual_table=True, + ) + for i in range(len(compute_kernels)) + ] + + # since we don't have access to _group_tables_per_rank + tables_per_rank: List[List[ShardedEmbeddingTable]] = [tables] + + # taking only the list for the first rank + table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0] + + # assert that they are grouped together + self.assertEqual(len(table_groups), 1) + + table_group = table_groups[0] + + expect_failure = False + for table in tables: + if ( + table.num_embeddings % total_bucket != 0 + or total_bucket % WORLD_SIZE != 0 + ): + expect_failure = True + break + + with patch( + "torchrec.distributed.batched_embedding_kernel.dist.get_world_size", + return_value=WORLD_SIZE, + ), patch( + "torchrec.distributed.batched_embedding_kernel.dist.get_rank", + return_value=my_rank, + ): + if expect_failure: + with self.assertRaises(AssertionError): + zch_kv_emb = ZeroCollisionKeyValueEmbedding( + table_group, None, torch.device("cpu") + ) + else: + zch_kv_emb = ZeroCollisionKeyValueEmbedding( + table_group, None, torch.device("cpu") + ) + expected_tuple = [ + ( + total_bucket / WORLD_SIZE * my_rank, + total_bucket / WORLD_SIZE * (my_rank + 1), + table.num_embeddings / total_bucket, + ) + for table in tables + ] + self.assertEqual(zch_kv_emb._bucket_spec, expected_tuple) diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py index eb0f9545e..a9a77fcaf 100644 --- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -7,6 +7,7 @@ # pyre-strict +import copy import unittest from typing import cast, List, OrderedDict, Union @@ -17,6 +18,7 @@ from torchrec.distributed.batched_embedding_kernel import ( KeyValueEmbedding, KeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, ) from torchrec.distributed.embedding_types import ( EmbeddingComputeKernel, @@ -37,7 +39,7 @@ TestEmbeddingCollectionSharder, TestSequenceSparseNN, ) -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import KeyValueParams, ModuleSharder, ShardingType from torchrec.modules.embedding_configs import ( DataType, EmbeddingBagConfig, @@ -112,25 +114,20 @@ def _copy_ssd_emb_modules( ) emb1_kv = { - t: (w, w_id, bucket_cnt) - for t, w, w_id, bucket_cnt in emb_module1.get_named_split_embedding_weights_snapshot() + t: pmt + for t, pmt, _, _ in emb_module1.get_named_split_embedding_weights_snapshot() } for ( t, - w, + pmt, _, _, ) in emb_module2.get_named_split_embedding_weights_snapshot(): - w1 = 
emb1_kv[t][0] - w1_full_tensor = w1.full_tensor() + pmt1 = emb1_kv[t] + w1 = pmt1.full_tensor() # write value into ssd for both emb module for later comparison - w.wrapped.set_range( - 0, 0, w1_full_tensor.size(0), w1_full_tensor - ) - w1.wrapped.set_range( - 0, 0, w1_full_tensor.size(0), w1_full_tensor - ) + pmt.wrapped.set_range(0, 0, w1.size(0), w1) # purge after loading. This is needed, since we pass a batch # through dmp when instantiating them. @@ -193,6 +190,7 @@ def test_ssd_load_state_dict( table.name: ParameterConstraints( sharding_types=[sharding_type], compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), ) for i, table in enumerate(self.tables) } @@ -396,6 +394,7 @@ def test_ssd_fused_optimizer( table.name: ParameterConstraints( sharding_types=[sharding_type], compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), ) for i, table in enumerate(self.tables) } @@ -515,6 +514,7 @@ def test_ssd_mixed_kernels( compute_kernels=( [base_kernel_type] if i % 2 == fused_first else [kernel_type] ), + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), ) for i, table in enumerate(self.tables) } @@ -586,6 +586,7 @@ def test_ssd_mixed_sharding_types( else [ShardingType.ROW_WISE.value] ), compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), ) for i, table in enumerate(self.tables) } @@ -686,25 +687,20 @@ def _copy_ssd_emb_modules( ) emb1_kv = { - t: (w, w_id, bucket_cnt) - for t, w, w_id, bucket_cnt in emb_module1.get_named_split_embedding_weights_snapshot() + t: pmt + for t, pmt, _, _ in emb_module1.get_named_split_embedding_weights_snapshot() } for ( t, - w, + pmt, _, _, ) in emb_module2.get_named_split_embedding_weights_snapshot(): - w1 = emb1_kv[t][0] - w1_full_tensor = w1.full_tensor() + pmt1 = emb1_kv[t] + w1 = pmt1.full_tensor() # write value into ssd for both emb module for later comparison - w.wrapped.set_range( - 0, 0, w1_full_tensor.size(0), w1_full_tensor - ) - w1.wrapped.set_range( - 0, 0, w1_full_tensor.size(0), w1_full_tensor - ) + pmt.wrapped.set_range(0, 0, w1.size(0), w1) # purge after loading. This is needed, since we pass a batch # through dmp when instantiating them. 
@@ -769,6 +765,7 @@ def test_ssd_load_state_dict( table.name: ParameterConstraints( sharding_types=[sharding_type], compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), ) for i, table in enumerate(self.tables) } @@ -786,6 +783,359 @@ def test_ssd_load_state_dict( self._compare_models(m1, m2, is_deterministic=is_deterministic) +class ZeroCollisionSequenceModelParallelStateDictTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + self.shared_features = [] + self.embedding_groups = {} + + super().setUp(backend=backend) + + def _create_tables(self) -> None: + num_features = 4 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + total_num_buckets=10, + use_virtual_table=True, + ) + for i in range(num_features) + ] + + shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 20, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + total_num_buckets=10, + use_virtual_table=True, + ) + for i in range(shared_features) + ] + + self.tables += initial_tables + shared_features_tables + self.shared_features += [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups["group_0"] = [ + (f"{feature}@{table.name}" if feature in self.shared_features else feature) + for table in self.tables + for feature in table.feature_names + ] + + def _create_model(self) -> nn.Module: + return TestSequenceSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + embedding_groups=self.embedding_groups, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + m1.module.sparse.ec._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + m2.module.sparse.ec._lookups, + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = { + ZeroCollisionKeyValueEmbedding, + } + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." + ) + emb_module1.flush() + emb_module2.flush() + + emb1_kv = { + t: (sharded_t, sharded_w_id, bucket) + for t, sharded_t, sharded_w_id, bucket in emb_module1.get_named_split_embedding_weights_snapshot() + } + for ( + t, + sharded_t2, + _, + _, + ) in emb_module2.get_named_split_embedding_weights_snapshot(): + assert t in emb1_kv + sharded_t1 = emb1_kv[t][0] + sharded_w1_id = emb1_kv[t][1] + w1_id = sharded_w1_id.local_shards()[0].tensor + + pmt1 = sharded_t1.local_shards()[0].tensor + w1 = pmt1.get_weights_by_ids(w1_id) + + # write value into ssd for both emb module for later comparison + pmt2 = sharded_t2.local_shards()[0].tensor + pmt2.wrapped.set_weights_and_ids(w1, w1_id.view(-1)) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. 
+ emb_module1.purge() + emb_module2.purge() + + @staticmethod + def _copy_fused_modules_into_ssd_emb_modules( + fused_m: DistributedModelParallel, ssd_m: DistributedModelParallel + ) -> None: + """ + Util function to copy from fused embedding module to SSD TBE for initialization. It + requires both DMP modules to have the same sharding plan. + """ + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + fused_m.module.sparse.ec._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + ssd_m.module.sparse.ec._lookups, + ): + for fused_emb_module, ssd_emb_module in zip( + fused_lookup._emb_modules, ssd_lookup._emb_modules + ): + ssd_emb_modules = {ZeroCollisionKeyValueEmbedding} + if type(ssd_emb_module) in ssd_emb_modules: + fused_state_dict = fused_emb_module.state_dict() + for ( + t, + sharded_t, + _, + _, + ) in ssd_emb_module.get_named_split_embedding_weights_snapshot(): + weight_key = f"{t}.weight" + fused_sharded_t = fused_state_dict[weight_key] + fused_weight = fused_sharded_t.local_shards()[0].tensor.to( + "cpu" + ) + + # write value into ssd for both emb module for later comparison + pmt = sharded_t.local_shards()[0].tensor + pmt.wrapped.set_range(0, 0, fused_weight.size(0), fused_weight) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + fused_emb_module.purge() + ssd_emb_module.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_load_state_dict( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. 
+ """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ), + ), + ] + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models( + m1, m2, is_deterministic=is_deterministic, use_virtual_table=True + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_numerical_accuracy( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Make sure it produces same numbers as normal TBE. 
+ """ + self._set_table_weights_precision(dtype) + base_kernel_type = EmbeddingComputeKernel.FUSED.value + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + base_kernel_type, # base kernel type + fused_params=fused_params, + ), + ), + ] + ssd_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + ssd_constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + # for fused model, we need to change the table config to non-kvzch + ssd_tables = copy.deepcopy(self.tables) + for table in self.tables: + table.total_num_buckets = None + table.use_virtual_table = False + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + self.tables = ssd_tables + (ssd_model, _), batch = self._generate_dmps_and_batch( + ssd_sharders, constraints=ssd_constraints + ) + + # load state dict for dense modules + ssd_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + # for this to work, we expect the order of lookups to be the same + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + assert len(fused_model.module.sparse.ec._lookups) == len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + ssd_model.module.sparse.ec._lookups + ), "Expect same number of lookups" + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + fused_model.module.sparse.ec._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + ssd_model.module.sparse.ec._lookups, + ): + assert len(fused_lookup._emb_modules) == len( + ssd_lookup._emb_modules + ), "Expect same number of emb modules" + + self._copy_fused_modules_into_ssd_emb_modules(fused_model, ssd_model) + + if is_training: + self._train_models(fused_model, ssd_model, batch) + self._eval_models( + fused_model, ssd_model, batch, is_deterministic=is_deterministic + ) + + # TODO: uncomment this when we have optimizer plumb through + # def test_ssd_fused_optimizer( + + # TODO: uncomment this when we have multiple kernels in rw support(unblock input dist) + # def test_ssd_mixed_kernels + + # TODO: uncomment this when we support different sharding types, e.g. 
tw, tw_rw together with rw + # def test_ssd_mixed_sharding_types + + # TODO: remove after development is done def main() -> None: unittest.main() diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 6adca02c0..31e284cb0 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -423,7 +423,12 @@ def add_params_from_parameter_sharding( fused_params["output_dtype"] = parameter_sharding.output_dtype if ( - parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value} + parameter_sharding.compute_kernel + in { + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, + } and parameter_sharding.key_value_params is not None ): kv_params = parameter_sharding.key_value_params
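
Reviewer sketch, appended for context and not part of the patch: the snippet below mirrors how the tests in this diff exercise the new virtual-table path, declaring a KV-ZCH table via `total_num_buckets`/`use_virtual_table`, constraining it to the new SSD_VIRTUAL_TABLE kernel under row-wise sharding, and reproducing the per-rank bucket math from ZeroCollisionKeyValueEmbedding.get_sharded_local_buckets(). Table names and sizes are illustrative, not taken from the patch.

from typing import Tuple

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingConfig

# A zero-collision (virtual) table: num_embeddings must be divisible by
# total_num_buckets, and total_num_buckets by world_size (both asserted in
# get_sharded_local_buckets above).
table = EmbeddingConfig(
    name="table_0",  # illustrative name
    embedding_dim=16,
    num_embeddings=4000,
    feature_names=["feature_0"],
    total_num_buckets=8,
    use_virtual_table=True,
)

# Route the table to the new SSD-backed virtual-table kernel; DRAM_VIRTUAL_TABLE
# selects the same ZeroCollisionKeyValueEmbedding with BackendType.DRAM.
constraints = {
    table.name: ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
        compute_kernels=[EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value],
    )
}


def sharded_local_bucket(
    num_embeddings: int, total_num_buckets: int, world_size: int, rank: int
) -> Tuple[int, int, int]:
    """Per-rank (bucket_offset_start, bucket_offset_end, bucket_size), mirroring
    ZeroCollisionKeyValueEmbedding.get_sharded_local_buckets() for one table."""
    assert total_num_buckets % world_size == 0
    assert num_embeddings % total_num_buckets == 0
    buckets_per_rank = total_num_buckets // world_size
    start = buckets_per_rank * rank
    end = min(total_num_buckets, buckets_per_rank * (rank + 1))
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    return start, end, bucket_size


# e.g. world_size=4, rank=1 -> buckets [2, 4) of size 500 each
assert sharded_local_bucket(4000, 8, world_size=4, rank=1) == (2, 4, 500)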