Skip to content

Commit

Permalink
[Frontend] Add torch.tensor binding (#61)
Browse files Browse the repository at this point in the history
add torch.tensor binding
  • Loading branch information
yaoyaoding committed Jan 6, 2023
1 parent 582d573 commit 054405e
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional, Union, Sequence
from typing import Optional, Union, Sequence, Any
import operator
import torch
from hidet.graph.tensor import Tensor, full_like
from hidet.graph.tensor import Tensor, full_like, from_torch
from hidet.graph import ops
from hidet.utils import same_list
from hidet.ir.type import DataType
Expand Down Expand Up @@ -383,3 +383,17 @@ def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)")
return ops.matmul(input, mat2)


@register_function(torch.tensor)
def torch_tensor(
data: Any, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False
) -> Tensor:
if requires_grad and torch.is_grad_enabled():
warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
if isinstance(data, Tensor):
device = device_from_torch(torch_device=device) if device is not None else device
return data.to(device=device, dtype=dtype_from_torch(dtype))
else:
tt = torch.tensor(data, dtype=dtype, device=device)
return from_torch(tt)

0 comments on commit 054405e

Please sign in to comment.