Skip to content

Commit

Permalink
Use maxint to bound integers. (#96121)
Browse files Browse the repository at this point in the history
We don't actually support arbitrary precision integers.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch/pytorch#96121
Approved by: https://github.com/tugsbayasgalan, https://github.com/lezcano
  • Loading branch information
ezyang authored and cyyever committed Mar 12, 2023
1 parent c379f7b commit c9f39a8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
10 changes: 10 additions & 0 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,16 @@ def forward(self, crop_camera_1, mask_1):
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
return None""")

def test_unbacked_slice(self):
def f(x, m):
x = x[m]
return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

make_fx(f, tracing_mode="symbolic")(
torch.randn((12, 3, 3)),
torch.randint(0, 2, (12,), dtype=torch.bool)
)

@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_unbacked_batch_resnet(self):
mod = torchvision.models.resnet18()
Expand Down
3 changes: 3 additions & 0 deletions test/test_sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Owner(s): ["oncall: pt2"]

import itertools
import sys

import sympy
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -50,6 +51,8 @@
2**24,
2**32,
2**37 - 1,
sys.maxsize - 1,
sys.maxsize,
]
# less constants for N^2 situations
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
Expand Down
6 changes: 3 additions & 3 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ def nonzero(fake_mode, func, arg):
raise DynamicOutputShapeException(func)

if arg.nonzero_memo is None:
import sys

from torch.fx.experimental.symbolic_shapes import constrain_range

nnz = fake_mode.shape_env.create_unbacked_symint()
Expand All @@ -438,9 +440,7 @@ def nonzero(fake_mode, func, arg):
# disjoint with what can actually occur. But this is fine:
# remember, the hypothesis is that if your later code works
# with N >= 2, it will work with N = 1 and N = 0.
lower = 2
upper = None
constrain_range(nnz, min=lower, max=upper)
constrain_range(nnz, min=2, max=sys.maxsize - 1)

arg._nonzero_memo = nnz
arg._nonzero_memo_vc = arg._version
Expand Down
17 changes: 14 additions & 3 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ def create_unbacked_symfloat(self):
def create_unbacked_symint(self):
symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
self.var_to_range[symbol] = ValueRanges.unknown()
self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize)
return SymInt(SymNode(symbol, self, int, None))

# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
Expand All @@ -1361,7 +1361,10 @@ def create_symbol(self, val: int, source: Source, dyn=False) -> "sympy.Expr":

# We also infer that it must be not 0/1
lower = 2 if self.specialize_zero_one else 0
self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes. Your sizevar isn't going to be
# anywhere near the max 64-bit integer anyway.
self.var_to_range[sympy_expr] = ValueRanges(lower, sys.maxsize - 1)

if not dyn and self.duck_shape:
# This implements duck-shaping: input sizes that match are assigned
Expand Down Expand Up @@ -1577,12 +1580,20 @@ def _verify(expr, potential_expr):
if not _simplified:
for symbol, sources in symbol_to_source.items():
assert sources
assert symbol.is_integer
r = self.var_to_range[symbol]
bounds = []
if r.lower != -sympy.oo:
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
if r.upper != sympy.oo:
# NB: This looks like an off-by-one error but it's not: the
# upper bound may be sys.maxsize - 1 because we intentionally
# exclude sys.maxsize from our bounds to deal with direct
# == INT_MAX guards, but it's still dumb to actually test it.
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
bounds.append(str(r.upper))
if len(bounds) > 1:
exprs.append(" <= ".join(bounds))
Expand Down

0 comments on commit c9f39a8

Please sign in to comment.