-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Operator] Add leaky_relu and conv2d_transpose operator (#3)
- Loading branch information
1 parent
cfe3605
commit e2e1ed1
Showing
9 changed files
with
237 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
python/hidet/graph/ops/definitions/conv2d_transpose/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .conv2d_transpose import conv2d_transpose, Conv2dTransposeOp |
88 changes: 88 additions & 0 deletions
88
python/hidet/graph/ops/definitions/conv2d_transpose/conv2d_transpose.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import Sequence, Union | ||
from hidet.ir.expr import if_then_else, And | ||
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode | ||
from hidet.graph.ops.definitions.utils import compute, input_like, normalize_stride, reduce, normalize_padding | ||
|
||
|
||
class Conv2dTransposeTask(Task): | ||
def __init__( | ||
self, | ||
data: TensorNode, | ||
weight: TensorNode, | ||
stride: Sequence[int], # [sx, sy] | ||
padding: Sequence[int], # [px0, py0, px1, py1] | ||
groups: int, | ||
output_padding: Sequence[int], # [opx, opy] | ||
): | ||
n, oc, p, q = data.const_shape() | ||
oc, wc, kx, ky = weight.const_shape() | ||
c = wc * groups | ||
sx, sy = stride | ||
px0, py0, px1, py1 = padding | ||
h = (p - 1) * sx + -px0 - px1 + kx + output_padding[0] | ||
w = (q - 1) * sy + -py0 - py1 + ky + output_padding[1] | ||
|
||
if output_padding[0] >= stride[0] or output_padding[1] >= stride[1]: | ||
raise ValueError( | ||
'Conv2dTranspose expect the output_padding < stride, \n' | ||
'but got output_padding, stride: {}, {}'.format(output_padding, stride) | ||
) | ||
if any(p < 0 for p in padding): | ||
raise ValueError('Negative padding is not supported.') | ||
|
||
out_group_size = oc // groups | ||
output = compute( | ||
name='out', | ||
shape=[n, c, h, w], | ||
fcompute=lambda ni, ci, hi, wi: reduce( | ||
shape=[out_group_size, kx, ky], | ||
fcompute=lambda ogi, kxi, kyi: if_then_else( | ||
cond=And.join( | ||
hi + px0 >= kxi, | ||
hi + px0 < p * sx + kxi, | ||
(hi + px0 - kxi) % sx == 0, | ||
wi + py0 >= kyi, | ||
wi + py0 < q * sy + kyi, | ||
(wi + py0 - kyi) % sy == 0, | ||
), | ||
then_expr=( | ||
data[ni, (ci // wc) * out_group_size + ogi, (hi + px0 - kxi) // sx, (wi + py0 - kyi) // sy] | ||
* weight[(ci // wc) * out_group_size + ogi, ci % wc, kxi, kyi] | ||
), | ||
else_expr=0.0, | ||
), | ||
reduce_type='sum', | ||
), | ||
) | ||
super().__init__(name='conv2d_transpose', inputs=[data, weight], outputs=[output]) | ||
|
||
|
||
class Conv2dTransposeOp(Operator): | ||
def __init__( | ||
self, | ||
x: Tensor, | ||
w: Tensor, | ||
stride: Sequence[int], | ||
padding: Sequence[int], | ||
groups: int, | ||
output_padding: Sequence[int], | ||
): | ||
stride = normalize_stride(stride) | ||
padding = normalize_padding(padding) | ||
output_padding = normalize_stride(output_padding) # normalize output padding same as stride | ||
super().__init__( | ||
inputs=[x, w], | ||
task=Conv2dTransposeTask(input_like(x, 'x'), input_like(w, 'w'), stride, padding, groups, output_padding), | ||
attributes={'stride': stride, 'groups': groups, 'output_padding': output_padding}, | ||
) | ||
|
||
|
||
def conv2d_transpose( | ||
data: Tensor, | ||
weight: Tensor, | ||
stride: Union[int, Sequence[int]], | ||
padding: Union[int, Sequence[int]], | ||
groups: int = 1, | ||
output_padding: Union[int, Sequence[int]] = 0, | ||
) -> Tensor: | ||
return Conv2dTransposeOp(data, weight, stride, padding, groups, output_padding).get_output(0) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pytest | ||
import numpy as np | ||
import torch | ||
import hidet | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'in_channels, out_channels, kernel_size, stride, pads, groups, height, width, output_padding', | ||
[[10, 20, (5, 5), (3, 2), [2, 1], 5, 11, 10, (2, 1)]], | ||
) | ||
def test_conv2d_transpose(in_channels, out_channels, kernel_size, stride, pads, groups, height, width, output_padding): | ||
torch_data = torch.ones(1, in_channels, height, width, dtype=torch.float32).cuda() | ||
torch_weight = torch.ones( | ||
out_channels, in_channels // groups, kernel_size[0], kernel_size[1], dtype=torch.float32 | ||
).cuda() | ||
|
||
torch_output = torch.nn.functional.conv2d( | ||
torch_data, torch_weight, stride=stride, padding=pads, groups=groups, bias=None, dilation=1 | ||
) | ||
hidet_data = hidet.from_torch(torch_data) | ||
hidet_weight = hidet.from_torch(torch_weight) | ||
hidet_output = hidet.ops.conv_pad(hidet_data, pads) | ||
hidet_output = hidet.ops.conv2d(hidet_output, hidet_weight, stride, groups) | ||
np.testing.assert_allclose(hidet_output.numpy(), torch_output.cpu().numpy(), atol=1e-5) | ||
torch_transpose_output = torch.nn.functional.conv_transpose2d( | ||
torch_output, | ||
torch_weight, | ||
stride=stride, | ||
padding=pads, | ||
groups=groups, | ||
bias=None, | ||
dilation=1, | ||
output_padding=output_padding, | ||
) | ||
hidet_transpose_output = hidet.ops.conv2d_transpose( | ||
hidet_output, hidet_weight, stride, pads, groups, output_padding=output_padding | ||
) | ||
np.testing.assert_allclose(hidet_transpose_output.numpy(), torch_transpose_output.cpu().numpy(), atol=1e-5) | ||
|
||
|
||
if __name__ == '__main__': | ||
pytest.main([__file__]) |