Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamo] module tests + operator support #148

Merged
merged 1 commit into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 62 additions & 5 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@
Number = Union[int, float, bool]


@register_function(torch.nn.functional.conv1d)
def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
x = ops.conv_pad(x, padding)
y = ops.conv1d(x, weight, stride=stride, dilations=dilation, groups=groups)
if bias is not None:
y = y + ops.unsqueeze(bias, [0, 2])
return y


@register_function(torch.nn.functional.conv_transpose1d)
def conv1d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
if dilation != 1 and not same_list(dilation, [1]):
raise NotImplementedError("dilation != 1")
y = ops.conv1d_transpose(x, weight, stride, padding, groups, output_padding)
if bias is not None:
y = y + ops.unsqueeze(bias, [0, 2])
return y


@register_function(torch.nn.functional.conv2d)
def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
x = ops.conv_pad(x, padding)
Expand All @@ -34,6 +55,18 @@ def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d
return y


@register_function(torch.nn.functional.conv_transpose2d)
def conv2d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
if dilation != 1 and not same_list(dilation, [1, 1]):
raise NotImplementedError("dilation != 1")
y = ops.conv2d_transpose(x, weight, stride, padding, groups, output_padding)
if bias is not None:
y = y + ops.unsqueeze(bias, [0, 2, 3])
return y


@register_function(torch.nn.functional.conv3d)
def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
x = ops.conv_pad(x, padding)
Expand All @@ -43,6 +76,18 @@ def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d
return y


@register_function(torch.nn.functional.conv_transpose3d)
def conv3d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
):
if dilation != 1 and not same_list(dilation, [1, 1, 1]):
raise NotImplementedError("dilation != 1")
y = ops.conv3d_transpose(x, weight, stride, padding, groups, output_padding)
if bias is not None:
y = y + ops.unsqueeze(bias, [0, 2, 3, 4])
return y


@register_function(torch.nn.functional.adaptive_avg_pool2d)
def adaptive_avg_pool2d(x: Tensor, output_size):
return ops.adaptive_avg_pool2d(x, output_size)
Expand Down Expand Up @@ -106,11 +151,6 @@ def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)


@register_function(operator.neg)
def neg(x: Tensor):
return -x


@register_function(torch.sin)
def sin(x: Tensor):
return ops.sin(x)
Expand Down Expand Up @@ -277,6 +317,11 @@ def sub(x: Tensor, y: Tensor):
return x - y


@register_function(operator.neg)
def neg(x: Tensor):
return -x


@register_function(torch.nn.functional.softmax)
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, _stacklevel: int = 3, dtype=None):
Expand Down Expand Up @@ -398,6 +443,11 @@ def tanh(x: Tensor):
return ops.tanh(x)


@register_function(torch.nn.functional.hardtanh)
def hardtanh(x: Tensor, min_val: float, max_val: float):
return ops.hardtanh(x, min_val, max_val)


@register_function(torch.nn.functional.embedding)
def embedding(
x: Tensor,
Expand Down Expand Up @@ -681,6 +731,13 @@ def logsigmoid(x: Tensor):
return ops.logsigmoid(x)


@register_function(torch.nn.functional.mish)
AndreSlavescu marked this conversation as resolved.
Show resolved Hide resolved
def mish(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: mish with inplace=True is not supported. Treat as inplace=False.')
return ops.multiply(x, ops.tanh(ops.softplus(x, 1.0, 20.0)))


@register_function(torch.gather)
def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):
if sparse_grad:
Expand Down
77 changes: 77 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,37 @@
from . import register_functions as regs


@register_module(torch.nn.Conv1d)
class HidetConv1d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Conv1d)
return regs.conv1d(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
stride=self.mod.stride,
padding=self.mod.padding,
dilation=self.mod.dilation,
groups=self.mod.groups,
)


@register_module(torch.nn.ConvTranspose1d)
class HidetConvTranspose1d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose1d)
return regs.conv1d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
stride=self.mod.stride,
padding=self.mod.padding,
output_padding=self.mod.output_padding,
groups=self.mod.groups,
dilation=self.mod.dilation,
)


@register_module(torch.nn.Conv2d)
class HidetConv2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
Expand All @@ -31,6 +62,22 @@ def __call__(self, x: Tensor) -> Tensor:
)


@register_module(torch.nn.ConvTranspose2d)
class HidetConvTranspose2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose2d)
return regs.conv2d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
stride=self.mod.stride,
padding=self.mod.padding,
output_padding=self.mod.output_padding,
groups=self.mod.groups,
dilation=self.mod.dilation,
)


