From a05d41204738b2de1e7ac0e3301b663e63f19d67 Mon Sep 17 00:00:00 2001
From: Peter Fu
Date: Mon, 10 Feb 2025 11:48:45 -0800
Subject: [PATCH] Refactor itep logger

Summary: Put the logger interface in the torchrec module and the Scuba
implementation in fbgemm, because every GenericITEPModule needs a logger but
only the references to GenericITEPModule in the fbgemm module use Scuba.

Differential Revision: D69308574
---
 torchrec/distributed/itep_embeddingbag.py |  1 +
 torchrec/modules/itep_logger.py           | 60 +++++++++++++++++++++++
 torchrec/modules/itep_modules.py          | 45 ++++++++++++++++-
 3 files changed, 104 insertions(+), 2 deletions(-)
 create mode 100644 torchrec/modules/itep_logger.py

diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py
index d8daa4bb3..8646eca30 100644
--- a/torchrec/distributed/itep_embeddingbag.py
+++ b/torchrec/distributed/itep_embeddingbag.py
@@ -81,6 +81,7 @@ def __init__(
             lookups=self._embedding_bag_collection._lookups,
             pruning_interval=module._itep_module.pruning_interval,
             enable_pruning=module._itep_module.enable_pruning,
+            itep_logger=module._itep_module.itep_logger,
         )
 
     def prefetch(
diff --git a/torchrec/modules/itep_logger.py b/torchrec/modules/itep_logger.py
new file mode 100644
index 000000000..fa729488a
--- /dev/null
+++ b/torchrec/modules/itep_logger.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Mapping, Optional, Tuple, Union
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class ITEPLogger(ABC):
+    @abstractmethod
+    def log_table_eviction_info(
+        self,
+        iteration: Optional[Union[bool, float, int]],
+        rank: Optional[int],
+        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
+        eviction_tables: Mapping[str, float],
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def log_run_info(
+        self,
+    ) -> None:
+        pass
+
+
+class ITEPLoggerDefault(ITEPLogger):
+    """
+    No-op logger used as the default.
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        """
+        Initialize ITEPLoggerDefault.
+        """
+        pass
+
+    def log_table_eviction_info(
+        self,
+        iteration: Optional[Union[bool, float, int]],
+        rank: Optional[int],
+        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
+        eviction_tables: Mapping[str, float],
+    ) -> None:
+        logger.info(
+            f"iteration={iteration}, rank={rank}, table_to_sizes_mapping={table_to_sizes_mapping}, eviction_tables={eviction_tables}"
+        )
+
+    def log_run_info(
+        self,
+    ) -> None:
+        pass
diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py
index 4e0002ca8..0e74baf3b 100644
--- a/torchrec/modules/itep_modules.py
+++ b/torchrec/modules/itep_modules.py
@@ -15,6 +15,8 @@
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable
 from torchrec.modules.embedding_modules import reorder_inverse_indices
+from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault
+
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
 
 try:
@@ -63,8 +65,8 @@ def __init__(
         lookups: Optional[List[nn.Module]] = None,
         enable_pruning: bool = True,
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
+        itep_logger: Optional[ITEPLogger] = None,
     ) -> None:
-
         super(GenericITEPModule, self).__init__()
 
         # Construct in-training embedding pruning args
@@ -75,6 +77,11 @@
             table_name_to_unpruned_hash_sizes
         )
 
+        self.itep_logger: ITEPLogger = (
+            itep_logger if itep_logger is not None else ITEPLoggerDefault()
+        )
+        self.itep_logger.log_run_info()
+
         # Map each feature to a physical address_lookup/row_util buffer
         self.feature_table_map: Dict[str, int] = {}
         self.table_name_to_idx: Dict[str, int] = {}
@@ -97,6 +104,8 @@ def print_itep_eviction_stats(
         cur_iter: int,
     ) -> None:
         table_name_to_eviction_ratio = {}
+        buffer_idx_to_eviction_ratio = {}
+        buffer_idx_to_sizes = {}
         num_buffers = len(self.buffer_offsets_list) - 1
 
         for buffer_idx in range(num_buffers):
@@ -113,6 +122,8 @@
             table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
                 eviction_ratio
             )
+            buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
+            buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)
 
         # Sort the mapping by eviction ratio in descending order
         sorted_mapping = dict(
@@ -122,6 +133,34 @@
                 key=lambda item: item[1],
                 reverse=True,
             )
         )
+
+        logged_eviction_mapping = {}
+        for idx in sorted_mapping.keys():
+            try:
+                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
+                    sorted_mapping[idx]
+                )
+            except KeyError:
+                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                pass
+
+        table_to_sizes_mapping = {}
+        for idx in buffer_idx_to_sizes.keys():
+            try:
+                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
+                    buffer_idx_to_sizes[idx]
+                )
+            except KeyError:
+                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                pass
+
+        self.itep_logger.log_table_eviction_info(
+            iteration=None,
+            rank=None,
+            table_to_sizes_mapping=table_to_sizes_mapping,
+            eviction_tables=logged_eviction_mapping,
+        )
+
         # Print the sorted mapping
         logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
 
@@ -263,8 +302,10 @@ def init_itep_state(self) -> None:
         if self.current_device is None:
             self.current_device = torch.device("cuda")
 
+        self.reversed_feature_table_map: Dict[int, str] = {
+            idx: feature_name for feature_name, idx in self.feature_table_map.items()
+        }
         self.buffer_offsets_list = buffer_offsets
-
         # Create buffers for address_lookup and row_util
         self.create_itep_buffers(
             buffer_size=buffer_size,
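
Usage sketch (not part of the diff): with this change, callers can inject their own
ITEPLogger implementation through the new `itep_logger` argument, and GenericITEPModule
falls back to ITEPLoggerDefault when none is given. The snippet below is a minimal
illustration based only on the interface added in this patch; `MyPrintITEPLogger` and
the constructor arguments other than `itep_logger` are hypothetical placeholders rather
than code from this diff.

    from typing import Mapping, Optional, Tuple, Union

    from torchrec.modules.itep_logger import ITEPLogger
    from torchrec.modules.itep_modules import GenericITEPModule


    class MyPrintITEPLogger(ITEPLogger):
        """Hypothetical logger that simply prints ITEP eviction stats."""

        def log_table_eviction_info(
            self,
            iteration: Optional[Union[bool, float, int]],
            rank: Optional[int],
            table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
            eviction_tables: Mapping[str, float],
        ) -> None:
            # Receives the same payload GenericITEPModule passes to the logger.
            print(
                f"iteration={iteration}, rank={rank}, "
                f"sizes={table_to_sizes_mapping}, evictions={eviction_tables}"
            )

        def log_run_info(self) -> None:
            # Per this patch, called once from GenericITEPModule.__init__.
            print("ITEP module initialized")


    itep_module = GenericITEPModule(
        table_name_to_unpruned_hash_sizes={},  # placeholder; real table sizes go here
        itep_logger=MyPrintITEPLogger(),       # omit to fall back to ITEPLoggerDefault
    )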