270 changes: 266 additions & 4 deletions torchrec/distributed/test_utils/test_model.py
@@ -7,6 +7,7 @@

# pyre-strict

import copy
import random
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
@@ -239,10 +240,16 @@ def _validate_pooling_factor(
else None
)

global_float = torch.rand(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.rand(batch_size * world_size, device=device)
if randomize_indices:
global_float = torch.rand(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.rand(batch_size * world_size, device=device)
else:
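# Zero-filled floats/labels keep generated batches deterministic when
# indices are not randomized.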
global_float = torch.zeros(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.zeros(batch_size * world_size, device=device)

# Split global batch into local batches.
local_inputs = []
@@ -939,6 +946,7 @@ def __init__(
max_feature_lengths_list: Optional[List[Dict[str, int]]] = None,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
preproc_module: Optional[nn.Module] = None,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -960,13 +968,22 @@ def __init__(
embedding_names = (
list(embedding_groups.values())[0] if embedding_groups else None
)
self._embedding_names: List[str] = (
embedding_names
if embedding_names
else [feature for table in tables for feature in table.feature_names]
)
self._weighted_features: List[str] = [
feature for table in weighted_tables for feature in table.feature_names
]
self.over: nn.Module = over_arch_clazz(
tables, weighted_tables, embedding_names, dense_device
)
self.register_buffer(
"dummy_ones",
torch.ones(1, device=dense_device),
)
self.preproc_module = preproc_module

def sparse_forward(self, input: ModelInput) -> KeyedTensor:
return self.sparse(
@@ -993,6 +1010,8 @@ def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
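# Apply the optional preproc (e.g. feature augmentation) before the
# dense and sparse forward passes.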
if self.preproc_module:
input = self.preproc_module(input)
return self.dense_forward(input, self.sparse_forward(input))


@@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
continue

return kt


class TestPreprocNonWeighted(nn.Module):
"""
Basic preproc module for testing: selects the first three features
(feature_0 through feature_2) from a KeyedJaggedTensor.

Args: None
Example:
>>> TestPreprocNonWeighted()
Returns:
List[KeyedJaggedTensor]
"""

def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
"""
Selects features 0, 1, and 2 from the input KJT, dropping feature 3
"""
# split
jt_0 = kjt["feature_0"]
jt_1 = kjt["feature_1"]
jt_2 = kjt["feature_2"]

# merge only features 0,1,2, removing feature 3
return [
KeyedJaggedTensor.from_jt_dict(
{
"feature_0": jt_0,
"feature_1": jt_1,
"feature_2": jt_2,
}
)
]
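
# A minimal usage sketch for TestPreprocNonWeighted (hypothetical toy values;
# assumes one batch example per feature):
#
#   kjt = KeyedJaggedTensor.from_lengths_sync(
#       keys=["feature_0", "feature_1", "feature_2", "feature_3"],
#       values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
#       lengths=torch.tensor([2, 2, 2, 2]),
#   )
#   (filtered,) = TestPreprocNonWeighted()(kjt)
#   assert filtered.keys() == ["feature_0", "feature_1", "feature_2"]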


class TestPreprocWeighted(nn.Module):
"""
Basic module for testing

Args: None
Examples:
>>> TestPreprocWeighted()
Returns:
List[KeyedJaggedTensor]
"""

def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
"""
Selects only weighted_feature_0 from the weighted KJT
"""

# split
jt_0 = kjt["weighted_feature_0"]

# keep only weighted_feature_0
return [
KeyedJaggedTensor.from_jt_dict(
{
"weighted_feature_0": jt_0,
}
)
]
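
# A minimal usage sketch for TestPreprocWeighted (hypothetical toy values):
#
#   kjt = KeyedJaggedTensor.from_lengths_sync(
#       keys=["weighted_feature_0", "weighted_feature_1"],
#       values=torch.tensor([1, 2, 3, 4]),
#       lengths=torch.tensor([2, 2]),
#       weights=torch.rand(4),
#   )
#   (filtered,) = TestPreprocWeighted()(kjt)
#   assert filtered.keys() == ["weighted_feature_0"]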


class TestModelWithPreproc(nn.Module):
"""
Basic module with up to 3 preproc modules:
- preproc on idlist_features for non-weighted EBC
- preproc on idscore_features for weighted EBC
- optional preproc on model input shared by both EBCs

Args:
tables: embedding tables for the non-weighted EBC
weighted_tables: embedding tables for the weighted EBC
device: device on which to place the modules
preproc_module: optional preproc module applied to the full ModelInput
num_float_features: number of dense float features
run_preproc_inline: if True, rebuilds idlist_features inline instead of
delegating to a preproc module

Example:
>>> TestModelWithPreproc(tables, weighted_tables, device)

Returns:
Tuple[torch.Tensor, torch.Tensor]
"""

def __init__(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
device: torch.device,
preproc_module: Optional[nn.Module] = None,
num_float_features: int = 10,
run_preproc_inline: bool = False,
) -> None:
super().__init__()
self.dense = TestDenseArch(num_float_features, device)

self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
tables=tables,
device=device,
)
self.weighted_ebc = EmbeddingBagCollection(
tables=weighted_tables,
is_weighted=True,
device=device,
)
self.preproc_nonweighted = TestPreprocNonWeighted()
self.preproc_weighted = TestPreprocWeighted()
self._preproc_module = preproc_module
self._run_preproc_inline = run_preproc_inline

def forward(
self,
input: ModelInput,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runs preprco for EBC and weighted EBC, optionally runs preproc for input

Args:
input
Returns:
Tuple[torch.Tensor, torch.Tensor]
"""
modified_input = input

if self._preproc_module is not None:
modified_input = self._preproc_module(modified_input)
elif self._run_preproc_inline:
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
modified_input.idlist_features.keys(),
modified_input.idlist_features.values(),
modified_input.idlist_features.lengths(),
)

modified_idlist_features = self.preproc_nonweighted(
modified_input.idlist_features
)
modified_idscore_features = self.preproc_weighted(
modified_input.idscore_features
)
ebc_out = self.ebc(modified_idlist_features[0])
weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])

pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
return pred.sum(), pred
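
# A minimal construction sketch (hypothetical table configs; feature names
# must match the features the preproc modules above select):
#
#   tables = [
#       EmbeddingBagConfig(
#           num_embeddings=100,
#           embedding_dim=8,
#           name=f"table_{i}",
#           feature_names=[f"feature_{i}"],
#       )
#       for i in range(3)
#   ]
#   weighted_tables = [
#       EmbeddingBagConfig(
#           num_embeddings=100,
#           embedding_dim=8,
#           name="weighted_table_0",
#           feature_names=["weighted_feature_0"],
#       )
#   ]
#   model = TestModelWithPreproc(tables, weighted_tables, torch.device("cpu"))
#   loss, pred = model(model_input)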


class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate a feature-augmentation preproc (e.g. negative
sampling) for testing.

Args:
extra_input: extra ModelInput whose features are appended to each batch
has_params: if True, adds a linear layer so the module has parameters

Example:
>>> preproc = TestNegSamplingModule(extra_input)
>>> out = preproc(model_input)

Returns:
ModelInput
"""

def __init__(
self,
extra_input: ModelInput,
has_params: bool = False,
) -> None:
super().__init__()
self._extra_input = extra_input
if has_params:
self._linear: nn.Module = nn.Linear(30, 30)

def forward(self, input: ModelInput) -> ModelInput:
"""
Appends extra features to model input

Args:
input
Returns:
ModelInput
"""

# merge extra input
modified_input = copy.deepcopy(input)

# dim=0 (batch dimension) increases by self._extra_input.float_features.shape[0]
modified_input.float_features = torch.concat(
(modified_input.float_features, self._extra_input.float_features), dim=0
)

# stride will be the same but the features will be joined
modified_input.idlist_features = KeyedJaggedTensor.concat(
[modified_input.idlist_features, self._extra_input.idlist_features]
)
if self._extra_input.idscore_features is not None:
# stride will be the same but the features will be joined
modified_input.idscore_features = KeyedJaggedTensor.concat(
# pyre-ignore
[modified_input.idscore_features, self._extra_input.idscore_features]
)

# dim=0 (batch dimension) increases by self._extra_input.label.shape[0]
modified_input.label = torch.concat(
(modified_input.label, self._extra_input.label), dim=0
)

return modified_input
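
# A minimal usage sketch (assumed shapes for illustration): if `input` has a
# batch of B and `extra_input` a batch of E, the output batch dimension is
# B + E for float_features and label:
#
#   preproc = TestNegSamplingModule(extra_input)
#   out = preproc(model_input)
#   assert out.float_features.shape[0] == (
#       model_input.float_features.shape[0]
#       + extra_input.float_features.shape[0]
#   )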


class TestPositionWeightedPreprocModule(torch.nn.Module):
"""
Basic module for testing that wraps a PositionWeightedProcessor feature
processor.

Args:
max_feature_lengths: max feature length per feature name
device: device on which to place the processor

Example:
>>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
>>> out = preproc(model_input)
Returns:
ModelInput
"""

def __init__(
self, max_feature_lengths: Dict[str, int], device: torch.device
) -> None:
super().__init__()
self.fp_proc = PositionWeightedProcessor(
max_feature_lengths=max_feature_lengths,
device=device,
)

def forward(self, input: ModelInput) -> ModelInput:
"""
Runs PositionWeightedProcessor

Args:
input
Returns:
ModelInput
"""
modified_input = copy.deepcopy(input)
modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
return modified_input
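
# A minimal usage sketch (hypothetical max lengths; keys must match the input
# KJT's feature names):
#
#   preproc = TestPositionWeightedPreprocModule(
#       max_feature_lengths={"feature_0": 10},
#       device=torch.device("cpu"),
#   )
#   out = preproc(model_input)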