Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,65 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
# TODO: turn on AOT Inductor test once the support is ready
test_aot_inductor=False,
)

def test_kjt_values_specialization(self):
with dynamo_skipfiles_allow("torchrec"):
from torch._dynamo.testing import CompileCounter

kjt0 = KeyedJaggedTensor(
values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64),
keys=["f0", "f1", "f2"],
lengths=torch.tensor([0, 0, 1, 1, 2, 2]),
stride=2,
)
torch._dynamo.decorators.mark_unbacked(kjt0._values, 0)

counter = CompileCounter()

@torch._dynamo.optimize(counter, nopython=True)
def f(kjt):
l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1])
return l[0].values().sum() + l[1].values().sum() + l[2].values().sum()

f(kjt0)
self.assertEqual(counter.frame_count, 1)

kjt1 = KeyedJaggedTensor(
values=torch.tensor([], dtype=torch.int64),
keys=["f0", "f1", "f2"],
lengths=torch.tensor([0, 0, 0, 0, 0, 0]),
stride=2,
)
torch._dynamo.decorators.mark_unbacked(kjt1._values, 0)
f(kjt1)
self.assertEqual(counter.frame_count, 1)

def test_kjt_values_specialization_utils(self):
with dynamo_skipfiles_allow("torchrec"):
from torch._dynamo.testing import CompileCounter

kjt0 = KeyedJaggedTensor(
values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64),
keys=["f0", "f1", "f2"],
lengths=torch.tensor([0, 0, 1, 1, 2, 2]),
stride=2,
).sync()

counter = CompileCounter()

@torch._dynamo.optimize(counter, nopython=True)
def f(kjt):
l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1])
return l[0].values().sum() + l[1].values().sum() + l[2].values().sum()

f(kjt_for_pt2_tracing(kjt0))
self.assertEqual(counter.frame_count, 1)

kjt1 = KeyedJaggedTensor(
values=torch.tensor([], dtype=torch.int64),
keys=["f0", "f1", "f2"],
lengths=torch.tensor([0, 0, 0, 0, 0, 0]),
stride=2,
).sync()
f(kjt_for_pt2_tracing(kjt1))
self.assertEqual(counter.frame_count, 1)
4 changes: 2 additions & 2 deletions torchrec/pt2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def kjt_for_pt2_tracing(
# We can mark static lengths dimension as we have fixed batch_size, but using VB path for tracing
torch._dynamo.decorators.mark_static(lengths, 0)
values = kjt.values().long()
torch._dynamo.decorators.mark_dynamic(values, 0)
torch._dynamo.decorators.mark_unbacked(values, 0)

return KeyedJaggedTensor(
keys=kjt.keys(),
Expand All @@ -71,7 +71,7 @@ def kjt_for_pt2_tracing(
stride = kjt.stride()

values = kjt.values().long()
torch._dynamo.decorators.mark_dynamic(values, 0)
torch._dynamo.decorators.mark_unbacked(values, 0)

return KeyedJaggedTensor(
keys=kjt.keys(),
Expand Down