[Frontend] Add more ops and op mappings for PyTorch frontend (#148)
flatten all changes into one commit for review

Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
AndreSlavescu and yaoyaoding committed Apr 26, 2023
1 parent 30ae787 commit 0f8d3fa
Showing 29 changed files with 1,042 additions and 40 deletions.
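
A minimal sketch of how the new mappings are exercised end to end, assuming a CUDA build of hidet whose torch.compile backend is registered on import (shapes are illustrative):

import torch
import hidet  # noqa: F401 -- importing hidet registers the 'hidet' torch.compile backend

model = torch.nn.ConvTranspose1d(4, 8, kernel_size=3, stride=2).cuda().eval()
x = torch.randn(1, 4, 16, device='cuda')
compiled = torch.compile(model, backend='hidet')
y = compiled(x)  # dispatched through the HidetConvTranspose1d mapping added below
print(y.shape)  # torch.Size([1, 8, 33]): (16 - 1) * 2 + (3 - 1) + 1
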
67 changes: 62 additions & 5 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -25,6 +25,27 @@
Number = Union[int, float, bool]


@register_function(torch.nn.functional.conv1d)
def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
    x = ops.conv_pad(x, padding)
    y = ops.conv1d(x, weight, stride=stride, dilations=dilation, groups=groups)
    if bias is not None:
        y = y + ops.unsqueeze(bias, [0, 2])
    return y


@register_function(torch.nn.functional.conv_transpose1d)
def conv1d_transpose(
    x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
    if dilation != 1 and not same_list(dilation, [1]):
        raise NotImplementedError("dilation != 1")
    y = ops.conv1d_transpose(x, weight, stride, padding, groups, output_padding)
    if bias is not None:
        y = y + ops.unsqueeze(bias, [0, 2])
    return y
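
A transposed 1-D convolution maps input length L_in to (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. A hedged shape check of this mapping against PyTorch; the hidet.from_torch conversion is an assumption, and the positional arguments follow the order used in the register function above:

import torch
import hidet

x = torch.randn(1, 4, 8, device='cuda')
w = torch.randn(4, 2, 3, device='cuda')  # [in_channels, out_channels // groups, kernel]
ref = torch.nn.functional.conv_transpose1d(x, w, stride=2)
# stride=2, padding=0, groups=1, output_padding=0
out = hidet.ops.conv1d_transpose(hidet.from_torch(x), hidet.from_torch(w), 2, 0, 1, 0)
assert tuple(out.shape) == tuple(ref.shape)  # (1, 2, 17)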


@register_function(torch.nn.functional.conv2d)
def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
    x = ops.conv_pad(x, padding)
@@ -34,6 +55,18 @@ def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d
    return y


@register_function(torch.nn.functional.conv_transpose2d)
def conv2d_transpose(
    x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
    if dilation != 1 and not same_list(dilation, [1, 1]):
        raise NotImplementedError("dilation != 1")
    y = ops.conv2d_transpose(x, weight, stride, padding, groups, output_padding)
    if bias is not None:
        y = y + ops.unsqueeze(bias, [0, 2, 3])
    return y


@register_function(torch.nn.functional.conv3d)
def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
    x = ops.conv_pad(x, padding)
@@ -43,6 +76,18 @@ def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d
    return y


@register_function(torch.nn.functional.conv_transpose3d)
def conv3d_transpose(
    x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
    if dilation != 1 and not same_list(dilation, [1, 1, 1]):
        raise NotImplementedError("dilation != 1")
    y = ops.conv3d_transpose(x, weight, stride, padding, groups, output_padding)
    if bias is not None:
        y = y + ops.unsqueeze(bias, [0, 2, 3, 4])
    return y


@register_function(torch.nn.functional.adaptive_avg_pool2d)
def adaptive_avg_pool2d(x: Tensor, output_size):
    return ops.adaptive_avg_pool2d(x, output_size)
@@ -106,11 +151,6 @@ def iadd(x: Tensor, y: Tensor):
    return ops.add(x, y)


@register_function(operator.neg)
def neg(x: Tensor):
    return -x


@register_function(torch.sin)
def sin(x: Tensor):
    return ops.sin(x)
@@ -277,6 +317,11 @@ def sub(x: Tensor, y: Tensor):
    return x - y


@register_function(operator.neg)
def neg(x: Tensor):
    return -x


@register_function(torch.nn.functional.softmax)
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, _stacklevel: int = 3, dtype=None):
@@ -398,6 +443,11 @@ def tanh(x: Tensor):
    return ops.tanh(x)


@register_function(torch.nn.functional.hardtanh)
def hardtanh(x: Tensor, min_val: float, max_val: float):
    return ops.hardtanh(x, min_val, max_val)
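
Hardtanh is simply a clamp to [min_val, max_val]; the intended semantics can be verified in plain PyTorch:

import torch

x = torch.linspace(-3, 3, steps=7)
assert torch.equal(torch.nn.functional.hardtanh(x, -1.0, 1.0), torch.clamp(x, -1.0, 1.0))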


@register_function(torch.nn.functional.embedding)
def embedding(
    x: Tensor,
@@ -681,6 +731,13 @@ def logsigmoid(x: Tensor):
    return ops.logsigmoid(x)


@register_function(torch.nn.functional.mish)
def mish(x: Tensor, inplace: bool = False):
    if inplace:
        warnings.warn_once('hidet: mish with inplace=True is not supported. Treat as inplace=False.')
    return ops.multiply(x, ops.tanh(ops.softplus(x, 1.0, 20.0)))
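
The identity used here is mish(x) = x * tanh(softplus(x)), with the 1.0 and 20.0 arguments mirroring PyTorch's softplus defaults (beta=1, threshold=20). A quick check against torch's reference implementation:

import torch

x = torch.randn(64)
ref = torch.nn.functional.mish(x)
ours = x * torch.tanh(torch.nn.functional.softplus(x, beta=1.0, threshold=20.0))
assert torch.allclose(ref, ours, atol=1e-6)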


@register_function(torch.gather)
def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):
    if sparse_grad:
77 changes: 77 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -16,6 +16,37 @@
from . import register_functions as regs


@register_module(torch.nn.Conv1d)
class HidetConv1d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.Conv1d)
        return regs.conv1d(
            x=x,
            weight=self.param('weight'),
            bias=self.param('bias', optional=True),
            stride=self.mod.stride,
            padding=self.mod.padding,
            dilation=self.mod.dilation,
            groups=self.mod.groups,
        )


@register_module(torch.nn.ConvTranspose1d)
class HidetConvTranspose1d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.ConvTranspose1d)
        return regs.conv1d_transpose(
            x=x,
            weight=self.param('weight'),
            bias=self.param('bias', optional=True),
            stride=self.mod.stride,
            padding=self.mod.padding,
            output_padding=self.mod.output_padding,
            groups=self.mod.groups,
            dilation=self.mod.dilation,
        )


@register_module(torch.nn.Conv2d)
class HidetConv2d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
@@ -31,6 +62,22 @@ def __call__(self, x: Tensor) -> Tensor:
        )


@register_module(torch.nn.ConvTranspose2d)
class HidetConvTranspose2d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.ConvTranspose2d)
        return regs.conv2d_transpose(
            x=x,
            weight=self.param('weight'),
            bias=self.param('bias', optional=True),
            stride=self.mod.stride,
            padding=self.mod.padding,
            output_padding=self.mod.output_padding,
            groups=self.mod.groups,
            dilation=self.mod.dilation,
        )


@register_module(torch.nn.Conv3d)
class HidetConv3d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
@@ -46,6 +93,22 @@ def __call__(self, x: Tensor) -> Tensor:
        )


@register_module(torch.nn.ConvTranspose3d)
class HidetConvTranspose3d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.ConvTranspose3d)
        return regs.conv3d_transpose(
            x=x,
            weight=self.param('weight'),
            bias=self.param('bias', optional=True),
            stride=self.mod.stride,
            padding=self.mod.padding,
            output_padding=self.mod.output_padding,
            groups=self.mod.groups,
            dilation=self.mod.dilation,
        )


@register_module(torch.nn.AdaptiveAvgPool2d)
class HidetAdaptiveAvgPool2d(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
@@ -163,6 +226,13 @@ def __call__(self, x: Tensor) -> Tensor:
        return regs.tanh(x)


@register_module(torch.nn.Hardtanh)
class HidetHardtanh(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.Hardtanh)
        return regs.hardtanh(x, self.mod.min_val, self.mod.max_val)


@register_module(torch.nn.Embedding)
class HidetEmbedding(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
@@ -303,3 +373,10 @@ class HidetLogSigmoid(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.LogSigmoid)
        return regs.logsigmoid(x)


@register_module(torch.nn.Mish)
class HidetMish(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.Mish)
        return regs.mish(x, self.mod.inplace)
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/__init__.py
@@ -12,9 +12,12 @@
# pylint: disable=redefined-builtin
from . import definitions

from .definitions.conv1d import conv1d
from .definitions.conv1d_transpose import conv1d_transpose
from .definitions.conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform
from .definitions.conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .definitions.conv3d import conv3d, conv3d_gemm
from .definitions.conv3d_transpose import conv3d_transpose
from .definitions.matmul import batch_matmul, matmul
from .definitions.pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .definitions.pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
4 changes: 4 additions & 0 deletions python/hidet/graph/ops/definitions/__init__.py
@@ -25,9 +25,13 @@
from .image import resize2d
from .cumulative import cumsum
from .special import barrier
from .conv1d import conv1d
from .conv1d_transpose import conv1d_transpose
from .attention import attention
from .conv2d import conv2d
from .conv2d_transpose import conv2d_transpose
from .conv3d import conv3d
from .conv3d_transpose import conv3d_transpose
from .matmul import batch_matmul, matmul

from .matmul import BatchMatmulOp, MatmulOp
13 changes: 13 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d/__init__.py
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv1d import conv1d
from .conv1d import Conv1dOp
68 changes: 68 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d/conv1d.py
@@ -0,0 +1,68 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Sequence
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import compute, input_like, normalize_stride, normalize_dilations, reduce


class Conv1dTask(Task):
    def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
        n, c, l = data.const_shape()
        oc, wc, k = weight.const_shape()
        s = normalize_stride(stride, dim=1)[0]
        dil = normalize_dilations(dilations, dim=1)[0]
        l_out = (l - dil * (k - 1) - 1) // s + 1  # output length
        if c % groups != 0 or oc % groups != 0:
            raise ValueError(
                'Conv1d expects: in_channels % groups == 0 and out_channels % groups == 0, \n'
                'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
            )
        if wc * groups != c:
            raise ValueError(
                'Conv1d expects the weight tensor to have shape [out_channels, in_channels / groups, kernel_size], \n'
                'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, k], c, groups)
            )
        out_group_size = oc // groups
        output = compute(
            name='out',
            shape=[n, oc, l_out],
            fcompute=lambda ni, oci, li: reduce(
                shape=[wc, k],
                fcompute=lambda wci, ki: (
                    data[ni, (oci // out_group_size) * wc + wci, li * s + ki * dil] * weight[oci, wci, ki]
                ),
                reduce_type='sum',
            ),
        )
        self.channels = c
        self.stride = s
        self.groups = groups
        super().__init__(name='conv1d', inputs=[data, weight], outputs=[output])
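
For readers unfamiliar with hidet's compute/reduce DSL: the fcompute above expresses a plain grouped 1-D convolution. An equivalent NumPy reference that mirrors the Task's index arithmetic (a sketch for illustration, not part of this diff):

import numpy as np

def conv1d_ref(data, weight, s=1, dil=1, groups=1):
    n, c, l = data.shape
    oc, wc, k = weight.shape                  # wc == c // groups
    l_out = (l - dil * (k - 1) - 1) // s + 1  # same output length as Conv1dTask
    out_group_size = oc // groups
    out = np.zeros((n, oc, l_out), dtype=data.dtype)
    for ni in range(n):
        for oci in range(oc):
            g = oci // out_group_size         # which input-channel group this output channel reads
            for li in range(l_out):
                out[ni, oci, li] = sum(
                    data[ni, g * wc + wci, li * s + ki * dil] * weight[oci, wci, ki]
                    for wci in range(wc)
                    for ki in range(k)
                )
    return out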


class Conv1dOp(Operator):
    def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
        super().__init__(
            inputs=[x, w],
            task=Conv1dTask(input_like(x, 'x'), input_like(w, 'w'), stride, dilations, groups),
            attributes={'stride': stride, 'groups': groups, 'dilations': dilations},
        )


def conv1d(
    data: Tensor,
    weight: Tensor,
    stride: Union[int, Sequence[int]] = 1,
    dilations: Union[int, Sequence[int]] = 1,
    groups: int = 1,
) -> Tensor:
    return Conv1dOp(data, weight, stride, dilations, groups).get_output(0)
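
A hedged numeric check of the wrapper against PyTorch, exercising stride, dilation, and groups together; hidet.from_torch and Tensor.torch() are assumptions about hidet's conversion API, while the keyword names match the signature above:

import torch
import hidet

x = torch.randn(1, 6, 20, device='cuda')
w = torch.randn(9, 2, 3, device='cuda')  # groups=3, so 6 / 3 = 2 input channels per group
ref = torch.nn.functional.conv1d(x, w, stride=2, dilation=2, groups=3)
out = hidet.ops.conv1d(hidet.from_torch(x), hidet.from_torch(w), stride=2, dilations=2, groups=3)
assert torch.allclose(out.torch(), ref, atol=1e-4)
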
12 changes: 12 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d_transpose/__init__.py
@@ -0,0 +1,12 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv1d_transpose import conv1d_transpose
