diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 4c9a29810..b8a0ed574 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -327,20 +327,45 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
-            embedding_specs=embedding_specs,
-            device=device,
-            pooling_mode=self._pooling,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            bounds_check_mode=(
+        # Determine embedding cache mode for KV embedding tables
+        embedding_cache_mode = False  # Default: False = randomized initialization
+        if tbe_clazz == KVEmbeddingInference:
+            # For KV embedding tables, set cache mode based on embedding table configuration
+            # Check if any table has NoEvictionPolicy - use zero init for those
+            for table in config.embedding_tables:
+                if (
+                    table.virtual_table_eviction_policy is not None
+                    and type(table.virtual_table_eviction_policy).__name__
+                    == "NoEvictionPolicy"
+                ):
+                    embedding_cache_mode = True  # True = zero initialization
+                    break
+
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": embedding_specs,
+            "device": device,
+            "pooling_mode": self._pooling,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "bounds_check_mode": (
                 bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
             ),
-            feature_names_per_table=[
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if tbe_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
@@ -495,6 +520,7 @@ def __init__(
 
         managed: List[EmbeddingLocation] = []
         is_virtual_table = False
+        embedding_cache_mode = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
@@ -504,6 +530,8 @@ def __init__(
                 managed.append(EmbeddingLocation.HOST)
             if table.use_virtual_table:
                 is_virtual_table = True
+            if table.enable_embedding_update:
+                embedding_cache_mode = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -529,8 +557,9 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
-            embedding_specs=[
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": [
                 (
                     table.name,
                     local_rows,
@@ -549,15 +578,25 @@ def __init__(
                     managed,
                 )
             ],
-            device=device,
-            pooling_mode=PoolingMode.NONE,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            feature_names_per_table=[
+            "device": device,
+            "pooling_mode": PoolingMode.NONE,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if embedding_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 3e979b34d..ddd9de087 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -764,9 +764,9 @@ def __init__(  # noqa C901
         self._output_dtype = output_dtype
         self._device = device
         self.row_alignment = row_alignment
-        self._key_to_tables: Dict[Tuple[DataType, bool], List[EmbeddingConfig]] = (
-            defaultdict(list)
-        )
+        self._key_to_tables: Dict[
+            Tuple[DataType, bool, bool], List[EmbeddingConfig]
+        ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._features_order: Optional[List[int]] = None
 
@@ -789,12 +789,24 @@ def __init__(  # noqa C901
                 + f" {self._embedding_dim}"
             )
             if hasattr(table, "use_virtual_table"):
-                key = (table.data_type, table.use_virtual_table)
+                key = (table.data_type, table.use_virtual_table, False)
+                if hasattr(table, "use_virtual_table") and hasattr(
+                    table, "enable_embedding_update"
+                ):
+                    key = (
+                        table.data_type,
+                        table.use_virtual_table,
+                        table.enable_embedding_update,
+                    )
             else:
-                key = (table.data_type, False)
+                key = (table.data_type, False, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
+        for (
+            data_type,
+            use_virtual_table,
+            enable_embedding_update,
+        ), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
             ] = None
@@ -825,15 +837,20 @@ def __init__(  # noqa C901
                 if use_virtual_table
                 else IntNBitTableBatchedEmbeddingBagsCodegen
             )
-            emb_module = embedding_clazz(
-                embedding_specs=embedding_specs,
-                pooling_mode=PoolingMode.NONE,
-                weight_lists=weight_lists,
-                device=device,
-                output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
-                row_alignment=row_alignment,
-                feature_table_map=feature_table_map,
-            )
+            kwargs: Dict[str, Any] = {
+                "embedding_specs": embedding_specs,
+                "pooling_mode": PoolingMode.NONE,
+                "weight_lists": weight_lists,
+                "device": device,
+                "output_dtype": data_type_to_sparse_type(
+                    dtype_to_data_type(output_dtype)
+                ),
+                "row_alignment": row_alignment,
+                "feature_table_map": feature_table_map,
+            }
+            if embedding_clazz == KVEmbeddingInference:
+                kwargs["embedding_cache_mode"] = enable_embedding_update
+            emb_module = embedding_clazz(**kwargs)
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
@@ -869,6 +886,7 @@ def __init__(  # noqa C901
                         "weight_qbias", qbias
                     )
 
+        # pyre-ignore [8]
         self._embedding_names_by_batched_tables: Dict[
             Tuple[DataType, bool], List[str]
         ] = {
@@ -934,6 +952,7 @@ def forward(
             f = kjts_per_key[i]
             lengths = _get_feature_length(f)
             indices, offsets = _fx_trec_unwrap_kjt(f)
+            # pyre-ignore [6]
            embedding_names = self._embedding_names_by_batched_tables[key]
             lookup = (
                 emb_module(indices=indices, offsets=offsets)
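
Note on the construction pattern used above: `embedding_cache_mode` is only meaningful to `KVEmbeddingInference`, so the patch collects the shared constructor arguments in a kwargs dict and adds the extra key only when that class is selected, leaving the plain `IntNBitTableBatchedEmbeddingBagsCodegen` call unchanged. Below is a minimal, self-contained sketch of the same conditional-kwargs idea; the `DenseTBE`/`KVTBE` classes and their signatures are hypothetical stand-ins, not the FBGEMM API.

    from typing import Any, Dict

    class DenseTBE:  # stand-in for IntNBitTableBatchedEmbeddingBagsCodegen
        def __init__(self, embedding_specs: list, device: str) -> None:
            self.embedding_specs = embedding_specs
            self.device = device

    class KVTBE(DenseTBE):  # stand-in for KVEmbeddingInference
        def __init__(
            self, embedding_specs: list, device: str, embedding_cache_mode: bool = False
        ) -> None:
            super().__init__(embedding_specs, device)
            self.embedding_cache_mode = embedding_cache_mode

    def build(clazz: type, specs: list, device: str, cache_mode: bool) -> DenseTBE:
        # Shared kwargs first; add the KV-only kwarg conditionally so the
        # dense constructor never sees an unexpected keyword argument.
        kwargs: Dict[str, Any] = {"embedding_specs": specs, "device": device}
        if clazz is KVTBE:
            kwargs["embedding_cache_mode"] = cache_mode
        return clazz(**kwargs)

    assert build(KVTBE, [("t0", 8, 16)], "cpu", True).embedding_cache_mode
    assert not hasattr(build(DenseTBE, [("t0", 8, 16)], "cpu", True), "embedding_cache_mode")

The change in embedding_modules.py widens the table-grouping key from `(data_type, use_virtual_table)` to a 3-tuple that also carries `enable_embedding_update`, so updatable KV tables land in their own TBE group. A toy illustration of that grouping follows; `Table` is a hypothetical stand-in for `EmbeddingConfig`.

    from collections import defaultdict
    from typing import Dict, List, NamedTuple, Tuple

    class Table(NamedTuple):  # hypothetical stand-in for EmbeddingConfig
        name: str
        data_type: str
        use_virtual_table: bool
        enable_embedding_update: bool

    tables = [
        Table("t0", "int8", False, False),
        Table("t1", "int8", True, False),
        Table("t2", "int8", True, True),  # updatable virtual table
    ]

    # Group tables by the 3-tuple key, mirroring self._key_to_tables.
    key_to_tables: Dict[Tuple[str, bool, bool], List[Table]] = defaultdict(list)
    for t in tables:
        key_to_tables[(t.data_type, t.use_virtual_table, t.enable_embedding_update)].append(t)

    # Three distinct groups: plain, virtual, and virtual + updatable.
    assert len(key_to_tables) == 3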