backport: fall back to eager on NotImplementedError (pytorch#107863)
    Follow-up to pytorch#107710:

    Help Dynamo fall back to eager when compiling unimplemented NumPy constructs:

    - arrays of strings
    - (arg){min, max} for complex types (see the sketch after this list)
    - various arguments typed as NotImplemented (`np.ones(4, order="F")`, etc.)
    - NumPy functions which torch._numpy does not implement
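
    For example, a hedged sketch of the (arg){min, max} case (this assumes the NotImplementedError added to torch._numpy's argmin in this PR triggers the same graph break and eager fallback; plain NumPy handles complex argmin natively):

    ```
    import torch
    import numpy as np

    @torch.compile(fullgraph=False)
    def fn(x):
        # complex argmin is not implemented in torch._numpy, so dynamo should
        # break the graph here and run the call in eager NumPy instead
        return np.argmin(x)

    # NumPy orders complex numbers lexicographically (real part, then imaginary),
    # so the minimum of [1+1j, 0+2j] is at index 1
    print(fn(np.array([1 + 1j, 0 + 2j])))
    ```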

    To test, run the following (we do not implement arrays of strings):

    ```
    import torch
    import numpy as np

    @torch.compile(fullgraph=False)
    def fn():
        return np.asarray(["L", "U"])
    ```

    and observe that it compiles with `fullgraph=False` and fails with `fullgraph=True`.
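
    A minimal sketch of the fullgraph=True contrast (assuming only the public torch.compile API; the exact exception type raised on the forbidden graph break is not pinned down here):

    ```
    import torch
    import numpy as np

    @torch.compile(fullgraph=True)
    def fn():
        # torch._numpy has no implementation for arrays of strings, so tracing
        # this call needs a graph break, which fullgraph=True disallows
        return np.asarray(["L", "U"])

    try:
        fn()
    except Exception as exc:
        print(f"compilation failed as expected: {type(exc).__name__}")
    ```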

    Fixes pytorch#107970

    Pull Request resolved: pytorch#107863
    Approved by: https://github.com/ezyang, https://github.com/lezcano
ev-br committed Oct 26, 2023
1 parent 3f59221 commit a3a161f
Showing 12 changed files with 102 additions and 6 deletions.
34 changes: 34 additions & 0 deletions test/dynamo/test_misc.py
@@ -1544,6 +1544,40 @@ def fn(x):
self.assertEqual(r.dtype, torch.int64)
self.assertEqual(cnts.frame_count, 1)

def test_numpy_unique_f16(self):
def fn():
x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16)
return np.unique(x)

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)

r = opt_fn()
self.assertEqual(r.dtype, np.float16)
self.assertEqual(cnts.frame_count, 1)

def test_numpy_fallback_on_eager(self):
def fn():
return np.asarray(["L", "U"])

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)

r = opt_fn()
self.assertEqual(cnts.frame_count, 0) # graph break
self.assertEqual(r, np.asarray(["L", "U"]))

# repeat with a different function
def fn2():
return np.random.choice(["L", "U"])

cnts2 = torch._dynamo.testing.CompileCounter()
opt_fn2 = torch._dynamo.optimize(cnts2)(fn2)

r2 = opt_fn2()
self.assertEqual(cnts2.frame_count, 0)  # graph break
assert r2 in ("L", "U")

def test_inplace_view_on_graph_input(self):
# graph break when calling methods with inplace_view tag on graph input
func_args_map = {
1 change: 1 addition & 0 deletions test/test_sparse.py
@@ -591,6 +591,7 @@ def test_sparse_bool(self, device, dtype):
b = a.to_sparse().to_dense()
self.assertEqual(a, b)

@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/108667")
@dtypes(torch.double, torch.cdouble)
def test_scalar(self, device, dtype):
# tensor with value
1 change: 1 addition & 0 deletions test/test_torch.py
@@ -4950,6 +4950,7 @@ def compare_strides(s1, s2, div):

@onlyCUDA
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
@skipIfTorchDynamo("NotImplementedError: PrimTorch does not support pinned memory")
def test_pin_memory_from_constructor(self, device):
def _get_like(t, **kwargs):
return [
1 change: 1 addition & 0 deletions test/torch_np/numpy_tests/core/test_einsum.py
@@ -21,6 +21,7 @@
global_size_dict = dict(zip(chars, sizes))


@pytest.mark.skip
class TestEinsum:
def test_einsum_errors(self):
for do_opt in [True, False]:
1 change: 1 addition & 0 deletions test/torch_np/numpy_tests/lib/test_shape_base_.py
@@ -45,6 +45,7 @@ def wrapped(a, axis, **kwargs):
return wrapped


@pytest.mark.skip
class TestTakeAlongAxis:
def test_argequivalent(self):
"""Test it translates from arg<func> to <func>"""
31 changes: 31 additions & 0 deletions test/torch_np/test_ndarray_methods.py
@@ -578,3 +578,34 @@ def test_extra_methods(name):
a = np.ones(3)
with pytest.raises(AttributeError):
getattr(a, name)


class TestNoExtraMethods:
# make sure ndarray does not carry extra methods/attributes
# >>> set(dir(a)) - set(dir(a.tensor.numpy()))
@pytest.mark.parametrize("name", ["fn", "ivar", "method", "name", "plain", "rvar"])
def test_extra_methods(self, name):
a = np.ones(3)
with pytest.raises(AttributeError):
getattr(a, name)


class TestIter:
def test_iter_1d(self):
# numpy generates array scalars, we do 0D arrays
a = np.arange(5)
lst = list(a)
assert all(type(x) == np.ndarray for x in lst)
assert all(x.ndim == 0 for x in lst)

def test_iter_2d(self):
# numpy iterates over the 0th axis
a = np.arange(5)[None, :]
lst = list(a)
assert len(lst) == 1
assert type(lst[0]) == np.ndarray
assert_equal(lst[0], np.arange(5))


if __name__ == "__main__":
run_tests()
7 changes: 7 additions & 0 deletions torch/_dynamo/utils.py
@@ -1392,6 +1392,7 @@ def run_node(tracer, node, args, kwargs, nnmodule):
raise an AssertionError.
"""
op = node.op

try:
if op == "call_function":
return node.target(*args, **kwargs)
@@ -1405,6 +1406,12 @@ def run_node(tracer, node, args, kwargs, nnmodule):
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except NotImplementedError as e:
# NB: mimic how wrap_fake_exception does it
from .exc import unimplemented

raise unimplemented(f"running {op} {node.target}(*{args}, **{kwargs})") from e

except Exception as e:
fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
8 changes: 3 additions & 5 deletions torch/_numpy/_funcs_impl.py
@@ -1378,6 +1378,8 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=


def _sort_helper(tensor, axis, kind, order):
if tensor.dtype.is_complex:
raise NotImplementedError(f"sorting {tensor.dtype} is not supported")
(tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
axis = _util.normalize_axis_index(axis, tensor.ndim)

Expand All @@ -1387,17 +1389,13 @@ def _sort_helper(tensor, axis, kind, order):


def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
if a.dtype.is_complex:
return NotImplemented
# `order` keyword arg is only relevant for structured dtypes; so not supported here.
a, axis, stable = _sort_helper(a, axis, kind, order)
result = torch.sort(a, dim=axis, stable=stable)
return result.values


def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
if a.dtype.is_complex:
return NotImplemented
a, axis, stable = _sort_helper(a, axis, kind, order)
return torch.argsort(a, dim=axis, stable=stable)

Expand All @@ -1406,7 +1404,7 @@ def searchsorted(
a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
):
if a.dtype.is_complex:
return NotImplemented
raise NotImplementedError(f"searchsorted with dtype={a.dtype}")

return torch.searchsorted(a, v, side=side, sorter=sorter)

2 changes: 2 additions & 0 deletions torch/_numpy/_normalizations.py
@@ -193,6 +193,7 @@ def wrapped(*args, **kwds):
sig = inspect.signature(func)
params = sig.parameters
first_param = next(iter(params.values()))

# NumPy's API does not have positional args before variadic positional args
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
args = [maybe_normalize(arg, first_param) for arg in args]
@@ -210,6 +211,7 @@ def wrapped(*args, **kwds):
name: maybe_normalize(arg, params[name]) if name in params else arg
for name, arg in kwds.items()
}

result = func(*args, **kwds)

# keepdims
12 changes: 12 additions & 0 deletions torch/_numpy/_reductions_impl.py
@@ -71,6 +71,9 @@ def argmax(
*,
keepdims: KeepDims = False,
):
if a.is_complex():
raise NotImplementedError(f"argmax with dtype={a.dtype}.")

axis = _util.allow_only_single_axis(axis)

if a.dtype == torch.bool:
Expand All @@ -88,6 +91,9 @@ def argmin(
*,
keepdims: KeepDims = False,
):
if a.is_complex():
raise NotImplementedError(f"argmin with dtype={a.dtype}.")

axis = _util.allow_only_single_axis(axis)

if a.dtype == torch.bool:
@@ -134,6 +140,9 @@ def amax(
initial: NotImplementedType = None,
where: NotImplementedType = None,
):
if a.is_complex():
raise NotImplementedError(f"amax with dtype={a.dtype}")

return a.amax(axis)


Expand All @@ -149,6 +158,9 @@ def amin(
initial: NotImplementedType = None,
where: NotImplementedType = None,
):
if a.is_complex():
raise NotImplementedError(f"amin with dtype={a.dtype}")

return a.amin(axis)


5 changes: 4 additions & 1 deletion torch/_numpy/_util.py
@@ -201,7 +201,10 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
if isinstance(obj, torch.Tensor):
tensor = obj
else:
tensor = torch.as_tensor(obj)
try:
tensor = torch.as_tensor(obj)
except Exception:
raise NotImplementedError(f"failed to convert {obj} to ndarray")

# tensor.dtype is the pytorch default, typically float32. If obj's elements
# are not exactly representable in float32, we've lost precision:
5 changes: 5 additions & 0 deletions torch/_numpy/testing/utils.py
@@ -195,6 +195,10 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
else:
return True

if isinstance(desired, str) and isinstance(actual, str):
assert actual == desired
return

if isinstance(desired, dict):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
@@ -209,6 +213,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
for k in range(len(desired)):
assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose)
return

from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit

if isinstance(actual, ndarray) or isinstance(desired, ndarray):
