17
17
#include " mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18
18
#include " mlir/IR/PatternMatch.h"
19
19
#include " mlir/Transforms/DialectConversion.h"
20
+ #include " llvm/ADT/STLExtras.h"
20
21
21
22
using namespace mlir ;
22
23
@@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
296
297
}
297
298
}
298
299
299
- FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder (
300
- rewriter, funcOp, argsInterchange.getArrayRef (),
301
- resultsInterchange.getArrayRef ());
300
+ llvm::SmallVector<int > oldArgToNewArg (argsInterchange.size ());
301
+ for (auto [newArgIdx, oldArgIdx] : llvm::enumerate (argsInterchange))
302
+ oldArgToNewArg[oldArgIdx] = newArgIdx;
303
+
304
+ llvm::SmallVector<int > oldResToNewRes (resultsInterchange.size ());
305
+ for (auto [newResIdx, oldResIdx] : llvm::enumerate (resultsInterchange))
306
+ oldResToNewRes[oldResIdx] = newResIdx;
307
+
308
+ FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping (
309
+ rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
302
310
if (failed (newFuncOpOrFailure))
303
311
return emitSilenceableFailure (getLoc ())
304
312
<< " failed to replace function signature '" << getFunctionName ()
@@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
312
320
});
313
321
314
322
for (func::CallOp callOp : callOps)
315
- func::replaceCallOpWithNewOrder (rewriter, callOp,
316
- argsInterchange.getArrayRef (),
317
- resultsInterchange.getArrayRef ());
323
+ func::replaceCallOpWithNewMapping (rewriter, callOp, oldArgToNewArg,
324
+ oldResToNewRes);
318
325
}
319
326
320
327
results.set (cast<OpResult>(getTransformedModule ()), {targetModuleOp});
@@ -330,6 +337,50 @@ void transform::ReplaceFuncSignatureOp::getEffects(
330
337
transform::modifiesPayload (effects);
331
338
}
332
339
340
+ // ===----------------------------------------------------------------------===//
341
+ // DeduplicateFuncArgsOp
342
+ // ===----------------------------------------------------------------------===//
343
+
344
+ DiagnosedSilenceableFailure
345
+ transform::DeduplicateFuncArgsOp::apply (transform::TransformRewriter &rewriter,
346
+ transform::TransformResults &results,
347
+ transform::TransformState &state) {
348
+ auto payloadOps = state.getPayloadOps (getModule ());
349
+ if (!llvm::hasSingleElement (payloadOps))
350
+ return emitDefiniteFailure () << " requires a single module to operate on" ;
351
+
352
+ auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin ());
353
+ if (!targetModuleOp)
354
+ return emitSilenceableFailure (getLoc ())
355
+ << " target is expected to be module operation" ;
356
+
357
+ func::FuncOp funcOp =
358
+ targetModuleOp.lookupSymbol <func::FuncOp>(getFunctionName ());
359
+ if (!funcOp)
360
+ return emitSilenceableFailure (getLoc ())
361
+ << " function with name '" << getFunctionName () << " ' is not found" ;
362
+
363
+ std::string errorMessage;
364
+ auto transformationResult = func::deduplicateArgsOfFuncOp (
365
+ rewriter, funcOp, targetModuleOp, errorMessage);
366
+ if (failed (transformationResult))
367
+ return emitSilenceableFailure (getLoc ()) << errorMessage;
368
+
369
+ auto [newFuncOp, newCallOp] = *transformationResult;
370
+
371
+ results.set (cast<OpResult>(getTransformedModule ()), {targetModuleOp});
372
+ results.set (cast<OpResult>(getTransformedFunction ()), {newFuncOp});
373
+
374
+ return DiagnosedSilenceableFailure::success ();
375
+ }
376
+
377
+ void transform::DeduplicateFuncArgsOp::getEffects (
378
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
379
+ transform::consumesHandle (getModuleMutable (), effects);
380
+ transform::producesHandle (getOperation ()->getOpResults (), effects);
381
+ transform::modifiesPayload (effects);
382
+ }
383
+
333
384
// ===----------------------------------------------------------------------===//
334
385
// Transform op registration
335
386
// ===----------------------------------------------------------------------===//
0 commit comments