@@ -104,16 +104,18 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
104
104
return success ();
105
105
}
106
106
107
- // Specialization for `linalg::GenericOp`.
107
+ // / Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp `.
108
108
// / A pattern to convert Generic Linalg operations which work on tensors to
109
109
// / use buffers. BufferPlacement pass should be later used to move
110
110
// / Alloc operations to the correct positions and insert the missing Dealloc
111
111
// / operations in the correct places.
112
- static void finalizeBufferAllocation (ConversionPatternRewriter &rewriter,
113
- linalg::GenericOp genericOp,
114
- ValueRange inputs, ValueRange outputs) {
112
+ template <typename GenericOpTy>
113
+ static void
114
+ finalizeBufferAllocationForGenericOp (ConversionPatternRewriter &rewriter,
115
+ GenericOpTy genericOp, ValueRange inputs,
116
+ ValueRange outputs) {
115
117
// Generate a new linalg operation that works on buffers.
116
- auto newGenericOp = rewriter.create <linalg::GenericOp >(
118
+ auto newGenericOp = rewriter.create <GenericOpTy >(
117
119
genericOp.getLoc (),
118
120
/* resultTensorTypes=*/ llvm::None,
119
121
/* inputs=*/ inputs,
@@ -147,9 +149,7 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
147
149
rewriter.replaceOp (genericOp, outputs);
148
150
}
149
151
150
- // TODO: Specialization for `linalg::IndexedGenericOp`.
151
-
152
- // Specialization for all other `linalg::LinalgOp`.
152
+ // / Specialization for all other `linalg::LinalgOp`.
153
153
static void finalizeBufferAllocation (ConversionPatternRewriter &rewriter,
154
154
linalg::LinalgOp linalgOp,
155
155
ValueRange inputs, ValueRange outputs) {
@@ -207,8 +207,15 @@ class BufferizeAnyLinalgOp : public ConversionPattern {
207
207
208
208
// Delegate to the linalg generic pattern.
209
209
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
210
- finalizeBufferAllocation (rewriter, genericOp, adaptor.inputs (),
211
- newOutputBuffers);
210
+ finalizeBufferAllocationForGenericOp<GenericOp>(
211
+ rewriter, genericOp, adaptor.inputs (), newOutputBuffers);
212
+ return success ();
213
+ }
214
+
215
+ // Delegate to the linalg indexed generic pattern.
216
+ if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
217
+ finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
218
+ rewriter, genericOp, adaptor.inputs (), newOutputBuffers);
212
219
return success ();
213
220
}
214
221
0 commit comments