327 changes: 307 additions & 20 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -41,6 +41,7 @@
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
@@ -50,7 +51,11 @@
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
from torchrec.distributed.embedding_kernel import (
BaseEmbedding,
create_virtual_sharded_tensors,
get_state_dict,
)
from torchrec.distributed.embedding_types import (
compute_kernel_to_embedding_location,
DTensorMetadata,
@@ -65,7 +70,7 @@
ShardMetadata,
TensorProperties,
)
from torchrec.distributed.utils import append_prefix
from torchrec.distributed.utils import append_prefix, none_throws
from torchrec.modules.embedding_configs import (
data_type_to_sparse_type,
pooling_type_to_pooling_mode,
@@ -169,6 +174,24 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
return ssd_tbe_params


def _populate_zero_collision_tbe_params(
tbe_params: Dict[str, Any],
sharded_local_buckets: List[Tuple[int, int, int]],
) -> None:
"""
Construct Zero Collision TBE params from config and fused params dict.
"""
bucket_offsets: List[Tuple[int, int]] = [
(offset_start, offset_end)
for offset_start, offset_end, _ in sharded_local_buckets
]
bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]

tbe_params["kv_zch_params"] = KVZCHParams(
bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
)
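# Illustrative example (not part of this change): with
# sharded_local_buckets=[(0, 2, 100), (2, 4, 50)], the call above produces
# KVZCHParams(bucket_offsets=[(0, 2), (2, 4)], bucket_sizes=[100, 50]).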


class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
def __init__(
self,
@@ -676,24 +699,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
self._emb_module.update_hyper_parameters(params_dict)


def _gen_named_parameters_by_table_ssd(
emb_module: SSDTableBatchedEmbeddingBags,
table_name_to_count: Dict[str, int],
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
) -> Iterator[Tuple[str, nn.Parameter]]:
"""
Return an empty tensor to indicate that the table is on a remote device.
"""
for table in config.embedding_tables:
table_name = table.name
# placeholder
weight: nn.Parameter = nn.Parameter(torch.empty(0))
# pyre-ignore
weight._in_backward_optimizers = [EmptyFusedOptimizer()]
yield (table_name, weight)


def _gen_named_parameters_by_table_ssd_pmt(
emb_module: SSDTableBatchedEmbeddingBags,
table_name_to_count: Dict[str, int],
@@ -956,6 +961,10 @@ def __init__(
**ssd_tbe_params,
).to(device)

logger.info(
f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
)

self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
config,
self._emb_module,
@@ -1064,6 +1073,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
The optional ShardedTensor outputs for weight_id and bucket_cnt are not used here
since this kernel is not KV-ZCH.
"""
for config, tensor in zip(
self._config.embedding_tables,
@@ -1095,6 +1106,280 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
return self.emb_module.split_embedding_weights(no_snapshot)


class ZeroCollisionKeyValueEmbedding(
BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule
):
def __init__(
self,
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
backend_type: BackendType = BackendType.SSD,
) -> None:
super().__init__(config, pg, device)

assert (
len(config.embedding_tables) > 0
), "Expected to see at least one table in SSD TBE, but found 0."
assert (
len({table.embedding_dim for table in config.embedding_tables}) == 1
), "Currently we expect all tables in SSD TBE to have the same embedding dimension."

ssd_tbe_params = _populate_ssd_tbe_params(config)
self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
compute_kernel = config.embedding_tables[0].compute_kernel
embedding_location = compute_kernel_to_embedding_location(compute_kernel)

self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
feature_table_map=self._feature_table_map,
ssd_cache_location=embedding_location,
pooling_mode=PoolingMode.NONE,
backend_type=backend_type,
**ssd_tbe_params,
).to(device)

logger.info(
f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
)

self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
config,
self._emb_module,
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
pg=pg,
)
)
self.init_parameters()

# every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db
# cache the split weights result so that multiple calls in the same train iteration only trigger the backend scan once
self._split_weights_res: Optional[
Tuple[
List[ShardedTensor],
List[ShardedTensor],
List[ShardedTensor],
]
] = None
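# cache lifecycle (descriptive note, not part of this change): the first call to
# get_named_split_embedding_weights_snapshot() in an iteration populates _split_weights_res,
# subsequent calls reuse it, and forward() resets it back to None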

def init_parameters(self) -> None:
"""
An advantage of KV TBE is that weights do not need to be initialized, so this is a no-op.
"""
pass

@property
def emb_module(
self,
) -> SSDTableBatchedEmbeddingBags:
return self._emb_module

@property
def fused_optimizer(self) -> FusedOptimizer:
"""
SSD Embedding fuses the optimizer update with the backward pass.
"""
return self._optim

def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
"""
Utility to compute (bucket offset start, bucket offset end, bucket size) for each table based on the embedding sharding spec.
"""
sharded_local_buckets: List[Tuple[int, int, int]] = []
world_size = dist.get_world_size(self._pg)
local_rank = dist.get_rank(self._pg)

for table in self._config.embedding_tables:
total_num_buckets = none_throws(table.total_num_buckets)
assert (
total_num_buckets % world_size == 0
), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
assert (
table.total_num_buckets
and table.num_embeddings % table.total_num_buckets == 0
), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
bucket_offset_start = total_num_buckets // world_size * local_rank
bucket_offset_end = min(
total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
)
bucket_size = (
table.num_embeddings + total_num_buckets - 1
) // total_num_buckets
sharded_local_buckets.append(
(bucket_offset_start, bucket_offset_end, bucket_size)
)
logger.info(
f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
)
return sharded_local_buckets
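# Worked example (hypothetical numbers, not part of this change): with total_num_buckets=8,
# num_embeddings=1000, world_size=4 and local_rank=1, this yields bucket_offset_start=2,
# bucket_offset_end=4 and bucket_size=(1000 + 8 - 1) // 8 = 125.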

def state_dict(
self,
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
"""
Args:
no_snapshot (bool): the tensors in the returned dict are
PartiallyMaterializedTensors. This argument controls whether each
PartiallyMaterializedTensor owns a RocksDB snapshot handle: True means it
does not hold a snapshot handle; False means it does.
"""
# when no_snapshot=False, a flush is required; we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
emb_table_config_copy,
emb_tables,
self._pg,
destination,
prefix,
)
return ret
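# Usage sketch (assumed checkpointing flow, not part of this change): callers that need
# consistent reads are expected to flush first and then call state_dict(no_snapshot=False)
# so that each returned PartiallyMaterializedTensor holds a RocksDB snapshot handle.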

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
"""
This is the only allowed way to get the state_dict.
"""
for name, tensor in self.named_split_embedding_weights(
prefix, recurse, remove_duplicate
):
# hack before we support optimizer on sharded parameter level
# can delete after PEA deprecation
# pyre-ignore [6]
param = nn.Parameter(tensor)
# pyre-ignore
param._in_backward_optimizers = [EmptyFusedOptimizer()]
yield name, param

# pyre-ignore [15]
def named_split_embedding_weights(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
assert (
remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights()[0],
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
Tuple[
str,
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Return an iterator over embedding tables, for each table yielding:
the table name,
a PMT for the embedding table with a valid RocksDB snapshot to support tensor IO,
an optional ShardedTensor for weight_id,
an optional ShardedTensor for bucket_cnt
"""
if self._split_weights_res is not None:
pmt_sharded_t_list = self._split_weights_res[0]
# pyre-ignore
weight_id_sharded_t_list = self._split_weights_res[1]
bucket_cnt_sharded_t_list = self._split_weights_res[2]
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
table_config = self._config.embedding_tables[table_idx]
key = append_prefix(prefix, f"{table_config.name}")

yield key, pmt_sharded_t, weight_id_sharded_t_list[
table_idx
], bucket_cnt_sharded_t_list[table_idx]
return

pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
no_snapshot=False
)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")

pmt_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy, pmt_list, self._pg, prefix
)
weight_id_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore
)
bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy, bucket_cnt_list, self._pg, prefix # pyre-ignore
)
# pyre-ignore
assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
assert (
len(pmt_sharded_t_list)
== len(weight_id_sharded_t_list)
== len(bucket_cnt_sharded_t_list)
)
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
table_config = self._config.embedding_tables[table_idx]
key = append_prefix(prefix, f"{table_config.name}")

yield key, pmt_sharded_t, weight_id_sharded_t_list[
table_idx
], bucket_cnt_sharded_t_list[table_idx]

self._split_weights_res = (
pmt_sharded_t_list,
weight_id_sharded_t_list,
bucket_cnt_sharded_t_list,
)
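# Usage sketch (illustrative, not part of this change): a checkpoint writer could consume
# the iterator as
#   for key, pmt_st, weight_id_st, bucket_cnt_st in emb.get_named_split_embedding_weights_snapshot():
#       ...
# where pmt_st wraps PartiallyMaterializedTensors backed by the RocksDB snapshot taken above.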

def flush(self) -> None:
"""
Flush the embeddings in the cache back to SSD. This is expected to be expensive.
"""
self.emb_module.flush()

def purge(self) -> None:
"""
Reset the cache space. This is needed when we load state dict.
"""
# TODO: move the following to SSD TBE.
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)
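# note (assumption, not part of this change): filling lxu_cache_state with -1 is presumed to
# mark every cache slot as unoccupied so stale entries are not served after a state dict load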

# pyre-ignore [15]
def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
return self.emb_module.split_embedding_weights(no_snapshot)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# reset the cached split weights at each forward pass during training
self._split_weights_res = None

return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
)
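# Usage sketch (hypothetical feature/key names, not part of this change):
#   kjt = KeyedJaggedTensor.from_lengths_sync(
#       keys=["f1"], values=torch.tensor([10, 11, 12]), lengths=torch.tensor([1, 2])
#   )
#   out = zch_emb(kjt)  # zch_emb: a ZeroCollisionKeyValueEmbedding instance; values/offsets are cast to int64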


class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
def __init__(
self,
@@ -1563,6 +1848,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
The optional ShardedTensor outputs for weight_id and bucket_cnt are not used here
since this kernel is not KV-ZCH.
"""
for config, tensor in zip(
self._config.embedding_tables,