diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 2fbc68d54db11..16758aa9945e4 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -7,7 +7,6 @@ import unittest
 
 from torch._subclasses import FakeTensor
-
 
 class FakeTensorTest(TestCase):
     def test_basic(self):
         x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
@@ -39,6 +38,10 @@ def test_throw(self):
         z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
         self.assertRaises(Exception, lambda: torch.lerp(x, y, z))
 
+    def test_dispatch_device(self):
+        x = FakeTensor.from_tensor(torch.rand([4, 4]))
+        self.assertEqual(x.device.type, "cpu")
+
 
 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index aec8aaf0a402c..1fef1e4c4ee42 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -614,7 +614,8 @@ def gen_pyi(
             ],
             "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
             "_make_subclass": [
-                "def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."
+                "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
+                " dispatch_device: _bool=False) -> Tensor: ..."
             ],
             "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
             "__setitem__": [
diff --git a/torch/_subclasses/__init__.py b/torch/_subclasses/__init__.py
index f6e50bec6f73e..53019165aa20f 100644
--- a/torch/_subclasses/__init__.py
+++ b/torch/_subclasses/__init__.py
@@ -1,10 +1,8 @@
 import torch
 
-from torch._subclasses.base_tensor import BaseTensor
 from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops
 
 __all__ = [
-    "BaseTensor",
     "FakeTensor",
     "_device_not_kwarg_ops",
 ]
diff --git a/torch/_subclasses/base_tensor.py b/torch/_subclasses/base_tensor.py
deleted file mode 100644
index 9b9dff74adb0f..0000000000000
--- a/torch/_subclasses/base_tensor.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch
-
-# Ideally, tensor subclasses would would inherit directly from Tensor.
-# This is just our staging ground for applying behavior that hasn't yet made it
-# into the core Tensor class but that we would like to apply by default.
-class BaseTensor(torch.Tensor):
-    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
-    # to ensure that super().__new__ can cooperate with each other
-    @staticmethod
-    def __new__(cls, elem, *, requires_grad=None):
-        if requires_grad is None:
-            return super().__new__(cls, elem)  # type: ignore[call-arg]
-        else:
-            return cls._make_subclass(cls, elem, requires_grad)
-
-    # If __torch_dispatch__ is defined (which it will be for all our examples)
-    # the default torch function implementation (which preserves subclasses)
-    # typically must be disabled
-    __torch_function__ = torch._C._disabled_torch_function_impl
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 6903573e1d79d..54bcafa4ffb21 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1,6 +1,5 @@
 import torch
-from torch._subclasses import BaseTensor
 from torch.utils._pytree import tree_map
 from functools import partial
 from torch.fx.operator_schemas import normalize_function
 
@@ -27,12 +26,12 @@
 # which tracks devices that would have been used.
 
 
-class FakeTensor(BaseTensor):
+class FakeTensor(torch.Tensor):
     fake_device: torch.device
 
     @staticmethod
     def __new__(cls, elem, device):
-        return super().__new__(cls, elem)
+        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad, dispatch_device=True)
 
     def __init__(self, elem, device: Union[torch.device, str]):
         # elem does not need to be recorded, because FakeTensor *is a* elem
@@ -46,16 +45,22 @@ def from_tensor(t):
         existing_device = t.device
         return FakeTensor(t.to(device="meta"), existing_device)
 
-    @property
-    def device(self):
-        return self.fake_device
+    # TODO: resolve error in default __repr__
+    def __repr__(self):
+        return f"FakeTensor({self.fake_device})"
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
+        # This class virtualizes .device() calls; we need to short-circuit
+        # them here instead of calling device again, or we would keep recursing
+        if func == torch.ops.prim.device.default:
+            assert len(args) == 1 and isinstance(args[0], FakeTensor)
+            return args[0].fake_device
 
         # Run the original computation
+
         r = super().__torch_dispatch__(func, types, args, kwargs)
 
         def wrap(e, device):
@@ -140,3 +145,5 @@ def merge_devices(t):
         assert common_device is not None, f"Could not find common device for {func}"
 
         return common_device
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
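
Usage note: a minimal sketch of the behavior this patch enables, mirroring the new test_dispatch_device test above. It only uses the FakeTensor API visible in this diff (from_tensor, the virtualized .device, and the new __repr__); the printed format is an assumption based on the __repr__ added here.

    import torch
    from torch._subclasses import FakeTensor

    # Storage lives on the meta device, but .device reports the original
    # device because prim.device is short-circuited in __torch_dispatch__.
    x = FakeTensor.from_tensor(torch.rand([4, 4]))
    assert x.device.type == "cpu"
    print(x)  # expected to print something like: FakeTensor(cpu)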