Commit 98e0816

Extend __new__ on subclasses to set custom_device and custom_strides
Pull Request resolved: pytorch#77970

Approved by: https://github.com/Chillee
Elias Ellison authored and pytorchmergebot committed May 31, 2022
1 parent 678213e commit 98e0816
Showing 5 changed files with 19 additions and 29 deletions.
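
Before the per-file diff, a minimal sketch of what the extended _make_subclass enables. This is illustrative only: VirtualDeviceTensor and the chosen devices are invented for the example, and the pattern simply mirrors the FakeTensor changes below, assuming a PyTorch build that includes this commit. Passing dispatch_device=True makes .device queries go through __torch_dispatch__ as torch.ops.prim.device, so the subclass can report a device of its own choosing:

import torch

class VirtualDeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, device):
        # dispatch_device=True: .device is no longer answered from the
        # wrapped storage; it is dispatched as torch.ops.prim.device.
        return torch.Tensor._make_subclass(
            cls, elem, elem.requires_grad, dispatch_device=True
        )

    def __init__(self, elem, device):
        self.virtual_device = torch.device(device)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        # Answer the device query ourselves; falling through would ask
        # for .device again and recurse.
        if func == torch.ops.prim.device.default:
            return args[0].virtual_device
        return super().__torch_dispatch__(func, types, args, kwargs)

    __torch_function__ = torch._C._disabled_torch_function_impl

t = VirtualDeviceTensor(torch.empty(2, 2, device="meta"), "cuda:0")
print(t.device)  # cuda:0, even though the data lives on the meta device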
5 changes: 4 additions & 1 deletion test/test_fake_tensor.py
@@ -7,7 +7,6 @@
 import unittest
 from torch._subclasses import FakeTensor
 
-
 class FakeTensorTest(TestCase):
     def test_basic(self):
         x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
@@ -39,6 +38,10 @@ def test_throw(self):
         z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
         self.assertRaises(Exception, lambda: torch.lerp(x, y, z))
 
+    def test_dispatch_device(self):
+        x = FakeTensor.from_tensor(torch.rand([4, 4]))
+        self.assertEqual(x.device.type, "cpu")
+
 
 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
3 changes: 2 additions & 1 deletion tools/pyi/gen_pyi.py
@@ -614,7 +614,8 @@ def gen_pyi(
         ],
         "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
         "_make_subclass": [
-            "def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."
+            "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
+            " dispatch_device: _bool=False) -> Tensor: ..."
         ],
         "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
         "__setitem__": [
2 changes: 0 additions & 2 deletions torch/_subclasses/__init__.py
@@ -1,10 +1,8 @@
 import torch
 
-from torch._subclasses.base_tensor import BaseTensor
 from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops
 
 __all__ = [
-    "BaseTensor",
     "FakeTensor",
     "_device_not_kwarg_ops",
 ]
19 changes: 0 additions & 19 deletions torch/_subclasses/base_tensor.py

This file was deleted.

19 changes: 13 additions & 6 deletions torch/_subclasses/fake_tensor.py
@@ -1,6 +1,5 @@
 import torch
 
-from torch._subclasses import BaseTensor
 from torch.utils._pytree import tree_map
 from functools import partial
 from torch.fx.operator_schemas import normalize_function
@@ -27,12 +26,12 @@
 # which tracks devices that would have been used.
 
 
-class FakeTensor(BaseTensor):
+class FakeTensor(torch.Tensor):
     fake_device: torch.device
 
     @staticmethod
     def __new__(cls, elem, device):
-        return super().__new__(cls, elem)
+        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad, dispatch_device=True)
 
     def __init__(self, elem, device: Union[torch.device, str]):
         # elem does not need to be recorded, because FakeTensor *is a* elem
@@ -46,16 +45,22 @@ def from_tensor(t):
         existing_device = t.device
         return FakeTensor(t.to(device="meta"), existing_device)
 
-    @property
-    def device(self):
-        return self.fake_device
+    # TODO: resolve error in default __repr__
+    def __repr__(self):
+        return f"FakeTensor({self.fake_device})"
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
 
+        # This class virtualizes .device() calls; short-circuit them here
+        # instead of dispatching device again, or we would recurse forever.
+        if func == torch.ops.prim.device.default:
+            assert len(args) == 1 and isinstance(args[0], FakeTensor)
+            return args[0].fake_device
+
         # Run the original computation
 
         r = super().__torch_dispatch__(func, types, args, kwargs)
 
         def wrap(e, device):
@@ -140,3 +145,5 @@ def merge_devices(t):
         assert common_device is not None, f"Could not find common device for {func}"
 
         return common_device
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
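
In use (a sketch matching the new test_dispatch_device test above, assuming the torch._subclasses module as it stands after this diff): the wrapped data lives on the meta device, while the intercepted prim.device query reports the original device.

import torch
from torch._subclasses import FakeTensor

x = FakeTensor.from_tensor(torch.rand([4, 4]))  # storage is moved to "meta"
print(x.device.type)  # "cpu" -- answered by the __torch_dispatch__ short-circuit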
