[Dynamo] non-linear operator support + tests #143

Merged
21 commits merged on Mar 24, 2023

Commits
f52ae28
assertion message for IR check
AndreSlavescu Jan 24, 2023
82afa1a
assertion message for valid IR check
AndreSlavescu Jan 24, 2023
8cd661c
fixed lint
AndreSlavescu Jan 24, 2023
a7db719
sigmoid and hardsigmoid module support
AndreSlavescu Mar 3, 2023
4fda6f5
Merge branch 'docs' of https://github.com/AndreSlavescu/hidet into docs
AndreSlavescu Mar 3, 2023
61ecba3
Merge branch 'main' of https://github.com/hidet-org/hidet into docs
AndreSlavescu Mar 7, 2023
03e6250
Merge branch 'main' of https://github.com/AndreSlavescu/hidet into docs
AndreSlavescu Mar 7, 2023
bae449d
Sigmoid/Hardsigmoid Modules
AndreSlavescu Mar 7, 2023
4c9c78f
extended module/function support
AndreSlavescu Mar 9, 2023
923457b
reformat
AndreSlavescu Mar 9, 2023
88c1a87
register_functions typo fix + normalization tests
AndreSlavescu Mar 9, 2023
28c4960
tests + refactoring
AndreSlavescu Mar 14, 2023
c414ffa
removed group norm
AndreSlavescu Mar 16, 2023
3ca4c14
batch_norm bug fix
AndreSlavescu Mar 16, 2023
7fcec9f
non-linear operators
AndreSlavescu Mar 23, 2023
79d086a
non-linear activation operators
AndreSlavescu Mar 23, 2023
1c48549
Merge pull request #7 from AndreSlavescu/operators
AndreSlavescu Mar 23, 2023
4ded510
Merge branch 'main' of https://github.com/hidet-org/hidet into operators
AndreSlavescu Mar 24, 2023
aae1a8b
Merge branch 'hidet-org:main' into main
AndreSlavescu Mar 24, 2023
74e45d2
optimized operator definitions
AndreSlavescu Mar 24, 2023
963c203
Merge pull request #8 from AndreSlavescu/operators
AndreSlavescu Mar 24, 2023
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -37,7 +37,7 @@

project = 'Hidet'
author = 'Hidet Team'
copyright = '2022, Hidet Authors'
copyright = '2023, Hidet Authors'

# The full version, including alpha/beta/rc tags
release = '0.1'
48 changes: 48 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -88,6 +88,14 @@ def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]):
return y


@register_function(torch.nn.functional.bilinear)
def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
y = ops.matmul(x_1, ops.matmul(weight, x_2))
if bias is not None:
y = y + bias
return y


@register_function(operator.add)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
@@ -456,3 +464,43 @@ def hardswish(x: Tensor, inplace: bool):
if inplace:
warnings.warn_once('hidet: hardswish with inplace=True is not supported. Treat as inplace=False.')
return ops.hardswish(x)


@register_function(torch.nn.functional.softmin)
def softmin(x: Tensor, axis: int):
return ops.softmin(x, axis)


@register_function(torch.nn.functional.softplus)
def softplus(x: Tensor, beta: int, threshold: int):
return ops.softplus(x, beta, threshold)


@register_function(torch.nn.functional.softshrink)
def softshrink(x: Tensor, lambda_val: float):
return ops.softshrink(x, lambda_val)


@register_function(torch.nn.functional.tanhshrink)
def tanhshrink(x: Tensor):
return ops.tanhshrink(x)


@register_function(torch.nn.functional.hardshrink)
def hardshrink(x: Tensor, lambda_val: float):
return ops.hardshrink(x, lambda_val)


@register_function(torch.nn.functional.softsign)
def softsign(x: Tensor):
return ops.softsign(x)


@register_function(torch.nn.functional.celu)
def celu(x: Tensor, alpha: float):
return ops.celu(x, alpha)


@register_function(torch.nn.functional.logsigmoid)
def logsigmoid(x: Tensor):
return ops.logsigmoid(x)
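
Context for reviewers (not part of the diff): each registration above is what lets the Dynamo frontend translate the matching torch.nn.functional call when a traced graph is compiled through hidet. A minimal sketch exercising two of the new functionals, assuming the backend is exposed to torch.compile under the name 'hidet' as elsewhere in this repo; the Toy module is made up for illustration:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        x = torch.nn.functional.softplus(x, beta=1, threshold=20)
        return torch.nn.functional.hardshrink(x, lambd=0.5)

model = Toy().cuda().eval()
x = torch.randn(8, 16, device='cuda')
compiled = torch.compile(model, backend='hidet')  # routed through this frontend
torch.testing.assert_close(compiled(x), model(x), rtol=1e-4, atol=1e-4)
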
63 changes: 63 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -226,3 +226,66 @@ class HidetSiLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.SiLU)
return regs.silu(x, self.mod.inplace)


@register_module(torch.nn.Softmax)
class HidetSoftmax(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softmax)
return regs.softmax(x, self.mod.dim)


@register_module(torch.nn.Softmin)
class HidetSoftmin(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softmin)
return regs.softmin(x, self.mod.dim)


@register_module(torch.nn.Softplus)
class HidetSoftplus(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softplus)
return regs.softplus(x, self.mod.beta, self.mod.threshold)


@register_module(torch.nn.Softsign)
class HidetSoftsign(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softsign)
return regs.softsign(x)


@register_module(torch.nn.Softshrink)
class HidetSoftshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softshrink)
return regs.softshrink(x, self.mod.lambd)


@register_module(torch.nn.Tanhshrink)
class HidetTanhshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Tanhshrink)
return regs.tanhshrink(x)


@register_module(torch.nn.Hardshrink)
class HidetHardshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardshrink)
return regs.hardshrink(x, self.mod.lambd)


@register_module(torch.nn.CELU)
class HidetCELU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.CELU)
return regs.celu(x, self.mod.alpha)


@register_module(torch.nn.LogSigmoid)
class HidetLogSigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.LogSigmoid)
return regs.logsigmoid(x)
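
Context for reviewers (not part of the diff): each HidetModule wrapper follows the same shape: assert the wrapped torch module type, then delegate to the matching function in register_functions with the module's own hyperparameters. A hypothetical sketch of adding one more module by the same pattern; torch.nn.Mish exists in PyTorch, but regs.mish (and the register_function it would need) is an assumption, not part of this PR:

@register_module(torch.nn.Mish)
class HidetMish(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        # type-check, then delegate, exactly like the wrappers above
        assert isinstance(self.mod, torch.nn.Mish)
        return regs.mish(x, self.mod.inplace)  # assumed helper, not in this PR
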
3 changes: 2 additions & 1 deletion python/hidet/graph/ops/__init__.py
@@ -18,8 +18,9 @@
from .definitions.matmul import batch_matmul, matmul
from .definitions.pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .definitions.pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
from .definitions.softmax import softmax
from .definitions.activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
from .definitions.activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink
from .definitions.activation import softshrink, softmax, softmin, hardtanh
from .definitions.norm import batch_norm_infer, instance_norm, layer_norm
from .definitions.image import resize2d
from .definitions.create import full, arange, linspace
5 changes: 3 additions & 2 deletions python/hidet/graph/ops/definitions/__init__.py
@@ -40,8 +40,9 @@
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, split, pad, conv_pad
from .pool import avg_pool2d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .pool import max_pool2d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
from .softmax import softmax
from .activation import relu, sigmoid, relu6, clip, prelu
from .activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
from .activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink, softshrink
from .activation import softmax, softmin, hardtanh
from .norm import batch_norm_infer, instance_norm
from .image import resize2d
from .cumulative import cumsum
159 changes: 157 additions & 2 deletions python/hidet/graph/ops/definitions/activation.py
@@ -12,9 +12,10 @@
from typing import Optional
import math
from hidet.ir import primitives as prim
from hidet.ir.expr import if_then_else
from .utils import Tensor
from hidet.ir.expr import if_then_else, BitwiseAnd
from .utils import Tensor, Operator, normalize_dim, input_like
from .arithmetic import UnaryElementwiseOp, BinaryElementwiseOp
from .softmax import SoftmaxTask


class ReluOp(UnaryElementwiseOp):
@@ -87,6 +88,108 @@ def __init__(self, x: Tensor):
)


class ThresholdOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, threshold_val: float, value: float) -> Tensor:
super().__init__(x, op=lambda v: if_then_else(v > x.dtype(threshold_val), v, x.dtype(value)), name='threshold')


class HardTanhOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, min_val: float = -1.0, max_val: float = 1.0) -> Tensor:
super().__init__(x, op=lambda v: prim.min(x.dtype(max_val), prim.max(x.dtype(min_val), v)), name='hardtanh')


class EluOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, alpha: float = 1.0) -> Tensor:
super().__init__(
x, op=lambda v: if_then_else(v > 0, v, x.dtype(alpha) * (prim.exp(v) - x.dtype(1.0))), name='elu'
)


class SeluOp(UnaryElementwiseOp):
def __init__(
self,
x: Tensor,
alpha: float = 1.6732632423543772848170429916717,
scale: float = 1.0507009873554804934193349852946,
) -> Tensor:
super().__init__(
x,
op=lambda v: x.dtype(scale)
* (prim.max(x.dtype(0.0), v) + prim.min(x.dtype(0.0), x.dtype(alpha) * (prim.exp(v) - x.dtype(-1.0)))),
name='selu',
)


class CeluOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, alpha: float = 1.0) -> Tensor:
super().__init__(
x,
op=lambda v: prim.max(x.dtype(0.0), v)
+ prim.min(x.dtype(0.0), x.dtype(alpha) * (prim.exp(v / x.dtype(alpha)) - x.dtype(1.0))),
name='celu',
)


class LogSigmoidOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(x, op=lambda v: -(prim.log(x.dtype(1.0) + prim.exp(-v))), name='logsigmoid')


class HardShrinkOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, lambda_val: float = 0.5):
super().__init__(
x,
op=lambda v: if_then_else(BitwiseAnd(v >= x.dtype(-lambda_val), v <= x.dtype(lambda_val)), x.dtype(0), v),
name='hardshrink',
)


class TanhShrinkOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(
x, op=lambda v: v - (prim.exp(v) - prim.exp(-v)) / (prim.exp(v) + prim.exp(-v)), name='tanhshrink'
)


class SoftSignOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(x, op=lambda v: v / (x.dtype(1.0) + prim.abs(v)), name='softsign')


class SoftPlusOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, beta: int = 1, threshold_val: int = 20):
super().__init__(
x,
op=lambda v: if_then_else(
v * x.dtype(beta) <= x.dtype(threshold_val),
(x.dtype(1.0 / beta)) * prim.log(x.dtype(1.0) + prim.exp(x.dtype(beta) * v)),
v,
),
name='softplus',
)


class SoftShrinkOp(UnaryElementwiseOp):
def __init__(self, x: Tensor, lambda_val: float = 0.5):
super().__init__(
x,
op=lambda v: if_then_else(
v > x.dtype(lambda_val),
v - x.dtype(lambda_val),
if_then_else(v < x.dtype(-lambda_val), v + x.dtype(lambda_val), x.dtype(0.0)),
),
name='softshrink',
)


class SoftmaxOp(Operator):
def __init__(self, x: Tensor, axis: int = 1):
axis = normalize_dim(axis, len(x.shape))
super().__init__(
inputs=[x], task=SoftmaxTask(input_like(x, 'x'), axis), attributes={'axis': axis}, name='softmax'
)


def relu(x) -> Tensor:
return ReluOp(x).get_output(0)

@@ -125,3 +228,55 @@ def prelu(x: Tensor, slope: Tensor) -> Tensor:

def hardswish(x: Tensor) -> Tensor:
return HardSwishOp(x).get_output(0)


def threshold(x: Tensor, threshold_val: float, value: float) -> Tensor:
return ThresholdOp(x, threshold_val, value).get_output(0)


def hardtanh(x: Tensor, min_val: float, max_val: float) -> Tensor:
return HardTanhOp(x, min_val, max_val).get_output(0)


def elu(x: Tensor, alpha: float) -> Tensor:
return EluOp(x, alpha).get_output(0)


def selu(x: Tensor, alpha: float, scale: float) -> Tensor:
return SeluOp(x, alpha, scale).get_output(0)


def celu(x: Tensor, alpha: float) -> Tensor:
return CeluOp(x, alpha).get_output(0)


def logsigmoid(x: Tensor) -> Tensor:
return LogSigmoidOp(x).get_output(0)


def hardshrink(x: Tensor, lambda_val: float) -> Tensor:
return HardShrinkOp(x, lambda_val).get_output(0)


def tanhshrink(x: Tensor) -> Tensor:
return TanhShrinkOp(x).get_output(0)


def softsign(x: Tensor) -> Tensor:
return SoftSignOp(x).get_output(0)


def softplus(x: Tensor, beta: int, threshold_val: int) -> Tensor:
return SoftPlusOp(x, beta, threshold_val).get_output(0)


def softshrink(x: Tensor, lambda_val: float) -> Tensor:
return SoftShrinkOp(x, lambda_val).get_output(0)


def softmax(x: Tensor, axis=1) -> Tensor:
return SoftmaxOp(x, axis).get_output(0)


def softmin(x: Tensor, axis: int) -> Tensor:
return SoftmaxOp(-x, axis).get_output(0)
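
Aside (not part of the diff): softmin above reuses SoftmaxOp on the negated input, and SoftShrinkOp implements the usual three-branch shrinkage. Both identities can be sanity-checked against PyTorch independently of hidet:

import torch

x = torch.randn(4, 8)

# softmin(x) == softmax(-x), the identity the softmin wrapper relies on
torch.testing.assert_close(torch.nn.functional.softmin(x, dim=1), torch.softmax(-x, dim=1))

# softshrink: x - lambd above lambd, x + lambd below -lambd, otherwise 0 (matches SoftShrinkOp)
lambd = 0.5
expected = torch.where(x > lambd, x - lambd,
                       torch.where(x < -lambd, x + lambd, torch.zeros_like(x)))
torch.testing.assert_close(torch.nn.functional.softshrink(x, lambd), expected)
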
12 changes: 1 addition & 11 deletions python/hidet/graph/ops/definitions/softmax.py
@@ -11,7 +11,7 @@
# limitations under the License.
from hidet.ir.func import IRModule
from hidet.ir import primitives as prim
from .utils import Task, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, reduce
from .utils import Task, TensorNode, compute, reduce


class SoftmaxTask(Task):
@@ -62,13 +62,3 @@ def implement_cuda(self, workding_dir: str) -> IRModule:
from hidet.graph.ops.schedules import softmax_cuda_schedule

return softmax_cuda_schedule(self)


class SoftmaxOp(Operator):
def __init__(self, x: Tensor, axis: int = 1):
axis = normalize_dim(axis, len(x.shape))
super().__init__(inputs=[x], task=SoftmaxTask(input_like(x, 'x'), axis), attributes={'axis': axis})


def softmax(x: Tensor, axis=1) -> Tensor:
return SoftmaxOp(x, axis).get_output(0)
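
Note (not part of the diff): SoftmaxOp and the softmax wrapper move into activation.py while SoftmaxTask stays here, and the public import path is unchanged because hidet.graph.ops re-exports the activation symbols. A trivial check, assuming the package layout in this PR:

from hidet.graph import ops

print(ops.softmax)  # still importable after the move
print(ops.softmin)  # new in this PR
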
2 changes: 1 addition & 1 deletion python/hidet/ir/primitives/__init__.py
@@ -13,7 +13,7 @@

# base primitive functions
# pylint: disable=redefined-builtin
from .math import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, expm1
from .math import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, expm1, abs
from .math import max, min, exp, pow, sqrt, rsqrt, erf, ceil, log, log2, log10, log1p, round, floor, trunc
from .math import isfinite, isinf, isnan

4 changes: 4 additions & 0 deletions python/hidet/ir/primitives/cpu/math/float32.py
@@ -37,6 +37,7 @@ def register():
'rsqrt': 'rsqrtf',
'log': 'logf',
'round': 'roundf',
'abs': 'fabsf',
'ceil': 'ceilf',
'floor': 'floorf',
'expm1': 'expm1f',
@@ -94,6 +95,9 @@ def log(self, a: Expr) -> Expr:
def round(self, a: Expr) -> Expr:
return self.call('round', a)

def abs(self, a: Expr) -> Expr:
return self.call('abs', a)

def ceil(self, a: Expr) -> Expr:
return self.call('ceil', a)

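
Aside (not part of the diff): the 'abs' entry above maps prim.abs to fabsf for CPU float32, which is what the new SoftSignOp (v / (1 + |v|)) lowers through. A hedged end-to-end sketch; hidet.asarray and CPU execution of the op are assumed to work as in the rest of the codebase:

import numpy as np
import hidet

x = hidet.asarray(np.array([-2.0, -0.5, 0.0, 3.0], dtype=np.float32))  # assumed tensor constructor
y = hidet.ops.softsign(x)  # elementwise v / (1 + |v|), using prim.abs -> fabsf on CPU
print(y.numpy())           # expect approximately [-0.667, -0.333, 0.0, 0.75]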