Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operators] Conv2d fp16 implicit gemm kernel #283

Merged
merged 28 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
# pylint: disable=redefined-builtin
from .conv1d import conv1d
from .conv1d_transpose import conv1d_transpose
from .conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform
from .conv2d import (
conv2d,
conv2d_channel_last,
conv2d_winograd,
conv2d_gemm,
conv2d_gemm_fp16,
conv2d_gemm_fp16_channel_last,
conv2d_gemm_image_transform,
)
from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .conv3d import conv3d, conv3d_gemm
from .conv3d_transpose import conv3d_transpose
Expand Down
12 changes: 9 additions & 3 deletions python/hidet/graph/ops/conv2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv2d import conv2d
from .conv2d import Conv2dOp
from .conv2d import conv2d, conv2d_channel_last
from .conv2d import Conv2dOp, Conv2dChannelLastOp
from .conv2d_winograd import conv2d_winograd, conv2d_winograd_image_transform, conv2d_winograd_filter_transform
from .conv2d_winograd import conv2d_winograd_inverse_transform
from .conv2d_winograd import Conv2dWinogradInverseTransformOp, Conv2dWinogradFilterTransformOp
from .conv2d_winograd import Conv2dWinogradImageTransformOp
from .conv2d_gemm import conv2d_gemm, conv2d_gemm_image_transform, conv2d_gemm_filter_transform
from .conv2d_gemm import (
conv2d_gemm,
conv2d_gemm_fp16,
conv2d_gemm_fp16_channel_last,
conv2d_gemm_image_transform,
conv2d_gemm_filter_transform,
)
from .conv2d_gemm import conv2d_gemm_inverse_transform
from .conv2d_gemm import Conv2dGemmImageTransformOp

Expand Down
63 changes: 63 additions & 0 deletions python/hidet/graph/ops/conv2d/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,48 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila
super().__init__(name='conv2d', inputs=[data, weight], outputs=[output])


class Conv2dChannelLastTask(Task):
    """Direct 2D convolution task with channel-last (NHWC) data layout.

    Computes
        out[n, p, q, oc] = sum_{wc, kx, ky} data[n, p*sx + kx*dilx, q*sy + ky*dily, g*wc_size + wc]
                                            * weight[oc, wc, kx, ky]
    where g = oc // (out_channels // groups). The input is assumed to be
    already padded: no padding is applied when computing the output size.

    Parameters
    ----------
    data: TensorNode
        Input tensor with shape [n, h, w, c] (channel-last).
    weight: TensorNode
        Filter tensor with shape [out_channels, in_channels // groups, kx, ky].
    stride: List[int]
        Strides (sx, sy) along the spatial dimensions.
    dilations: List[int]
        Dilations (dilx, dily) along the spatial dimensions.
    groups: int
        Number of convolution groups; both in_channels and out_channels must
        be divisible by it.
    """

    def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
        # pylint: disable=too-many-locals
        # we assume that only data needs to have dynamic shape
        n, h, w, c = data.shape
        oc, wc, kx, ky = weight.shape
        sx, sy = stride
        dilx, dily = dilations
        # output spatial extent for a pre-padded input
        p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1
        self._assert(
            # BUG FIX: grouped convolution requires BOTH in_channels and out_channels
            # to be divisible by groups (as the message below states); the original
            # used ir.logical_or, which accepted inputs where only one of them was.
            ir.logical_and(c % groups == 0, oc % groups == 0),
            msg=(
                'Conv2d expect the in_channels % groups == 0 and out_channels % groups == 0, \n'
                'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
            ),
        )
        self._assert(
            wc * groups == c,
            msg=(
                'Conv2d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n'
                'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups)
            ),
        )
        out_group_size = oc // groups
        output = compute(
            name='out',
            shape=[n, p, q, oc],
            fcompute=lambda ni, pi, qi, oci: reduce(
                shape=[wc, kx, ky],
                fcompute=lambda wci, kxi, kyi: (
                    # (oci // out_group_size) selects the group; wci indexes within it
                    data[ni, pi * sx + kxi * dilx, qi * sy + kyi * dily, (oci // out_group_size) * wc + wci]
                    * weight[oci, wci, kxi, kyi]
                ),
                reduce_type='sum',
            ),
        )
        self.channels = c
        self.stride = stride
        self.groups = groups
        super().__init__(name='conv2d_channel_last', inputs=[data, weight], outputs=[output])


class Conv2dOp(Operator):
def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
stride = normalize_stride(stride)
Expand All @@ -68,6 +110,17 @@ def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union
)


class Conv2dChannelLastOp(Operator):
    """Graph operator wrapping Conv2dChannelLastTask (NHWC 2D convolution)."""

    def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
        # Canonicalize the hyper-parameters before building the task.
        normalized_stride = normalize_stride(stride)
        normalized_dilations = normalize_dilations(dilations)
        task = Conv2dChannelLastTask(
            input_like(x, 'x'), input_like(w, 'w'), normalized_stride, normalized_dilations, groups
        )
        super().__init__(
            inputs=[x, w],
            attributes={'stride': normalized_stride, 'groups': groups, 'dilations': normalized_dilations},
            task=task,
        )


def conv2d(
    data: Tensor,
    weight: Tensor,
    stride: Union[int, Sequence[int]] = (1, 1),
    dilations: Union[int, Sequence[int]] = (1, 1),
    groups: int = 1,
) -> Tensor:
    """2D convolution (channel-first layout).

    Parameters
    ----------
    data: Tensor
        Input tensor.
    weight: Tensor
        Filter tensor with shape [out_channels, in_channels // groups, kx, ky].
    stride: Union[int, Sequence[int]]
        Stride along the spatial dimensions.
    dilations: Union[int, Sequence[int]]
        Dilation along the spatial dimensions.
    groups: int
        Number of convolution groups.

    Returns
    -------
    Tensor
        The convolution output.
    """
    # NOTE(review): the stride/dilations parameter lines were hidden by a collapsed
    # diff in the page this was taken from; defaults mirror conv2d_channel_last below.
    return Conv2dOp(data, weight, stride, dilations, groups).get_output(0)


def conv2d_channel_last(
    data: Tensor,
    weight: Tensor,
    stride: Union[int, Sequence[int]] = (1, 1),
    dilations: Union[int, Sequence[int]] = (1, 1),
    groups: int = 1,
) -> Tensor:
    """2D convolution on channel-last (NHWC) input.

    Thin functional wrapper around Conv2dChannelLastOp; returns its single output.
    """
    op = Conv2dChannelLastOp(data, weight, stride, dilations, groups)
    return op.get_output(0)