From 484d6d830332cef3affc57d65bfed59fff1b0826 Mon Sep 17 00:00:00 2001
From: Jie You
Date: Mon, 10 Jun 2024 01:13:32 -0700
Subject: [PATCH] Move ITEP TorchRec Module to OSS: Step 1 (#2074)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2074

Moving `ITEPEmbeddingBagCollection` to OSS folder

Reviewed By: PaulZhang12

Differential Revision: D58173777
---
 torchrec/distributed/itep_embeddingbag.py  | 199 ++++
 torchrec/modules/itep_embedding_modules.py |  79 +++
 torchrec/modules/itep_modules.py           | 470 ++++++++++++++++++
 .../tests/test_itep_embedding_modules.py   | 363 ++++++++++++++
 4 files changed, 1111 insertions(+)
 create mode 100644 torchrec/distributed/itep_embeddingbag.py
 create mode 100644 torchrec/modules/itep_embedding_modules.py
 create mode 100644 torchrec/modules/itep_modules.py
 create mode 100644 torchrec/modules/tests/test_itep_embedding_modules.py

diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py
new file mode 100644
index 000000000..855c1894c
--- /dev/null
+++ b/torchrec/distributed/itep_embeddingbag.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Type
+
+import torch
+
+from torchrec.distributed.embedding_types import (
+    BaseEmbeddingSharder,
+    KJTList,
+    ShardedEmbeddingModule,
+)
+from torchrec.distributed.embeddingbag import (
+    EmbeddingBagCollectionContext,
+    EmbeddingBagCollectionSharder,
+    ShardedEmbeddingBagCollection,
+)
+from torchrec.distributed.types import (
+    Awaitable,
+    LazyAwaitable,
+    ParameterSharding,
+    QuantizedCommCodecs,
+    ShardingEnv,
+    ShardingType,
+)
+from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
+from torchrec.modules.itep_modules import GenericITEPModule
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+
+
+@dataclass
+class ITEPEmbeddingBagCollectionContext(EmbeddingBagCollectionContext):
+    is_reindexed: bool = False
+
+
+class ShardedITEPEmbeddingBagCollection(
+    ShardedEmbeddingModule[
+        KJTList,
+        List[torch.Tensor],
+        KeyedTensor,
+        ITEPEmbeddingBagCollectionContext,
+    ]
+):
+    def __init__(
+        self,
+        module: ITEPEmbeddingBagCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        ebc_sharder: EmbeddingBagCollectionSharder,
+        env: ShardingEnv,
+        device: torch.device,
+    ) -> None:
+        super().__init__()
+
+        self._device = device
+        self._env = env
+
+        # Iteration counter for ITEP Module. Pinning on CPU because used for condition checking and checkpointing.
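+        # The counter is incremented once per compute()/compute_and_output_dist() call and
+        # read via .item() when invoking the ITEP module; register_buffer keeps it in
+        # state_dict so the pruning schedule survives checkpoint/resume.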
+ self.register_buffer( + "_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")) + ) + + self._embedding_bag_collection: ShardedEmbeddingBagCollection = ( + ebc_sharder.shard( + module._embedding_bag_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + ) + + # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case + self._itep_module: GenericITEPModule = GenericITEPModule( + table_name_to_unpruned_hash_sizes=module._itep_module.table_name_to_unpruned_hash_sizes, + lookups=self._embedding_bag_collection._lookups, + pruning_interval=module._itep_module.pruning_interval, + enable_pruning=module._itep_module.enable_pruning, + ) + + def prefetch( + self, + dist_input: KJTList, + forward_stream: Optional[torch.cuda.Stream] = None, + ctx: Optional[ITEPEmbeddingBagCollectionContext] = None, + ) -> None: + assert ( + ctx is not None + ), "ITEP Prefetch call requires ITEPEmbeddingBagCollectionContext" + dist_input = self._reindex(dist_input) + ctx.is_reindexed = True + self._embedding_bag_collection.prefetch(dist_input, forward_stream, ctx) + + # pyre-ignore + def input_dist( + self, + ctx: ITEPEmbeddingBagCollectionContext, + features: KeyedJaggedTensor, + force_insert: bool = False, + ) -> Awaitable[Awaitable[KJTList]]: + return self._embedding_bag_collection.input_dist(ctx, features) + + def _reindex(self, dist_input: KJTList) -> KJTList: + for i in range(len(dist_input)): + remapped_kjt = self._itep_module(dist_input[i], self._iter.item()) + dist_input[i] = remapped_kjt + return dist_input + + def compute( + self, + ctx: ITEPEmbeddingBagCollectionContext, + dist_input: KJTList, + ) -> List[torch.Tensor]: + if not ctx.is_reindexed: + dist_input = self._reindex(dist_input) + ctx.is_reindexed = True + + self._iter += 1 + return self._embedding_bag_collection.compute(ctx, dist_input) + + def output_dist( + self, + ctx: ITEPEmbeddingBagCollectionContext, + output: List[torch.Tensor], + ) -> LazyAwaitable[KeyedTensor]: + + ebc_awaitable = self._embedding_bag_collection.output_dist(ctx, output) + return ebc_awaitable + + def compute_and_output_dist( + self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList + ) -> LazyAwaitable[KeyedTensor]: + # Insert forward() function of GenericITEPModule into compute_and_output_dist() + for i in range(len(input)): + remapped_kjt = self._itep_module(input[i], self._iter.item()) + input[i] = remapped_kjt + self._iter += 1 + ebc_awaitable = self._embedding_bag_collection.compute_and_output_dist( + ctx, input + ) + return ebc_awaitable + + def create_context(self) -> ITEPEmbeddingBagCollectionContext: + return ITEPEmbeddingBagCollectionContext() + + +class ITEPEmbeddingBagCollectionSharder( + BaseEmbeddingSharder[ITEPEmbeddingBagCollection] +): + def __init__( + self, + ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._ebc_sharder: EmbeddingBagCollectionSharder = ( + ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry) + ) + + def shard( + self, + module: ITEPEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedITEPEmbeddingBagCollection: + + # Enforce GPU for ITEPEmbeddingBagCollection + if device is None: + device = torch.device("cuda") + + return ShardedITEPEmbeddingBagCollection( + module, + params, + 
ebc_sharder=self._ebc_sharder, + env=env, + device=device, + ) + + def shardable_parameters( + self, module: ITEPEmbeddingBagCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._ebc_sharder.shardable_parameters(module._embedding_bag_collection) + + @property + def module_type(self) -> Type[ITEPEmbeddingBagCollection]: + return ITEPEmbeddingBagCollection + + def sharding_types(self, compute_device_type: str) -> List[str]: + types = [ + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_WISE.value, + ] + return types diff --git a/torchrec/modules/itep_embedding_modules.py b/torchrec/modules/itep_embedding_modules.py new file mode 100644 index 000000000..32a8f45b5 --- /dev/null +++ b/torchrec/modules/itep_embedding_modules.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import List + +import torch +import torch.nn as nn +from torchrec.modules.embedding_configs import EmbeddingBagConfig + +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.itep_modules import GenericITEPModule + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class ITEPEmbeddingBagCollection(nn.Module): + """ + ITEPEmbeddingBagCollection represents a EmbeddingBagCollection module and an In-Training Embedding Pruning (ITEP) module. + The inputs into the ITEP-EBC will first be modified by the ITEP module before being passed into the embedding bag collection. + Args: + embedding_bag_collection (EmbeddingBagCollection): The EmbeddingBagCollection module to lookup embeddings. + itep_module (GenericITEPModule): A single ITEP module that modifies the input features. + Example: + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=ebc, + itep_module=itep_module + ) + Note: + The forward method modifies the input features using the ITEP module before passing them to the EmbeddingBagCollection. + It also increments an internal iteration counter each time it is called. + For details of input and output types, see EmbeddingBagCollection. + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + itep_module: GenericITEPModule, + ) -> None: + super().__init__() + self._embedding_bag_collection = embedding_bag_collection + self._itep_module = itep_module + # Iteration counter for ITEP. Pinning on CPU because used for condition checking and checkpointing. + self.register_buffer( + "_iter", + torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")), + ) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + """ + Forward pass for the ITEPEmbeddingBagCollection module. + The input features are first passed through the ITEP module, which modifies them. + The modified features are then passed to the EmbeddingBagCollection to get the pooled embeddings. + The internal iteration counter is incremented at each call. + Args: + features (KeyedJaggedTensor): The input features for the embedding lookup. + Returns: + KeyedTensor: The pooled embeddings from the EmbeddingBagCollection. + Note: + The iteration counter is incremented after each forward pass to keep track of the number of iterations. 
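+        Example (illustrative; assumes itep_ebc was built as in the class-level example,
+        with tables serving features "feature1" and "feature2"):
+            features = KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature1", "feature2"],
+                values=torch.tensor([0, 1, 2, 3, 4, 5]),
+                lengths=torch.tensor([2, 1, 1, 2]),
+            )
+            pooled_embeddings = itep_ebc(features)  # KeyedTensor keyed by feature name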
+ """ + + features = self._itep_module(features, self._iter.item()) + pooled_embeddings = self._embedding_bag_collection(features) + self._iter += 1 + + return pooled_embeddings + + def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: + return self._embedding_bag_collection.embedding_bag_configs() diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py new file mode 100644 index 000000000..3dfa22ed4 --- /dev/null +++ b/torchrec/modules/itep_modules.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.modules.embedding_modules import reorder_inverse_indices +from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor + + +torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:intraining_embedding_pruning_gpu" +) + +logger: logging.Logger = logging.getLogger(__name__) + + +class GenericITEPModule(nn.Module): + """ + A generic Module for applying In-Training Embedding Pruning (ITEP). + This module can be hooked into the forward() of EmbeddingBagCollection. + It will prune the embedding tables during-training, by applying a remapping transform to the embedding lookup indices. + + Args: + table_name_to_unpruned_hash_sizes (Dict[str, int]): Map of table name to unpruned hash size. + lookups (List[nn.Module], optional): List of lookups in the EBC. Defaults to None. + enable_pruning (bool, optional): Enable pruning or not. Defaults to True. + pruning_interval (int, optional): Pruning interval. Defaults to 1001. + Example: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes={"table1": 1000, "table2": 2000}, + lookups=ShardedEmbeddingBagCollection._lookups, + enable_pruning=True, + pruning_interval=1001 + ) + Note: + The `lookups` argument is optional and is used in the sharded case. If not provided, the module will skip initialization for the dummy module. + The `table_name_to_unpruned_hash_sizes` argument must not be empty. It is a map of table names to their unpruned hash sizes. + """ + + def __init__( + self, + table_name_to_unpruned_hash_sizes: Dict[str, int], + lookups: Optional[List[nn.Module]] = None, + enable_pruning: bool = True, + pruning_interval: int = 1001, # Default pruning interval 1001 iterations + ) -> None: + + super(GenericITEPModule, self).__init__() + + # Construct in-training embedding pruning args + self.enable_pruning: bool = enable_pruning + self.pruning_interval: int = pruning_interval + self.lookups: Optional[List[nn.Module]] = lookups + self.table_name_to_unpruned_hash_sizes: Dict[str, int] = ( + table_name_to_unpruned_hash_sizes + ) + + # Map each feature to a physical address_lookup/row_util buffer + self.feature_table_map: Dict[str, int] = {} + self.table_name_to_idx: Dict[str, int] = {} + self.buffer_offsets_list: List[int] = [] + self.idx_to_table_name: Dict[int, str] = {} + # Prevent multi-pruning, after moving iteration counter to outside. + self.last_pruned_iter = -1 + + if self.lookups is not None: + self.init_itep_state() + else: + logger.info( + "ITEP init: no lookups provided. Skipping init for dummy module." 
+ ) + + def print_itep_eviction_stats( + self, + pruned_indices_offsets: torch.Tensor, + pruned_indices_total_length: torch.Tensor, + cur_iter: int, + ) -> None: + table_name_to_eviction_ratio = {} + + num_buffers = len(self.buffer_offsets_list) - 1 + for buffer_idx in range(num_buffers): + pruned_start = pruned_indices_offsets[buffer_idx] + pruned_end = pruned_indices_offsets[buffer_idx + 1] + pruned_length = pruned_end - pruned_start + + if pruned_length > 0: + start = self.buffer_offsets_list[buffer_idx] + end = self.buffer_offsets_list[buffer_idx + 1] + buffer_length = end - start + assert buffer_length > 0 + eviction_ratio = pruned_length.item() / buffer_length + table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = ( + eviction_ratio + ) + + # Sort the mapping by eviction ratio in descending order + sorted_mapping = dict( + sorted( + table_name_to_eviction_ratio.items(), + key=lambda item: item[1], + reverse=True, + ) + ) + # Print the sorted mapping + logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}") + + # Calculate percentage of indiced updated/evicted during ITEP iter + pruned_indices_ratio = ( + pruned_indices_total_length / self.buffer_offsets_list[-1] + if self.buffer_offsets_list[-1] > 0 + else 0 + ) + logger.info( + f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices." + ) + + def get_table_hash_sizes(self, table: ShardedEmbeddingTable) -> Tuple[int, int]: + unpruned_hash_size = table.num_embeddings + + if table.name in self.table_name_to_unpruned_hash_sizes: + unpruned_hash_size = self.table_name_to_unpruned_hash_sizes[table.name] + else: + # Tables are not pruned by ITEP if table.name not in table_name_to_unpruned_hash_sizes + unpruned_hash_size = table.num_embeddings + logger.info( + f"ITEP: table {table.name} not pruned, because table name is not present in table_name_to_unpruned_hash_sizes." + ) + + return (table.num_embeddings, unpruned_hash_size) + + def create_itep_buffers( + self, + buffer_size: int, + buffer_offsets: List[int], + table_names: List[str], + emb_sizes: List[int], + ) -> None: + """ + Register ITEP specific buffer in a way that it can be accessed by torch.ops.fbgemm.init_address_lookup, and also in checkpoint is invididual + """ + # Buffers do not enter backward pass + with torch.no_grad(): + # Don't use register_buffer for buffer_offsets and emb_sizes. 
Because they may change as sharding plan change between preemption/resume + self.buffer_offsets = torch.tensor( + buffer_offsets, dtype=torch.int64, device=self.current_device + ) + self.emb_sizes = torch.tensor( + emb_sizes, dtype=torch.int64, device=self.current_device + ) + + self.address_lookup = torch.zeros( + buffer_size, dtype=torch.int64, device=self.current_device + ) + self.row_util = torch.zeros( + buffer_size, dtype=torch.float32, device=self.current_device + ) + + # Register buffers + for idx, table_name in enumerate(table_names): + self.register_buffer( + f"{table_name}_itp_address_lookup", + self.address_lookup[buffer_offsets[idx] : buffer_offsets[idx + 1]], + ) + self.register_buffer( + f"{table_name}_itp_row_util", + self.row_util[buffer_offsets[idx] : buffer_offsets[idx + 1]], + ) + + def init_itep_state(self) -> None: + idx = 0 + buffer_size = 0 + # Record address_lookup/row_util buffer lengths and offsets for each feature + buffer_offsets: List[int] = [0] # number of buffers + 1 + table_names: List[str] = [] # number of buffers + 1 + emb_sizes: List[int] = [] # Store embedding table sizes + self.current_device = None + + # Iterate over all tables + # pyre-ignore + for lookup in self.lookups: + for emb in lookup._emb_modules: + + emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables + for table in emb_tables: + + ( + pruned_hash_size, + unpruned_hash_size, + ) = self.get_table_hash_sizes(table) + + # Skip tables that are not pruned, aka pruned_hash_size == unpruned_hash_size. + if pruned_hash_size == unpruned_hash_size: + continue + + logger.info( + f"ITEP: Pruning enabled for table {table.name} with features {table.feature_names}, pruned_hash_size {pruned_hash_size} vs unpruned_hash_size {unpruned_hash_size}" + ) + + # buffer size for address_lookup and row_util + buffer_size += unpruned_hash_size + buffer_offsets.append(buffer_size) + table_names.append(table.name) + emb_sizes.append(pruned_hash_size) + + # Create feature to table mappings + for feature_name in table.feature_names: + self.feature_table_map[feature_name] = idx + + # Create table_name to buffer idx mappings + self.table_name_to_idx[table.name] = idx + self.idx_to_table_name[idx] = table.name + idx += 1 + + # Check that all features have the same device + if ( + table.local_metadata is not None + and table.local_metadata.placement is not None + ): + if self.current_device is None: + self.current_device = ( + table.local_metadata.placement.device() + ) + else: + assert ( + self.current_device + == table.local_metadata.placement.device() + ), f"Device of table {table}: {table.local_metadata.placement.device()} does not match existing device: {self.current_device}" + + if self.current_device is None: + self.current_device = torch.device("cuda") + + self.buffer_offsets_list = buffer_offsets + + # Create buffers for address_lookup and row_util + self.create_itep_buffers( + buffer_size=buffer_size, + buffer_offsets=buffer_offsets, + table_names=table_names, + emb_sizes=emb_sizes, + ) + + logger.info( + f"ITEP: done init_state with feature_table_map {self.feature_table_map} and buffer_offsets {self.buffer_offsets_list}" + ) + + # initialize address_lookup + torch.ops.fbgemm.init_address_lookup( + self.address_lookup, + self.buffer_offsets, + self.emb_sizes, + ) + + def reset_weight_momentum( + self, + pruned_indices: torch.Tensor, + pruned_indices_offsets: torch.Tensor, + ) -> None: + if self.lookups is not None: + # pyre-ignore + for lookup in self.lookups: + for emb in lookup._emb_modules: + 
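# For this TBE module, map each non-empty pruned-index range to its logical table id and buffer id, then reset the weights and optimizer momentum of the pruned rows. +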
emb_tables: List[ShardedEmbeddingTable] = ( + emb._config.embedding_tables + ) + + logical_idx = 0 + logical_table_ids = [] + buffer_ids = [] + for table in emb_tables: + name = table.name + if name in self.table_name_to_idx: + buffer_idx = self.table_name_to_idx[name] + start = pruned_indices_offsets[buffer_idx] + end = pruned_indices_offsets[buffer_idx + 1] + length = end - start + if length > 0: + logical_table_ids.append(logical_idx) + buffer_ids.append(buffer_idx) + logical_idx += table.num_features() + + if len(logical_table_ids) > 0: + emb.emb_module.reset_embedding_weight_momentum( + pruned_indices, + pruned_indices_offsets, + torch.tensor( + logical_table_ids, + dtype=torch.int32, + requires_grad=False, + ), + torch.tensor( + buffer_ids, dtype=torch.int32, requires_grad=False + ), + ) + + # Flush UVM cache after ITEP eviction to remove stale states + def flush_uvm_cache(self) -> None: + if self.lookups is not None: + # pyre-ignore + for lookup in self.lookups: + for emb in lookup._emb_modules: + emb.emb_module.flush() + emb.emb_module.reset_cache_states() + + def get_remap_info(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + keys = features.keys() + length_per_key = features.length_per_key() + offset_per_key = features.offset_per_key() + + buffer_idx = [] + feature_lengths = [] + feature_offsets = [] + for i in range(len(keys)): + key = keys[i] + if key not in self.feature_table_map: + continue + buffer_idx.append(self.feature_table_map[key]) + feature_lengths.append(length_per_key[i]) + feature_offsets.append(offset_per_key[i]) + + return [ + torch.tensor(buffer_idx, dtype=torch.int32, device=torch.device("cpu")), + torch.tensor( + feature_lengths, dtype=torch.int64, device=torch.device("cpu") + ), + torch.tensor( + feature_offsets, dtype=torch.int64, device=torch.device("cpu") + ), + ] + + def get_full_values_list(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + inverse_indices = features.inverse_indices() + batch_size = inverse_indices[1].numel() // len(inverse_indices[0]) + keys = features.keys() + if not all(key in self.feature_table_map for key in keys): + keys = [key for key in keys if key in self.feature_table_map] + key_indices = [features._key_indices()[key] for key in keys] + features = features.permute(key_indices) + indices = ( + inverse_indices[1] + if keys == inverse_indices[0] + else reorder_inverse_indices(inverse_indices, keys) + ) + spk_tensor = _pin_and_move( + torch.tensor(features.stride_per_key()), features.device() + ) + offset_indices = ( + indices + _to_offsets(spk_tensor)[:-1].unsqueeze(-1) + ).flatten() + full_values, full_lengths = torch.ops.fbgemm.keyed_jagged_index_select_dim1( + features.values(), + features.lengths(), + features.offsets(), + offset_indices, + features.lengths().numel(), + ) + full_lpk = torch.sum(full_lengths.view(-1, batch_size), dim=1).tolist() + return list(torch.split(full_values, full_lpk)) + + def forward( + self, + sparse_features: KeyedJaggedTensor, + cur_iter: int, + ) -> KeyedJaggedTensor: + """ + Args: + sparse_features (KeyedJaggedTensor]): input embedding lookup indices to be + remapped. + cur_iter (int): iteration counter. + + Returns: + KeyedJaggedTensor: remapped KJT + + NOTE: + We use the same forward method for sharded and non-sharded case. 
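+            The pruning schedule warms up before settling on pruning_interval: pruning
+            triggers at cur_iter 2, 5, 8, then 29, 59, 89, then 299, 599, 899, and
+            whenever (cur_iter + 1) is a multiple of pruning_interval. Pruning runs only
+            in training mode and at most once per iteration, and the weight/momentum
+            reset is skipped until cur_iter > pruning_interval.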
+ """ + + if not self.enable_pruning or self.lookups is None: + return sparse_features + + num_buffers = self.buffer_offsets.size(dim=0) - 1 + if num_buffers <= 0: + return sparse_features + + start_pruning: bool = ( + (cur_iter < 10 and (cur_iter + 1) % 3 == 0) + or (cur_iter < 100 and (cur_iter + 1) % 30 == 0) + or (cur_iter < 1000 and (cur_iter + 1) % 300 == 0) + or ((cur_iter + 1) % self.pruning_interval == 0) + ) + if start_pruning and self.training and self.last_pruned_iter != cur_iter: + # Pruning function outputs the indices that need weight/momentum reset + # The indices order is by physical buffer + ( + pruned_indices, + pruned_indices_offsets, + pruned_indices_total_length, + ) = torch.ops.fbgemm.prune_embedding_tables( + cur_iter, + self.pruning_interval, + self.address_lookup, + self.row_util, + self.buffer_offsets, + self.emb_sizes, + ) + # After pruning, reset weight and momentum of pruned indices + if pruned_indices_total_length > 0 and cur_iter > self.pruning_interval: + self.reset_weight_momentum(pruned_indices, pruned_indices_offsets) + + if pruned_indices_total_length > 0: + # Flush UVM cache after every ITEP eviction (every pruning_interval iterations) + self.flush_uvm_cache() + logger.info( + f"ITEP: trying to flush UVM after ITEP eviction, {cur_iter=}" + ) + + self.last_pruned_iter = cur_iter + + # Print eviction stats + self.print_itep_eviction_stats( + pruned_indices_offsets, pruned_indices_total_length, cur_iter + ) + + ( + buffer_idx, + feature_lengths, + feature_offsets, + ) = self.get_remap_info(sparse_features) + + update_utils: bool = ( + (cur_iter < 10) + or (cur_iter < 100 and (cur_iter + 1) % 19 == 0) + or ((cur_iter + 1) % 39 == 0) + ) + full_values_list = None + if update_utils and sparse_features.variable_stride_per_key(): + if sparse_features.inverse_indices_or_none() is not None: + # full util update mode require reconstructing original input indicies from VBE input + full_values_list = self.get_full_values_list(sparse_features) + else: + logger.info( + "Switching to deduped util updating mode due to features missing inverse indices. " + f"features {list(sparse_features.keys())=} with variable stride: {sparse_features.variable_stride_per_key()}" + ) + + remapped_values = torch.ops.fbgemm.remap_indices_update_utils( + cur_iter, + buffer_idx, + feature_lengths, + feature_offsets, + sparse_features.values(), + self.address_lookup, + self.row_util, + self.buffer_offsets, + full_values_list=full_values_list, + ) + + sparse_features._values = remapped_values + + return sparse_features diff --git a/torchrec/modules/tests/test_itep_embedding_modules.py b/torchrec/modules/tests/test_itep_embedding_modules.py new file mode 100644 index 000000000..3e9bc0801 --- /dev/null +++ b/torchrec/modules/tests/test_itep_embedding_modules.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +import random +import unittest +from typing import Dict, List +from unittest.mock import MagicMock, patch + +import torch +from torchrec import KeyedJaggedTensor +from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection +from torchrec.modules.itep_modules import GenericITEPModule + +MOCK_NS: str = "torchrec.modules.itep_modules" + + +class TestITEPEmbeddingBagCollection(unittest.TestCase): + # Setting up the environment for the tests. + def setUp(self) -> None: + # Embedding bag configurations for testing + embedding_bag_config1 = EmbeddingBagConfig( + name="table1", + embedding_dim=4, + num_embeddings=50, + feature_names=["feature1"], + ) + embedding_bag_config2 = EmbeddingBagConfig( + name="table2", + embedding_dim=4, + num_embeddings=40, + feature_names=["feature2"], + ) + unpruned_hash_size_1, unpruned_hash_size_2 = (100, 80) + self._table_name_to_pruned_hash_sizes = {"table1": 50, "table2": 40} + self._table_name_to_unpruned_hash_sizes = { + "table1": unpruned_hash_size_1, + "table2": unpruned_hash_size_2, + } + self._feature_name_to_unpruned_hash_sizes = { + "feature1": unpruned_hash_size_1, + "feature2": unpruned_hash_size_2, + } + self._batch_size = 8 + + # Util function for creating sharded embedding tables from embedding bag configurations. + def embedding_bag_config_to_sharded_table( + config: EmbeddingBagConfig, + ) -> ShardedEmbeddingTable: + return ShardedEmbeddingTable( + name=config.name, + embedding_dim=config.embedding_dim, + num_embeddings=config.num_embeddings, + feature_names=config.feature_names, + ) + + sharded_et1 = embedding_bag_config_to_sharded_table(embedding_bag_config1) + sharded_et2 = embedding_bag_config_to_sharded_table(embedding_bag_config2) + + # Create test ebc + self._embedding_bag_collection = EmbeddingBagCollection( + tables=[ + embedding_bag_config1, + embedding_bag_config2, + ], + device=torch.device("cuda"), + ) + + # Create a mock object for tbe lookups + self._mock_list_emb_tables = [ + sharded_et1, + sharded_et2, + ] + self._mock_lookups = [MagicMock()] + self._mock_lookups[0]._emb_modules = [MagicMock()] + self._mock_lookups[0]._emb_modules[0]._config = MagicMock() + self._mock_lookups[0]._emb_modules[ + 0 + ]._config.embedding_tables = self._mock_list_emb_tables + + def generate_input_kjt_cuda( + self, feature_name_to_unpruned_hash_sizes: Dict[str, int], use_vbe: bool = False + ) -> KeyedJaggedTensor: + keys = [] + values = [] + lengths = [] + cuda_device = torch.device("cuda") + + # Input KJT uses unpruned hash size (same as sigrid hash), and feature names + for key, unpruned_hash_size in feature_name_to_unpruned_hash_sizes.items(): + value = [] + length = [] + for _ in range(self._batch_size): + L = random.randint(0, 8) + for _ in range(L): + index = random.randint(0, unpruned_hash_size - 1) + value.append(index) + length.append(L) + keys.append(key) + values += value + lengths += length + + # generate kjt + if use_vbe: + inverse_indices_list = [] + inverse_indices = None + num_keys = len(keys) + deduped_batch_size = len(lengths) // num_keys + # Fix the number of samples after duplicate to 2x the number of + # deduplicated ones + full_batch_size = deduped_batch_size * 2 + stride_per_key_per_rank = [] + + for _ in range(num_keys): + 
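# VBE case: each key advertises the deduplicated batch size as its per-rank stride, and inverse indices map every sample of the duplicated (2x) batch back to a deduplicated row. +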
stride_per_key_per_rank.append([deduped_batch_size]) + # Generate random inverse indices for each key + keyed_inverse_indices = torch.randint( + low=0, + high=deduped_batch_size, + size=(full_batch_size,), + dtype=torch.int32, + device=cuda_device, + ) + inverse_indices_list.append(keyed_inverse_indices) + inverse_indices = ( + keys, + torch.stack(inverse_indices_list), + ) + + input_kjt_cuda = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor( + copy.deepcopy(values), + dtype=torch.int32, + device=cuda_device, + ), + lengths=torch.tensor( + copy.deepcopy(lengths), + dtype=torch.int32, + device=cuda_device, + ), + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, + ) + else: + input_kjt_cuda = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor( + copy.deepcopy(values), + dtype=torch.int32, + device=cuda_device, + ), + lengths=torch.tensor( + copy.deepcopy(lengths), + dtype=torch.int32, + device=cuda_device, + ), + ) + + return input_kjt_cuda + + def generate_expected_address_lookup_buffer( + self, + list_et: List[ShardedEmbeddingTable], + table_name_to_unpruned_hash_sizes: Dict[str, int], + table_name_to_pruned_hash_sizes: Dict[str, int], + ) -> torch.Tensor: + + address_lookup = [] + for et in list_et: + table_name = et.name + unpruned_hash_size = table_name_to_unpruned_hash_sizes[table_name] + pruned_hash_size = table_name_to_pruned_hash_sizes[table_name] + for idx in range(unpruned_hash_size): + if idx < pruned_hash_size: + address_lookup.append(idx) + else: + address_lookup.append(0) + + return torch.tensor(address_lookup, dtype=torch.int64) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_init_itep_module(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=5, + ) + + # Check the address lookup and row util after initialization + expected_address_lookup = self.generate_expected_address_lookup_buffer( + self._mock_list_emb_tables, + self._table_name_to_unpruned_hash_sizes, + self._table_name_to_pruned_hash_sizes, + ) + expetec_row_util = torch.zeros( + expected_address_lookup.shape, dtype=torch.float32 + ) + torch.testing.assert_close( + expected_address_lookup, + itep_module.address_lookup.cpu(), + atol=0, + rtol=0, + equal_nan=True, + ) + torch.testing.assert_close( + expetec_row_util, + itep_module.row_util.cpu(), + atol=1.0e-5, + rtol=1.0e-5, + equal_nan=True, + ) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_init_itep_module_without_pruned_table(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes={}, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=5, + ) + + self.assertEqual(itep_module.address_lookup.cpu().shape, torch.Size([0])) + self.assertEqual(itep_module.row_util.cpu().shape, torch.Size([0])) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_train_forward(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + 
itep_module=itep_module, + ) + + # Test forward 2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_train_forward_vbe(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Test forward 2000 times + for _ in range(5): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes, use_vbe=True + ) + _ = itep_ebc(input_kjt) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + # Mock out reset_weight_momentum to count calls + @patch(f"{MOCK_NS}.GenericITEPModule.reset_weight_momentum") + def test_check_pruning_schedule( + self, + mock_reset_weight_momentum: MagicMock, + ) -> None: + random.seed(1) + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Test forward 2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # Check that reset_weight_momentum is called + self.assertEqual(mock_reset_weight_momentum.call_count, 5) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + # Mock out reset_weight_momentum to count calls + @patch(f"{MOCK_NS}.GenericITEPModule.reset_weight_momentum") + def test_eval_forward( + self, + mock_reset_weight_momentum: MagicMock, + ) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Set eval mode + itep_ebc.eval() + + # Test forward 2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # Check that reset_weight_momentum is not called + self.assertEqual(mock_reset_weight_momentum.call_count, 0)
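
Usage sketch (illustrative): wiring the unsharded modules added by this patch, following the construction pattern used in the unit tests above. Table/feature names and hash sizes here are made up, and the sketch assumes the fbgemm in-training pruning ops loaded by torchrec/modules/itep_modules.py are available. Built without lookups, GenericITEPModule acts as a passthrough; pruning state is only created once the collection is sharded via ITEPEmbeddingBagCollectionSharder.

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
    from torchrec.modules.itep_modules import GenericITEPModule
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # The EBC is sized to the pruned hash size (50); ITEP tracks the unpruned size (100).
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="table1",
                embedding_dim=4,
                num_embeddings=50,
                feature_names=["feature1"],
            )
        ],
    )
    itep_module = GenericITEPModule(
        table_name_to_unpruned_hash_sizes={"table1": 100},
        # lookups defaults to None here, so this instance is a passthrough "dummy";
        # the sharded path re-creates it with the EBC lookups (see itep_embeddingbag.py).
    )
    itep_ebc = ITEPEmbeddingBagCollection(
        embedding_bag_collection=ebc,
        itep_module=itep_module,
    )

    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature1"],
        values=torch.tensor([3, 7, 42]),
        lengths=torch.tensor([2, 1]),
    )
    pooled = itep_ebc(features)  # KeyedTensor of pooled embeddings, one key per feature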