Fake Tensor Part 1
Pull Request resolved: pytorch#77969

Approved by: https://github.com/ezyang
Elias Ellison authored and pytorchmergebot committed May 31, 2022
1 parent d136852 commit 678213e
Showing 13 changed files with 273 additions and 36 deletions.
2 changes: 1 addition & 1 deletion test/test_autograd.py
@@ -42,7 +42,7 @@
onlyCPU, onlyCUDA, dtypes, dtypesIfCUDA,
deviceCountAtLeast, skipMeta, dtypesIfMPS)
from torch.testing._internal.common_dtype import floating_types_and
-from torch.testing._internal.logging_tensor import no_dispatch
+from torch.utils._mode_utils import no_dispatch

import pickle

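The import change above (and the identical ones in the next few files) all point at the same cleanup: no_dispatch(), previously copy-pasted into logging_tensor.py, proxy_tensor.py, and composite_compliance.py (see the deletions further down), now lives in torch.utils._mode_utils. A minimal usage sketch, not part of this diff (the tensor construction is an arbitrary example):

import torch
from torch.utils._mode_utils import no_dispatch

# no_dispatch() temporarily disables __torch_dispatch__ interposition,
# so ops inside the block run on plain tensors instead of re-entering
# a subclass handler.
with no_dispatch():
    t = torch.ones(2, 2)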
2 changes: 1 addition & 1 deletion test/test_decomp.py
@@ -7,7 +7,7 @@
from torch._decomp import decomposition_table

from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
-from torch.testing._internal.logging_tensor import no_dispatch
+from torch.utils._mode_utils import no_dispatch
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
TestCase,
81 changes: 81 additions & 0 deletions test/test_fake_tensor.py
@@ -0,0 +1,81 @@
# Owner(s): ["module: unknown"]

from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import itertools
from torch.testing._internal.jit_utils import RUN_CUDA
import unittest
from torch._subclasses import FakeTensor


class FakeTensorTest(TestCase):
def test_basic(self):
        x = FakeTensor.from_tensor(torch.empty(4, 2, 2, device="cpu"))
y = x + x
self.assertEqual(y.shape, (4, 2, 2))
self.assertEqual(y.device, torch.device("cpu"))

@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_shape_take_not_device(self):
x = FakeTensor.from_tensor(torch.empty(1, device="cpu"))
y = FakeTensor.from_tensor(torch.empty(8, 8, device="cuda"))
out = x.resize_as_(y)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.device.type, "cpu")

@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_zero_dim(self):
x = FakeTensor.from_tensor(torch.tensor(0.0))
y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
out = x + y
self.assertEqual(out.shape, (4, 4))
self.assertEqual(out.device, y.device)

@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_throw(self):
x = FakeTensor.from_tensor(torch.tensor(0.0))
y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
self.assertRaises(Exception, lambda: torch.lerp(x, y, z))


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
)


class FakeTensorOperatorInvariants(TestCase):
def test_non_kwarg_only_device(self):
def get_op(schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)

for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
continue

ten_type = torch._C.TensorType.get()
if not any(
contains_type(arg.type, ten_type)
for arg in itertools.chain(schema.arguments, schema.returns)
):
continue

opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
has_non_kwarg_device = any(
not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
for arg in schema.arguments
)
if has_non_kwarg_device:
self.assertTrue(
get_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
)


if __name__ == "__main__":
run_tests()
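For orientation, a minimal sketch of the behavior these tests exercise, using only the API this commit adds (the sketch itself is not part of the diff): a FakeTensor stores its data on the meta device, so no real kernel runs and no real memory is allocated, while .device reports the device the original tensor came from.

import torch
from torch._subclasses import FakeTensor

x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
y = x + x           # dispatched as a meta-tensor op; no real computation
print(y.shape)      # torch.Size([2, 2])
print(y.device)     # cpu -- the recorded fake_device, not "meta"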
2 changes: 1 addition & 1 deletion test/test_meta.py
@@ -19,7 +19,7 @@
instantiate_device_type_tests,
onlyCUDA,
)
-from torch.testing._internal.logging_tensor import no_dispatch
+from torch.utils._mode_utils import no_dispatch
from torch.testing._internal.common_methods_invocations import op_db
from torchgen.utils import YamlLoader
from torchgen.model import OperatorName
3 changes: 2 additions & 1 deletion test/test_python_dispatch.py
@@ -7,8 +7,9 @@
from torch.cuda.jiterator import _create_jit_fn
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS
+from torch.utils._mode_utils import no_dispatch
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
-    log_input, capture_logs, no_dispatch, capture_logs_with_logging_tensor_mode
+    log_input, capture_logs, capture_logs_with_logging_tensor_mode
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode

10 changes: 10 additions & 0 deletions torch/_subclasses/__init__.py
@@ -0,0 +1,10 @@
import torch

from torch._subclasses.base_tensor import BaseTensor
from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops

__all__ = [
"BaseTensor",
"FakeTensor",
"_device_not_kwarg_ops",
]
19 changes: 19 additions & 0 deletions torch/_subclasses/base_tensor.py
@@ -0,0 +1,19 @@
import torch

# Ideally, tensor subclasses would inherit directly from Tensor.
# This is just our staging ground for applying behavior that hasn't yet made it
# into the core Tensor class but that we would like to apply by default.
class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # so that subclass __new__ implementations can cooperate with super().__new__
@staticmethod
def __new__(cls, elem, *, requires_grad=None):
if requires_grad is None:
return super().__new__(cls, elem) # type: ignore[call-arg]
else:
return cls._make_subclass(cls, elem, requires_grad)

# If __torch_dispatch__ is defined (which it will be for all our examples)
# the default torch function implementation (which preserves subclasses)
# typically must be disabled
__torch_function__ = torch._C._disabled_torch_function_impl
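As a sketch of how BaseTensor is meant to be subclassed (LoudTensor is a hypothetical example, not part of the commit): a subclass defines __torch_dispatch__ and can defer to the default super() implementation to actually run the op, which is exactly the pattern FakeTensor follows below.

import torch
from torch._subclasses import BaseTensor

class LoudTensor(BaseTensor):
    # Hypothetical subclass: report every dispatched op, then run it.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"dispatching {func}")
        return super().__torch_dispatch__(func, types, args, kwargs or {})

t = LoudTensor(torch.ones(2))
t + t  # prints the dispatched op (aten.add.Tensor) before computing the sum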
142 changes: 142 additions & 0 deletions torch/_subclasses/fake_tensor.py
@@ -0,0 +1,142 @@
import torch

from torch._subclasses import BaseTensor
from torch.utils._pytree import tree_map
from functools import partial
from torch.fx.operator_schemas import normalize_function
from typing import Union

aten = torch.ops.aten

_device_not_kwarg_ops = (
aten._resize_output_.default,
aten.nested_tensor.default,
aten.pin_memory.default,
aten.is_pinned.default,
aten.to.device,
aten.to.prim_Device,
aten._pin_memory.default,
aten._resize_output.functional,
aten._resize_output.out,
)

# Meta tensors give you the ability to run PyTorch code without having to
# actually do computation through tensors allocated on a `meta` device.
# Because the device is `meta`, meta tensors do not model device propagation.
# FakeTensor extends meta tensors to also carry an additional `fake_device`
# which tracks the device that would have been used.


class FakeTensor(BaseTensor):
fake_device: torch.device

@staticmethod
def __new__(cls, elem, device):
return super().__new__(cls, elem)

def __init__(self, elem, device: Union[torch.device, str]):
        # elem does not need to be recorded, because FakeTensor *is an* elem
assert elem.device.type == "meta"
device = device if isinstance(device, torch.device) else torch.device(device)
assert device.type != "meta"
self.fake_device = device

@staticmethod
def from_tensor(t):
existing_device = t.device
return FakeTensor(t.to(device="meta"), existing_device)

@property
def device(self):
return self.fake_device

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}


# Run the original computation
r = super().__torch_dispatch__(func, types, args, kwargs)

def wrap(e, device):
# inplace ops can return fake tensors
if isinstance(e, torch.Tensor) and not isinstance(e, cls):
return FakeTensor(e, device)
else:
return e

# TODO: handle non-kwarg devices
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
assert (
func != aten._pin_memory.default and func != aten.pin_memory.default
), f"NYI: {func}"

# if device is specified, use that
if kwargs.get("device", None):
return tree_map(partial(wrap, device=kwargs["device"]), r)

# operators which copy size from another tensor do not
# also take device from the size tensor
# other size_as operators are not builtin operators
if func == aten.resize_as_.default:
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
# device of the input is returned
return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

common_device = FakeTensor._find_common_device(func, args, kwargs)

return tree_map(partial(wrap, device=common_device), r)

@staticmethod
def _find_common_device(func, args, kwargs):
        # cpu zero-dim tensors can be passed to cuda kernels, so
        # overwrite the common_device if the only existing device
        # comes from a cpu zero-dim tensor
common_device = None
is_cpu_zero_dim = None

def cpu_zero_dim(t):
return t.device.type == "cpu" and t.dim() == 0

def merge_devices(t):
nonlocal common_device
nonlocal is_cpu_zero_dim
if not isinstance(t, FakeTensor):
return

if common_device is None:
common_device = t.device
is_cpu_zero_dim = cpu_zero_dim(t)
return

t_is_cpu_zero_dim = cpu_zero_dim(t)
if t.device == common_device:
if is_cpu_zero_dim:
is_cpu_zero_dim = t_is_cpu_zero_dim
return

            # mismatching devices!
            # if the current tensor is a cpu zero-dim tensor, defer to the existing device
if t_is_cpu_zero_dim:
return

            # the existing common device came from a cpu zero-dim tensor; overwrite it
if is_cpu_zero_dim:
common_device = t.device
is_cpu_zero_dim = t_is_cpu_zero_dim
return

            # mismatching devices of non-zero-dim tensors; throw
            # this might be valid behavior that needs to be explicitly modeled, e.g. reshape_as
raise Exception(
f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
)

tree_map(merge_devices, args)
tree_map(merge_devices, kwargs)

assert common_device is not None, f"Could not find common device for {func}"

return common_device
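The propagation rules above boil down to: an explicit device kwarg wins, resize_as_ keeps the input's device, a cpu zero-dim tensor defers to any other device, and any remaining device mismatch is an error. A sketch mirroring the tests earlier in the commit (not part of the diff; assumes a CUDA build):

import torch
from torch._subclasses import FakeTensor

scalar = FakeTensor.from_tensor(torch.tensor(1.0))             # cpu, zero-dim
mat = FakeTensor.from_tensor(torch.rand(4, 4, device="cuda"))
out = scalar + mat
print(out.device)  # cuda:0 -- the cpu zero-dim operand defers to cuda

cpu_mat = FakeTensor.from_tensor(torch.rand(4, 4))
mat + cpu_mat  # raises: two non-zero-dim tensors on different devices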
8 changes: 4 additions & 4 deletions torch/csrc/jit/python/python_ir.cpp
@@ -804,6 +804,9 @@ void initPythonIRBindings(PyObject* module_) {
s << t;
return s.str();
})
+      .def(
+          "containedTypes",
+          [](Type& self) { return self.containedTypes().vec(); })
.def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
.def(
"dim",
@@ -1008,10 +1011,7 @@ void initPythonIRBindings(PyObject* module_) {
});
py::class_<UnionType, Type, UnionTypePtr>(m, "UnionType")
.def(py::init(
-          [](const std::vector<TypePtr>& a) { return UnionType::create(a); }))
-      .def("containedTypes", [](UnionType& self) {
-        return self.containedTypes().vec();
-      });
+          [](const std::vector<TypePtr>& a) { return UnionType::create(a); }));
py::class_<ListType, Type, ListTypePtr>(m, "ListType")
.def(py::init([](TypePtr a) { return ListType::create(a); }))
.def_static("ofInts", &ListType::ofInts)
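Moving the containedTypes binding from UnionType to the Type base class is what lets the new test's contains_type() helper recurse through arbitrary schema types. A small usage sketch (not part of the diff):

import torch

# Every JIT type now exposes containedTypes(); a List[Tensor] type
# contains a single Tensor element type.
list_type = torch._C.ListType(torch._C.TensorType.get())
print([t.kind() for t in list_type.containedTypes()])  # ['TensorType']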
9 changes: 1 addition & 8 deletions torch/fx/experimental/proxy_tensor.py
@@ -10,6 +10,7 @@
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
+from torch.utils._mode_utils import no_dispatch
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager

@@ -21,14 +22,6 @@
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}


-@contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-    try:
-        yield
-    finally:
-        del guard


@contextmanager
def decompose(decomposition_table):
11 changes: 1 addition & 10 deletions torch/testing/_internal/composite_compliance.py
@@ -2,24 +2,15 @@
from torch import Tensor
import contextlib
import itertools
-from typing import Iterator
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from functools import partial
+from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import enable_torch_dispatch_mode
import torch.autograd.forward_ad as fwAD
from torch.overrides import enable_reentrant_dispatch
import re


-# TODO: move this into library proper
-@contextlib.contextmanager
-def no_dispatch() -> Iterator[None]:
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-    try:
-        yield
-    finally:
-        del guard

def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
elem = wrapper_tensor.elem
metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
10 changes: 0 additions & 10 deletions torch/testing/_internal/logging_tensor.py
@@ -1,21 +1,11 @@
import torch
from torch.utils._pytree import tree_map

from typing import Iterator, List
import logging
import contextlib
import itertools
from torch.utils._python_dispatch import TorchDispatchMode, push_torch_dispatch_mode

-# TODO: move this into library proper
-@contextlib.contextmanager
-def no_dispatch() -> Iterator[None]:
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-    try:
-        yield
-    finally:
-        del guard


# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin