
TroyGarden
Contributor

Summary:

# context

1. KJT contains three necessary tensors: `_values`, `_lengths`, `_offsets`
**a.** the shape of `_values` is independent
**b.** dim(`_lengths`) = dim(`batch_size`) * const(`len(kjt.keys())`)
**c.** dim(`_offsets`) = dim(`_lengths`) + 1
2. `_lengths` and `_offsets` can each be computed from the other, so usually a KJT stores only one of them in memory and computes the other when needed (see the sketch after this list).
3. previously only `_values` was marked as a dynamic shape (`_lengths` was left static), because `batch_size` and `len(kjt.keys())` are constant across iterations.
4. however, when we declare a KJT with both `_values` and `_offsets` as dynamic shapes, it won't pass the export function.
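A minimal sketch of how the three tensors relate, assuming a toy 3-key KJT with `batch_size = 2` (illustrative only, not torchrec's actual `KeyedJaggedTensor` code):

```python
import torch

keys = ["f1", "f2", "f3"]  # len(kjt.keys()) == 3
batch_size = 2

# _lengths has batch_size * len(keys) entries: one per (key, batch) pair
lengths = torch.tensor([1, 2, 0, 3, 1, 1])
assert lengths.numel() == batch_size * len(keys)

# _offsets is the cumulative sum of _lengths with a leading 0,
# hence dim(_offsets) = dim(_lengths) + 1
offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])
assert offsets.numel() == lengths.numel() + 1

# _values holds sum(lengths) ids; its length varies independently
values = torch.arange(int(lengths.sum()))

# either of _lengths / _offsets can be recovered from the other
assert torch.equal(offsets[1:] - offsets[:-1], lengths)
```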

# notes

1. the `feature2` in the test has **NO** impact on the failure because the export errors out before `feature2` is used
2. the error is purely due to the change that marks `_offsets` as dynamic.

# investigation

* the dynamic shape of `_offsets` is set to `3 * batch_size + 1` as shown below (a standalone sketch of how such a derived dim is declared follows the log):
```
{'features': [(<class 'torchrec.ir.utils.vlen1'>,), None, None, (<class 'torch.export.dynamic_shapes.3*batch_size1 + 1'>,)]}
```
* dynamic_shape `s1` is created for `_offsets`, dynamic_shape `s2` is created for `batch_size`
* why is there no `s1 == 3*batch_size + 1`?
```
0702 09:50:39.181000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s1 = 7 for L['args'][0][0]._offsets.size()[0] [2, 12884901886] (_export/non_strict_utils.py:93 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1"
V0702 09:50:39.183000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval False == False [statically known]
I0702 09:50:39.190000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s2 = 2 for batch_size1 [2, 4294967295] (export/dynamic_shapes.py:569 in _process_equalities), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2"
V0702 09:50:39.267000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval ((s1 - 1)//3) >= 0 == True [statically known]
I0702 09:50:39.273000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5104] eval Ne(((s1 - 1)//3), 0) [guard added] (_subclasses/functional_tensor.py:134 in __new__), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(((s1 - 1)//3), 0)"
V0702 09:50:39.322000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4736] _update_var_to_range s1 = VR[7, 7] (update)
I0702 09:50:39.330000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4855] set_replacement s1 = 7 (range_refined_to_singleton) VR[7, 7]
```
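The log shows `s1` being refined to the singleton 7 instead of being related to `s2`. For reference, a derived dim like `3*batch_size + 1` is declared via `torch.export.Dim` arithmetic; a standalone toy example (hypothetical module, unrelated to the TorchRec test) looks like:

```python
import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return values.sum() + offsets.sum()

batch_size = Dim("batch_size", min=2)  # symbolic batch size
vlen = Dim("vlen")                     # independent length of values

values = torch.arange(7.0)         # any length
offsets = torch.arange(3 * 2 + 1)  # 3 * batch_size + 1 with batch_size == 2

# offsets' dim 0 is declared as a dim derived from batch_size
ep = export(
    Toy(),
    (values, offsets),
    dynamic_shapes={"values": {0: vlen}, "offsets": {0: 3 * batch_size + 1}},
    strict=False,
)
```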

# resolve the issue

* there is an internal flag `_allow_complex_guards_as_runtime_asserts=True` that can support this correlation, by emitting complex guards (such as `s1 == 3*batch_size + 1`) as runtime assertions instead of rejecting them at export time
* before
```
        ep = torch.export.export(
            model,
            (feature1,),
            {},
            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=tuple(sparse_fqns),
        )
```
* after
```
        ep = torch.export._trace._export(
            model,
            (feature1,),
            {},
            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=tuple(sparse_fqns),
            _allow_complex_guards_as_runtime_asserts=True,
        )
```

Differential Revision: D59201188

@facebook-github-bot added the CLA Signed label on Jul 2, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59201188

Huanyu He and others added 2 commits July 2, 2024 22:08
Differential Revision: D59172050
Differential Revision: D59201188

TroyGarden pushed a commit to TroyGarden/torchrec that referenced this pull request Jul 27, 2024
Summary:
Pull Request resolved: meta-pytorch#2202

# context
* Discussed `ctx.new_dynamic_size()` with Angela, and she told me it's just to register a brand-new dynamic shape for the batch_size; so I traced back why registering batch_size as a dynamic shape caused so many problems.
* all the complexity really comes from this line in serializer.py: `batch_size = id_list_features.stride()`
* we get the batch_size from this function, which is derived from either `_maybe_compute_stride_kjt` or other sources.
* especially the line `stride = lengths.numel() // len(keys)` explains a lot of the error messages I encountered before (see the sketch after this list).
* this correlation between `lengths.numel()` and the batch_size really complicates the dynamic shape guards; for this I made up many artificial workarounds, like those in utils.py.
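A simplified sketch of that stride derivation (modeled on the description above, not torchrec's exact `_maybe_compute_stride_kjt`):

```python
import torch
from typing import List, Optional

def compute_stride(keys: List[str], lengths: Optional[torch.Tensor]) -> int:
    # the batch size (stride) is inferred from the lengths tensor, which
    # ties the symbolic size of lengths to the symbolic batch_size
    if lengths is not None and lengths.numel() > 0:
        return lengths.numel() // len(keys)
    return 0

keys = ["f1", "f2", "f3"]
lengths = torch.tensor([1, 2, 0, 3, 1, 1])  # 3 keys * batch_size 2
print(compute_stride(keys, lengths))  # 2
```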

# details
* we still declare/register the dynamic shape in the `mark_dynamic_kjt` function,
* but pass the `_dynamic_batch_size` to the KJT and retrieve it from the serializer's `meta_forward` (see the sketch below).
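A hedged sketch of that flow, using hypothetical stand-in names (`ToyKJT`, `meta_forward_batch_size`) rather than torchrec's real classes:

```python
import torch
from typing import List, Optional

class ToyKJT:
    """Stand-in for KeyedJaggedTensor with only the fields this sketch needs."""
    def __init__(self, keys: List[str], lengths: torch.Tensor) -> None:
        self._keys = keys
        self._lengths = lengths
        self._dynamic_batch_size: Optional[int] = None

def mark_dynamic_kjt(kjt: ToyKJT, batch_size: int) -> None:
    # stash the registered dynamic batch size on the KJT itself, instead of
    # re-deriving it as lengths.numel() // len(keys) at trace time
    kjt._dynamic_batch_size = batch_size

def meta_forward_batch_size(kjt: ToyKJT) -> int:
    # the serializer can read the stored value and only fall back to the
    # stride computation when no dynamic batch size was registered
    if kjt._dynamic_batch_size is not None:
        return kjt._dynamic_batch_size
    return kjt._lengths.numel() // len(kjt._keys)

kjt = ToyKJT(["f1", "f2", "f3"], torch.tensor([1, 2, 0, 3, 1, 1]))
mark_dynamic_kjt(kjt, 2)
print(meta_forward_batch_size(kjt))  # 2
```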

Differential Revision: D59201188
@TroyGarden deleted the export-D59201188 branch on August 8, 2024