[Operator] Add conv2d_transpose_gemm operator & fix a bug #13

Merged · 3 commits · Nov 7, 2022
15 changes: 12 additions & 3 deletions python/hidet/backend/build.py

@@ -1,5 +1,6 @@
 from typing import List, Optional
 import functools
+import warnings
 import os
 import ctypes
 import shutil
@@ -126,10 +127,18 @@ def compile_source(src_path: str, out_lib_path: str, keep_ptx=False) -> None:
         target_ptx_path = os.path.join(out_lib_dir, ptx_name)
         shutil.move(ptx_path, target_ptx_path)
         # os.rename(ptx_path, target_ptx_path)
+        with open(os.path.join(out_lib_dir, 'compile.sh'), 'w') as f:
+            f.write("#!/bin/bash\n")
+            f.write(" ".join(result.args))
         with open(os.path.join(out_lib_dir, 'nvcc_log.txt'), 'w') as f:
-            f.write('Command: {}\n'.format(" ".join(result.args)))
-            f.write(result.stdout.decode('utf-8'))
-            f.write(result.stderr.decode('utf-8'))
+            output = '\n'.join([result.stdout.decode('utf-8').strip(), result.stderr.decode('utf-8').strip()])
+            f.write(output)
+
+            lines = output.split('\n')
+            warning_lines = [line for line in lines if 'warning' in line]
+            warning_lines = warning_lines[: len(warning_lines) // 2]  # nvcc would print the same warning twice
+            if len(warning_lines) > 0:
+                warnings.warn('Compilation warnings:\n' + '\n'.join(warning_lines))
     except subprocess.CalledProcessError as e:
         print(' '.join(command))
         print(e.stderr.decode('utf-8'))
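Note: the dedup above keeps only the first half of the warning lines because, per the inline comment, nvcc prints each warning twice in the combined stdout/stderr output. A minimal standalone sketch of that logic (the log string is made up, not real nvcc output):

    import warnings

    fake_log = '\n'.join([
        'kernel.cu(10): warning: variable "x" was declared but never referenced',
        'ptxas info    : 0 bytes gmem',
        'kernel.cu(10): warning: variable "x" was declared but never referenced',
    ])

    # collect warning lines, then keep one copy of each duplicated warning
    warning_lines = [line for line in fake_log.split('\n') if 'warning' in line]
    warning_lines = warning_lines[: len(warning_lines) // 2]
    if warning_lines:
        warnings.warn('Compilation warnings:\n' + '\n'.join(warning_lines))
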
2 changes: 1 addition & 1 deletion python/hidet/backend/codegen.py

@@ -338,7 +338,7 @@ def visit_Let(self, e: Let):
         raise ValueError("please run 'expand_let_expr' pass before codegen")

     def visit_Var(self, e: Var):
-        cast2int = {'threadIdx.x', 'blockIdx.x'}
+        cast2int = {'threadIdx.x', 'threadIdx.y', 'threadIdx.z', 'blockIdx.x', 'blockIdx.y', 'blockIdx.z'}
         name = self.namer.get_name(e)
         if name in cast2int:
             return Text(f'(int){name}')
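Note: this is likely the bug fix from the PR title. CUDA's threadIdx/blockIdx are uint3, so their components are unsigned; the codegen casts them to int, presumably to keep generated index arithmetic signed, but previously only the .x components were covered. A rough illustration of the wraparound the cast avoids (NumPy stand-in, not hidet code):

    import numpy as np

    tid = np.array([0], dtype=np.uint32)   # stand-in for an unsigned threadIdx.y
    print(tid - 1)                         # [4294967295]: unsigned arithmetic wraps
    print(tid.astype(np.int64) - 1)        # [-1]: the behavior the (int) cast restores
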
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py

@@ -2,7 +2,7 @@
 from . import definitions

 from .definitions.conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform
-from .definitions.conv2d_transpose import conv2d_transpose
+from .definitions.conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
 from .definitions.matmul import batch_matmul, matmul
 from .definitions.pool import max_pool2d, avg_pool2d
 from .definitions.softmax import softmax
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/definitions/__init__.py

@@ -3,7 +3,7 @@
 from .conv2d import conv2d_gemm_image_transform, conv2d_gemm_filter_transform, conv2d_gemm_inverse_transform
 from .conv2d import conv2d_winograd_image_transform, conv2d_winograd_filter_transform, conv2d_winograd_inverse_transform

-from .conv2d_transpose import conv2d_transpose
+from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm

 from .matmul import batch_matmul, matmul
 from .pool import max_pool2d, avg_pool2d
python/hidet/graph/ops/definitions/conv2d_transpose/__init__.py

@@ -1 +1,4 @@
 from .conv2d_transpose import conv2d_transpose, Conv2dTransposeOp
+from .conv2d_transpose_gemm import conv2d_transpose_gemm
+
+from . import resolve
python/hidet/graph/ops/definitions/conv2d_transpose/conv2d_transpose.py

@@ -1,18 +1,19 @@
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple
 from hidet.ir.expr import if_then_else, And
+from hidet.ir.compute import compute, reduce
 from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
-from hidet.graph.ops.definitions.utils import compute, input_like, normalize_stride, reduce, normalize_padding
+from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding


 class Conv2dTransposeTask(Task):
     def __init__(
         self,
         data: TensorNode,
         weight: TensorNode,
-        stride: Sequence[int],  # [sx, sy]
-        padding: Sequence[int],  # [px0, py0, px1, py1]
+        stride: Tuple[int, int],
+        padding: Tuple[int, int, int, int],
         groups: int,
-        output_padding: Sequence[int],  # [opx, opy]
+        output_padding: Tuple[int, int],
     ):
         n, oc, p, q = data.const_shape()
         oc, wc, kx, ky = weight.const_shape()
@@ -30,12 +31,12 @@ def __init__(
         if any(p < 0 for p in padding):
             raise ValueError('Negative padding is not supported.')

-        out_group_size = oc // groups
+        og = oc // groups  # output channels in each group
         output = compute(
             name='out',
             shape=[n, c, h, w],
             fcompute=lambda ni, ci, hi, wi: reduce(
-                shape=[out_group_size, kx, ky],
+                shape=[og, kx, ky],
                 fcompute=lambda ogi, kxi, kyi: if_then_else(
                     cond=And.join(
                         hi + px0 >= kxi,
@@ -46,8 +47,8 @@ def __init__(
                         (wi + py0 - kyi) % sy == 0,
                     ),
                     then_expr=(
-                        data[ni, (ci // wc) * out_group_size + ogi, (hi + px0 - kxi) // sx, (wi + py0 - kyi) // sy]
-                        * weight[(ci // wc) * out_group_size + ogi, ci % wc, kxi, kyi]
+                        data[ni, (ci // wc) * og + ogi, (hi + px0 - kxi) // sx, (wi + py0 - kyi) // sy]
+                        * weight[(ci // wc) * og + ogi, ci % wc, kxi, kyi]
                     ),
                     else_expr=0.0,
                 ),
@@ -62,14 +63,11 @@ def __init__(
         self,
         x: Tensor,
         w: Tensor,
-        stride: Sequence[int],
-        padding: Sequence[int],
+        stride: Tuple[int, int],
+        padding: Tuple[int, int, int, int],
         groups: int,
-        output_padding: Sequence[int],
+        output_padding: Tuple[int, int],
     ):
-        stride = normalize_stride(stride)
-        padding = normalize_padding(padding)
-        output_padding = normalize_stride(output_padding)  # normalize output padding same as stride
         super().__init__(
             inputs=[x, w],
             task=Conv2dTransposeTask(input_like(x, 'x'), input_like(w, 'w'), stride, padding, groups, output_padding),
@@ -85,4 +83,7 @@ def conv2d_transpose(
     groups: int = 1,
     output_padding: Union[int, Sequence[int]] = 0,
 ) -> Tensor:
-    return Conv2dTransposeOp(data, weight, stride, padding, groups, output_padding).get_output(0)
+    sx, sy = normalize_stride(stride)
+    px0, py0, px1, py1 = normalize_padding(padding)
+    opx, opy = normalize_stride(output_padding)  # normalize output padding same as stride
+    return Conv2dTransposeOp(data, weight, (sx, sy), (px0, py0, px1, py1), groups, (opx, opy)).get_output(0)
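Note: the output spatial size used by these tasks is h = (p - 1)·sx - px0 - px1 + kx + opx (and analogously for w), which matches PyTorch's ConvTranspose2d formula. An illustrative check with the parameters from the test at the bottom of this diff (not part of the PR):

    # conv2d of an 11x10 image with kernel 5x5, stride (3, 2), pad (2, 1) gives p = q = 4
    p, q = 4, 4
    sx, sy, kx, ky = 3, 2, 5, 5
    px0, px1, py0, py1 = 2, 2, 1, 1
    opx, opy = 2, 1
    h = (p - 1) * sx - px0 - px1 + kx + opx   # 3*3 - 4 + 5 + 2 = 12
    w = (q - 1) * sy - py0 - py1 + ky + opy   # 3*2 - 2 + 5 + 1 = 10
    print(h, w)  # 12 10
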
python/hidet/graph/ops/definitions/conv2d_transpose/conv2d_transpose_gemm.py

@@ -0,0 +1,118 @@
+from typing import Sequence, Union, Tuple
+from hidet.ir.expr import if_then_else, And
+from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode
+from hidet.graph.ops.definitions.utils import input_like, normalize_stride, normalize_padding
+from hidet.ir.compute import compute
+from hidet.graph.ops.definitions.matmul import matmul
+
+
+class Conv2dTransposeGemmImageTask(Task):
+    def __init__(
+        self,
+        data: TensorNode,
+        kernel: Tuple[int, int],
+        stride: Sequence[int],  # [sx, sy]
+        padding: Sequence[int],  # [px0, py0, px1, py1]
+        groups: int,
+        output_padding: Sequence[int],  # [opx, opy]
+    ):
+        n, oc, p, q = data.const_shape()
+        kx, ky = kernel
+        sx, sy = stride
+        px0, py0, px1, py1 = padding
+        h = (p - 1) * sx - px0 - px1 + kx + output_padding[0]
+        w = (q - 1) * sy - py0 - py1 + ky + output_padding[1]
+        og = oc // groups  # output channels in each group
+
+        def fcompute(b, i, k):
+            gi = b
+            ni, hi, wi = i // (h * w), ((i // w) % h), (i % w)
+            ogi, kxi, kyi = k // (kx * ky), ((k // ky) % kx), (k % ky)
+            xx = hi + px0 - kxi
+            yy = wi + py0 - kyi
+            return if_then_else(
+                cond=And.join(xx >= 0, xx < p * sx, xx % sx == 0, yy >= 0, yy < q * sy, yy % sy == 0),
+                then_expr=data[ni, gi * og + ogi, xx // sx, yy // sy],
+                else_expr=0.0,
+            )
+
+        output = compute(name='gemm_x', shape=[groups, n * h * w, og * kx * ky], fcompute=fcompute)
+        super().__init__(name='conv2d_transpose_gemm_image', inputs=[data], outputs=[output])
+
+
+class Conv2dTransposeGemmImageOp(Operator):
+    def __init__(
+        self,
+        data: Tensor,
+        kernel: Tuple[int, int],
+        stride: Tuple[int, int],
+        padding: Tuple[int, int, int, int],
+        groups: int,
+        output_padding: Tuple[int, int],
+    ):
+        super().__init__(
+            inputs=[data],
+            task=Conv2dTransposeGemmImageTask(
+                input_like(data, 'data'), kernel, stride, padding, groups, output_padding
+            ),
+            attributes={
+                'kernel': kernel,
+                'stride': stride,
+                'padding': padding,
+                'groups': groups,
+                'output_padding': output_padding,
+            },
+        )
+
+
+def conv2d_transpose_gemm_image(
+    data: Tensor,
+    kernel: Tuple[int, int],
+    stride: Tuple[int, int],
+    padding: Tuple[int, int, int, int],
+    groups: int,
+    output_padding: Tuple[int, int],
+):
+    # input shape: [n, oc, p, q]
+    # output shape: [groups, n * h * w, og * kx * ky]
+    return Conv2dTransposeGemmImageOp(data, kernel, stride, padding, groups, output_padding).get_output(0)
+
+
+def conv2d_transpose_gemm_filter(weight: Tensor, groups: int = 1):
+    # input shape: [oc, wc, kx, ky] where oc = groups * og
+    # output shape: [groups, og * kx * ky, wc]
+    oc, wc, kx, ky = weight.shape
+    og = oc // groups
+    return weight.reshape([groups, og, wc, kx, ky]).rearrange([[0], [1, 3, 4], [2]])
+
+
+def conv2d_transpose_gemm_inverse(gemm_y, height: int, width: int):
+    # input shape: [groups, n * h * w, wc]
+    # output shape: [n, c, h, w] where c = groups * wc
+    groups, nhw, wc = gemm_y.shape
+    assert nhw % (height * width) == 0
+    n = nhw // (height * width)
+    return gemm_y.reshape([groups, n, height, width, wc]).rearrange([[1], [0, 4], [2], [3]])
+
+
+def conv2d_transpose_gemm(
+    data: Tensor,
+    weight: Tensor,
+    stride: Union[int, Sequence[int]],
+    padding: Union[int, Sequence[int]],
+    groups: int = 1,
+    output_padding: Union[int, Sequence[int]] = 0,
+) -> Tensor:
+    sx, sy = normalize_stride(stride)
+    px0, py0, px1, py1 = normalize_padding(padding)
+    opx, opy = normalize_stride(output_padding)  # normalize output padding same as stride
+    kx, ky = weight.shape[2:]
+    gemm_x = conv2d_transpose_gemm_image(data, (kx, ky), (sx, sy), (px0, py0, px1, py1), groups, (opx, opy))
+    gemm_w = conv2d_transpose_gemm_filter(weight, groups)
+    gemm_y = matmul(gemm_x, gemm_w)
+
+    p, q = data.shape[2:]
+    # use the normalized opx/opy here: output_padding itself may still be a plain int
+    h = (p - 1) * sx - px0 - px1 + kx + opx
+    w = (q - 1) * sy - py0 - py1 + ky + opy
+    y = conv2d_transpose_gemm_inverse(gemm_y, h, w)
+    return y
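Note: the decomposition is the usual im2col-style recipe: scatter the input into a [groups, n·h·w, og·kx·ky] matrix, reshape the filter to [groups, og·kx·ky, wc], batch-matmul over groups, then fold back to [n, groups·wc, h, w]. Below is a NumPy reference sketch ported directly from the fcompute and transforms above; the function names are mine, and it is meant as a reading aid, not PR code:

    import numpy as np

    def transpose_gemm_image_ref(data, kernel, stride, padding, output_padding, groups):
        # data: [n, oc, p, q]  ->  gemm_x: [groups, n*h*w, og*kx*ky]
        n, oc, p, q = data.shape
        kx, ky = kernel
        sx, sy = stride
        px0, py0, px1, py1 = padding
        h = (p - 1) * sx - px0 - px1 + kx + output_padding[0]
        w = (q - 1) * sy - py0 - py1 + ky + output_padding[1]
        og = oc // groups
        gemm_x = np.zeros((groups, n * h * w, og * kx * ky), dtype=data.dtype)
        for gi in range(groups):
            for i in range(n * h * w):
                ni, hi, wi = i // (h * w), (i // w) % h, i % w
                for k in range(og * kx * ky):
                    ogi, kxi, kyi = k // (kx * ky), (k // ky) % kx, k % ky
                    xx, yy = hi + px0 - kxi, wi + py0 - kyi
                    if 0 <= xx < p * sx and xx % sx == 0 and 0 <= yy < q * sy and yy % sy == 0:
                        gemm_x[gi, i, k] = data[ni, gi * og + ogi, xx // sx, yy // sy]
        return gemm_x

    def transpose_gemm_ref(data, weight, stride, padding, output_padding, groups=1):
        # weight: [oc, wc, kx, ky]; result: [n, groups*wc, h, w]
        n, oc, p, q = data.shape
        _, wc, kx, ky = weight.shape
        sx, sy = stride
        px0, py0, px1, py1 = padding
        og = oc // groups
        h = (p - 1) * sx - px0 - px1 + kx + output_padding[0]
        w = (q - 1) * sy - py0 - py1 + ky + output_padding[1]
        gemm_x = transpose_gemm_image_ref(data, (kx, ky), stride, padding, output_padding, groups)
        # filter transform: [oc, wc, kx, ky] -> [groups, og*kx*ky, wc], the rearrange above
        gemm_w = weight.reshape(groups, og, wc, kx, ky).transpose(0, 1, 3, 4, 2).reshape(groups, og * kx * ky, wc)
        gemm_y = gemm_x @ gemm_w  # batched matmul over the groups dimension
        # inverse transform: [groups, n*h*w, wc] -> [n, groups*wc, h, w]
        return gemm_y.reshape(groups, n, h, w, wc).transpose(1, 0, 4, 2, 3).reshape(n, groups * wc, h, w)
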
22 changes: 22 additions & 0 deletions python/hidet/graph/ops/definitions/conv2d_transpose/resolve.py

@@ -0,0 +1,22 @@
+from typing import List, Type, Optional
+from hidet.graph.ir import Operator, Tensor
+from hidet.graph import ops
+from hidet.graph.transforms import ResolveRule, register_resolve_rule
+
+from .conv2d_transpose import Conv2dTransposeOp
+
+
+@register_resolve_rule
+class Conv2dTransposeResolveRule(ResolveRule):
+    def op_cls(self) -> Type[Operator]:
+        return Conv2dTransposeOp
+
+    def resolve(self, op: Conv2dTransposeOp) -> Optional[List[Tensor]]:
+        attrs = op.attrs
+        data, weight = op.inputs
+        stride = attrs['stride']
+        padding = attrs['padding']
+        groups = attrs['groups']
+        output_padding = attrs['output_padding']
+        out = ops.conv2d_transpose_gemm(data, weight, stride, padding, groups, output_padding)
+        return [out]
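Note: the rule is applied during graph-level optimization, so any traced graph containing Conv2dTransposeOp is rewritten to the gemm path. A hedged usage sketch, assuming hidet's usual tracing entry points (hidet.symbol, hidet.trace_from, hidet.graph.optimize) and illustrative shapes:

    import hidet

    x = hidet.symbol([1, 20, 4, 4], dtype='float32', device='cuda')
    w = hidet.randn([20, 2, 5, 5], device='cuda')
    y = hidet.ops.conv2d_transpose(x, w, stride=(3, 2), padding=(2, 1), groups=5, output_padding=(2, 1))

    graph = hidet.trace_from(y, inputs=[x])
    graph_opt = hidet.graph.optimize(graph)  # Conv2dTransposeResolveRule rewrites the op here
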
4 changes: 1 addition & 3 deletions requirements-dev.txt

@@ -5,12 +5,10 @@ pytest==7.2
 black==22.10.0

 # linter
-pylint==2.15.5
+pylint==2.13.9

 # for models to test
 --extra-index-url https://download.pytorch.org/whl/cu116
 torch
---extra-index-url https://download.pytorch.org/whl/cu116
 torchvision
 transformers
-
10 changes: 6 additions & 4 deletions tests/graph/operators/test_conv2d_transpose.py

@@ -1,14 +1,18 @@
 import pytest
 import numpy as np
 import torch
+import torch.nn.functional
 import hidet


+@pytest.mark.parametrize("hidet_op", [hidet.ops.conv2d_transpose, hidet.ops.conv2d_transpose_gemm])
 @pytest.mark.parametrize(
     'in_channels, out_channels, kernel_size, stride, pads, groups, height, width, output_padding',
     [[10, 20, (5, 5), (3, 2), [2, 1], 5, 11, 10, (2, 1)]],
 )
-def test_conv2d_transpose(in_channels, out_channels, kernel_size, stride, pads, groups, height, width, output_padding):
+def test_conv2d_transpose(
+    hidet_op, in_channels, out_channels, kernel_size, stride, pads, groups, height, width, output_padding
+):
     torch_data = torch.ones(1, in_channels, height, width, dtype=torch.float32).cuda()
     torch_weight = torch.ones(
         out_channels, in_channels // groups, kernel_size[0], kernel_size[1], dtype=torch.float32
@@ -32,9 +36,7 @@ def test_conv2d_transpose(in_channels, out_channels, kernel_size, stride, pads,
         dilation=1,
         output_padding=output_padding,
     )
-    hidet_transpose_output = hidet.ops.conv2d_transpose(
-        hidet_output, hidet_weight, stride, pads, groups, output_padding=output_padding
-    )
+    hidet_transpose_output = hidet_op(hidet_output, hidet_weight, stride, pads, groups, output_padding=output_padding)
     np.testing.assert_allclose(hidet_transpose_output.numpy(), torch_transpose_output.cpu().numpy(), atol=1e-5)