add support for type_as/_to_copy
Pull Request resolved: pytorch#77971

Approved by: https://github.com/ezyang
Elias Ellison authored and pytorchmergebot committed May 31, 2022
1 parent 98e0816 commit 4c18f36
Showing 2 changed files with 18 additions and 0 deletions.
test/test_fake_tensor.py: 6 additions & 0 deletions
@@ -42,6 +42,12 @@ def test_dispatch_device(self):
        x = FakeTensor.from_tensor(torch.rand([4, 4]))
        self.assertEqual(x.device.type, "cpu")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_type_as(self):
        x = FakeTensor.from_tensor(torch.rand([16, 1], device='cpu'))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device='cuda'))
        out = x.type_as(y)
        self.assertEqual(out.device.type, "cuda")

def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    return maybe_contained_type.isSubtypeOf(type) or any(
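
The new test exercises the handler added below in fake_tensor.py: type_as is a composite op that bottoms out in aten._to_copy, the operator this commit teaches FakeTensor to intercept. A minimal sketch of that relationship, not part of the commit (the explicit _to_copy call and the cpu stand-in for the cuda tensor are illustrative assumptions):

import torch

# type_as(other) copies self to other's dtype and device; at the dispatcher
# level this decomposes to aten._to_copy, so handling _to_copy covers type_as.
x = torch.rand([16, 1])
y = torch.rand([4, 4])  # the real test uses device='cuda'; cpu keeps this runnable anywhere
out = torch.ops.aten._to_copy(x, dtype=y.dtype, device=y.device)
assert out.dtype == y.dtype and out.device == y.device
assert torch.equal(out, x.type_as(y))
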
torch/_subclasses/fake_tensor.py: 12 additions & 0 deletions
@@ -3,6 +3,7 @@
from torch.utils._pytree import tree_map
from functools import partial
from torch.fx.operator_schemas import normalize_function
from torch.utils._mode_utils import no_dispatch
from typing import Union

aten = torch.ops.aten
@@ -61,6 +62,17 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

        # Run the original computation

        # _to_copy fails when FakeTensors are copied to a cuda device
        # TODO: debug
        if func == torch.ops.aten._to_copy.default:
            _, new_kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            out_device = new_kwargs.pop("device", new_kwargs["input"].device)
            with no_dispatch():
                input = new_kwargs.pop("input").to("meta")
                return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)

        r = super().__torch_dispatch__(func, types, args, kwargs)

        def wrap(e, device):
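
With this branch in place, copying a FakeTensor to another device returns a FakeTensor that merely reports the target device while its underlying data stays on the meta device, so no real device memory is allocated. A rough sketch of the intended behavior, not taken from the commit, assuming a CUDA-enabled build (mirroring the RUN_CUDA guard on the test) and the FakeTensor API as it exists at this commit:

import torch
from torch._subclasses.fake_tensor import FakeTensor

# .to("cuda") dispatches to aten._to_copy.default and hits the new branch,
# which strips the device kwarg, runs _to_copy on a meta tensor, and rewraps.
x = FakeTensor.from_tensor(torch.rand([16, 1], device="cpu"))
out = x.to("cuda")
print(out.device.type)  # expected: "cuda", with no cuda memory actually allocated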
