Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,22 +258,23 @@ def test_dynamic_shape_ebc(self) -> None:

feature2 = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 8, 1, 2]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 10, 12]),
)
eager_out = model(feature2)

# Serialize EBC
collection = mark_dynamic_kjt(feature1)
collection = mark_dynamic_kjt(feature1, variable_length=True)
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
ep = torch.export._trace._export(
model,
(feature1,),
{},
dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=tuple(sparse_fqns),
_allow_complex_guards_as_runtime_asserts=True,
)

# Run forward on ExportedProgram
Expand Down Expand Up @@ -351,6 +352,54 @@ def test_deserialized_device(self) -> None:
continue
assert param.device.type == device.type, f"{name} should be on {device}"

def test_deserialize_device_kt_regroup(self) -> None:
class Model(nn.Module):
def __init__(self, ebc):
super().__init__()
self.ebc = ebc

def forward(
self,
features: KeyedJaggedTensor,
) -> List[torch.Tensor]:
kt = self.ebc(features)
return KeyedTensor.regroup([kt], [[key] for key in kt.keys()])

model = self.generate_model()
model = Model(model.ebc1)
id_list_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
values=torch.tensor([0, 1, 2, 3, 2, 3]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
)
eager_out = model(id_list_features)

# Serialize EBC
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
model,
(id_list_features,),
{},
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_model = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(
unflatten_model, JsonSerializer, torch.device("cuda")
)
device = torch.device("cuda")
deserialized_model.to(device)
id_list_features = id_list_features.to(device)

deserialized_model.load_state_dict(model.state_dict())
# Run forward on deserialized model
deserialized_out = deserialized_model(id_list_features)

for i, tensor in enumerate(deserialized_out):
assert eager_out[i].shape == tensor.shape
assert torch.allclose(eager_out[i].to(tensor), tensor)

def test_compound_module(self) -> None:
tb1_config = EmbeddingBagConfig(
name="t1",
Expand Down
69 changes: 63 additions & 6 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
)


@torch.library.custom_op("torchrec::_pin_and_move_op_int", mutates_args=())
def _pin_and_move_op_int_impl(data: List[int], other: torch.Tensor) -> torch.Tensor:
tensor = torch.Tensor(data)
device = other.device
return (
tensor.pin_memory().to(device=device, non_blocking=True)
if device.type == "cuda" and tensor.device.type == "cpu"
else tensor.to(device=device, non_blocking=True)
)


@torch.library.register_fake("torchrec::_pin_and_move_op_int")
def _pin_and_move_op_int_fake(data: List[int], other: torch.Tensor) -> torch.Tensor:
return torch.empty(len(data), dtype=torch.int64, device=other.device)


def _cumsum(o: List[int]) -> List[int]:
ret = [0] * (len(o) + 1)
for i in range(len(o)):
Expand Down Expand Up @@ -169,17 +185,22 @@ def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups_list(
keys, lengths, groups
)
args = torch.ops.torchrec._pin_and_move_op_int(
permute + inv_permute + offsets + inv_offsets, values[0]
)
permute, inv_permute, offsets, inv_offsets = args.split(
[len(permute), len(permute), len(offsets), len(offsets)]
)
values = torch.concat(values, dim=1)
device = values.device
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
values,
_pin_and_move(offsets, device),
_pin_and_move(permute, device),
_pin_and_move(inv_offsets, device),
_pin_and_move(inv_permute, device),
offsets,
permute,
inv_offsets,
inv_permute,
)
return list(torch.split(permuted_values, splits, dim=1))

Expand All @@ -198,6 +219,42 @@ def _desugar_keyed_tensors(
)


def _remap_to_groups_list(
keys: List[List[str]],
key_lengths: List[List[int]],
groups: List[List[str]],
) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
"""
Given a list of keys and lengths per key for each group, return the permute indices, inverse_permute indices, offsets, inv_offsets, splits.
The output is used to re-arrange values based on groups with a single cat operation.
"""

lengths: List[int] = []
flat_keys: List[str] = []
flat_groups: List[str] = []

for sub_keys_length in key_lengths:
lengths.extend(sub_keys_length)
for sub_keys in keys:
flat_keys.extend(sub_keys)

for sub_group in groups:
flat_groups.extend(sub_group)

key_splits = [len(sub_group) for sub_group in groups]

index_map = {key: idx for idx, key in enumerate(flat_keys)}
permute = [index_map[key] for key in flat_groups]
inv_lengths = [lengths[i] for i in permute]
splits = _sum_by_splits(inv_lengths, key_splits)

inv_permute = [0] * len(permute)
for i, p in enumerate(permute):
inv_permute[p] = i

return permute, inv_permute, _cumsum(lengths), _cumsum(inv_lengths), splits


@torch.fx.wrap
def _remap_to_groups(
keys: List[List[str]],
Expand Down