New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp #69154
[mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp #69154
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Aviad Cohen (AviadCo) Changeslinalg::CopyOp is much more generic and useful to promote buffers. In addition, this is linalg transform and makes more sense to use linalg operations when possible. Full diff: https://github.com/llvm/llvm-project/pull/69154.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a131f3097666197..5c140a7d692a930 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -209,7 +209,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
Location loc = linalgOp.getLoc();
auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
Value dst) -> LogicalResult {
- b.create<memref::CopyOp>(loc, src, dst);
+ b.create<linalg::CopyOp>(loc, src, dst);
return success();
};
copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 31b29c0e105d99d..a6f32fea814dd7d 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -52,20 +52,19 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vA]], %[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vB]], %[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vC]], %[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
//
// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
-// CHECK: memref.copy %[[partialC]], %[[vC]] :
-// CHECK: memref<?x?xf32, strided<[?, 1], offset: ?>> to
-// CHECK: memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
//
// CHECK-NOT: memref.dealloc %[[tmpA]] : memref<32xi8>
// CHECK-NOT: memref.dealloc %[[tmpB]] : memref<48xi8>
// CHECK-NOT: memref.dealloc %[[tmpC]] : memref<24xi8>
+
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -122,15 +121,13 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
// CHECK: %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vA_f64]], %[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vB_f64]], %[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vC_f64]], %[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
//
// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
-// CHECK: memref.copy %[[partialC_f64]], %[[vC_f64]] :
-// CHECK: memref<?x?xf64, strided<[?, 1], offset: ?>> to
-// CHECK: memref<?x?xf64, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
//
// CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -255,7 +252,7 @@ func.func @promote_rank_reducing_subviews(%arg0: memref<?x?x?x64xf32, strided<[
// CHECK: %[[c_view:.+]] = memref.view
// CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
- // CHECK-COUNT-3: memref.copy
+ // CHECK-COUNT-3: linalg.copy
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
// CHECK-SAME: outs(%[[c_pro_subview]]
@@ -351,8 +348,8 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
// CHECK: %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
// CHECK: %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
- // CHECK: memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
- // CHECK: memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+// CHECK: linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
+// CHECK: linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
// CHECK: ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
// CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
@@ -366,7 +363,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
linalg.yield %1 : f32
}
- // CHECK: memref.copy %[[VAL_62]], %[[VAL_5]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<4x3xf32, strided<[4, 1]>, 1>
+ // CHECK: linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
// CHECK: memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index a6daa9af2f37cec..6e5c1c78b329e02 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -27,10 +27,10 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
// CHECK: %[[VC:.*]] = memref.view %[[tmpC]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32>
// CHECK: %[[svCC:.+]] = memref.subview %[[VC]]
-// CHECK: memref.copy %[[svA]], %[[svAA]]
-// CHECK: memref.copy %[[svC]], %[[svCC]]
+// CHECK: linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
+// CHECK: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
// CHECK: linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
-// CHECK: memref.copy %[[svCC]], %[[svC]]
+// CHECK: linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
// CHECK: memref.dealloc %[[tmpA]]
// CHECK: memref.dealloc %[[tmpC]]
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index 2f98e394fe05198..c311d471f6c4ae9 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,9 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK: memref.copy %[[s1]], %[[l1]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK: memref.copy %[[s2]], %[[l2]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK: linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)
@@ -112,8 +112,8 @@ func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], o
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.view
// CHECK-NOT: memref.subview
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK-NOT: memref.copy
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT: linalg.copy
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
// CHECK-SAME: outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)
@@ -148,7 +148,7 @@ func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
transform.with_pdl_patterns {
@@ -182,7 +182,7 @@ func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>
// CHECK: linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
transform.with_pdl_patterns {
|
@nicolasvasilache What do you think about this patch? do you prefer I adjust PromoteOp to accept CopyOp as an option? |
@nicolasvasilache ping |
My apologies for the delay, LGTM! |
linalg::CopyOp is much more generic and useful to promote buffers. In addition, this is linalg transform and makes more sense to use linalg operations when possible.
88c7247
to
f18e1e8
Compare
@nicolasvasilache thanks! |
@nicolasvasilache this is useful transform! I will also review it |
Local branch amd-gfx c49aa73 Merged main:2633d94f289b into amd-gfx:38011e802b7f Remote branch main a7d6039 [mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp (llvm#69154)
linalg::CopyOp is much more generic and useful to promote buffers. In addition, this is linalg transform and makes more sense to use linalg operations when possible.