[mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp (#69154)

linalg::CopyOp is much more generic and useful for promoting buffers. In addition, promotion is a linalg transform, so it makes more sense to use linalg operations where possible. The old and new copy forms are sketched below.
AviadCo committed Oct 26, 2023
1 parent 196d154 commit a7d6039
Showing 4 changed files with 23 additions and 27 deletions.
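For reference, the change swaps the op emitted by the default promotion copy callback. The two forms look roughly like this (illustrative IR only; the %src/%dst names and memref types are placeholders mirroring the patterns checked in the updated tests):

    // Before: the copy-in/copy-out was emitted as a memref op.
    memref.copy %src, %dst : memref<?x?xf32> to memref<?x?xf32>
    // After: the same copy is emitted as a linalg op on buffers.
    linalg.copy ins(%src : memref<?x?xf32>) outs(%dst : memref<?x?xf32>)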
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -209,7 +209,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
Location loc = linalgOp.getLoc();
auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
Value dst) -> LogicalResult {
-b.create<memref::CopyOp>(loc, src, dst);
+b.create<linalg::CopyOp>(loc, src, dst);
return success();
};
copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
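The default callback above only applies when no user-supplied copy callbacks are set. As a rough sketch of how a downstream user can still pick a different copy op (assuming LinalgPromotionOptions::setCopyInOutFns keeps its current shape; everything not shown in this diff is illustrative):

    // Illustrative only: override promotion's copy-in/copy-out callbacks.
    // Requires "mlir/Dialect/Linalg/Transforms/Transforms.h".
    linalg::LinalgPromotionOptions options;
    options.setCopyInOutFns(
        /*copyIn=*/[](OpBuilder &b, Value src, Value dst) -> LogicalResult {
          // Emit whatever copy op the target prefers; the new default used
          // when these callbacks are absent is linalg::CopyOp.
          b.create<linalg::CopyOp>(src.getLoc(), src, dst);
          return success();
        },
        /*copyOut=*/[](OpBuilder &b, Value src, Value dst) -> LogicalResult {
          b.create<linalg::CopyOp>(src.getLoc(), src, dst);
          return success();
        });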
28 changes: 12 additions & 16 deletions mlir/test/Dialect/Linalg/promote.mlir
@@ -52,15 +52,13 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>

-// CHECK: memref.copy %[[vA]], %[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vB]], %[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vC]], %[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
//
// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
-// CHECK: memref.copy %[[partialC]], %[[vC]] :
-// CHECK: memref<?x?xf32, strided<[?, 1], offset: ?>> to
-// CHECK: memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
//
// CHECK-NOT: memref.dealloc %[[tmpA]] : memref<32xi8>
// CHECK-NOT: memref.dealloc %[[tmpB]] : memref<48xi8>
@@ -124,15 +122,13 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
// CHECK: %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>

-// CHECK: memref.copy %[[vA_f64]], %[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vB_f64]], %[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[vC_f64]], %[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+// CHECK: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
//
// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
-// CHECK: memref.copy %[[partialC_f64]], %[[vC_f64]] :
-// CHECK: memref<?x?xf64, strided<[?, 1], offset: ?>> to
-// CHECK: memref<?x?xf64, strided<[?, 1], offset: ?>>
+// CHECK: linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
//
// CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -263,7 +259,7 @@ func.func @promote_rank_reducing_subviews(%arg0: memref<?x?x?x64xf32, strided<[
// CHECK: %[[c_view:.+]] = memref.view
// CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]

-// CHECK-COUNT-3: memref.copy
+// CHECK-COUNT-3: linalg.copy
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
// CHECK-SAME: outs(%[[c_pro_subview]]
@@ -361,8 +357,8 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
// CHECK: %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
// CHECK: %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-// CHECK: memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-// CHECK: memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+// CHECK: linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
+// CHECK: linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
// CHECK: ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
// CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
@@ -376,7 +372,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
linalg.yield %1 : f32
}

-// CHECK: memref.copy %[[VAL_62]], %[[VAL_5]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<4x3xf32, strided<[4, 1]>, 1>
+// CHECK: linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
// CHECK: memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -27,10 +27,10 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
// CHECK: %[[VC:.*]] = memref.view %[[tmpC]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32>
// CHECK: %[[svCC:.+]] = memref.subview %[[VC]]

-// CHECK: memref.copy %[[svA]], %[[svAA]]
-// CHECK: memref.copy %[[svC]], %[[svCC]]
+// CHECK: linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
+// CHECK: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
// CHECK: linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
-// CHECK: memref.copy %[[svCC]], %[[svC]]
+// CHECK: linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
// CHECK: memref.dealloc %[[tmpA]]
// CHECK: memref.dealloc %[[tmpC]]

14 changes: 7 additions & 7 deletions mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,9 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK: memref.copy %[[s1]], %[[l1]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK: memref.copy %[[s2]], %[[l2]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK: linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)
@@ -114,8 +114,8 @@ func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], o
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.view
// CHECK-NOT: memref.subview
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK-NOT: memref.copy
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT: linalg.copy
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
// CHECK-SAME: outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)
@@ -149,7 +149,7 @@ func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)

module attributes {transform.with_named_sequence} {
@@ -182,7 +182,7 @@ func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<
// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>
// CHECK: linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
-// CHECK: memref.copy %[[s0]], %[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
+// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
// CHECK: linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)

module attributes {transform.with_named_sequence} {
