Commit
[Operator] Add gather operator and torch.zeros, torch.neg mapping (#174)

yaoyaoding committed Apr 13, 2023
1 parent 7f634d8 commit ef81b2a
Showing 7 changed files with 99 additions and 6 deletions.
7 changes: 3 additions & 4 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -52,10 +52,9 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
     logger.info('schedule search space: %d', search_space)
 
     has_cpu_tensor = any(tensor.device.type == 'cpu' for tensor in graph_opt.inputs + graph_opt.outputs)
-    has_cuda_tensor = any(tensor.device.type == 'cuda' for tensor in graph_opt.inputs + graph_opt.outputs)
-
-    if has_cpu_tensor and has_cuda_tensor:
-        raise RuntimeError('the flow graph contains both CPU and CUDA tensors, currently not supported by hidet')
+    # has_cuda_tensor = any(tensor.device.type == 'cuda' for tensor in graph_opt.inputs + graph_opt.outputs)
+    # if has_cpu_tensor and has_cuda_tensor:
+    #     raise RuntimeError('the flow graph contains both CPU and CUDA tensors, currently not supported by hidet')
 
     def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
         torch_inputs: List[torch.Tensor] = []
35 changes: 35 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -106,6 +106,11 @@ def iadd(x: Tensor, y: Tensor):
     return ops.add(x, y)
 
 
+@register_function(operator.neg)
+def neg(x: Tensor):
+    return -x
+
+
 @register_function(torch.sin)
 def sin(x: Tensor):
     return ops.sin(x)
@@ -285,6 +290,27 @@ def matmul(x: Tensor, y: Tensor):
     return ops.matmul(x, y)
 
 
+@register_function(torch.zeros)
+def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
+    import hidet
+
+    if out is not None:
+        raise NotImplementedError("out is not None")
+    if layout is not None:
+        raise NotImplementedError("layout is not None")
+    if len(size) == 1:
+        if isinstance(size[0], (list, tuple)):
+            size = size[0]
+    shape = [int(v) for v in size]
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _ = pin_memory
+    _ = requires_grad
+
+    return hidet.zeros(shape, dtype=dtype_from_torch(dtype), device=device_from_torch(device))
+
+
 @register_function(torch.ones)
 def ones(
     *size: Union[int, Sequence[int]],
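For context on the size normalization above: torch.zeros accepts the extents either as varargs or as a single sequence, and the mapping unwraps the latter before building the shape list. A small illustration (not part of the diff; both calls should reach hidet.zeros with shape [2, 3]):

    torch.zeros(2, 3)    # varargs: size == (2, 3)
    torch.zeros([2, 3])  # single sequence: size == ([2, 3],), unwrapped by the isinstance check
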
@@ -653,3 +679,12 @@ def celu(x: Tensor, alpha: float):
 @register_function(torch.nn.functional.logsigmoid)
 def logsigmoid(x: Tensor):
     return ops.logsigmoid(x)
+
+
+@register_function(torch.gather)
+def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):
+    if sparse_grad:
+        warnings.warn_once('hidet: gather with sparse_grad=True is not supported. Treat as sparse_grad=False.')
+    if out is not None:
+        raise NotImplementedError('hidet: gather with out=... is not supported')
+    return ops.gather(x, index, axis=dim)
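A minimal end-to-end sketch of how these registrations get exercised, assuming a CUDA build of hidet and that importing hidet makes the 'hidet' torch.compile backend available (per hidet's dynamo integration); illustrative only, not part of the commit:

    import torch
    import hidet  # importing hidet registers the 'hidet' dynamo backend

    def f(x: torch.Tensor) -> torch.Tensor:
        # -x lowers through the operator.neg mapping; torch.zeros and
        # torch.gather hit the functions registered above during tracing
        index = torch.zeros(2, x.shape[1], dtype=torch.int64, device=x.device)
        return torch.gather(-x, 0, index)

    compiled = torch.compile(f, backend='hidet')
    y = compiled(torch.randn(4, 8, device='cuda'))  # shape [2, 8]
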
5 changes: 5 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -44,6 +44,11 @@ def tensor_half(self: Tensor) -> Tensor:
     return ops.cast(self, "float16")
 
 
+@register_method(torch.Tensor.bool)
+def tensor_bool(self: Tensor) -> Tensor:
+    return ops.cast(self, "bool")
+
+
 @register_method(torch.Tensor.to)
 def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
     """
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -36,7 +36,7 @@
 from .definitions.reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
 from .definitions.cumulative import cumsum
 from .definitions.transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
-from .definitions.transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims
+from .definitions.transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather
 from .definitions.transform import permute_dims
 from .definitions.fusion import fused_operator
 from .definitions.special import barrier
33 changes: 32 additions & 1 deletion python/hidet/graph/ops/definitions/transform.py
@@ -16,6 +16,7 @@
 from hidet.ir.utils import index_deserialize, index_serialize
 from hidet.utils import prod
 from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast
+from .utils import TensorInput
 
 
 def same_shape(shape_a: Sequence[int], shape_b: Sequence[int]) -> bool:
@@ -187,6 +188,22 @@ def fmap(*output_indices):
         super().__init__(name='take', inputs=[data, indices], outputs=[output])
 
 
+class GatherTask(Task):
+    def __init__(self, data: TensorInput, indices: TensorInput, axis=0):
+        data_shape = data.const_shape()
+        indices_shape = indices.const_shape()
+        output_shape = data_shape[:axis] + [indices_shape[axis]] + data_shape[axis + 1 :]
+
+        def fmap(*output_indices):
+            index_value = indices[output_indices]
+            index_value = if_then_else(index_value < 0, index_value + data_shape[axis], index_value)
+            data_indices = output_indices[:axis] + (index_value,) + output_indices[axis + 1 :]
+            return data[data_indices]
+
+        output = compute(name='output', shape=output_shape, fcompute=lambda *output_indices: fmap(*output_indices))
+        super().__init__(name='gather', inputs=[data, indices], outputs=[output])
+
+
 class StridedSliceTask(Task):
     def __init__(
         self,
@@ -445,6 +462,15 @@ def __init__(self, data: Tensor, indices: Tensor, axis: int):
         )
 
 
+class GatherOp(Operator):
+    def __init__(self, data: Tensor, indices: Tensor, axis: int):
+        super().__init__(
+            inputs=[data, indices],
+            attributes={'axis': axis},
+            task=GatherTask(input_like(data, 'data'), input_like(indices, 'indices'), axis=axis),
+        )
+
+
 class StridedSliceOp(Operator):
     def __init__(
         self,
@@ -478,8 +504,9 @@ def normalize(data_shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[int]]):
         if k > 0:
             i = i if i is not None else 0
             j = j if j is not None else n
-            if not (-n <= i <= n and -n <= j <= n):
+            if not (-n <= i <= n and -n <= j):
                 raise IndexError('Invalid slice')
+            j = min(j, n)
             if i < 0:
                 i += n
             if j < 0:
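The relaxed bound check above brings strided_slice closer to Python slice semantics, where an end past the extent is clamped rather than rejected (e.g. [1, 2, 3][1:100] == [2, 3]). A hedged sketch of a call this now permits, assuming the starts/ends/axes keywords of the strided_slice definition below:

    import hidet
    from hidet import ops

    x = hidet.randn([10])
    y = ops.strided_slice(x, starts=[3], ends=[1000], axes=[0])  # end clamped to 10 -> shape [7]
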
@@ -648,6 +675,10 @@ def take(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
     return TakeOp(data, indices, axis).get_output(0)
 
 
+def gather(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
+    return GatherOp(data, indices, axis).outputs[0]
+
+
 def strided_slice(
     data: Tensor,
     starts: Sequence[Optional[int]],
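gather follows the semantics of GatherTask above: the output keeps data's shape except along axis, where it takes indices' extent, and every output element reads data at the per-element index, with negative indices wrapping. A NumPy reference sketch of that computation (illustrative; gather_ref is not part of the commit):

    import numpy as np

    def gather_ref(data: np.ndarray, indices: np.ndarray, axis: int = 0) -> np.ndarray:
        # mirrors GatherTask: indices must match data's extent on every dim except `axis`
        out_shape = list(data.shape)
        out_shape[axis] = indices.shape[axis]
        out = np.empty(out_shape, dtype=data.dtype)
        for out_idx in np.ndindex(*out_shape):
            idx = int(indices[out_idx])
            if idx < 0:
                idx += data.shape[axis]  # wrap negatives, like the if_then_else in the task
            src = list(out_idx)
            src[axis] = idx
            out[out_idx] = data[tuple(src)]
        return out

    # the same result through the new public op (assuming a CUDA build):
    #     out = ops.gather(hidet.asarray(data).cuda(), hidet.asarray(indices).cuda(), axis=0)
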
4 changes: 4 additions & 0 deletions python/hidet/graph/tensor.py
@@ -951,6 +951,10 @@ def torch(self):
         """
         import torch
 
+        if self.dtype == dtypes.boolean:
+            # workaround for torch not supporting exporting boolean to dlpack
+            return torch.from_dlpack(self.to(dtype='uint8')).bool()
+
         return torch.from_dlpack(self)


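A quick sketch of the round trip this workaround enables, since boolean tensors previously failed at the DLPack boundary (illustrative; assumes hidet.asarray infers a boolean dtype from a bool list):

    import torch
    import hidet

    mask = hidet.asarray([True, False, True])  # boolean hidet tensor
    t = mask.torch()                           # exported as uint8 via DLPack, then cast back with .bool()
    assert t.dtype == torch.bool
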
19 changes: 19 additions & 0 deletions tests/operators/test_transform.py
@@ -12,6 +12,7 @@
 from typing import Optional, List
 import pytest
 import numpy as np
+import torch
 import hidet as hi
 from hidet import ops
 from hidet.utils import prod
@@ -25,6 +26,13 @@ def check_transform(shape, numpy_op, hidet_op, dtype=np.float32, atol=0, rtol=0):
     np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)
 
 
+def check_transform_torch(shape, torch_op, hidet_op, dtype=np.float32, atol=0, rtol=0):
+    data = torch.asarray(np.array(np.random.randn(*shape)).astype(dtype))
+    torch_result = torch_op(data)
+    hidet_result = hidet_op(hi.asarray(data).cuda()).cpu().numpy()
+    np.testing.assert_allclose(actual=hidet_result, desired=torch_result.cpu().numpy(), atol=atol, rtol=rtol)
+
+
 @pytest.mark.parametrize(
     "shape, new_shape",
     [
@@ -106,6 +114,17 @@ def test_take(shape, indices_shape, axis):
     check_transform(shape, lambda x: np.take(x, indices, axis), lambda x: ops.take(x, hi.asarray(indices).cuda(), axis))
 
 
+@pytest.mark.parametrize("shape, indices_shape, axis", [[[1234, 512], [2100, 512], 0], [[12, 34, 56], [12, 1, 56], 1]])
+def test_gather(shape, indices_shape, axis):
+    dim_extent = shape[axis]
+    indices = np.random.randint(0, dim_extent - 1, indices_shape).astype(np.int64)
+    check_transform_torch(
+        shape,
+        lambda x: torch.gather(x, axis, torch.asarray(indices)),
+        lambda x: ops.gather(x, hi.asarray(indices).cuda(), axis),
+    )
+
+
 @pytest.mark.parametrize(
     "shape, starts, ends, axes, strides",
     [
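To run just the new test (standard pytest filtering; the test moves tensors to CUDA, so a GPU build is required):

    pytest tests/operators/test_transform.py -k test_gather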
