From 818ea937ac4b5979dacd4ddcf1f4a6c371348267 Mon Sep 17 00:00:00 2001
From: James Dong
Date: Thu, 10 Jul 2025 13:38:42 -0700
Subject: [PATCH 1/2] Add missing fields to KJT's PyTree flatten/unflatten
 logic for VBE KJT (#2952)

Summary:

# Context

* Currently the torchrec IR serializer does not support exporting variable batch (VBE) KJTs, because the `stride_per_key_per_rank` and `inverse_indices` fields are needed to deserialize a VBE KJT but are not included in the KJT's PyTree flatten/unflatten functions.
* This diff updates the KJT PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`, as sketched below.
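A minimal sketch of the round trip this enables (the KJT data is hypothetical; `tree_flatten`/`tree_unflatten` are the `torch.utils._pytree` utilities, and the assertions assume this diff is applied):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A VBE KJT: per-key batch sizes differ, so stride_per_key_per_rank
# and inverse_indices are required to reconstruct it faithfully.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# With this diff, both VBE fields survive the pytree round trip.
leaves, spec = tree_flatten(kjt)
rebuilt = tree_unflatten(leaves, spec)
assert torch.equal(rebuilt.values(), kjt.values())
assert rebuilt.stride() == kjt.stride()
```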
# Ref

Reviewed By: TroyGarden

Differential Revision: D74295924
---
 torchrec/ir/tests/test_serializer.py | 80 ++++++++++++++++------------
 torchrec/sparse/jagged_tensor.py     | 38 +++++++++++--
 2 files changed, 79 insertions(+), 39 deletions(-)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 5b582902b..561c67cd5 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -207,8 +207,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )
@@ -293,24 +299,37 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        kjt_1 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
+            ),
         )
-        eager_out = model(id_list_features)
+        eager_out = model(kjt_1)
+        eager_out_2 = model(kjt_2)
 
         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (id_list_features,),
+            (kjt_1,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
@@ -318,17 +337,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         )
 
         # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
+        ep_output = ep.module()(kjt_1)
+        ep_output_2 = ep.module()(kjt_2)
 
+        self.assertEqual(len(ep_output), len(kjt_1.keys()))
+        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
+        for i, tensor in enumerate(ep_output_2):
+            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])
 
         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
 
         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)
 
-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(kjt_1)
 
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
+        deserialized_out_2 = deserialized_model(kjt_2)
+
+        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
+        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
+
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index ebdce6acb..6eba8cd75 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1728,6 +1728,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_stride_per_key_per_rank",
+        "_inverse_indices",
     ]
 
     def __init__(
@@ -3016,7 +3018,26 @@ def dist_init(
 def _kjt_flatten(
     t: KeyedJaggedTensor,
 ) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+    """
+    Used by PyTorch's pytree utilities for serialization and processing.
+    Extracts the tensor attributes of a KeyedJaggedTensor and returns them as a
+    flat list, along with the metadata needed to reconstruct the KeyedJaggedTensor.
+
+    Component tensors are returned as dynamic leaves; KJT metadata is added as
+    a static spec.
+
+    Returns:
+        Tuple containing:
+        - List[Optional[torch.Tensor]]: all tensor attributes (_values, _weights,
+          _lengths, _offsets, _stride_per_key_per_rank, and the tensor part of
+          _inverse_indices if present)
+        - List[str]: the keys of the KeyedJaggedTensor, used as the static context
+    """
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values, t._keys
 
 
 def _kjt_flatten_with_keys(
@@ -3030,15 +3051,24 @@ def _kjt_flatten_with_keys(
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: List[str],  # context is _keys
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context,
+        *values[:-2],
+        stride_per_key_per_rank=values[-2],
+        inverse_indices=(context, values[-1]) if values[-1] is not None else None,
+    )
 
 
 def _kjt_flatten_spec(
     t: KeyedJaggedTensor, spec: TreeSpec
 ) -> List[Optional[torch.Tensor]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields]
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values
 
 
 register_pytree_node(
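For reference, a sketch of the leaf layout `_kjt_flatten` emits after this diff; the `values[:-2]` / `values[-2]` / `values[-1]` indexing in `_kjt_unflatten` relies on exactly this order (the helper name is hypothetical, and it reads the same private attributes the patched code does):

```python
from typing import List, Optional

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def flattened_leaves(kjt: KeyedJaggedTensor) -> List[Optional[torch.Tensor]]:
    # Mirrors _kjt_flatten after this diff: the first five leaves follow
    # KeyedJaggedTensor._fields[:-1]; the last leaf is the tensor half of
    # _inverse_indices (or None), not the raw (keys, tensor) tuple.
    return [
        kjt._values,                   # values[0]
        kjt._weights,                  # values[1]
        kjt._lengths,                  # values[2]
        kjt._offsets,                  # values[3]
        kjt._stride_per_key_per_rank,  # values[-2] in _kjt_unflatten
        kjt._inverse_indices[1] if kjt._inverse_indices is not None else None,  # values[-1]
    ]
```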
From 2a398b91e9cfbc975ca8561c4751a66c256ad822 Mon Sep 17 00:00:00 2001
From: James Dong
Date: Thu, 10 Jul 2025 13:38:42 -0700
Subject: [PATCH 2/2] Update KJT stride calculation logic to be based off of
 inverse_indices for VBE KJTs. (#3119)

Summary:
For VBE KJTs, update the `_maybe_compute_stride_kjt` logic to calculate stride based on `inverse_indices` when it is set.

Currently, the stride of a VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This differs from the batch size of the EBC output KeyedTensor, which is based on `inverse_indices`, and the mismatch causes issues in IR module serialization.
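A hedged sketch of the behavior change (the numbers mirror the unit test added below):

```python
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# inverse_indices has shape (num_keys, batch_size).
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], inverse_indices),
)

# Before this diff: stride() == max per-key stride == 2.
# After this diff:  stride() == inverse_indices.shape[-1] == 3,
# matching the batch size of the EBC output KeyedTensor.
assert kjt.stride() == 3
```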
Reviewed By: TroyGarden

Differential Revision: D76997485
---
 torchrec/sparse/jagged_tensor.py                  |  6 ++++++
 torchrec/sparse/tests/test_keyed_jagged_tensor.py | 12 ++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 6eba8cd75..bd6b4bcef 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1095,6 +1095,7 @@ def _maybe_compute_stride_kjt(
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
     stride_per_key_per_rank: Optional[torch.IntTensor],
+    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
 ) -> int:
     if stride is None:
         if len(keys) == 0:
@@ -1102,6 +1103,10 @@ def _maybe_compute_stride_kjt(
         elif (
             stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
         ):
+            # For a VBE KJT, the batch size should be based on inverse_indices when it is set.
+            if inverse_indices is not None:
+                return inverse_indices[1].shape[-1]
+
             s = stride_per_key_per_rank.sum(dim=1).max().item()
             if not torch.jit.is_scripting() and is_non_strict_exporting():
                 stride = torch.sym_int(s)
@@ -2146,6 +2151,7 @@ def stride(self) -> int:
             self._lengths,
             self._offsets,
             self._stride_per_key_per_rank,
+            self._inverse_indices,
         )
         self._stride = stride
         return stride
diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
index 1636a06bd..bac0a2c52 100644
--- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
@@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None:
             lengths=torch.tensor([], device=torch.device("meta")),
         )
 
+    def test_vbe_kjt_stride(self) -> None:
+        inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], inverse_indices),
+        )
+
+        self.assertEqual(kjt.stride(), inverse_indices.shape[-1])
+
 
 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None: