From 4c18f362a9d88e664fa1ced1ac77e1c6ccac960e Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 31 May 2022 07:02:18 -0700
Subject: [PATCH] add support for type_as/_to_copy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77971
Approved by: https://github.com/ezyang
---
 test/test_fake_tensor.py         |  6 ++++++
 torch/_subclasses/fake_tensor.py | 12 ++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 16758aa9945e4..cefb167ebef42 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -42,6 +42,12 @@ def test_dispatch_device(self):
         x = FakeTensor.from_tensor(torch.rand([4, 4]))
         self.assertEqual(x.device.type, "cpu")
 
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_type_as(self):
+        x = FakeTensor.from_tensor(torch.rand([16, 1], device='cpu'))
+        y = FakeTensor.from_tensor(torch.rand([4, 4], device='cuda'))
+        out = x.type_as(y)
+        self.assertEqual(out.device.type, "cuda")
 
 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 54bcafa4ffb21..e83754fe81037 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -3,6 +3,7 @@
 from torch.utils._pytree import tree_map
 from functools import partial
 from torch.fx.operator_schemas import normalize_function
+from torch.utils._mode_utils import no_dispatch
 from typing import Union
 
 aten = torch.ops.aten
@@ -61,6 +62,17 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 
         # Run the original computation
 
+        # _to_copy fails when run with FakeTensors to cuda device
+        # TODO: debug
+        if func == torch.ops.aten._to_copy.default:
+            _, new_kwargs = normalize_function(
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            )
+            out_device = new_kwargs.pop("device", new_kwargs["input"].device)
+            with no_dispatch():
+                input = new_kwargs.pop("input").to("meta")
+                return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)
+
         r = super().__torch_dispatch__(func, types, args, kwargs)
 
         def wrap(e, device):
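
Note (not part of the patch): a minimal sketch of the device-propagation idea
behind the _to_copy branch above, assuming the normalize_function and
aten._to_copy call signatures used in the diff. x.type_as(y) lowers to
aten._to_copy with y's dtype and device; the branch normalizes every argument
into a kwarg, splits off the target device, runs the actual copy entirely on
the meta device, and re-wraps the meta result in a FakeTensor carrying the
target device, so no real allocation on the target device ever happens:

    import torch
    from torch.fx.operator_schemas import normalize_function

    t = torch.rand(4, 4, device="meta")
    _, new_kwargs = normalize_function(
        torch.ops.aten._to_copy.default,
        args=(t,),
        kwargs={"device": "cuda"},
        normalize_to_only_use_kwargs=True,
    )
    # Every argument is now a kwarg, so the output device can be popped off
    # (falling back to the input's device) before running the copy on meta:
    out_device = new_kwargs.pop("device", new_kwargs["input"].device)
    meta_out = torch.ops.aten._to_copy(new_kwargs.pop("input").to("meta"), **new_kwargs)
    print(out_device, meta_out.device)  # cuda, meta

In the patch itself this all happens under no_dispatch() so that neither the
.to("meta") call nor the _to_copy re-enters FakeTensor's __torch_dispatch__.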