diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 90c8eade7..5440dba18 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -604,3 +604,65 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None: # TODO: turn on AOT Inductor test once the support is ready test_aot_inductor=False, ) + + def test_kjt_values_specialization(self): + with dynamo_skipfiles_allow("torchrec"): + from torch._dynamo.testing import CompileCounter + + kjt0 = KeyedJaggedTensor( + values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 1, 1, 2, 2]), + stride=2, + ) + torch._dynamo.decorators.mark_unbacked(kjt0._values, 0) + + counter = CompileCounter() + + @torch._dynamo.optimize(counter, nopython=True) + def f(kjt): + l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1]) + return l[0].values().sum() + l[1].values().sum() + l[2].values().sum() + + f(kjt0) + self.assertEqual(counter.frame_count, 1) + + kjt1 = KeyedJaggedTensor( + values=torch.tensor([], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 0, 0, 0, 0]), + stride=2, + ) + torch._dynamo.decorators.mark_unbacked(kjt1._values, 0) + f(kjt1) + self.assertEqual(counter.frame_count, 1) + + def test_kjt_values_specialization_utils(self): + with dynamo_skipfiles_allow("torchrec"): + from torch._dynamo.testing import CompileCounter + + kjt0 = KeyedJaggedTensor( + values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 1, 1, 2, 2]), + stride=2, + ).sync() + + counter = CompileCounter() + + @torch._dynamo.optimize(counter, nopython=True) + def f(kjt): + l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1]) + return l[0].values().sum() + l[1].values().sum() + l[2].values().sum() + + f(kjt_for_pt2_tracing(kjt0)) + self.assertEqual(counter.frame_count, 1) + + kjt1 = KeyedJaggedTensor( + values=torch.tensor([], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 0, 0, 0, 0]), + stride=2, + ).sync() + f(kjt_for_pt2_tracing(kjt1)) + self.assertEqual(counter.frame_count, 1) diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index bf0800042..271da8308 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -44,7 +44,7 @@ def kjt_for_pt2_tracing( # We can mark static lengths dimension as we have fixed batch_size, but using VB path for tracing torch._dynamo.decorators.mark_static(lengths, 0) values = kjt.values().long() - torch._dynamo.decorators.mark_dynamic(values, 0) + torch._dynamo.decorators.mark_unbacked(values, 0) return KeyedJaggedTensor( keys=kjt.keys(), @@ -71,7 +71,7 @@ def kjt_for_pt2_tracing( stride = kjt.stride() values = kjt.values().long() - torch._dynamo.decorators.mark_dynamic(values, 0) + torch._dynamo.decorators.mark_unbacked(values, 0) return KeyedJaggedTensor( keys=kjt.keys(),