diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 41670249936e6..7266687584b38 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1858,6 +1858,7 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
 
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
+/// Note: currently limited to pure tensor semantics.
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..ebc4dcf6bbcb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include <cassert>
 #include <utility>
 
 namespace mlir {
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter,
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..152a392afe247 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -26,6 +26,26 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Memref semantics are not supported.
+// Check that we emit an error.
+func.func @negative_conv_memref(%arg0: memref<1x16x16x4xf32>, %arg1: memref<16x3x3x4xf32>, %arg2: memref<1x14x14x16xf32>) {
+  // expected-note@below {{when applied to this op}}
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: memref<1x16x16x4xf32>, memref<16x3x3x4xf32>) outs(%arg2: memref<1x14x14x16xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{failed to apply}}
+    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we get the proper handles for the img2col tensor producer
 // and the final instruction.
 
@@ -267,6 +287,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @batch_nchw_conv_with_filter_encoding(%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.*]]: tensor<16x4x3x3xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<8x16x14x14xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x4x3x3xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSED_FILTER]], %[[COL_TENSOR]] : tensor<16x36xf32, 42 : i32>, tensor<8x36x196xf32>)
+func.func @batch_nchw_conv_with_filter_encoding(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32, 42 : i32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32, 42 : i32>)
+    outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
@@ -290,7 +335,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 // CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 // CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 // CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 // CHECK: linalg.yield %{{.+}} : f32
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +372,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%[[INPUT:.+]]: tensor<1x16x16x4xf32>, %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32, 42 : i32>)
+func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%input: tensor<1x16x16x4xf32>, %filter: tensor<16x3x3x4xf32, 42 : i32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc
+    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%input, %filter: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+    outs(%out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
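For context, a minimal sketch of how these patterns are typically driven from C++. This is illustrative only: `runImg2Col` is a hypothetical wrapper, not part of this patch, and the greedy-driver entry point is `applyPatternsGreedily` on recent MLIR revisions (`applyPatternsAndFoldGreedily` on older ones). The point is that with this change, convolutions with memref semantics now fail to match and are left untouched rather than being rewritten as if they produced tensors:

// Hypothetical driver (illustrative, not part of this patch): populate the
// img2col patterns and apply them greedily. Convs with pure tensor
// semantics are decomposed into linalg.generic + matmul; memref-based
// convs hit the new notifyMatchFailure and stay intact.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runImg2Col(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateConvertConv2DToImg2ColPatterns(patterns);
  // The driver converges once no pattern applies; non-tensor convs simply
  // never match, so they do not cause the whole application to fail.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}

Note the contrast with the transform-dialect path exercised in the negative test above: `transform.structured.convert_conv2d_to_img2col` reports a hard error when its payload op fails to match, which is why the test expects a diagnostic rather than a silent no-op.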