@register_module(torch.nn.Conv3d)
class HidetConv3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
Expand All @@ -46,6 +93,22 @@ def __call__(self, x: Tensor) -> Tensor:
)


@register_module(torch.nn.ConvTranspose3d)
class HidetConvTranspose3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose3d)
return regs.conv3d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
stride=self.mod.stride,
padding=self.mod.padding,
output_padding=self.mod.output_padding,
groups=self.mod.groups,
dilation=self.mod.dilation,
)


@register_module(torch.nn.AdaptiveAvgPool2d)
class HidetAdaptiveAvgPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
Expand Down Expand Up @@ -163,6 +226,13 @@ def __call__(self, x: Tensor) -> Tensor:
return regs.tanh(x)


@register_module(torch.nn.Hardtanh)
class HidetHardtanh(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardtanh)
return regs.hardtanh(x, self.mod.min_val, self.mod.max_val)


@register_module(torch.nn.Embedding)
class HidetEmbedding(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
Expand Down Expand Up @@ -303,3 +373,10 @@ class HidetLogSigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.LogSigmoid)
return regs.logsigmoid(x)


@register_module(torch.nn.Mish)
class HidetMish(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Mish)
return regs.mish(x, self.mod.inplace)
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# pylint: disable=redefined-builtin
from . import definitions

from .definitions.conv1d import conv1d
from .definitions.conv1d_transpose import conv1d_transpose
from .definitions.conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform
from .definitions.conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .definitions.conv3d import conv3d, conv3d_gemm
from .definitions.conv3d_transpose import conv3d_transpose
from .definitions.matmul import batch_matmul, matmul
from .definitions.pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .definitions.pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
Expand Down
4 changes: 4 additions & 0 deletions python/hidet/graph/ops/definitions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
from .image import resize2d
from .cumulative import cumsum
from .special import barrier
from .conv1d import conv1d
from .conv1d_transpose import conv1d_transpose
from .attention import attention
from .conv2d import conv2d
from .conv2d_transpose import conv2d_transpose
from .conv3d import conv3d
from .conv3d_transpose import conv3d_transpose
from .matmul import batch_matmul, matmul

from .matmul import BatchMatmulOp, MatmulOp
Expand Down
13 changes: 13 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv1d import conv1d
from .conv1d import Conv1dOp
68 changes: 68 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d/conv1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Sequence
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.definitions.utils import compute, input_like, normalize_stride, normalize_dilations, reduce


class Conv1dTask(Task):
def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
n, c, l = data.const_shape()
oc, wc, k = weight.const_shape()
s = normalize_stride(stride, dim=1)[0]
dil = normalize_dilations(dilations, dim=1)[0]
len_in = (l - dil * (k - 1) - 1) // s + 1
if c % groups != 0 or oc % groups != 0:
raise ValueError(
'Conv1d expects: in_channels % groups == 0 and out_channels % groups == 0, \n'
'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
)
if wc * groups != c:
raise ValueError(
'Conv1d expects the weight tensor has shape [out_channels, in_channels / groups, kernel_size], \n'
'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, k], c, groups)
)
out_group_size = oc // groups
output = compute(
name='out',
shape=[n, oc, len_in],
fcompute=lambda ni, oci, li: reduce(
shape=[wc, k],
fcompute=lambda wci, ki: (
data[ni, (oci // out_group_size) * wc + wci, li * s + ki * dil] * weight[oci, wci, ki]
),
reduce_type='sum',
),
)
self.channels = c
self.stride = s
self.groups = groups
super().__init__(name='conv1d', inputs=[data, weight], outputs=[output])


class Conv1dOp(Operator):
def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
super().__init__(
inputs=[x, w],
task=Conv1dTask(input_like(x, 'x'), input_like(w, 'w'), stride, dilations, groups),
attributes={'stride': stride, 'groups': groups, 'dilations': dilations},
)


def conv1d(
data: Tensor,
weight: Tensor,
stride: Union[int, Sequence[int]] = (1),
dilations: Union[int, Sequence[int]] = (1),
groups: int = 1,
) -> Tensor:
return Conv1dOp(data, weight, stride, dilations, groups).get_output(0)
12 changes: 12 additions & 0 deletions python/hidet/graph/ops/definitions/conv1d_transpose/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv1d_transpose import conv1d_transpose