92 changes: 91 additions & 1 deletion torchrec/ir/tests/test_serializer.py
@@ -769,7 +769,7 @@ def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> Dict[str, torch.Tensor]:
-               return self.regroup([self.ebc(features)])
+               return self.regroup(keyed_tensors=[self.ebc(features)])

        class myModel(nn.Module):
            def __init__(self, ebc, regroup):
@@ -813,6 +813,96 @@ def forward(
        for key in eager_out.keys():
            torch.testing.assert_close(deserialized_out[key], eager_out[key])
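The one-line change above switches the regroup call from positional to keyword form. For context (illustrative, not part of the diff; `regroup` stands for a `KTRegroupAsDict` module and `kt` for a `KeyedTensor`), these are the two call styles the export machinery now has to handle:

    out = regroup([kt])                # positional: recorded in node.args
    out = regroup(keyed_tensors=[kt])  # keyword: recorded in node.kwargs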

    def test_key_order_with_ebc_and_regroup_input_kwargs(self) -> None:
        tb1_config = EmbeddingBagConfig(
            name="t1",
            embedding_dim=3,
            num_embeddings=10,
            feature_names=["f1"],
        )
        tb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )
        tb3_config = EmbeddingBagConfig(
            name="t3",
            embedding_dim=5,
            num_embeddings=10,
            feature_names=["f3"],
        )
        id_list_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f2", "f3", "f4", "f5"],
            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
        )
        ebc1 = EmbeddingBagCollection(
            tables=[tb1_config, tb2_config, tb3_config],
            is_weighted=False,
        )
        ebc2 = EmbeddingBagCollection(
            tables=[tb1_config, tb3_config, tb2_config],
            is_weighted=False,
        )
        ebc2.load_state_dict(ebc1.state_dict())
        regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

        class mySparse(nn.Module):
            def __init__(self, ebc):
                super().__init__()
                self.ebc = ebc

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> KeyedTensor:
                return self.ebc(features)

        class myModel(nn.Module):
            def __init__(self, ebc, regroup):
                super().__init__()
                self.regroup = regroup
                self.sparse = mySparse(ebc)

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> Dict[str, torch.Tensor]:
                sparse_out = self.sparse(features)
                return self.regroup(keyed_tensors=[sparse_out])

        model = myModel(ebc1, regroup)
        eager_out = model(id_list_features)

        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
        ep = torch.export.export(
            model,
            (id_list_features,),
            {},
            strict=False,
            # allows the KJT to not be unflattened, so we can run a forward
            # on the unflattened EP
            preserve_module_call_signature=(tuple(sparse_fqns)),
        )
        unflatten_ep = torch.export.unflatten(ep)
        deserialized_model = decapsulate_ir_modules(
            unflatten_ep,
            JsonSerializer,
            short_circuit_pytree_ebc_regroup=True,
            finalize_interpreter_modules=True,
        )

        # we export the model with ebc1 and unflatten it, then swap in ebc2
        # (think of this as the sharding process producing a shardedEBC),
        # so that we can mimic the key-order change
        # pyre-fixme[16]: `Module` has no attribute `ebc`.
        # pyre-fixme[16]: `Tensor` has no attribute `ebc`.
        deserialized_model.sparse.ebc = ebc2

        deserialized_out = deserialized_model(id_list_features)
        for key in eager_out.keys():
            torch.testing.assert_close(deserialized_out[key], eager_out[key])
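A quick sanity check on the expected shapes (not part of the diff, and assuming `KTRegroupAsDict` concatenates the grouped embeddings along the last dim): the offsets above define 15 bags across 5 keys, i.e. a batch of 3 per key, so:

    out = model(id_list_features)
    assert out["odd"].shape == (3, 3 + 5)  # f1 (dim 3) and f3 (dim 5) concatenated
    assert out["even"].shape == (3, 4)     # f2 (dim 4) alone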

    def test_cast_in_regroup(self) -> None:
        class Model(nn.Module):
            def __init__(self, ebc, fpebc, regroup):
20 changes: 14 additions & 6 deletions torchrec/ir/utils.py
@@ -12,7 +12,7 @@
import logging
import operator
from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import cast, Dict, List, Optional, Tuple, Type, Union

import torch

@@ -370,25 +370,33 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
    # remove tree_unflatten from the in_fqns (in-coming nodes)
    for fqn in in_fqns:
        submodule, node = _get_graph_node(module, fqn)
-       assert len(node.args) == 1
-       getitem_getitem: Node = node.args[0]  # pyre-ignore[9]
+       # the kt_regroup node will have either one positional arg or one kwarg
+       assert len(node.args) == 1 or len(node.kwargs) == 1
+       use_args = len(node.args) == 1
+
+       getitem_getitem = cast(
+           Node, node.args[0] if use_args else list(node.kwargs.values())[0]
+       )
        assert (
            getitem_getitem.op == "call_function"
            and getitem_getitem.target == operator.getitem
        )
-       tree_unflatten_getitem = node.args[0].args[0]  # pyre-ignore[16]
+       tree_unflatten_getitem = cast(Node, getitem_getitem.args[0])
        assert (
            tree_unflatten_getitem.op == "call_function"
            and tree_unflatten_getitem.target == operator.getitem
        )
-       tree_unflatten = tree_unflatten_getitem.args[0]
+       tree_unflatten = cast(Node, tree_unflatten_getitem.args[0])
        assert (
            tree_unflatten.op == "call_function"
            and tree_unflatten.target == torch.utils._pytree.tree_unflatten
        )
        logger.info(f"Removing tree_unflatten from {fqn}")
        input_nodes = tree_unflatten.args[0]
-       node.args = (input_nodes,)
+       if use_args:
+           node.args = (input_nodes,)
+       else:
+           node.kwargs = {list(node.kwargs.keys())[0]: input_nodes}
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        # `eliminate_dead_code`.
        submodule.graph.eliminate_dead_code()
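The reason the rewrite has to look at both attributes: torch.fx records keyword calls in `node.kwargs` rather than `node.args`. A self-contained sketch (illustrative only; `Regroup` and `LeafTracer` are stand-ins, not TorchRec code):

    import torch
    import torch.fx as fx

    class Regroup(torch.nn.Module):
        def forward(self, keyed_tensors):
            return torch.cat(keyed_tensors, dim=-1)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.regroup = Regroup()

        def forward(self, x, y):
            # keyword call: the traced node keeps its input in node.kwargs,
            # leaving node.args empty
            return self.regroup(keyed_tensors=[x, y])

    class LeafTracer(fx.Tracer):
        # keep Regroup opaque so the call shows up as a single call_module node
        def is_leaf_module(self, m, qualname):
            return isinstance(m, Regroup) or super().is_leaf_module(m, qualname)

    m = M()
    gm = fx.GraphModule(m, LeafTracer().trace(m))
    for node in gm.graph.nodes:
        if node.op == "call_module":
            # same check as in the hunk above: the payload arrives either
            # positionally or as the sole keyword argument
            assert len(node.args) == 1 or len(node.kwargs) == 1
            use_args = len(node.args) == 1
            payload = node.args[0] if use_args else list(node.kwargs.values())[0]
            print(node.target, node.args, node.kwargs, payload)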