[Frontend] Add torch.tensor binding #61

Merged · 1 commit · Jan 6, 2023
18 changes: 16 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -10,10 +10,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Sequence, Any
 import operator
 import torch
-from hidet.graph.tensor import Tensor, full_like
+from hidet.graph.tensor import Tensor, full_like, from_torch
 from hidet.graph import ops
 from hidet.utils import same_list
 from hidet.ir.type import DataType
@@ -383,3 +383,17 @@ def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
     if out is not None:
         raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)")
     return ops.matmul(input, mat2)
+
+
+@register_function(torch.tensor)
+def torch_tensor(
+    data: Any, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False
+) -> Tensor:
+    if requires_grad and torch.is_grad_enabled():
+        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
+    if isinstance(data, Tensor):
+        device = device_from_torch(torch_device=device) if device is not None else device
+        return data.to(device=device, dtype=dtype_from_torch(dtype))
+    else:
+        tt = torch.tensor(data, dtype=dtype, device=device)
+        return from_torch(tt)
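
The new binding branches on the input: if `data` is already a hidet Tensor it is only converted to the requested device/dtype, otherwise a temporary `torch.tensor` is built and wrapped with `from_torch`; `requires_grad` is ignored with a one-time warning since the hidet frontend does not trace gradients. For context, here is a minimal usage sketch (not part of the PR) showing how such a call would reach this wrapper. It assumes torch 2.x and that importing `hidet` registers the `'hidet'` TorchDynamo backend, as in hidet's documented `torch.compile` integration; the module and shapes are made up for illustration.

import torch
import hidet  # assumed to register the 'hidet' dynamo backend as a side effect

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # When this module is traced and compiled with the hidet backend,
        # the frontend maps this torch.tensor call to the torch_tensor
        # wrapper above, producing a hidet Tensor in the traced graph.
        scale = torch.tensor([2.0], dtype=x.dtype, device=x.device)
        return x * scale

model = Scale().cuda()
compiled = torch.compile(model, backend='hidet')
y = compiled(torch.randn(8, device='cuda'))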