diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 746708355..96b209f9a 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -258,15 +258,15 @@ def test_dynamic_shape_ebc(self) -> None:
 
         feature2 = KeyedJaggedTensor.from_offsets_sync(
             keys=["f1", "f2", "f3"],
-            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 8, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 10, 12]),
         )
         eager_out = model(feature2)
 
         # Serialize EBC
-        collection = mark_dynamic_kjt(feature1)
+        collection = mark_dynamic_kjt(feature1, variable_length=True)
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
             model,
             (feature1,),
             {},
@@ -274,6 +274,7 @@ def test_dynamic_shape_ebc(self) -> None:
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=tuple(sparse_fqns),
+            _allow_complex_guards_as_runtime_asserts=True,
         )
 
         # Run forward on ExportedProgram
@@ -351,6 +352,54 @@ def test_deserialized_device(self) -> None:
                 continue
             assert param.device.type == device.type, f"{name} should be on {device}"
 
+    def test_deserialize_device_kt_regroup(self) -> None:
+        class Model(nn.Module):
+            def __init__(self, ebc):
+                super().__init__()
+                self.ebc = ebc
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> List[torch.Tensor]:
+                kt = self.ebc(features)
+                return KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
+
+        model = self.generate_model()
+        model = Model(model.ebc1)
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+        eager_out = model(id_list_features)
+
+        # Serialize EBC
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+        unflatten_model = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_model, JsonSerializer, torch.device("cuda")
+        )
+        device = torch.device("cuda")
+        deserialized_model.to(device)
+        id_list_features = id_list_features.to(device)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(id_list_features)
+
+        for i, tensor in enumerate(deserialized_out):
+            assert eager_out[i].shape == tensor.shape
+            assert torch.allclose(eager_out[i].to(tensor), tensor)
+
     def test_compound_module(self) -> None:
         tb1_config = EmbeddingBagConfig(
             name="t1",
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index b5fe73002..d0fed8609 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -58,6 +58,22 @@ def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     )
 
 
+@torch.library.custom_op("torchrec::_pin_and_move_op_int", mutates_args=())
+def _pin_and_move_op_int_impl(data: List[int], other: torch.Tensor) -> torch.Tensor:
+    tensor = torch.tensor(data, dtype=torch.int64)
+    device = other.device
+    return (
+        tensor.pin_memory().to(device=device, non_blocking=True)
+        if device.type == "cuda" and tensor.device.type == "cpu"
+        else tensor.to(device=device, non_blocking=True)
+    )
+
+
+@torch.library.register_fake("torchrec::_pin_and_move_op_int")
+def _pin_and_move_op_int_fake(data: List[int], other: torch.Tensor) -> torch.Tensor:
+    return torch.empty(len(data), dtype=torch.int64, device=other.device)
+
+
 def _cumsum(o: List[int]) -> List[int]:
     ret = [0] * (len(o) + 1)
     for i in range(len(o)):
@@ -169,17 +185,22 @@ def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
 ) -> List[torch.Tensor]:
     keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
-    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
+    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups_list(
         keys, lengths, groups
     )
+    args = torch.ops.torchrec._pin_and_move_op_int(
+        permute + inv_permute + offsets + inv_offsets, values[0]
+    )
+    permute, inv_permute, offsets, inv_offsets = args.split(
+        [len(permute), len(permute), len(offsets), len(offsets)]
+    )
     values = torch.concat(values, dim=1)
-    device = values.device
     permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
         values,
-        _pin_and_move(offsets, device),
-        _pin_and_move(permute, device),
-        _pin_and_move(inv_offsets, device),
-        _pin_and_move(inv_permute, device),
+        offsets,
+        permute,
+        inv_offsets,
+        inv_permute,
     )
     return list(torch.split(permuted_values, splits, dim=1))
 
@@ -198,6 +219,42 @@ def _desugar_keyed_tensors(
     )
 
 
+def _remap_to_groups_list(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
+    """
+    Given a list of keys and lengths per key for each group, return the permute indices, inverse_permute indices, offsets, inv_offsets, splits.
+    The output is used to re-arrange values based on groups with a single cat operation.
+    """
+
+    lengths: List[int] = []
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+
+    for sub_keys_length in key_lengths:
+        lengths.extend(sub_keys_length)
+    for sub_keys in keys:
+        flat_keys.extend(sub_keys)
+
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+
+    key_splits = [len(sub_group) for sub_group in groups]
+
+    index_map = {key: idx for idx, key in enumerate(flat_keys)}
+    permute = [index_map[key] for key in flat_groups]
+    inv_lengths = [lengths[i] for i in permute]
+    splits = _sum_by_splits(inv_lengths, key_splits)
+
+    inv_permute = [0] * len(permute)
+    for i, p in enumerate(permute):
+        inv_permute[p] = i
+
+    return permute, inv_permute, _cumsum(lengths), _cumsum(inv_lengths), splits
+
+
 @torch.fx.wrap
 def _remap_to_groups(
     keys: List[List[str]],
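
Note (not part of the patch): a minimal worked example of what the new list-based helper _remap_to_groups_list returns, assuming the patch above is applied and the private helper is imported from torchrec.sparse.jagged_tensor. The keys, lengths, and groups below are made up for illustration, and the expected values are derived by hand from the list logic in the diff; this is a sketch, not an authoritative reference.

# Illustrative sketch only; inputs are hypothetical.
from torchrec.sparse.jagged_tensor import _remap_to_groups_list  # private helper added by this patch

# Two pooled KeyedTensors: the first has keys "f1" (dim 2) and "f2" (dim 3),
# the second has "f3" (dim 4). We regroup so that "f1" and "f3" form group 0.
keys = [["f1", "f2"], ["f3"]]
key_lengths = [[2, 3], [4]]
groups = [["f1", "f3"], ["f2"]]

permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups_list(
    keys, key_lengths, groups
)

# permute     == [0, 2, 1]     # source-key index for each key in regrouped order
# inv_permute == [0, 2, 1]     # inverse mapping back to the original key order
# offsets     == [0, 2, 5, 9]  # cumulative dims in the original key order
# inv_offsets == [0, 2, 6, 9]  # cumulative dims in the regrouped order
# splits      == [6, 3]        # total embedding dim of each output group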