Skip to content

Commit

Permalink
[Torch] Implements __torch_func__ protocol (#400)
Browse files Browse the repository at this point in the history
This PR implements the `__torch_func__` protocol as suggested at
https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
  • Loading branch information
yaoyaoding committed Dec 29, 2023
1 parent 4d71a8d commit 57154a3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/hidet/graph/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,32 @@ def __dlpack_device__(self) -> Tuple[int, int]:

return to_dlpack_device(self)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""
This function is used to support interoperability with PyTorch.
We can use hidet Tensor as the input of PyTorch function:
```
import torch
import hidet
a = hidet.randn([2, 3], dtype='float16', device='cuda')
b = torch.abs(a)
```
See the following documentation for more information:
https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
"""
import torch

if kwargs is None:
kwargs = {}
if not all(issubclass(t, (torch.Tensor, Tensor)) for t in types):
return NotImplemented
args = (arg.torch() if isinstance(arg, Tensor) else arg for arg in args)
kwargs = {k: v.torch() if isinstance(v, Tensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)

def tolist(self):
"""
Convert the tensor to a nested list of numbers.
Expand Down
17 changes: 17 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch
import hidet


def test_as_torch_tensor():
"""
test __torch_func__ protocol
"""
a = hidet.randn([32, 32], dtype='float16', device='cuda')
b = torch.abs(a)
c = hidet.ops.abs(a)
torch.testing.assert_close(b, c.torch())


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 57154a3

Please sign in to comment.