
[mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp #69154

Merged: 1 commit into llvm:main on Oct 26, 2023

Conversation

AviadCo (Contributor) commented Oct 16, 2023

linalg::CopyOp is much more generic and more useful for promoting buffers. In addition, this is a linalg transform, so it makes more sense to use linalg operations where possible.

AviadCo (Contributor, Author) commented Oct 16, 2023

Adding @amrami, @amirBish, and @aniragil as subscribers.

llvmbot (Collaborator) commented Oct 16, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

linalg::CopyOp is much more generic and more useful for promoting buffers. In addition, this is a linalg transform, so it makes more sense to use linalg operations where possible.
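
To make the effect concrete, here is a minimal before/after sketch of the copies that promotion materializes (operand names and types are illustrative, taken from the promote.mlir checks in this diff):

// Before this patch: promotion emitted the memref dialect copy.
memref.copy %vA, %partialA : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>

// After this patch: promotion emits the named linalg op, which downstream
// linalg transformations can treat like any other linalg operation.
linalg.copy ins(%vA : memref<?x?xf32, strided<[?, 1], offset: ?>>)
            outs(%partialA : memref<?x?xf32, strided<[?, 1], offset: ?>>)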


Full diff: https://github.com/llvm/llvm-project/pull/69154.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+1-1)
  • (modified) mlir/test/Dialect/Linalg/promote.mlir (+13-16)
  • (modified) mlir/test/Dialect/Linalg/promotion_options.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/transform-promotion.mlir (+7-7)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a131f3097666197..5c140a7d692a930 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -209,7 +209,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
   Location loc = linalgOp.getLoc();
   auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
                                    Value dst) -> LogicalResult {
-    b.create<memref::CopyOp>(loc, src, dst);
+    b.create<linalg::CopyOp>(loc, src, dst);
     return success();
   };
   copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 31b29c0e105d99d..a6f32fea814dd7d 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -52,20 +52,19 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
 //       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
 
-//       CHECK:         memref.copy %[[vA]], %[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vB]], %[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vC]], %[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
 //
-//       CHECK:         memref.copy %[[partialC]], %[[vC]] :
-//       CHECK:           memref<?x?xf32, strided<[?, 1], offset: ?>> to
-//       CHECK:           memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //
 //   CHECK-NOT:         memref.dealloc %[[tmpA]] : memref<32xi8>
 //   CHECK-NOT:         memref.dealloc %[[tmpB]] : memref<48xi8>
 //   CHECK-NOT:         memref.dealloc %[[tmpC]] : memref<24xi8>
 
+
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -122,15 +121,13 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
 //       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
 
-//       CHECK:         memref.copy %[[vA_f64]], %[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vB_f64]], %[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vC_f64]], %[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
 //
-//       CHECK:         memref.copy %[[partialC_f64]], %[[vC_f64]] :
-//       CHECK:           memref<?x?xf64, strided<[?, 1], offset: ?>> to
-//       CHECK:           memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -255,7 +252,7 @@ func.func @promote_rank_reducing_subviews(%arg0:  memref<?x?x?x64xf32, strided<[
   // CHECK: %[[c_view:.+]] = memref.view
   // CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
 
-  // CHECK-COUNT-3: memref.copy
+  // CHECK-COUNT-3: linalg.copy
   // CHECK: linalg.generic
   // CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
   // CHECK-SAME: outs(%[[c_pro_subview]]
@@ -351,8 +348,8 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
   // CHECK:           %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-  // CHECK:           memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-  // CHECK:           memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+// CHECK:           linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
+// CHECK:           linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
   // CHECK:           linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
   // CHECK:           ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
   // CHECK:             %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
@@ -366,7 +363,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
     linalg.yield %1 : f32
   }
 
-  // CHECK:           memref.copy %[[VAL_62]], %[[VAL_5]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<4x3xf32, strided<[4, 1]>, 1>
+  // CHECK:           linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
   // CHECK:           memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index a6daa9af2f37cec..6e5c1c78b329e02 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -27,10 +27,10 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
 //      CHECK:       %[[VC:.*]] = memref.view %[[tmpC]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32>
 //      CHECK:       %[[svCC:.+]] = memref.subview %[[VC]]
 
-//      CHECK:       memref.copy %[[svA]], %[[svAA]]
-//      CHECK:       memref.copy %[[svC]], %[[svCC]]
+//      CHECK:       linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
+//      CHECK:       linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
 //      CHECK:       linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
-//      CHECK:       memref.copy %[[svCC]], %[[svC]]
+//      CHECK:       linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //      CHECK:       memref.dealloc %[[tmpA]]
 //      CHECK:       memref.dealloc %[[tmpC]]
 
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index 2f98e394fe05198..c311d471f6c4ae9 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,9 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
 // CHECK:               %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
 // CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK:               memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK:               memref.copy %[[s1]], %[[l1]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK:               memref.copy %[[s2]], %[[l2]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK:               linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK:               linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK:               linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.matmul
 // CHECK-SAME:                 ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
 // CHECK-SAME:                outs(%[[v2]] : memref<?x?xf32>)
@@ -112,8 +112,8 @@ func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], o
 // CHECK-NOT:     memref.alloc
 // CHECK-NOT:     memref.view
 // CHECK-NOT:     memref.subview
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK-NOT:     memref.copy
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT:     linalg.copy
 // CHECK:         linalg.matmul
 // CHECK-SAME:           ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
 // CHECK-SAME:          outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)
@@ -148,7 +148,7 @@ func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
 // CHECK:         linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
 
 transform.with_pdl_patterns {
@@ -182,7 +182,7 @@ func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>
 // CHECK:         linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
 
 transform.with_pdl_patterns {

AviadCo (Contributor, Author) commented Oct 22, 2023

@nicolasvasilache What do you think about this patch? Or would you prefer that I adjust PromoteOp to accept the copy op as an option?
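
For context, promotion already exposes copy-in/copy-out hooks, so callers who still want memref.copy (or any custom copy) can override the default callback. A minimal C++ sketch, assuming the setCopyInOutFns setter on LinalgPromotionOptions (the callback signature matches defaultCopyCallBack in the diff above; exact helper names may differ):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: keep emitting memref.copy during promotion by overriding the
// default copy callbacks (which, after this patch, create linalg.copy).
static linalg::LinalgPromotionOptions memrefCopyPromotionOptions() {
  auto copyFn = [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
    b.create<memref::CopyOp>(src.getLoc(), src, dst);
    return success();
  };
  return linalg::LinalgPromotionOptions().setCopyInOutFns(copyFn, copyFn);
}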

AviadCo (Contributor, Author) commented Oct 26, 2023

@nicolasvasilache ping

nicolasvasilache (Contributor) commented

My apologies for the delay, LGTM!

linalg::CopyOp is much more generic and more useful for promoting buffers. In
addition, this is a linalg transform, so it makes more sense to use linalg
operations where possible.
@AviadCo force-pushed the linalg/promote-op-enable-linalg-copy branch from 88c7247 to f18e1e8 on October 26, 2023 at 14:18
AviadCo (Contributor, Author) commented Oct 26, 2023

@nicolasvasilache thanks!

nicolasvasilache (Contributor) commented

@AviadCo please also note that @chelini has started work on a transform that specializes a linalg.generic into a named op: #70326.
For now it only handles copy, but it will evolve.
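
For illustration, a minimal transform-dialect sketch of driving that specialization (the transform.structured.specialize syntax is assumed from llvm#70326 and may have evolved since):

// Match copy-like linalg.generic ops and rewrite them into named ops
// such as linalg.copy.
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  %generics = transform.structured.match ops{["linalg.generic"]} in %arg0
    : (!transform.any_op) -> !transform.any_op
  %named = transform.structured.specialize %generics
    : (!transform.any_op) -> !transform.any_op
}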

AviadCo (Contributor, Author) commented Oct 26, 2023

@nicolasvasilache this is a useful transform! I will review it as well.

@AviadCo merged commit a7d6039 into llvm:main on Oct 26, 2023 (3 checks passed).
@AviadCo deleted the linalg/promote-op-enable-linalg-copy branch on October 26, 2023 at 15:54.
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Nov 2, 2023
Local branch amd-gfx c49aa73 Merged main:2633d94f289b into amd-gfx:38011e802b7f
Remote branch main a7d6039 [mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp (llvm#69154)