From cea7dd1646ab147edac8f0e22f0aa85cf3136fef Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 31 May 2022 07:02:18 -0700
Subject: [PATCH] Add FakeTensorMode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77972
Approved by: https://github.com/ezyang
---
 test/test_fake_tensor.py         | 63 ++++++++++++++++++++++++-----
 torch/_subclasses/__init__.py    |  1 +
 torch/_subclasses/fake_tensor.py | 68 ++++++++++++++++++++++++++------
 3 files changed, 110 insertions(+), 22 deletions(-)

diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index cefb167ebef42..433bd0f42d4e7 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1,11 +1,13 @@
-# Owner(s): ["module: unknown"]
+# Owner(s): ["module: meta tensors"]
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import itertools
 from torch.testing._internal.jit_utils import RUN_CUDA
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode
 import unittest
-from torch._subclasses import FakeTensor
+


 class FakeTensorTest(TestCase):
     def test_basic(self):
@@ -44,11 +46,30 @@ def test_dispatch_device(self):

     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_type_as(self):
-        x = FakeTensor.from_tensor(torch.rand([16, 1], device='cpu'))
-        y = FakeTensor.from_tensor(torch.rand([4, 4], device='cuda'))
+        x = FakeTensor.from_tensor(torch.rand([16, 1], device="cpu"))
+        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
         out = x.type_as(y)
         self.assertEqual(out.device.type, "cuda")

+    def test_constructor(self):
+        with enable_torch_dispatch_mode(FakeTensorMode):
+            x = torch.rand([4, 4], device="cpu")
+
+        self.assertTrue(isinstance(x, FakeTensor))
+        self.assertTrue(x.device.type == "cpu")
+
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_fake_mode_non_fake_inputs(self):
+        x = torch.tensor(0.1)
+        y = torch.rand([4, 4], device="cuda")
+
+        with enable_torch_dispatch_mode(FakeTensorMode):
+            out = x + y
+
+        self.assertTrue(isinstance(out, FakeTensor))
+        self.assertTrue(out.device.type == "cuda")
+
+
 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
         contains_type(e, maybe_contained_type) for e in type.containedTypes()
@@ -56,12 +77,14 @@ def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):


 class FakeTensorOperatorInvariants(TestCase):
+    @staticmethod
+    def get_aten_op(schema):
+        namespace, name = schema.name.split("::")
+        overload = schema.overload_name if schema.overload_name else "default"
+        assert namespace == "aten"
+        return getattr(getattr(torch.ops.aten, name), overload)
+
     def test_non_kwarg_only_device(self):
-        def get_op(schema):
-            namespace, name = schema.name.split("::")
-            overload = schema.overload_name if schema.overload_name else "default"
-            assert namespace == "aten"
-            return getattr(getattr(torch.ops.aten, name), overload)

         for schema in torch._C._jit_get_all_schemas():
             namespace = schema.name.split("::")[0]
@@ -82,9 +105,29 @@ def get_op(schema):
             )
             if has_non_kwarg_device:
                 self.assertTrue(
-                    get_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
+                    self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                 )

+    def test_tensor_constructors_all_have_kwarg_device(self):
+        for schema in torch._C._jit_get_all_schemas():
+            namespace = schema.name.split("::")[0]
+            if namespace != "aten":
+                continue
+
+            op = self.get_aten_op(schema)
+            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
+                continue
+
+            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
+            has_kwarg_device = any(
+                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
+                for arg in schema.arguments
+            )
+
+            self.assertTrue(
+                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
+            )
+

 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_subclasses/__init__.py b/torch/_subclasses/__init__.py
index 53019165aa20f..4dfd90f451232 100644
--- a/torch/_subclasses/__init__.py
+++ b/torch/_subclasses/__init__.py
@@ -5,4 +5,5 @@
 __all__ = [
     "FakeTensor",
     "_device_not_kwarg_ops",
+    "_is_tensor_constructor",
 ]
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e83754fe81037..3f88cd3f5cfbc 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -5,6 +5,8 @@
 from torch.fx.operator_schemas import normalize_function
 from torch.utils._mode_utils import no_dispatch
 from typing import Union
+from torch._ops import OpOverload
+import functools

 aten = torch.ops.aten

@@ -20,6 +22,29 @@
     aten._resize_output.out,
 )

+# this op is never actually used
+_non_kwarg_device_constructors = (torch.ops.aten._list_to_tensor,)
+
+
+def contains_tensor_types(type):
+    tensor_type = torch._C.TensorType.get()
+    return type.isSubtypeOf(tensor_type) or any(
+        contains_tensor_types(e) for e in type.containedTypes()
+    )
+
+
+@functools.lru_cache(None)
+def _is_tensor_constructor(func: OpOverload):
+    assert isinstance(func, OpOverload)
+    schema = func._schema
+    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
+        return False
+    # TODO: no real reason to restrict multiple outputs
+    return (
+        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
+    )
+
+
 # Meta tensors give you the ability to run PyTorch code without having to
 # actually do computation through tensors allocated on a `meta` device.
 # Because the device is `meta`, meta tensors do not model device propagation.
@@ -60,7 +85,20 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             return args[0].fake_device

-        # Run the original computation
+        def wrap(e, device=None):
+            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
+                if device:
+                    return FakeTensor(e, device)
+                else:
+                    return FakeTensor.from_tensor(e)
+            else:
+                return e
+
+        # if we are in the dispatch mode, we will enter this function even if the inputs
+        # are not FakeTensors, and they need to be wrapped
+        if cls == FakeTensorMode:
+            args = tree_map(wrap, args)
+            kwargs = tree_map(wrap, kwargs)

         # _to_copy fails when run with FakeTensors to cuda device
         # TODO: debug
@@ -71,22 +109,25 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             out_device = new_kwargs.pop("device", new_kwargs["input"].device)
             with no_dispatch():
                 input = new_kwargs.pop("input").to("meta")
-                return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)
+                return FakeTensor(
+                    torch.ops.aten._to_copy(input, **new_kwargs), out_device
+                )

-        r = super().__torch_dispatch__(func, types, args, kwargs)
+        if _is_tensor_constructor(func):
+            assert func not in _non_kwarg_device_constructors
+            _, new_kwargs = normalize_function(
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            )
+            # cpu is default device if none is specified
+            out_device = new_kwargs.pop("device", torch.device("cpu"))
+            new_kwargs["device"] = torch.device("meta")
+            r = super().__torch_dispatch__(func, types, (), new_kwargs)
+            return FakeTensor(r, out_device)

-        def wrap(e, device):
-            # inplace ops can return fake tensors
-            if isinstance(e, torch.Tensor) and not isinstance(e, cls):
-                return FakeTensor(e, device)
-            else:
-                return e
+        r = super().__torch_dispatch__(func, types, args, kwargs)

         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
-        assert (
-            func != aten._pin_memory.default and func != aten.pin_memory.default
-        ), f"NYI: {func}"

         # if device is specified, use that
         if kwargs.get("device", None):
@@ -159,3 +200,6 @@ def merge_devices(t):
         return common_device

     __torch_function__ = torch._C._disabled_torch_function_impl
+
+class FakeTensorMode(FakeTensor):
+    pass
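
A minimal usage sketch, distilled from the tests added in this patch, not part
of the commit itself. It assumes the torch_dispatch mode API as of this change
(enable_torch_dispatch_mode accepting the FakeTensorMode class); later PyTorch
releases reworked that interface, so treat it as illustrative. The mixed-input
case below is the CUDA test adapted to CPU so it runs without a GPU; the cpu
output device follows from the same device propagation the CUDA test checks.

    import torch
    from torch._subclasses.fake_tensor import (
        FakeTensor,
        FakeTensorMode,
        _is_tensor_constructor,
    )
    from torch.utils._python_dispatch import enable_torch_dispatch_mode

    # _is_tensor_constructor classifies ops that build a tensor out of
    # non-tensor arguments; these are the calls the mode reroutes to the
    # meta device before reattaching the requested device.
    assert _is_tensor_constructor(torch.ops.aten.rand.default)
    assert not _is_tensor_constructor(torch.ops.aten.add.Tensor)

    # Constructors run inside the mode allocate only meta storage, but the
    # resulting FakeTensor still reports the device a real run would use.
    with enable_torch_dispatch_mode(FakeTensorMode):
        x = torch.rand([4, 4], device="cpu")
    assert isinstance(x, FakeTensor)
    assert x.device.type == "cpu"

    # Non-fake inputs that reach the mode are wrapped automatically, so a
    # computation mixing real tensors also returns a FakeTensor whose device
    # comes from device propagation across the inputs.
    a = torch.tensor(0.1)
    b = torch.rand([4, 4])
    with enable_torch_dispatch_mode(FakeTensorMode):
        out = a + b
    assert isinstance(out, FakeTensor)
    assert out.device.type == "cpu"

The invariant the tests exercise is that no computation or allocation happens
on the reported device: kernels run against meta tensors, while fake_device
records where an actual run would have placed the result.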