Skip to content

Commit fd92c5d

Browse files
committed
[mlir][linalg] Add bufferization pattern for linalg.indexed_generic.
Differential Revision: https://reviews.llvm.org/D92014
1 parent 7b52542, commit fd92c5d

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -104,16 +104,18 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
104104
return success();
105105
}
106106

107-
// Specialization for `linalg::GenericOp`.
107+
/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`.
108108
/// A pattern to convert Generic Linalg operations which work on tensors to
109109
/// use buffers. BufferPlacement pass should be later used to move
110110
/// Alloc operations to the correct positions and insert the missing Dealloc
111111
/// operations in the correct places.
112-
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
113-
linalg::GenericOp genericOp,
114-
ValueRange inputs, ValueRange outputs) {
112+
template <typename GenericOpTy>
113+
static void
114+
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
115+
GenericOpTy genericOp, ValueRange inputs,
116+
ValueRange outputs) {
115117
// Generate a new linalg operation that works on buffers.
116-
auto newGenericOp = rewriter.create<linalg::GenericOp>(
118+
auto newGenericOp = rewriter.create<GenericOpTy>(
117119
genericOp.getLoc(),
118120
/*resultTensorTypes=*/llvm::None,
119121
/*inputs=*/inputs,
@@ -147,9 +149,7 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
147149
rewriter.replaceOp(genericOp, outputs);
148150
}
149151

150-
// TODO: Specialization for `linalg::IndexedGenericOp`.
151-
152-
// Specialization for all other `linalg::LinalgOp`.
152+
/// Specialization for all other `linalg::LinalgOp`.
153153
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
154154
linalg::LinalgOp linalgOp,
155155
ValueRange inputs, ValueRange outputs) {
@@ -207,8 +207,15 @@ class BufferizeAnyLinalgOp : public ConversionPattern {
207207

208208
// Delegate to the linalg generic pattern.
209209
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
210-
finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
211-
newOutputBuffers);
210+
finalizeBufferAllocationForGenericOp<GenericOp>(
211+
rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
212+
return success();
213+
}
214+
215+
// Delegate to the linalg indexed generic pattern.
216+
if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
217+
finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
218+
rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
212219
return success();
213220
}
214221

mlir/test/Dialect/Linalg/bufferize.mlir

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
4545
// CHECK: linalg.generic
4646
// CHECK-SAME: ins(%{{.*}} : memref<4xf32>)
4747
// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>)
48+
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
4849
func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
4950
%0, %1 = linalg.generic {
5051
indexing_maps = [#map0, #map0, #map0],
@@ -59,6 +60,31 @@ func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
5960

6061
// -----
6162

63+
#map0 = affine_map<(d0) -> (d0)>
64+
65+
// CHECK-LABEL: func @multiple_results_indexed
66+
// CHECK: %[[RESULT0:.*]] = alloc() : memref<4xi32>
67+
// CHECK: %[[RESULT1:.*]] = alloc() : memref<4xi32>
68+
// CHECK: linalg.indexed_generic
69+
// CHECK-SAME: ins(%{{.*}} : memref<4xi32>)
70+
// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xi32>, memref<4xi32>)
71+
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32):
72+
func @multiple_results_indexed(%arg0: tensor<4xi32>)
73+
-> (tensor<4xi32>, tensor<4xi32>) {
74+
%0, %1 = linalg.indexed_generic {
75+
indexing_maps = [#map0, #map0, #map0],
76+
iterator_types = ["parallel"]
77+
} ins(%arg0 : tensor<4xi32>) {
78+
^bb0(%i: index, %gen_arg1: i32):
79+
%i_i32 = index_cast %i : index to i32
80+
%tmp1 = addi %gen_arg1, %i_i32 : i32
81+
linalg.yield %tmp1, %tmp1 : i32, i32
82+
} -> tensor<4xi32>, tensor<4xi32>
83+
return %0, %1 : tensor<4xi32>, tensor<4xi32>
84+
}
85+
86+
// -----
87+
6288
#map_2d = affine_map<(d0, d1) -> (d0, d1)>
6389
#map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
6490

0 commit comments

Comments
 (0)