6 changes: 6 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -855,6 +855,12 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no
"""
This provides consistency between this class and the EmbeddingBagCollection's
nn.Module API calls (state_dict, named_modules, etc)

Args:
skip_registering (bool): If True, skips registering state_dict hooks. This is useful
    for dynamic sharding, where the state_dict hooks do not need to be
    re-registered when the module is resharded. Default: False.

"""
self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
for table_name in self._table_names:
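To make the new flag concrete, here is a minimal sketch of the pattern the docstring describes; the class and hook names below are illustrative stand-ins, not TorchRec's actual internals:

import torch.nn as nn

class ShardedCollectionSketch(nn.Module):  # hypothetical stand-in
    def _initialize_torch_state(self, skip_registering: bool = False) -> None:
        # the exposed torch state (nn.ModuleDict) is rebuilt on every (re)shard
        self.embedding_bags = nn.ModuleDict()
        if not skip_registering:
            # hooks are attached only once; a dynamic reshard passes
            # skip_registering=True because they are already in place
            self._register_state_dict_hook(self._post_state_dict_hook)

    @staticmethod
    def _post_state_dict_hook(module, state_dict, prefix, local_metadata):
        # illustrative no-op hook
        return state_dict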
32 changes: 27 additions & 5 deletions torchrec/distributed/test_utils/test_input.py
@@ -53,6 +53,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
)

def record_stream(self, stream: torch.Stream) -> None:
"""
need to explicitly call `record_stream` for non-pytorch native object (KJT)
"""
self.float_features.record_stream(stream)
if isinstance(self.idlist_features, KeyedJaggedTensor):
self.idlist_features.record_stream(stream)
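A hedged usage sketch of why the explicit call matters, assuming `cpu_batch` is an already-constructed ModelInput (the variable names are illustrative): a batch copied on a side stream must record the stream that will consume it, so the CUDA caching allocator does not reuse the memory too early.

import torch

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    # ModelInput.to copies float_features and the KJTs with non_blocking=True
    gpu_batch = cpu_batch.to(torch.device("cuda"), non_blocking=True)

# the default stream consumes the batch: wait for the copy, then record the
# consumer stream on every tensor inside the batch (including the KJTs)
torch.cuda.current_stream().wait_stream(copy_stream)
gpu_batch.record_stream(torch.cuda.current_stream())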
@@ -204,7 +207,7 @@ def generate_local_batches(
all_zeros: bool = False,
) -> List["ModelInput"]:
"""
Returns multi-rank batches of world_size
Returns `world_size` multi-rank batches of `ModelInput`, one per rank
"""
return [
cls.generate(
@@ -255,15 +258,15 @@ def generate(
all_zeros: bool = False,
) -> "ModelInput":
"""
Returns a single batch
Returns a single batch of `ModelInput`
"""
float_features = (
torch.zeros((batch_size, num_float_features), device=device)
if all_zeros
else torch.rand((batch_size, num_float_features), device=device)
)
idlist_features = (
ModelInput._create_standard_kjt(
ModelInput.create_standard_kjt(
batch_size=batch_size,
tables=tables,
pooling_avg=pooling_avg,
@@ -281,7 +284,7 @@ def generate(
else None
)
idscore_features = (
ModelInput._create_standard_kjt(
ModelInput.create_standard_kjt(
batch_size=batch_size,
tables=weighted_tables,
pooling_avg=pooling_avg,
@@ -324,6 +327,13 @@ def _create_features_lengths_indices(
lengths_dtype: torch.dtype = torch.int64,
all_zeros: bool = False,
) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
"""
Create keys, lengths, and indices for a KeyedJaggedTensor from embedding table configs.

Returns:
Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
Feature names, per-feature lengths, and per-feature indices.
"""
pooling_factor_per_feature: List[int] = []
num_embeddings_per_feature: List[int] = []
max_length_per_feature: List[Optional[int]] = []
@@ -395,6 +405,14 @@ def _assemble_kjt(
use_offsets: bool = False,
offsets_dtype: torch.dtype = torch.int64,
) -> KeyedJaggedTensor:
"""

Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices.

This method is used to generate corresponding local_batches and global_batch KJTs.
It concatenates the lengths and indices for each feature to form a complete KJT.
"""

lengths = torch.cat(lengths_per_feature)
indices = torch.cat(indices_per_feature)
offsets = None
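As a minimal illustration of the assembly step described above (feature names and values are made up): per-feature lengths and indices are concatenated in feature order and wrapped in a single KJT.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = ["feature_0", "feature_1"]
lengths_per_feature = [torch.tensor([1, 2]), torch.tensor([0, 3])]      # batch_size = 2
indices_per_feature = [torch.tensor([3, 0, 2]), torch.tensor([1, 1, 2])]

kjt = KeyedJaggedTensor(
    keys=features,
    values=torch.cat(indices_per_feature),
    lengths=torch.cat(lengths_per_feature),
)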
@@ -407,7 +425,7 @@ def _assemble_kjt(
return KeyedJaggedTensor(features, indices, weights, lengths, offsets)

@staticmethod
def _create_standard_kjt(
def create_standard_kjt(
batch_size: int,
tables: Union[
List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig]
@@ -464,6 +482,10 @@ def _create_batched_standard_kjts(
lengths_dtype: torch.dtype = torch.int64,
all_zeros: bool = False,
) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]:
"""
generate a global KJT and corresponding per-rank KJTs, the data are the same
so that they can be used for result comparison.
"""
data_per_rank = [
ModelInput._create_features_lengths_indices(
batch_size,
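With `_create_standard_kjt` promoted to the public `create_standard_kjt`, callers such as the dynamic-sharding test below can generate inputs directly from table configs. A hedged usage sketch (the table config values are illustrative):

from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=4,
        embedding_dim=8,
        name="table_0",
        feature_names=["feature_0"],
    )
]
kjt = ModelInput.create_standard_kjt(batch_size=2, tables=tables)
print(kjt.keys())  # ["feature_0"]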
55 changes: 10 additions & 45 deletions torchrec/distributed/tests/test_dynamic_sharding.py
@@ -35,6 +35,7 @@
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_sharding import copy_state_dict

from torchrec.distributed.types import (
@@ -58,46 +59,6 @@ def feature_name(i: int) -> str:
return "feature_" + str(i)


def generate_input_by_world_size(
world_size: int,
num_tables: int,
num_embeddings: int = 4,
max_mul: int = 3,
) -> List[KeyedJaggedTensor]:
# TODO merge with new ModelInput generator in TestUtils
kjt_input_per_rank = []
mul = random.randint(1, max_mul)
total_size = num_tables * mul

for _ in range(world_size):
feature_names = [feature_name(i) for i in range(num_tables)]
lengths = []
values = []
counting_l = 0
for i in range(total_size):
if i == total_size - 1:
lengths.append(total_size - counting_l)
break
next_l = random.randint(0, total_size - counting_l)
values.extend(
[random.randint(0, num_embeddings - 1) for _ in range(next_l)]
)
lengths.append(next_l)
counting_l += next_l

# for length in lengths:

kjt_input_per_rank.append(
KeyedJaggedTensor.from_lengths_sync(
keys=feature_names,
values=torch.LongTensor(values),
lengths=torch.LongTensor(lengths),
)
)

return kjt_input_per_rank


def generate_embedding_bag_config(
data_type: DataType,
num_tables: int = 3,
@@ -372,9 +333,13 @@ def _run_ebc_resharding_test(
):
return

kjt_input_per_rank = generate_input_by_world_size(
world_size, num_tables, num_embeddings
)
kjt_input_per_rank = [
ModelInput.create_standard_kjt(
batch_size=2,
tables=embedding_bag_config,
)
for _ in range(world_size)
]

# initial_state_dict filled with deterministic dummy values
initial_state_dict = create_test_initial_state_dict(
@@ -418,8 +383,8 @@ def test_dynamic_sharding_ebc_tw(
old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]

if new_ranks == old_ranks:
return
while new_ranks == old_ranks:
new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
per_param_sharding = {}
new_per_param_sharding = {}
