[Operator] Add gather operator and torch.zeros, torch.neg mapping #174

Merged · 1 commit · Apr 13, 2023
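For context, here is a minimal sketch of how the three new mappings could be exercised end to end through the dynamo backend. It is illustrative and not part of the PR; the `Demo` module, the shapes, and the assumption that importing `hidet` registers the `'hidet'` backend for `torch.compile` are mine.

```python
import torch
import hidet  # assumed to register the 'hidet' dynamo backend on import


class Demo(torch.nn.Module):
    def forward(self, x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        base = torch.zeros(index.shape, device=x.device)  # torch.zeros -> hidet zeros mapping
        gathered = torch.gather(x, 1, index)              # torch.gather -> ops.gather
        return -(base + gathered)                         # unary minus -> operator.neg mapping


model = Demo().cuda()
x = torch.randn(4, 8, device='cuda')
index = torch.randint(0, 8, (4, 3), device='cuda')

compiled = torch.compile(model, backend='hidet')
torch.testing.assert_close(compiled(x, index), model(x, index))
```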
7 changes: 3 additions & 4 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -52,10 +52,9 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
logger.info('schedule search space: %d', search_space)

has_cpu_tensor = any(tensor.device.type == 'cpu' for tensor in graph_opt.inputs + graph_opt.outputs)
- has_cuda_tensor = any(tensor.device.type == 'cuda' for tensor in graph_opt.inputs + graph_opt.outputs)
-
- if has_cpu_tensor and has_cuda_tensor:
-     raise RuntimeError('the flow graph contains both CPU and CUDA tensors, currently not supported by hidet')
+ # has_cuda_tensor = any(tensor.device.type == 'cuda' for tensor in graph_opt.inputs + graph_opt.outputs)
+ # if has_cpu_tensor and has_cuda_tensor:
+ #     raise RuntimeError('the flow graph contains both CPU and CUDA tensors, currently not supported by hidet')

def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
torch_inputs: List[torch.Tensor] = []
35 changes: 35 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -106,6 +106,11 @@ def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)


@register_function(operator.neg)
def neg(x: Tensor):
return -x


@register_function(torch.sin)
def sin(x: Tensor):
return ops.sin(x)
@@ -285,6 +290,27 @@ def matmul(x: Tensor, y: Tensor):
return ops.matmul(x, y)


@register_function(torch.zeros)
def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
import hidet

if out is not None:
raise NotImplementedError("out is not None")
if layout is not None:
raise NotImplementedError("layout is not None")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = [int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()

_ = pin_memory
_ = requires_grad

return hidet.zeros(shape, dtype=dtype_from_torch(dtype), device=device_from_torch(device))


@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
@@ -653,3 +679,12 @@ def celu(x: Tensor, alpha: float):
@register_function(torch.nn.functional.logsigmoid)
def logsigmoid(x: Tensor):
return ops.logsigmoid(x)


@register_function(torch.gather)
def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):
if sparse_grad:
warnings.warn_once('hidet: gather with sparse_grad=True is not supported. Treat as sparse_grad=False.')
if out is not None:
raise NotImplementedError('hidet: gather with out=... is not supported')
return ops.gather(x, index, axis=dim)
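For reference, the semantics being mapped here, as an illustrative NumPy sketch rather than code from this PR: along dimension `dim`, every output element picks an input element whose coordinate on `dim` comes from `index`, so for a 2-D input with `dim=1`, `out[i][j] = x[i][index[i][j]]`. The helper name `gather_ref` is invented for this example; the output takes the index tensor's shape, and the `GatherTask` added below additionally assumes the index matches the data extents on the non-gather dimensions, as in the new tests.

```python
import numpy as np


def gather_ref(data: np.ndarray, indices: np.ndarray, axis: int = 0) -> np.ndarray:
    # Plain-Python mirror of the gather semantics: walk every output position,
    # replace its coordinate along `axis` with the value stored in `indices`,
    # and copy the selected element. Negative indices wrap around the extent of `axis`.
    out = np.empty(indices.shape, dtype=data.dtype)
    for out_idx in np.ndindex(indices.shape):
        k = int(indices[out_idx])
        if k < 0:
            k += data.shape[axis]
        out[out_idx] = data[out_idx[:axis] + (k,) + out_idx[axis + 1:]]
    return out


x = np.arange(12.0).reshape(3, 4)
idx = np.array([[0, 3], [2, 1], [1, 1]])
np.testing.assert_allclose(gather_ref(x, idx, axis=1), np.take_along_axis(x, idx, axis=1))
```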
5 changes: 5 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -44,6 +44,11 @@ def tensor_half(self: Tensor) -> Tensor:
return ops.cast(self, "float16")


@register_method(torch.Tensor.bool)
def tensor_bool(self: Tensor) -> Tensor:
return ops.cast(self, "bool")


@register_method(torch.Tensor.to)
def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
"""
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -36,7 +36,7 @@
from .definitions.reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .definitions.cumulative import cumsum
from .definitions.transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
- from .definitions.transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims
+ from .definitions.transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather
from .definitions.transform import permute_dims
from .definitions.fusion import fused_operator
from .definitions.special import barrier
33 changes: 32 additions & 1 deletion python/hidet/graph/ops/definitions/transform.py
@@ -16,6 +16,7 @@
from hidet.ir.utils import index_deserialize, index_serialize
from hidet.utils import prod
from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast
from .utils import TensorInput


def same_shape(shape_a: Sequence[int], shape_b: Sequence[int]) -> bool:
@@ -187,6 +188,22 @@ def fmap(*output_indices):
super().__init__(name='take', inputs=[data, indices], outputs=[output])


class GatherTask(Task):
def __init__(self, data: TensorInput, indices: TensorInput, axis=0):
data_shape = data.const_shape()
indices_shape = indices.const_shape()
output_shape = data_shape[:axis] + [indices_shape[axis]] + data_shape[axis + 1 :]

def fmap(*output_indices):
index_value = indices[output_indices]
index_value = if_then_else(index_value < 0, index_value + data_shape[axis], index_value)
data_indices = output_indices[:axis] + (index_value,) + output_indices[axis + 1 :]
return data[data_indices]

output = compute(name='output', shape=output_shape, fcompute=lambda *output_indices: fmap(*output_indices))
super().__init__(name='gather', inputs=[data, indices], outputs=[output])


class StridedSliceTask(Task):
def __init__(
self,
@@ -445,6 +462,15 @@ def __init__(self, data: Tensor, indices: Tensor, axis: int):
)


class GatherOp(Operator):
def __init__(self, data: Tensor, indices: Tensor, axis: int):
super().__init__(
inputs=[data, indices],
attributes={'axis': axis},
task=GatherTask(input_like(data, 'data'), input_like(indices, 'indices'), axis=axis),
)


class StridedSliceOp(Operator):
def __init__(
self,
@@ -478,8 +504,9 @@ def normalize(data_shape, starts, ends, axes: Optional[List[int]], strides: Opti
if k > 0:
i = i if i is not None else 0
j = j if j is not None else n
- if not (-n <= i <= n and -n <= j <= n):
+ if not (-n <= i <= n and -n <= j):
      raise IndexError('Invalid slice')
+ j = min(j, n)
if i < 0:
i += n
if j < 0:
@@ -648,6 +675,10 @@ def take(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
return TakeOp(data, indices, axis).get_output(0)


def gather(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
return GatherOp(data, indices, axis).outputs[0]


def strided_slice(
data: Tensor,
starts: Sequence[Optional[int]],
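One behavioural note on the relaxed bounds check in `normalize` above: an end index past the extent is now clamped to the extent, matching Python and torch slicing, instead of raising `IndexError`. A hypothetical illustration (not from the PR; the exact `strided_slice` call signature is assumed from the surrounding code):

```python
import hidet

x = hidet.randn([4, 512])
# Previously ends=[10000] on a 512-wide axis would raise IndexError; it is now clamped to 512.
y = hidet.ops.strided_slice(x, starts=[0], ends=[10000], axes=[1])
print(y.shape)  # the sliced axis keeps its full extent of 512
```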
6 changes: 5 additions & 1 deletion python/hidet/graph/tensor.py
@@ -951,6 +951,10 @@ def torch(self):
"""
import torch

if self.dtype == dtypes.boolean:
# workaround for torch not supporting exporting boolean to dlpack
return torch.from_dlpack(self.to(dtype='uint8')).bool()

return torch.from_dlpack(self)
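A quick round-trip illustration of the boolean workaround (a sketch, not part of the diff; it assumes `'bool'` is accepted as a hidet dtype name, consistent with the `ops.cast(self, "bool")` method registration above):

```python
import torch
import hidet

h = hidet.zeros([3], dtype='bool')
t = h.torch()  # exported to torch as uint8 via DLPack, then cast back with .bool()
assert t.dtype == torch.bool
```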


@@ -1016,7 +1020,7 @@ def zeros(shape: Sequence[int], dtype='float32', device='cpu') -> Tensor:
shape: Sequence[int]
The shape of new tensor.

- dtype: str
+ dtype: str or DataType
The data type of element of the tensor.

device: Device or str, default 'cpu'
19 changes: 19 additions & 0 deletions tests/operators/test_transform.py
@@ -12,6 +12,7 @@
from typing import Optional, List
import pytest
import numpy as np
import torch
import hidet as hi
from hidet import ops
from hidet.utils import prod
@@ -25,6 +26,13 @@ def check_transform(shape, numpy_op, hidet_op, dtype=np.float32, atol=0, rtol=0)
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


def check_transform_torch(shape, torch_op, hidet_op, dtype=np.float32, atol=0, rtol=0):
data = torch.asarray(np.array(np.random.randn(*shape)).astype(dtype))
torch_result = torch_op(data)
hidet_result = hidet_op(hi.asarray(data).cuda()).cpu().numpy()
np.testing.assert_allclose(actual=hidet_result, desired=torch_result.cpu().numpy(), atol=atol, rtol=rtol)


@pytest.mark.parametrize(
"shape, new_shape",
[
@@ -106,6 +114,17 @@ def test_take(shape, indices_shape, axis):
check_transform(shape, lambda x: np.take(x, indices, axis), lambda x: ops.take(x, hi.asarray(indices).cuda(), axis))


@pytest.mark.parametrize("shape, indices_shape, axis", [[[1234, 512], [2100, 512], 0], [[12, 34, 56], [12, 1, 56], 1]])
def test_gather(shape, indices_shape, axis):
dim_extent = shape[axis]
indices = np.random.randint(0, dim_extent - 1, indices_shape).astype(np.int64)
check_transform_torch(
shape,
lambda x: torch.gather(x, axis, torch.asarray(indices)),
lambda x: ops.gather(x, hi.asarray(indices).cuda(), axis),
)


@pytest.mark.parametrize(
"shape, starts, ends, axes, strides",
[