Skip to content

Commit

Permalink
Generate type match guard for torch.Size input (#96421)
Browse files Browse the repository at this point in the history
I suppose hypothetically, if the user code ends up working
polymorphically over the SizeVariable, in such a way that a tuple would
work, this type match is not necessary.  But we do not carefully test
for this.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch/pytorch#96421
Approved by: https://github.com/jansel, https://github.com/voznesenskym
  • Loading branch information
ezyang authored and cyyever committed Mar 27, 2023
1 parent bf5bbbe commit 8269768
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
3 changes: 2 additions & 1 deletion test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,8 @@ def fn(x, s):
self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
self.assertEqual(cnts.op_count, 2)
# One recompile per differing input type
self.assertEqual(cnts.frame_count, 3)

def test_cell_output1(self):
out = None
Expand Down
17 changes: 17 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -2399,6 +2399,23 @@ def f(x):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))

def test_size_typematch(self):
def f(x, y):
if isinstance(x, torch.Size):
return y + 1
else:
return y + 2

y = torch.zeros(1)
x1 = torch.Size((3,))
x2 = (3,)

cnt = torch._dynamo.testing.CompileCounter()
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
self.assertTrue(same(f(x1, y), opt_f(x1, y)))
self.assertTrue(same(f(x2, y), opt_f(x2, y)))
self.assertEqual(cnt.frame_count, 2)

@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", False)
def test_not_rewrite_assert(self):
def f(x):
Expand Down
12 changes: 9 additions & 3 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,28 @@ def EQUALS_MATCH(self, guard: Guard):
self._produce_guard_code(guard, code)
return

# Add type check to prevent equality check between tensor and non-tensor.
code = list()

# If matching equality against list/tuple, we must also check that
# the internal types match. (TODO: what about nested lists?)
if istype(val, (list, tuple)):
# NB: LIST_LENGTH takes care of the outer __check_type_id test
self.LIST_LENGTH(guard)

for idx, elem in enumerate(val):
code.append(
f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
)

elif not istype(val, torch.Size):
else:
# Add type check to prevent equality check between tensor and non-tensor.
code.append(f"___check_type_id({ref}, {self.id_ref(t)})")

if istype(val, torch.Size):
val = tuple(val)

# TODO: It feels like it would be better to just implement our own
# equality test in C that handles all of the necessary type checking
# and NaN tests
code.append(f"{ref} == {val!r}")
self._produce_guard_code(guard, code)

Expand Down

0 comments on commit 8269768

Please sign in to comment.