[Operator] Add clamp/isinf/any/all op, enhance where op #343

Merged: 1 commit merged on Aug 5, 2023
39 changes: 37 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -25,7 +25,7 @@
from hidet.runtime.device import Device
from .interpreter import register_function, register_method
from .interpreter import warnings
from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar
from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar, convert_to_scalar_if_possible

Number = Union[int, float, bool]

@@ -590,7 +590,7 @@ def addmm(


@register_function(torch.where)
def where(condition: Tensor, x: Tensor, y: Tensor):
def where(condition: Tensor, x: Union[Tensor, Number], y: Union[Tensor, Number]):
    return ops.where(cond=condition, x=x, y=y)


@@ -1069,3 +1069,38 @@ def zeros_like(
    hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype else x.dtype

    return ops.full(x.shape, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero)


@register_function(torch.clamp)
def clamp(
    x: Tensor,
    min: Optional[Union[Tensor, Number]] = None,
    max: Optional[Union[Tensor, Number]] = None,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.clamp(..., out=...)")

    min = convert_to_scalar_if_possible(min)
    max = convert_to_scalar_if_possible(max)

    if min is None and max is None:
        return x
    elif min is None:
        if not isinstance(max, Tensor):
            assert isinstance(max, (int, float, complex))
            max = ops.full([], value=max, dtype=x.dtype, device=x.device)
        return ops.minimum(x, max)
    elif max is None:
        if not isinstance(min, Tensor):
            assert isinstance(min, (int, float, complex))
            min = ops.full([], value=min, dtype=x.dtype, device=x.device)
        return ops.maximum(x, min)
    else:
        return ops.clamp(x, min, max)


@register_function(torch.isinf)
def isinf(x: Tensor) -> Tensor:
    return ops.isinf(x)
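
A minimal sketch of how the new registrations above could be exercised end to end, assuming the standard torch.compile(..., backend='hidet') entry point; the function and inputs are illustrative only:

import torch


def postprocess(x: torch.Tensor) -> torch.Tensor:
    # torch.clamp with scalar bounds lowers to ops.minimum / ops.maximum / ops.clamp
    y = torch.clamp(x, min=-1.0, max=1.0)
    # torch.isinf maps to ops.isinf, and torch.where now also accepts Python scalars
    return torch.where(torch.isinf(x), 0.0, y)


compiled = torch.compile(postprocess, backend='hidet')
print(compiled(torch.randn(4, 4)))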
10 changes: 10 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -255,3 +255,13 @@ def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
@register_method(torch.Tensor.detach)
def tensor_detach(self: Tensor) -> Tensor:
    return self


@register_method(torch.Tensor.any)
def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor:
    return ops.any(self, axis=dim, keepdims=keepdim)


@register_method(torch.Tensor.all)
def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor:
    return ops.all(self, axis=dim, keepdims=keepdim)
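
A similar sketch for the new Tensor.any / Tensor.all method mappings, under the same assumed torch.compile(..., backend='hidet') setup; the reduction dim and inputs are illustrative:

import torch


def rows_all_finite(x: torch.Tensor) -> torch.Tensor:
    # Tensor.all dispatches to ops.all, with dim/keepdim mapped to axis/keepdims
    return (~torch.isinf(x)).all(dim=1, keepdim=True)


compiled = torch.compile(rows_all_finite, backend='hidet')
print(compiled(torch.tensor([[1.0, float('inf')], [2.0, 3.0]])))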
9 changes: 9 additions & 0 deletions python/hidet/graph/frontend/torch/utils.py
@@ -285,3 +285,12 @@ def normalize_to_scalar(value: Union[Tensor, Expr, float, int, bool]) -> Union[E
            raise RuntimeError(f'Cannot convert tensor {value.signature()} to scalar')
    else:
        return value


def convert_to_scalar_if_possible(x: Union[Tensor, Expr, float, int, bool]) -> Optional[Union[Expr, float, int, bool]]:
    if isinstance(x, Tensor):
        if len(x.shape) == 0 and x.storage:
            return x.item()
        return None
    else:
        return x
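
For reference, a small sketch of how the new helper behaves, assuming hidet.full and hidet.ones construct concrete (materialized) tensors as usual:

import hidet
from hidet.graph.frontend.torch.utils import convert_to_scalar_if_possible

# Python scalars pass through unchanged
assert convert_to_scalar_if_possible(1.5) == 1.5

# a 0-d tensor with materialized storage is unwrapped via .item()
assert convert_to_scalar_if_possible(hidet.full([], 2.0)) == 2.0

# tensors of rank >= 1 (and symbolic tensors without storage) yield None,
# so callers such as clamp() keep them on the tensor path
assert convert_to_scalar_if_possible(hidet.ones([2, 3])) is None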
13 changes: 3 additions & 10 deletions python/hidet/graph/ops/__init__.py
@@ -13,15 +13,8 @@
from .matmul import batch_matmul, matmul, matmul_x86
from .conv1d import conv1d, conv1d_gemm
from .conv1d_transpose import conv1d_transpose
from .conv2d import (
    conv2d,
    conv2d_channel_last,
    conv2d_winograd,
    conv2d_gemm,
    conv2d_gemm_fp16,
    conv2d_gemm_fp16_channel_last,
    conv2d_gemm_image_transform,
)
from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
from .conv2d import conv2d_gemm_fp16_channel_last, conv2d_gemm_image_transform
from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .conv3d import conv3d, conv3d_gemm
from .conv3d_transpose import conv3d_transpose
@@ -38,7 +31,7 @@
from .arithmetic import floor, ceil, round, trunc, sqrt, rsqrt, pow, abs
from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf
from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, bitwise_invert, bitwise_or
from .arithmetic import bitwise_xor, maximum, minimum
from .arithmetic import bitwise_xor, maximum, minimum, clamp
from .arithmetic import isfinite, isinf, isnan, sign, where
from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2
from .complex import real, imag, conj, make_complex
106 changes: 94 additions & 12 deletions python/hidet/graph/ops/arithmetic.py
@@ -13,14 +13,17 @@
from typing import List, Callable, Any, Union, Optional, Dict

from hidet.ir import primitives
from hidet.ir import expr, dtypes
from hidet.ir import Var, expr, dtypes
from hidet.ir.type import DataType
from hidet.ir.expr import Expr, Var, if_then_else
from hidet.ir.tools import rewrite
from hidet.ir.expr import Expr, if_then_else, is_true
from hidet.utils import prod, same_list
from .utils import Task, Operator, Tensor, TensorNode, InverseMap, compute, input_like
from .utils import broadcast_shape, broadcast_shapes, broadcast_indices

PyScalar = Union[int, float, bool]


# In order for the subgraph rewrite of Composite Elementwise Operator to work,
# we need to store the callable in an Operator object. But lambda cannot be pickled,
# so we define auxiliary classes UnaryElementwiseOperation and BinaryElementwiseOperation
@@ -117,7 +120,7 @@ def __init__(self, name: str, args: List[TensorNode], op: Callable[[Any], Any]):
            inverse_map={
                v: InverseMap.identity(len(v_shape))
                for v, v_shape in zip(args, shapes)
                if prod(v_shape) == prod(out_shape)
                if is_true(prod(v_shape) == prod(out_shape)) and len(v_shape) == len(out_shape)
            },
        )

@@ -207,6 +210,15 @@ def __init__(self, x: Tensor, y: Tensor, op, name: str):
        )


def get_dtype(scalar: Expr):
    from hidet.ir.tools import infer_type

    inferred_type = infer_type(scalar)
    if not isinstance(inferred_type, DataType):
        raise TypeError(f'Expected scalar to be of type DataType, got {type(inferred_type)}')
    return inferred_type


class CompositeElementwiseOp(Operator):
    def __init__(
        self,
@@ -238,37 +250,37 @@ def resolve_dtype(tensor_dtype: DataType, scalar_dtype: DataType) -> DataType:

class AddScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: v + dtype(scalar), attributes={'scalar': scalar}, name='adds')


class SubScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: v - dtype(scalar), attributes={'scalar': scalar}, name='subs')


class RSubScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: dtype(scalar) - v, attributes={'scalar': scalar}, name='rsubs')


class MultiplyScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: v * dtype(scalar), attributes={'scalar': scalar}, name='muls')


class DivideScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: v / dtype(scalar), attributes={'scalar': scalar}, name='divs')


class RDivideScalarOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, scalar: Expr):
        dtype = resolve_dtype(x.dtype, scalar.type)
        dtype = resolve_dtype(x.dtype, get_dtype(scalar))
        super().__init__(x, op=lambda v: dtype(scalar) / v, attributes={'scalar': scalar}, name='rdivs')


@@ -478,6 +490,19 @@ def __init__(self, x: Tensor):
        )


class ClampOp(UnaryElementwiseOp):
    def __init__(self, x: Tensor, min_value: Union[int, float], max_value: Union[int, float]):
        assert isinstance(min_value, (int, float))
        assert isinstance(max_value, (int, float))
        min_value = x.dtype(min_value)
        max_value = x.dtype(max_value)
        super().__init__(
            x,
            op=lambda a: if_then_else(a < min_value, min_value, if_then_else(a > max_value, max_value, a)),
            name='clamp',
        )


class RightShiftOp(BinaryElementwiseOp):
    def __init__(self, x: Tensor, y: Tensor):
        super().__init__(x, y, op=lambda a, b: expr.RightShift(a, b), name='rightshift')
@@ -522,6 +547,47 @@ def __init__(self, cond: Tensor, x: Tensor, y: Tensor):
        )


class WhereScalarScalarOp(Operator):
    def __init__(self, cond: Tensor, x: PyScalar, y: PyScalar):
        if isinstance(x, int) and isinstance(y, int):
            dtype = dtypes.default_int_dtype
        elif isinstance(x, float) or isinstance(y, float):
            dtype = dtypes.default_float_dtype
        else:
            raise ValueError(f'Unsupported scalar type: {type(x)}')
        x, y = dtype(x), dtype(y)
        super().__init__(
            inputs=[cond],
            attributes={'x': x, 'y': y},
            task=UnaryElementwiseTask(name='where', x=input_like(cond, 'cond'), op=lambda a: if_then_else(a, x, y)),
        )


class WhereScalarTensorOp(Operator):
    def __init__(self, cond: Tensor, y: Tensor, x: PyScalar):
        dtype = y.dtype
        x = dtype(x)
        super().__init__(
            inputs=[cond, y],
            attributes={'x': x},
            task=BinaryElementwiseTask(
                name='where', x=input_like(cond, 'cond'), y=input_like(y, 'y'), op=lambda a, b: if_then_else(a, x, b)
            ),
        )


class WhereTensorScalarOp(Operator):
    def __init__(self, cond: Tensor, x: Tensor, y: PyScalar):
        y = x.dtype(y)
        super().__init__(
            inputs=[cond, x],
            attributes={'y': y},
            task=BinaryElementwiseTask(
                name='where', x=input_like(cond, 'cond'), y=input_like(x, 'x'), op=lambda a, b: if_then_else(a, b, y)
            ),
        )


class MaxOp(Operator):
    def __init__(self, *tensors: Tensor):
        def scalar_max(args: List[expr.Expr]):
@@ -792,10 +858,25 @@ def sign(x: Tensor) -> Tensor:
    return SignOp(x).outputs[0]


def where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor:
def clamp(x: Tensor, min: Union[Tensor, float, int], max: Union[Tensor, float, int]) -> Tensor:
    if isinstance(min, Tensor) or isinstance(max, Tensor):
        raise NotImplementedError('clamp with tensor min/max is not implemented yet')
    return ClampOp(x, min, max).outputs[0]


def where(cond: Tensor, x: Union[Tensor, PyScalar], y: Union[Tensor, PyScalar]) -> Tensor:
    if cond.dtype != dtypes.boolean:
        raise ValueError('The condition tensor must have dtype "bool", but got {}'.format(cond.dtype.name))
    return WhereOp(cond, x, y).outputs[0]
    if isinstance(x, Tensor) and isinstance(y, Tensor):
        return WhereOp(cond, x, y).outputs[0]
    elif isinstance(x, Tensor) and isinstance(y, (int, float, complex)):
        return WhereTensorScalarOp(cond, x=x, y=y).outputs[0]
    elif isinstance(x, (int, float, complex)) and isinstance(y, Tensor):
        return WhereScalarTensorOp(cond, x=x, y=y).outputs[0]
    elif isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)):
        return WhereScalarScalarOp(cond, x=x, y=y).outputs[0]
    else:
        raise ValueError('Invalid arguments for where: x={}, y={}'.format(x, y))


def maximum(a: Tensor, b: Tensor, *others: Tensor) -> Tensor:
@@ -812,7 +893,8 @@ def mod(x: Tensor, y: Tensor) -> Tensor:
    return ModOp(x, y).outputs[0]


remainder = mod
def remainder(x: Tensor, y: Tensor) -> Tensor:
    return mod(x, y)


def abs(x: Tensor) -> Tensor:
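
Putting the graph-level pieces together, a short sketch of what the new ops.clamp and the scalar-aware ops.where accept; the tensor values are illustrative, and hidet.asarray is assumed for construction:

import hidet
from hidet.graph import ops

x = hidet.asarray([[-2.0, 0.5], [3.0, float('inf')]])

# ClampOp folds both scalar bounds into a single elementwise if_then_else
clamped = ops.clamp(x, min=-1.0, max=1.0)

# where() now dispatches on argument kinds:
#   tensor/tensor -> WhereOp, tensor/scalar -> WhereTensorScalarOp,
#   scalar/tensor -> WhereScalarTensorOp, scalar/scalar -> WhereScalarScalarOp
capped = ops.where(ops.isinf(x), 0.0, x)
flags = ops.where(ops.isinf(x), 1, 0)

print(clamped, capped, flags, sep='\n')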
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/reduce/reduce.py
@@ -56,7 +56,7 @@ def allow_epilogue(self) -> bool:
    def allow_prologue(self) -> bool:
        return False

    def implement_cuda(self, working_dir: str) -> IRModule:
    def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        rank = len(self.inputs[0].shape)
        if rank - 1 in self.dims:
            return tune.extract_ir_modules(self.cuda_schedule_reduce_by_warp)
@@ -80,7 +80,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule:
        xdtype = x.type.dtype
        shape: List[Int] = list(x.shape)
        lanes = 1
        vtype: DataType = xdtype
        vtype: Union[DataType, VectorType] = xdtype
        if xdtype.nbytes < 4:
            num_eles: int = 4 // xdtype.nbytes
            if is_constant(shape[-1]) and shape[-1] % num_eles == 0:
@@ -204,7 +204,7 @@ def cuda_schedule_reduce_by_default(self) -> IRModule:
        shape: List[Int] = list(x.shape)

        lanes = 1
        vtype: DataType = xdtype
        vtype: Union[VectorType, DataType] = xdtype
        if xdtype.nbytes < 4:
            num_eles: int = 4 // xdtype.nbytes
            if shape[-1] % num_eles == 0:
2 changes: 1 addition & 1 deletion python/hidet/ir/primitives/__init__.py
@@ -15,7 +15,7 @@
# pylint: disable=redefined-builtin
from .math import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, expm1, abs
from .math import max, min, exp, pow, sqrt, rsqrt, erf, ceil, log, log2, log10, log1p, round, floor, trunc
from .math import isfinite, isinf, isnan, make_vector
from .math import isfinite, isinf, isnan, make_vector, atan2, mod

from .complex import real, imag, conj, make_complex
