From 105cb6b71b040ff3b5db140424a0321de8342292 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 2 Jul 2024 22:08:52 -0700
Subject: [PATCH 1/2] serialization device issue in KT.regroup

Differential Revision: D59172050
---
 torchrec/ir/tests/test_serializer.py | 48 +++++++++++++++++++
 torchrec/sparse/jagged_tensor.py     | 69 +++++++++++++++++++++++++---
 2 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 746708355..2624eac5e 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -351,6 +351,54 @@ def test_deserialized_device(self) -> None:
                 continue
             assert param.device.type == device.type, f"{name} should be on {device}"
 
+    def test_deserialize_device_kt_regroup(self) -> None:
+        class Model(nn.Module):
+            def __init__(self, ebc):
+                super().__init__()
+                self.ebc = ebc
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> List[torch.Tensor]:
+                kt = self.ebc(features)
+                return KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
+
+        model = self.generate_model()
+        model = Model(model.ebc1)
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+        eager_out = model(id_list_features)
+
+        # Serialize EBC
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+        unflatten_model = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_model, JsonSerializer, torch.device("cuda")
+        )
+        device = torch.device("cuda")
+        deserialized_model.to(device)
+        id_list_features = id_list_features.to(device)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(id_list_features)
+
+        for i, tensor in enumerate(deserialized_out):
+            assert eager_out[i].shape == tensor.shape
+            assert torch.allclose(eager_out[i].to(tensor), tensor)
+
     def test_compound_module(self) -> None:
         tb1_config = EmbeddingBagConfig(
             name="t1",
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index b5fe73002..d0fed8609 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -58,6 +58,22 @@ def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     )
 
 
+@torch.library.custom_op("torchrec::_pin_and_move_op_int", mutates_args=())
+def _pin_and_move_op_int_impl(data: List[int], other: torch.Tensor) -> torch.Tensor:
+    tensor = torch.Tensor(data)
+    device = other.device
+    return (
+        tensor.pin_memory().to(device=device, non_blocking=True)
+        if device.type == "cuda" and tensor.device.type == "cpu"
+        else tensor.to(device=device, non_blocking=True)
+    )
+
+
+@torch.library.register_fake("torchrec::_pin_and_move_op_int")
+def _pin_and_move_op_int_fake(data: List[int], other: torch.Tensor) -> torch.Tensor:
+    return torch.empty(len(data), dtype=torch.int64, device=other.device)
+
+
 def _cumsum(o: List[int]) -> List[int]:
     ret = [0] * (len(o) + 1)
     for i in range(len(o)):
@@ -169,17 +185,22 @@ def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
 ) -> List[torch.Tensor]:
     keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
-    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
+    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups_list(
         keys, lengths, groups
     )
+    args = torch.ops.torchrec._pin_and_move_op_int(
+        permute + inv_permute + offsets + inv_offsets, values[0]
+    )
+    permute, inv_permute, offsets, inv_offsets = args.split(
+        [len(permute), len(permute), len(offsets), len(offsets)]
+    )
     values = torch.concat(values, dim=1)
-    device = values.device
     permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
         values,
-        _pin_and_move(offsets, device),
-        _pin_and_move(permute, device),
-        _pin_and_move(inv_offsets, device),
-        _pin_and_move(inv_permute, device),
+        offsets,
+        permute,
+        inv_offsets,
+        inv_permute,
     )
     return list(torch.split(permuted_values, splits, dim=1))
@@ -198,6 +219,42 @@ def _desugar_keyed_tensors(
     )
 
 
+def _remap_to_groups_list(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
+    """
+    Given a list of keys and lengths per key for each group, return the permute indices, inverse_permute indices, offsets, inv_offsets, and splits.
+    The output is used to re-arrange values based on groups with a single cat operation.
+    """
+
+    lengths: List[int] = []
+    flat_keys: List[str] = []
+    flat_groups: List[str] = []
+
+    for sub_keys_length in key_lengths:
+        lengths.extend(sub_keys_length)
+    for sub_keys in keys:
+        flat_keys.extend(sub_keys)
+
+    for sub_group in groups:
+        flat_groups.extend(sub_group)
+
+    key_splits = [len(sub_group) for sub_group in groups]
+
+    index_map = {key: idx for idx, key in enumerate(flat_keys)}
+    permute = [index_map[key] for key in flat_groups]
+    inv_lengths = [lengths[i] for i in permute]
+    splits = _sum_by_splits(inv_lengths, key_splits)
+
+    inv_permute = [0] * len(permute)
+    for i, p in enumerate(permute):
+        inv_permute[p] = i
+
+    return permute, inv_permute, _cumsum(lengths), _cumsum(inv_lengths), splits
+
+
 @torch.fx.wrap
 def _remap_to_groups(
     keys: List[List[str]],

From 75fa08183d60110238637bec041ef07a0fd1e01f Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 2 Jul 2024 23:18:46 -0700
Subject: [PATCH 2/2] how to mark KJT offsets as dynamic? (#2202)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2202

# context
1. KJT contains three necessary tensors: `_values`, `_lengths`, `_offsets`
   **a.** the shape of `_values` is independent
   **b.** dim(`_lengths`) = dim(`batch_size`) * const(`len(kjt.keys())`)
   **c.** dim(`_offsets`) = dim(`_lengths`) + 1
2. `_lengths` and `_offsets` can each be calculated from the other, so a KJT usually stores only one of them in memory and computes the other when needed.
3. previously, only `_lengths` was marked as a dynamic shape, because `batch_size` and `len(kjt.keys())` are constant across iterations.
4. however, when we declare a KJT with both `_values` and `_offsets` as dynamic shapes, it won't pass the export function.

# notes
1. the `feature2` in the test has **NO** impact on the failure, because export errors out before `feature2` is used.
2. the error is purely due to the change that marks `_offsets` as dynamic.

# investigation
* `_offsets` is set to `3 * batch_size + 1` as shown below:
```
{'features': [(,), None, None, (,)]}
```
* dynamic shape `s1` is created for `_offsets`, dynamic shape `s2` is created for `batch_size`
* why is there no `s1 == 3*batch_size + 1`?
```
I0702 09:50:39.181000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s1 = 7 for L['args'][0][0]._offsets.size()[0] [2, 12884901886] (_export/non_strict_utils.py:93 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1"
V0702 09:50:39.183000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval False == False [statically known]
I0702 09:50:39.190000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s2 = 2 for batch_size1 [2, 4294967295] (export/dynamic_shapes.py:569 in _process_equalities), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2"
V0702 09:50:39.267000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval ((s1 - 1)//3) >= 0 == True [statically known]
I0702 09:50:39.273000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5104] eval Ne(((s1 - 1)//3), 0) [guard added] (_subclasses/functional_tensor.py:134 in __new__), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(((s1 - 1)//3), 0)"
V0702 09:50:39.322000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4736] _update_var_to_range s1 = VR[7, 7] (update)
I0702 09:50:39.330000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4855] set_replacement s1 = 7 (range_refined_to_singleton) VR[7, 7]
```

# resolve the issue
* there is an internal flag, `_allow_complex_guards_as_runtime_asserts=True`, that can support this correlation
* before
```
ep = torch.export.export(
    model,
    (feature1,),
    {},
    dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
    strict=False,
    # Allows KJT to not be unflattened and run a forward on unflattened EP
    preserve_module_call_signature=tuple(sparse_fqns),
)
```
* after
```
ep = torch.export._trace._export(
    model,
    (feature1,),
    {},
    dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
    strict=False,
    # Allows KJT to not be unflattened and run a forward on unflattened EP
    preserve_module_call_signature=tuple(sparse_fqns),
    _allow_complex_guards_as_runtime_asserts=True,
)
```

Differential Revision: D59201188
---
 torchrec/ir/tests/test_serializer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 2624eac5e..96b209f9a 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -258,15 +258,15 @@ def test_dynamic_shape_ebc(self) -> None:
         feature2 = KeyedJaggedTensor.from_offsets_sync(
             keys=["f1", "f2", "f3"],
-            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 8, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 10, 12]),
         )
         eager_out = model(feature2)
 
         # Serialize EBC
-        collection = mark_dynamic_kjt(feature1)
+        collection = mark_dynamic_kjt(feature1, variable_length=True)
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
            model,
            (feature1,),
            {},
@@ -274,6 +274,7 @@
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=tuple(sparse_fqns),
+            _allow_complex_guards_as_runtime_asserts=True,
        )

        # Run forward on ExportedProgram
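
For readers following the `# context` reasoning in the second commit message, here is a minimal, torch-free sketch of the lengths/offsets bookkeeping it relies on. The key names and the lengths mirror the `id_list_features` KJT used in the new test (3 keys, batch size 2); the `cumsum` helper below is illustrative only and is not part of the torchrec API.

```python
from typing import List


def cumsum(lengths: List[int]) -> List[int]:
    # offsets[i] == sum(lengths[:i]); hence len(offsets) == len(lengths) + 1
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return offsets


keys = ["f1", "f2", "f3"]      # len(kjt.keys()) == 3, constant across iterations
batch_size = 2                 # the dynamic dimension
lengths = [2, 0, 1, 1, 1, 1]   # one entry per (key, sample) pair -> 3 * batch_size entries
offsets = cumsum(lengths)      # [0, 2, 2, 3, 4, 5, 6], the offsets used in the test KJT

assert len(lengths) == len(keys) * batch_size       # dim(_lengths) = batch_size * const
assert len(offsets) == len(keys) * batch_size + 1   # dim(_offsets) = 3 * batch_size + 1
assert offsets[-1] == sum(lengths)                  # total number of values
```

This `s1 == 3 * batch_size + 1` correlation between the `_offsets` symbol and `batch_size` is, per the summary above, what plain `torch.export.export` does not capture, and why the test switches to `torch.export._trace._export(..., _allow_complex_guards_as_runtime_asserts=True)`.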