torchrec/distributed/quant_embedding_kernel.py (58 additions & 19 deletions)
@@ -327,20 +327,45 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
-            embedding_specs=embedding_specs,
-            device=device,
-            pooling_mode=self._pooling,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            bounds_check_mode=(
+        # Determine embedding cache mode for KV embedding tables
+        embedding_cache_mode = False  # Default: False = randomized initialization
+        if tbe_clazz == KVEmbeddingInference:
+            # For KV embedding tables, set cache mode based on embedding table configuration
+            # Check if any table has NoEvictionPolicy - use zero init for those
+            for table in config.embedding_tables:
+                if (
+                    table.virtual_table_eviction_policy is not None
+                    and type(table.virtual_table_eviction_policy).__name__
+                    == "NoEvictionPolicy"
+                ):
+                    embedding_cache_mode = True  # True = zero initialization
+                    break
+
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": embedding_specs,
+            "device": device,
+            "pooling_mode": self._pooling,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "bounds_check_mode": (
                 bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
             ),
-            feature_names_per_table=[
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if tbe_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
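
Illustrative sketch (not part of the diff): the cache-mode decision above boils down to scanning the table configs for a NoEvictionPolicy. A minimal, runnable version with hypothetical TableConfig and NoEvictionPolicy stand-ins in place of TorchRec's real config classes:

from dataclasses import dataclass
from typing import List, Optional


class NoEvictionPolicy:
    """Stand-in for a virtual-table eviction policy that never evicts."""


@dataclass
class TableConfig:
    name: str
    virtual_table_eviction_policy: Optional[object] = None


def derive_embedding_cache_mode(tables: List[TableConfig]) -> bool:
    # False -> randomized initialization, True -> zero initialization,
    # mirroring the comment convention in the diff above.
    for table in tables:
        policy = table.virtual_table_eviction_policy
        if policy is not None and type(policy).__name__ == "NoEvictionPolicy":
            return True
    return False


tables = [
    TableConfig("t0"),
    TableConfig("t1", virtual_table_eviction_policy=NoEvictionPolicy()),
]
print(derive_embedding_cache_mode(tables))  # True: at least one NoEvictionPolicy table
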
@@ -495,6 +520,7 @@ def __init__(
 
         managed: List[EmbeddingLocation] = []
         is_virtual_table = False
+        embedding_cache_mode = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
@@ -504,6 +530,8 @@ def __init__(
                 managed.append(EmbeddingLocation.HOST)
             if table.use_virtual_table:
                 is_virtual_table = True
+            if table.enable_embedding_update:
+                embedding_cache_mode = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -529,8 +557,9 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
-            embedding_specs=[
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": [
                 (
                     table.name,
                     local_rows,
@@ -549,15 +578,25 @@ def __init__(
                     managed,
                 )
             ],
-            device=device,
-            pooling_mode=PoolingMode.NONE,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            feature_names_per_table=[
+            "device": device,
+            "pooling_mode": PoolingMode.NONE,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if embedding_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
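
Illustrative sketch (not part of the diff): both constructors now build a base kwargs dict and attach embedding_cache_mode only when the target class is the KV variant, since the dense TBE constructor does not accept that argument. A stand-alone version of the pattern with hypothetical BaseInferenceTBE and KVInferenceTBE classes:

from typing import Any, Dict, List, Tuple, Type


class BaseInferenceTBE:
    """Stand-in for IntNBitTableBatchedEmbeddingBagsCodegen."""

    def __init__(
        self, embedding_specs: List[Tuple[str, int, int]], device: str = "cpu"
    ) -> None:
        self.embedding_specs = embedding_specs
        self.device = device


class KVInferenceTBE(BaseInferenceTBE):
    """Stand-in for KVEmbeddingInference; only this variant accepts the extra flag."""

    def __init__(
        self,
        embedding_specs: List[Tuple[str, int, int]],
        device: str = "cpu",
        embedding_cache_mode: bool = False,
    ) -> None:
        super().__init__(embedding_specs, device)
        self.embedding_cache_mode = embedding_cache_mode


def build_module(
    clazz: Type[BaseInferenceTBE], embedding_cache_mode: bool
) -> BaseInferenceTBE:
    # Base kwargs shared by both classes.
    kwargs: Dict[str, Any] = {"embedding_specs": [("t0", 100, 8)], "device": "cpu"}
    # Add the class-specific kwarg only when the target class understands it;
    # passing it to the base class would raise a TypeError.
    if clazz is KVInferenceTBE:
        kwargs["embedding_cache_mode"] = embedding_cache_mode
    return clazz(**kwargs)


print(type(build_module(KVInferenceTBE, True)).__name__)    # KVInferenceTBE
print(type(build_module(BaseInferenceTBE, True)).__name__)  # BaseInferenceTBE
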
torchrec/quant/embedding_modules.py (34 additions & 15 deletions)
@@ -764,9 +764,9 @@ def __init__( # noqa C901
         self._output_dtype = output_dtype
         self._device = device
         self.row_alignment = row_alignment
-        self._key_to_tables: Dict[Tuple[DataType, bool], List[EmbeddingConfig]] = (
-            defaultdict(list)
-        )
+        self._key_to_tables: Dict[
+            Tuple[DataType, bool, bool], List[EmbeddingConfig]
+        ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._features_order: Optional[List[int]] = None
 
@@ -789,12 +789,24 @@ def __init__( # noqa C901
                     + f" {self._embedding_dim}"
                 )
             if hasattr(table, "use_virtual_table"):
-                key = (table.data_type, table.use_virtual_table)
+                key = (table.data_type, table.use_virtual_table, False)
+                if hasattr(table, "use_virtual_table") and hasattr(
+                    table, "enable_embedding_update"
+                ):
+                    key = (
+                        table.data_type,
+                        table.use_virtual_table,
+                        table.enable_embedding_update,
+                    )
             else:
-                key = (table.data_type, False)
+                key = (table.data_type, False, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
+        for (
+            data_type,
+            use_virtual_table,
+            enable_embedding_update,
+        ), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
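
Illustrative sketch (not part of the diff): the grouping key widens from (data_type, use_virtual_table) to (data_type, use_virtual_table, enable_embedding_update), so tables that allow embedding updates end up in their own batched module. A minimal version with a hypothetical EmbCfg config; the getattr defaults stand in for the hasattr checks above:

from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, List, Tuple


@dataclass
class EmbCfg:
    name: str
    data_type: str = "int8"
    use_virtual_table: bool = False
    enable_embedding_update: bool = False


def group_tables(
    tables: List[EmbCfg],
) -> DefaultDict[Tuple[str, bool, bool], List[EmbCfg]]:
    key_to_tables: DefaultDict[Tuple[str, bool, bool], List[EmbCfg]] = defaultdict(list)
    for table in tables:
        # Configs that predate the new fields fall back to False for both flags.
        key = (
            table.data_type,
            getattr(table, "use_virtual_table", False),
            getattr(table, "enable_embedding_update", False),
        )
        key_to_tables[key].append(table)
    return key_to_tables


groups = group_tables(
    [
        EmbCfg("plain"),
        EmbCfg("kv", use_virtual_table=True, enable_embedding_update=True),
    ]
)
for key, cfgs in groups.items():
    print(key, [c.name for c in cfgs])
# ('int8', False, False) ['plain']
# ('int8', True, True) ['kv']
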
@@ -825,15 +837,20 @@ def __init__( # noqa C901
                 if use_virtual_table
                 else IntNBitTableBatchedEmbeddingBagsCodegen
             )
-            emb_module = embedding_clazz(
-                embedding_specs=embedding_specs,
-                pooling_mode=PoolingMode.NONE,
-                weight_lists=weight_lists,
-                device=device,
-                output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
-                row_alignment=row_alignment,
-                feature_table_map=feature_table_map,
-            )
+            kwargs: Dict[str, Any] = {
+                "embedding_specs": embedding_specs,
+                "pooling_mode": PoolingMode.NONE,
+                "weight_lists": weight_lists,
+                "device": device,
+                "output_dtype": data_type_to_sparse_type(
+                    dtype_to_data_type(output_dtype)
+                ),
+                "row_alignment": row_alignment,
+                "feature_table_map": feature_table_map,
+            }
+            if embedding_clazz == KVEmbeddingInference:
+                kwargs["embedding_cache_mode"] = enable_embedding_update
+            emb_module = embedding_clazz(**kwargs)
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
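
Illustrative sketch (not part of the diff): per group, the three-element key is unpacked and enable_embedding_update is forwarded as embedding_cache_mode only when the KV class is selected. A compact stand-alone version with hypothetical DenseTBE and KVTBE classes:

from collections import defaultdict
from typing import Any, Dict, List, Tuple


class DenseTBE:
    def __init__(self, embedding_specs: List[Tuple[str, int, int]]) -> None:
        self.embedding_specs = embedding_specs


class KVTBE(DenseTBE):
    def __init__(
        self,
        embedding_specs: List[Tuple[str, int, int]],
        embedding_cache_mode: bool = False,
    ) -> None:
        super().__init__(embedding_specs)
        self.embedding_cache_mode = embedding_cache_mode


# Grouped specs keyed by (data_type, use_virtual_table, enable_embedding_update),
# as a grouping step like the one above might produce them.
key_to_specs: Dict[Tuple[str, bool, bool], List[Tuple[str, int, int]]] = defaultdict(list)
key_to_specs[("int8", False, False)].append(("plain_table", 100, 8))
key_to_specs[("int8", True, True)].append(("kv_table", 100, 8))

modules = []
for (data_type, use_virtual_table, enable_embedding_update), specs in key_to_specs.items():
    clazz = KVTBE if use_virtual_table else DenseTBE
    kwargs: Dict[str, Any] = {"embedding_specs": specs}
    if clazz is KVTBE:
        # The per-group enable_embedding_update flag becomes the KV module's cache mode.
        kwargs["embedding_cache_mode"] = enable_embedding_update
    modules.append(clazz(**kwargs))

print([type(m).__name__ for m in modules])  # ['DenseTBE', 'KVTBE']
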
@@ -869,6 +886,7 @@ def __init__( # noqa C901
                         "weight_qbias", qbias
                     )
 
+        # pyre-ignore [8]
         self._embedding_names_by_batched_tables: Dict[
             Tuple[DataType, bool], List[str]
         ] = {
@@ -934,6 +952,7 @@ def forward(
             f = kjts_per_key[i]
             lengths = _get_feature_length(f)
             indices, offsets = _fx_trec_unwrap_kjt(f)
+            # pyre-ignore [6]
             embedding_names = self._embedding_names_by_batched_tables[key]
             lookup = (
                 emb_module(indices=indices, offsets=offsets)