1 change: 1 addition & 0 deletions torchrec/distributed/itep_embeddingbag.py
@@ -81,6 +81,7 @@ def __init__(
            lookups=self._embedding_bag_collection._lookups,
            pruning_interval=module._itep_module.pruning_interval,
            enable_pruning=module._itep_module.enable_pruning,
            itep_logger=module._itep_module.itep_logger,
        )

    def prefetch(
60 changes: 60 additions & 0 deletions torchrec/modules/itep_logger.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import ABC, abstractmethod
from typing import Mapping, Optional, Tuple, Union

logger: logging.Logger = logging.getLogger(__name__)


class ITEPLogger(ABC):
    @abstractmethod
    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        pass

    @abstractmethod
    def log_run_info(
        self,
    ) -> None:
        pass


class ITEPLoggerDefault(ITEPLogger):
    """
    No-op logger used as the default.
    """

    def __init__(
        self,
    ) -> None:
        """
        Initialize ITEPLoggerDefault.
        """
        pass

    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        logger.info(
            f"iteration={iteration}, rank={rank}, table_to_sizes_mapping={table_to_sizes_mapping}, eviction_tables={eviction_tables}"
        )

    def log_run_info(
        self,
    ) -> None:
        pass
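Together, `ITEPLogger` and the new `itep_logger` constructor argument (next file) make the eviction sink pluggable. A minimal sketch of a custom implementation, assuming only the interface added in this diff; the class name and stdout sink are illustrative, not part of the PR:

```python
from typing import Mapping, Optional, Tuple, Union

from torchrec.modules.itep_logger import ITEPLogger


class StdoutITEPLogger(ITEPLogger):
    """Hypothetical logger that prints eviction stats instead of dropping them."""

    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        # table_to_sizes_mapping: table name -> (pruned rows, unpruned rows);
        # eviction_tables: table name -> eviction ratio.
        print(
            f"[rank={rank}] iter={iteration} "
            f"sizes={table_to_sizes_mapping} evictions={eviction_tables}"
        )

    def log_run_info(self) -> None:
        # Invoked once from GenericITEPModule.__init__ (see below).
        print("ITEP run started")
```

Passing an instance as `GenericITEPModule(..., itep_logger=StdoutITEPLogger())` replaces the `ITEPLoggerDefault` fallback; the sharded wrapper in `itep_embeddingbag.py` above forwards the same instance when it rebuilds the module.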
45 changes: 43 additions & 2 deletions torchrec/modules/itep_modules.py
@@ -15,6 +15,8 @@
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.modules.embedding_modules import reorder_inverse_indices
from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault

from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

try:
@@ -63,8 +65,8 @@ def __init__(
        lookups: Optional[List[nn.Module]] = None,
        enable_pruning: bool = True,
        pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
        itep_logger: Optional[ITEPLogger] = None,
    ) -> None:

        super(GenericITEPModule, self).__init__()

        # Construct in-training embedding pruning args
@@ -75,6 +77,11 @@ def __init__(
            table_name_to_unpruned_hash_sizes
        )

        self.itep_logger: ITEPLogger = (
            itep_logger if itep_logger is not None else ITEPLoggerDefault()
        )
        self.itep_logger.log_run_info()

        # Map each feature to a physical address_lookup/row_util buffer
        self.feature_table_map: Dict[str, int] = {}
        self.table_name_to_idx: Dict[str, int] = {}
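Because of the fallback above, existing call sites that omit `itep_logger` keep their old behavior. A construction sketch under assumptions: only the parameters visible in this diff are shown, and the table name and size are made up:

```python
from torchrec.modules.itep_modules import GenericITEPModule

module = GenericITEPModule(
    table_name_to_unpruned_hash_sizes={"table_a": 1_000_000},
    lookups=None,  # optional; supplied by the sharded wrapper in practice
    enable_pruning=True,
    pruning_interval=1001,
    # itep_logger omitted -> __init__ falls back to ITEPLoggerDefault()
)
# log_run_info() has already been called once during __init__.
```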
@@ -97,6 +104,8 @@ def print_itep_eviction_stats(
        cur_iter: int,
    ) -> None:
        table_name_to_eviction_ratio = {}
        buffer_idx_to_eviction_ratio = {}
        buffer_idx_to_sizes = {}

        num_buffers = len(self.buffer_offsets_list) - 1
        for buffer_idx in range(num_buffers):
@@ -113,6 +122,8 @@
            table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
                eviction_ratio
            )
            buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
            buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)

        # Sort the mapping by eviction ratio in descending order
        sorted_mapping = dict(
@@ -122,6 +133,34 @@
                reverse=True,
            )
        )

        logged_eviction_mapping = {}
        for idx in sorted_mapping.keys():
            try:
                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
                    sorted_mapping[idx]
                )
            except KeyError:
                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
                pass

        table_to_sizes_mapping = {}
        for idx in buffer_idx_to_sizes.keys():
            try:
                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
                    buffer_idx_to_sizes[idx]
                )
            except KeyError:
                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
                pass

        self.itep_logger.log_table_eviction_info(
            iteration=None,
            rank=None,
            table_to_sizes_mapping=table_to_sizes_mapping,
            eviction_tables=logged_eviction_mapping,
        )

        # Print the sorted mapping
        logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")

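The `try`/`except KeyError` guards above exist because in dummy mode the reverse map is never populated, so entries whose index is missing are silently skipped. A self-contained sketch of the index-to-name translation, with made-up table names and sizes standing in for the maps built in `init_itep_state()`:

```python
# Made-up stand-ins for the real maps.
reversed_feature_table_map = {0: "table_a", 1: "table_b"}
buffer_idx_to_sizes = {0: (800, 1000), 1: (450, 500)}  # (pruned, unpruned) rows

table_to_sizes_mapping = {}
for idx, sizes in buffer_idx_to_sizes.items():
    try:
        table_to_sizes_mapping[reversed_feature_table_map[idx]] = sizes
    except KeyError:
        pass  # dummy mode: index missing from the reverse map, entry skipped

assert table_to_sizes_mapping == {"table_a": (800, 1000), "table_b": (450, 500)}
```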
@@ -263,8 +302,10 @@ def init_itep_state(self) -> None:
        if self.current_device is None:
            self.current_device = torch.device("cuda")

        self.reversed_feature_table_map: Dict[int, str] = {
            idx: feature_name for feature_name, idx in self.feature_table_map.items()
        }
        self.buffer_offsets_list = buffer_offsets

        # Create buffers for address_lookup and row_util
        self.create_itep_buffers(
            buffer_size=buffer_size,