forked from pytorch/pytorch
Commit 678213e

Pull Request resolved: pytorch#77969
Approved by: https://github.com/ezyang

1 parent: d136852

Showing 13 changed files with 273 additions and 36 deletions.
test/test_fake_tensor.py
@@ -0,0 +1,81 @@
# Owner(s): ["module: unknown"]

from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import itertools
from torch.testing._internal.jit_utils import RUN_CUDA
import unittest
from torch._subclasses import FakeTensor


class FakeTensorTest(TestCase):
    def test_basic(self):
        x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
        y = FakeTensor.from_tensor(torch.empty(4, 2, 2, device="cpu"))
        y = x + y
        self.assertEqual(y.shape, (4, 2, 2))
        self.assertEqual(y.device, torch.device("cpu"))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_shape_take_not_device(self):
        x = FakeTensor.from_tensor(torch.empty(1, device="cpu"))
        y = FakeTensor.from_tensor(torch.empty(8, 8, device="cuda"))
        out = x.resize_as_(y)
        self.assertEqual(out.shape, (8, 8))
        self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        x = FakeTensor.from_tensor(torch.tensor(0.0))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
        out = x + y
        self.assertEqual(out.shape, (4, 4))
        self.assertEqual(out.device, y.device)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_throw(self):
        x = FakeTensor.from_tensor(torch.tensor(0.0))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
        z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
        self.assertRaises(Exception, lambda: torch.lerp(x, y, z))


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    return maybe_contained_type.isSubtypeOf(type) or any(
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
    )


class FakeTensorOperatorInvariants(TestCase):
    def test_non_kwarg_only_device(self):
        def get_op(schema):
            namespace, name = schema.name.split("::")
            overload = schema.overload_name if schema.overload_name else "default"
            assert namespace == "aten"
            return getattr(getattr(torch.ops.aten, name), overload)

        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue

            ten_type = torch._C.TensorType.get()
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertTrue(
                    get_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                )


if __name__ == "__main__":
    run_tests()
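The `contains_type` helper above matters because a schema argument typed `List[Tensor]` is not itself a subtype of `Tensor`; the recursion through `containedTypes()` is what catches it. A small illustration using the JIT type API, reusing the `contains_type` definition from this test file:

import torch

ten_type = torch._C.TensorType.get()
list_of_tensors = torch._C.ListType(ten_type)

# A direct subtype check misses the nested tensor type...
assert not ten_type.isSubtypeOf(list_of_tensors)
# ...but the recursive walk through containedTypes() finds it.
assert contains_type(list_of_tensors, ten_type)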
torch/_subclasses/__init__.py
@@ -0,0 +1,10 @@
import torch

from torch._subclasses.base_tensor import BaseTensor
from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops

__all__ = [
    "BaseTensor",
    "FakeTensor",
    "_device_not_kwarg_ops",
]
torch/_subclasses/base_tensor.py
@@ -0,0 +1,19 @@
import torch

# Ideally, tensor subclasses would inherit directly from Tensor.
# This is just our staging ground for applying behavior that hasn't yet made it
# into the core Tensor class but that we would like to apply by default.
class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # so that __new__ implementations in cooperative subclasses can chain
    # through super().__new__ correctly.
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)  # type: ignore[call-arg]
        else:
            return cls._make_subclass(cls, elem, requires_grad)

    # If __torch_dispatch__ is defined (which it will be for all our examples)
    # the default torch function implementation (which preserves subclasses)
    # typically must be disabled
    __torch_function__ = torch._C._disabled_torch_function_impl
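To make the cooperative-`__new__` comment concrete, here is a minimal, hypothetical subclass sketch (not part of this commit; the name and behavior are illustrative only). With `__torch_function__` disabled, nothing re-wraps results automatically, so `__torch_dispatch__` decides explicitly what comes back as a subclass instance, just as `FakeTensor` does below:

import torch
from torch.utils._pytree import tree_map
from torch._subclasses import BaseTensor

# Hypothetical example subclass, for illustration only.
class LoggingTensor(BaseTensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")
        # run the underlying computation, then re-wrap plain tensor results
        r = super().__torch_dispatch__(func, types, args, kwargs)
        return tree_map(
            lambda t: cls(t)
            if isinstance(t, torch.Tensor) and not isinstance(t, cls)
            else t,
            r,
        )

t = LoggingTensor(torch.ones(2, 2))
r = t + t  # prints the aten overload before running it
assert isinstance(r, LoggingTensor)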
torch/_subclasses/fake_tensor.py
@@ -0,0 +1,142 @@
import torch

from torch._subclasses import BaseTensor
from torch.utils._pytree import tree_map
from functools import partial
from torch.fx.operator_schemas import normalize_function
from typing import Union

aten = torch.ops.aten

_device_not_kwarg_ops = (
    aten._resize_output_.default,
    aten.nested_tensor.default,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._resize_output.functional,
    aten._resize_output.out,
)

# Meta tensors give you the ability to run PyTorch code without having to
# actually do computation through tensors allocated on a `meta` device.
# Because the device is `meta`, meta tensors do not model device propagation.
# FakeTensor extends meta tensors to also carry an additional `fake_device`
# which tracks the devices that would have been used.


class FakeTensor(BaseTensor):
    fake_device: torch.device

    @staticmethod
    def __new__(cls, elem, device):
        return super().__new__(cls, elem)

    def __init__(self, elem, device: Union[torch.device, str]):
        # elem does not need to be recorded, because FakeTensor *is* elem
        assert elem.device.type == "meta"
        device = device if isinstance(device, torch.device) else torch.device(device)
        assert device.type != "meta"
        self.fake_device = device

    @staticmethod
    def from_tensor(t):
        existing_device = t.device
        return FakeTensor(t.to(device="meta"), existing_device)

    @property
    def device(self):
        return self.fake_device

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Run the original computation
        r = super().__torch_dispatch__(func, types, args, kwargs)

        def wrap(e, device):
            # in-place ops can return their input, which is already a FakeTensor,
            # so only wrap plain tensors
            if isinstance(e, torch.Tensor) and not isinstance(e, cls):
                return FakeTensor(e, device)
            else:
                return e

        # TODO: handle non-kwarg devices
        assert func not in _device_not_kwarg_ops, f"NYI: {func}"
        assert (
            func != aten._pin_memory.default and func != aten.pin_memory.default
        ), f"NYI: {func}"

        # if device is specified, use that
        if kwargs.get("device", None):
            return tree_map(partial(wrap, device=kwargs["device"]), r)

        # operators which copy size from another tensor do not
        # also take device from the size tensor;
        # other size_as operators are not builtin operators
        if func == aten.resize_as_.default:
            _, new_kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            # the device of the input is returned
            return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

        common_device = FakeTensor._find_common_device(func, args, kwargs)

        return tree_map(partial(wrap, device=common_device), r)

    @staticmethod
    def _find_common_device(func, args, kwargs):
        # cpu zero-dim tensors can be passed to cuda kernels,
        # so overwrite the common_device if the only existing
        # device comes from a cpu zero-dim tensor
        common_device = None
        is_cpu_zero_dim = None

        def cpu_zero_dim(t):
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t):
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices!
            # if the current tensor is a cpu zero-dim tensor, defer to the existing device
            if t_is_cpu_zero_dim:
                return

            # the existing device came from a cpu zero-dim tensor, overwrite it
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices of non-zero-dim tensors; throw.
            # This might be valid behavior that needs to be explicitly modeled, e.g. reshape_as
            raise Exception(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        tree_map(merge_devices, args)
        tree_map(merge_devices, kwargs)

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device
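Putting the pieces together, a short usage sketch of the behavior exercised by the tests above (illustrative; the second half assumes a CUDA build):

import torch
from torch._subclasses import FakeTensor

# Storage lives on the `meta` device, so no real computation or memory
# traffic happens; .device reports the recorded fake_device instead.
x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
y = x + x
print(y.shape, y.device)  # torch.Size([2, 2]) cpu

if torch.cuda.is_available():
    # a cpu zero-dim tensor defers to the other operand's device,
    # mirroring how cuda kernels accept cpu scalar tensors in eager mode
    s = FakeTensor.from_tensor(torch.tensor(1.0))  # cpu, dim() == 0
    c = FakeTensor.from_tensor(torch.empty(4, 4, device="cuda"))
    print((s + c).device)  # cuda:0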