-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][linalg] Propagate filter tensor encoding in im2col #160099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: None (fabrizio-indirli) Changes: In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op. Full diff: https://github.com/llvm/llvm-project/pull/160099.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..12e2b6f5c3f0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -155,10 +155,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Location loc = convOp.getLoc();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a (B)MNK matmul.
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
- RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+ RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -435,9 +441,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto loc = convOp.getLoc();
MLIRContext *context = rewriter.getContext();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+ RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -560,11 +572,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Location loc = convOp.getLoc();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a "row-wise" matrix
// multiplication.
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+ RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..af911e3b3919a 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -290,7 +290,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +327,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK: func.func @conv2d_decompose_im2col_with_filter_encoding
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+ // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+func.func @conv2d_decompose_im2col_with_filter_encoding(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32, 42 : i32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// Check for signed extend when the input type is smaller than the accumulator type.
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
@llvm/pr-subscribers-mlir-linalg Author: None (fabrizio-indirli) Changes: In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op. Full diff: https://github.com/llvm/llvm-project/pull/160099.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..12e2b6f5c3f0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -155,10 +155,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Location loc = convOp.getLoc();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a (B)MNK matmul.
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
- RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+ RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -435,9 +441,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto loc = convOp.getLoc();
MLIRContext *context = rewriter.getContext();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+ RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -560,11 +572,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Location loc = convOp.getLoc();
+ if (!isa<RankedTensorType>(filterType))
+ return rewriter.notifyMatchFailure(
+ convOp, "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a "row-wise" matrix
// multiplication.
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+ RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..af911e3b3919a 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -290,7 +290,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +327,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK: func.func @conv2d_decompose_im2col_with_filter_encoding
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+ // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+func.func @conv2d_decompose_im2col_with_filter_encoding(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32, 42 : i32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// Check for signed extend when the input type is smaller than the accumulator type.
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
8797532
to
da4b961
Compare
Thanks for the review @banach-space ! Hopefully I addressed your comments, let me know should you have more suggestions :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM!
I've posted one request, but that's non-blocking and just nice-to-have. No need to wait for me to re-approve if you decide to update this PR.
In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op. Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com> Change-Id: I275d6fad0257d9813b9821341a6160144ae983e7
da4b961
to
d73b6c1
Compare
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/40528 Here is the relevant piece of the build log for the reference
|
In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op. Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op.