Add FakeTensorMode
Pull Request resolved: pytorch#77972

Approved by: https://github.com/ezyang
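
FakeTensorMode runs the fake-tensor machinery as a torch dispatch mode: inside the mode, tensor constructors return FakeTensors that report the requested device while allocating only meta storage, and non-fake tensor inputs are wrapped on the fly. A minimal usage sketch based on the tests added in this commit (import paths are as of this change and may move later):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._python_dispatch import enable_torch_dispatch_mode

# Constructors run inside the mode allocate on "meta" but remember
# the device the caller asked for.
with enable_torch_dispatch_mode(FakeTensorMode):
    x = torch.rand([4, 4], device="cpu")

assert isinstance(x, FakeTensor)
assert x.device.type == "cpu"
```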
Elias Ellison authored and pytorchmergebot committed May 31, 2022
1 parent 4c18f36 commit cea7dd1
Showing 3 changed files with 110 additions and 22 deletions.
63 changes: 53 additions & 10 deletions test/test_fake_tensor.py
@@ -1,11 +1,13 @@
-# Owner(s): ["module: unknown"]
+# Owner(s): ["module: meta tensors"]

 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
+import itertools
 from torch.testing._internal.jit_utils import RUN_CUDA
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode
 import unittest
-from torch._subclasses import FakeTensor


 class FakeTensorTest(TestCase):
     def test_basic(self):
@@ -44,24 +46,45 @@ def test_dispatch_device(self):

     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_type_as(self):
-        x = FakeTensor.from_tensor(torch.rand([16, 1], device='cpu'))
-        y = FakeTensor.from_tensor(torch.rand([4, 4], device='cuda'))
+        x = FakeTensor.from_tensor(torch.rand([16, 1], device="cpu"))
+        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
         out = x.type_as(y)
         self.assertEqual(out.device.type, "cuda")

+    def test_constructor(self):
+        with enable_torch_dispatch_mode(FakeTensorMode):
+            x = torch.rand([4, 4], device="cpu")
+
+        self.assertTrue(isinstance(x, FakeTensor))
+        self.assertTrue(x.device.type == "cpu")
+
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_fake_mode_non_fake_inputs(self):
+        x = torch.tensor(0.1)
+        y = torch.rand([4, 4], device="cuda")
+
+        with enable_torch_dispatch_mode(FakeTensorMode):
+            out = x + y
+
+        self.assertTrue(isinstance(out, FakeTensor))
+        self.assertTrue(out.device.type == "cuda")
+

 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
         contains_type(e, maybe_contained_type) for e in type.containedTypes()
     )


 class FakeTensorOperatorInvariants(TestCase):
+    @staticmethod
+    def get_aten_op(schema):
+        namespace, name = schema.name.split("::")
+        overload = schema.overload_name if schema.overload_name else "default"
+        assert namespace == "aten"
+        return getattr(getattr(torch.ops.aten, name), overload)
+
     def test_non_kwarg_only_device(self):
-        def get_op(schema):
-            namespace, name = schema.name.split("::")
-            overload = schema.overload_name if schema.overload_name else "default"
-            assert namespace == "aten"
-            return getattr(getattr(torch.ops.aten, name), overload)
-
         for schema in torch._C._jit_get_all_schemas():
             namespace = schema.name.split("::")[0]
@@ -82,9 +105,29 @@ def get_op(schema):
             )
             if has_non_kwarg_device:
                 self.assertTrue(
-                    get_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
+                    self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                 )

+    def test_tensor_constructors_all_have_kwarg_device(self):
+        for schema in torch._C._jit_get_all_schemas():
+            namespace = schema.name.split("::")[0]
+            if namespace != "aten":
+                continue
+
+            op = self.get_aten_op(schema)
+            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
+                continue
+
+            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
+            has_kwarg_device = any(
+                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
+                for arg in schema.arguments
+            )
+
+            self.assertTrue(
+                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
+            )
+

 if __name__ == "__main__":
     run_tests()
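
Both invariant tests above hinge on resolving a JIT schema to its Python OpOverload. A standalone sketch of that lookup, mirroring the get_aten_op helper added here (the free-function name is hypothetical):

```python
import torch

def aten_op_from_schema(schema):
    # "aten::add" with overload "Tensor" resolves to torch.ops.aten.add.Tensor;
    # schemas with no overload name map to the "default" overload object.
    namespace, name = schema.name.split("::")
    assert namespace == "aten"
    overload = schema.overload_name if schema.overload_name else "default"
    return getattr(getattr(torch.ops.aten, name), overload)

add_schemas = [s for s in torch._C._jit_get_all_schemas() if s.name == "aten::add"]
print([aten_op_from_schema(s) for s in add_schemas])
```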
1 change: 1 addition & 0 deletions torch/_subclasses/__init__.py
@@ -5,4 +5,5 @@
 __all__ = [
     "FakeTensor",
     "_device_not_kwarg_ops",
+    "_is_tensor_constructor",
 ]
68 changes: 56 additions & 12 deletions torch/_subclasses/fake_tensor.py
@@ -5,6 +5,8 @@
 from torch.fx.operator_schemas import normalize_function
 from torch.utils._mode_utils import no_dispatch
 from typing import Union
+from torch._ops import OpOverload
+import functools

 aten = torch.ops.aten

@@ -20,6 +22,29 @@
     aten._resize_output.out,
 )

+# this op is never actually used
+_non_kwarg_device_constructors = (torch.ops.aten._list_to_tensor,)
+
+
+def contains_tensor_types(type):
+    tensor_type = torch._C.TensorType.get()
+    return type.isSubtypeOf(tensor_type) or any(
+        contains_tensor_types(e) for e in type.containedTypes()
+    )
+
+
+@functools.lru_cache(None)
+def _is_tensor_constructor(func: OpOverload):
+    assert isinstance(func, OpOverload)
+    schema = func._schema
+    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
+        return False
+    # TODO: no real reason to restrict multiple outputs
+    return (
+        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
+    )
+
+
 # Meta tensors give you the ability to run PyTorch code without having to
 # actually do computation through tensors allocated on a `meta` device.
 # Because the device is `meta`, meta tensors do not model device propagation.
@@ -60,7 +85,20 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             return args[0].fake_device

-        # Run the original computation
+        def wrap(e, device=None):
+            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
+                if device:
+                    return FakeTensor(e, device)
+                else:
+                    return FakeTensor.from_tensor(e)
+            else:
+                return e
+
+        # if we are in the dispatch mode, we will enter this function even if the inputs
+        # are not FakeTensors, and they need to be wrapped
+        if cls == FakeTensorMode:
+            args = tree_map(wrap, args)
+            kwargs = tree_map(wrap, kwargs)

         # _to_copy fails when run with FakeTensors to cuda device
         # TODO: debug
@@ -71,22 +109,25 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             out_device = new_kwargs.pop("device", new_kwargs["input"].device)
             with no_dispatch():
                 input = new_kwargs.pop("input").to("meta")
-                return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)
+                return FakeTensor(
+                    torch.ops.aten._to_copy(input, **new_kwargs), out_device
+                )

-        r = super().__torch_dispatch__(func, types, args, kwargs)
+        if _is_tensor_constructor(func):
+            assert func not in _non_kwarg_device_constructors
+            _, new_kwargs = normalize_function(
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            )
+            # cpu is default device if none is specified
+            out_device = new_kwargs.pop("device", torch.device("cpu"))
+            new_kwargs["device"] = torch.device("meta")
+            r = super().__torch_dispatch__(func, types, (), new_kwargs)
+            return FakeTensor(r, out_device)

-        def wrap(e, device):
-            # inplace ops can return fake tensors
-            if isinstance(e, torch.Tensor) and not isinstance(e, cls):
-                return FakeTensor(e, device)
-            else:
-                return e
+        r = super().__torch_dispatch__(func, types, args, kwargs)

         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
+        assert (
+            func != aten._pin_memory.default and func != aten.pin_memory.default
+        ), f"NYI: {func}"

         # if device is specified, use that
         if kwargs.get("device", None):
@@ -159,3 +200,6 @@ def merge_devices(t):
         return common_device

     __torch_function__ = torch._C._disabled_torch_function_impl
+
+class FakeTensorMode(FakeTensor):
+    pass
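
The constructor path above works in two steps: _is_tensor_constructor screens for ops that take no tensor arguments and return exactly one tensor, then the dispatch branch reruns the constructor with device="meta" while remembering the device the caller requested (cpu when unspecified). A standalone sketch of that device bookkeeping (helper name hypothetical; the real branch additionally wraps the result in FakeTensor):

```python
import torch

def run_constructor_on_meta(ctor, *args, **kwargs):
    # Remember the requested device, then build the tensor on "meta"
    # so no real memory is allocated.
    out_device = torch.device(kwargs.pop("device", "cpu"))
    kwargs["device"] = torch.device("meta")
    meta_out = ctor(*args, **kwargs)
    return meta_out, out_device

out, dev = run_constructor_on_meta(torch.empty, 4, 4)
print(out.device, dev)  # meta cpu: shape/dtype computed, device remembered
```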
