1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1858,6 +1858,7 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// Note: currently limited to ops with pure tensor semantics.
/// \see rewriteInIm2Col for more details.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
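
For reference, a minimal sketch of how these patterns might be driven (hypothetical wrapper; only the populate call above is the documented API, and the greedy driver entry point is an assumption about current MLIR):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect the img2col rewrites and apply them greedily. Convolutions with
// memref (buffer) semantics fail to match and are left untouched.
static mlir::LogicalResult convertConv2DToImg2Col(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateConvertConv2DToImg2ColPatterns(patterns);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}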

38 changes: 35 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include <cassert>
#include <utility>

namespace mlir {
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

if (!convOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
convOp, "expected op to have pure tensor semantics");

if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {

Location loc = convOp.getLoc();

// The pure tensor semantics check above guarantees that the filter is a
// ranked tensor, so the cast is safe; the concrete type is needed to read
// the tensor's encoding.
assert(isa<RankedTensorType>(filterType) &&
"expected filter type to be a ranked tensor");
auto tensorFilterType = cast<RankedTensorType>(filterType);

// Reshape output and filter to the LHS and result of a (B)MNK matmul.
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
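
An illustrative example for this NHWC/HWCF case, which has no dedicated test in this patch (the shapes and the 42 : i32 encoding are placeholders): the collapse now carries the filter's encoding through, e.g.

%collapsed = tensor.collapse_shape %filter [[0, 1, 2], [3]]
    : tensor<3x3x4x16xf32, 42 : i32> into tensor<36x16xf32, 42 : i32>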

@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter,
auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

if (!convOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
convOp, "expected op to have pure tensor semantics");

if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

if (!convOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
convOp, "expected op to have pure tensor semantics");

if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto loc = convOp.getLoc();
MLIRContext *context = rewriter.getContext();

assert(isa<RankedTensorType>(filterType) &&
"expected filter type to be a ranked tensor");
auto tensorFilterType = cast<RankedTensorType>(filterType);

SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

if (!convOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
convOp, "expected op to have pure tensor semantics");

if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {

Location loc = convOp.getLoc();

assert(isa<RankedTensorType>(filterType) &&
"expected filter type to be a ranked tensor");
auto tensorFilterType = cast<RankedTensorType>(filterType);

// Reshape output and filter to the LHS and result of a "row-wise" matrix
// multiplication.
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

72 changes: 71 additions & 1 deletion mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -26,6 +26,26 @@ module attributes {transform.with_named_sequence} {

// -----

// Memref semantics are not supported.
// Check that we emit an error.
func.func @negative_conv_memref(%arg0: memref<1x16x16x4xf32>, %arg1: memref<16x3x3x4xf32>, %arg2: memref<1x14x14x16xf32>) {
// expected-note@below {{when applied to this op}}
linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: memref<1x16x16x4xf32>, memref<16x3x3x4xf32>) outs(%arg2: memref<1x14x14x16xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{failed to apply}}
%img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

// Check that we get the proper handles for the img2col tensor producer
// and the final instruction.

@@ -267,6 +287,31 @@

// -----

// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.

// CHECK: func.func @batch_nchw_conv_with_filter_encoding(%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.*]]: tensor<16x4x3x3xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<8x16x14x14xf32>)
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x4x3x3xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSED_FILTER]], %[[COL_TENSOR]] : tensor<16x36xf32, 42 : i32>, tensor<8x36x196xf32>)
func.func @batch_nchw_conv_with_filter_encoding(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32, 42 : i32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
%0 = linalg.conv_2d_nchw_fchw
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32, 42 : i32>)
outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
return %0 : tensor<8x16x14x14xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
@@ -290,7 +335,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +372,31 @@

// -----

// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.

// CHECK: func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%[[INPUT:.+]]: tensor<1x16x16x4xf32>, %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>)
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32, 42 : i32>)
func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%input: tensor<1x16x16x4xf32>, %filter: tensor<16x3x3x4xf32, 42 : i32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_fhwc
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
outs(%out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

// Check for signed extend when the input type is smaller than the accumulator type.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>