Skip to content

Commit bfc15d5

Browse files
committed
[mlir][func]-Add deduplicate funcOp arguments transform
This PR adds a new transform operation which removes the duplicate arguments from the function operation based on the callOp of this function. To have a more simple implementation for now, the transform will fail when having multiple callOps for the same function we want to eliminate the different arguments from. This pull request also adpat the utils under the func dialect to be reusable also for this transformOp.
1 parent 4b03252 commit bfc15d5

File tree

6 files changed

+447
-67
lines changed

6 files changed

+447
-67
lines changed

mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
134134
}];
135135
}
136136

137+
def DeduplicateFuncArgsOp
138+
: Op<Transform_Dialect, "func.deduplicate_func_args",
139+
[DeclareOpInterfaceMethods<TransformOpInterface>,
140+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
141+
let description = [{
142+
This transform takes a module and a function name, and deduplicates
143+
the arguments of the function. The function is expected to be defined in
144+
the module.
145+
146+
This transform will emit a silenceable failure if:
147+
- The function with the given name does not exist in the module.
148+
- The function does not have duplicate arguments.
149+
- The function does not have a single call.
150+
}];
151+
152+
let arguments = (ins TransformHandleTypeInterface:$module,
153+
SymbolRefAttr:$function_name);
154+
let results = (outs TransformHandleTypeInterface:$transformed_module,
155+
TransformHandleTypeInterface:$transformed_function);
156+
157+
let assemblyFormat = [{
158+
$function_name
159+
`at` $module attr-dict `:` functional-type(operands, results)
160+
}];
161+
}
162+
137163
#endif // FUNC_TRANSFORM_OPS

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,50 @@
1818

1919
#include "mlir/IR/PatternMatch.h"
2020
#include "llvm/ADT/ArrayRef.h"
21+
#include <string>
2122

2223
namespace mlir {
2324

25+
class ModuleOp;
26+
2427
namespace func {
2528

2629
class FuncOp;
2730
class CallOp;
2831

2932
/// Creates a new function operation with the same name as the original
30-
/// function operation, but with the arguments reordered according to
31-
/// the `newArgsOrder` and `newResultsOrder`.
33+
/// function operation, but with the arguments mapped according to
34+
/// the `oldArgToNewArg` and `oldResToNewRes`.
3235
/// The `funcOp` operation must have exactly one block.
3336
/// Returns the new function operation or failure if `funcOp` doesn't
3437
/// have exactly one block.
35-
FailureOr<FuncOp>
36-
replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
37-
llvm::ArrayRef<unsigned> newArgsOrder,
38-
llvm::ArrayRef<unsigned> newResultsOrder);
38+
/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
39+
/// maps the whole function arguments and results.
40+
mlir::FailureOr<mlir::func::FuncOp> replaceFuncWithNewMapping(
41+
mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
42+
ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
3943
/// Creates a new call operation with the values as the original
40-
/// call operation, but with the arguments reordered according to
41-
/// the `newArgsOrder` and `newResultsOrder`.
42-
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
43-
llvm::ArrayRef<unsigned> newArgsOrder,
44-
llvm::ArrayRef<unsigned> newResultsOrder);
44+
/// call operation, but with the arguments mapped according to
45+
/// the `oldArgToNewArg` and `oldResToNewRes`.
46+
/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
47+
/// maps the whole call operation arguments and results.
48+
mlir::func::CallOp replaceCallOpWithNewMapping(
49+
mlir::RewriterBase &rewriter, mlir::func::CallOp callOp,
50+
ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
51+
52+
/// This utility function examines all call operations within the given
53+
/// `moduleOp` that target the specified `funcOp`. It identifies duplicate
54+
/// operands in the call operations, creates mappings to deduplicate them, and
55+
/// then applies the transformation to both the function and its call sites. For
56+
/// now, it only supports one call operation for the function operation. The
57+
/// function returns a pair containing the new funcOp and the new callOp. Note:
58+
/// after the transformation, the original funcOp and callOp will be erased. The
59+
/// `errorMessage` will be set to the error message if the transformation fails.
60+
mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
61+
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
62+
mlir::ModuleOp moduleOp, std::string &errorMessage);
4563

4664
} // namespace func
4765
} // namespace mlir
4866

49-
#endif // MLIR_DIALECT_FUNC_UTILS_H
67+
#endif // MLIR_DIALECT_FUNC_UTILS_H

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1818
#include "mlir/IR/PatternMatch.h"
1919
#include "mlir/Transforms/DialectConversion.h"
20+
#include "llvm/ADT/STLExtras.h"
2021

2122
using namespace mlir;
2223

@@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
296297
}
297298
}
298299

299-
FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
300-
rewriter, funcOp, argsInterchange.getArrayRef(),
301-
resultsInterchange.getArrayRef());
300+
llvm::SmallVector<int> oldArgToNewArg(argsInterchange.size());
301+
for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
302+
oldArgToNewArg[oldArgIdx] = newArgIdx;
303+
304+
llvm::SmallVector<int> oldResToNewRes(resultsInterchange.size());
305+
for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
306+
oldResToNewRes[oldResIdx] = newResIdx;
307+
308+
FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
309+
rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
302310
if (failed(newFuncOpOrFailure))
303311
return emitSilenceableFailure(getLoc())
304312
<< "failed to replace function signature '" << getFunctionName()
@@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
312320
});
313321

314322
for (func::CallOp callOp : callOps)
315-
func::replaceCallOpWithNewOrder(rewriter, callOp,
316-
argsInterchange.getArrayRef(),
317-
resultsInterchange.getArrayRef());
323+
func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
324+
oldResToNewRes);
318325
}
319326

320327
results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +337,50 @@ void transform::ReplaceFuncSignatureOp::getEffects(
330337
transform::modifiesPayload(effects);
331338
}
332339

340+
//===----------------------------------------------------------------------===//
341+
// DeduplicateFuncArgsOp
342+
//===----------------------------------------------------------------------===//
343+
344+
DiagnosedSilenceableFailure
345+
transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
346+
transform::TransformResults &results,
347+
transform::TransformState &state) {
348+
auto payloadOps = state.getPayloadOps(getModule());
349+
if (!llvm::hasSingleElement(payloadOps))
350+
return emitDefiniteFailure() << "requires a single module to operate on";
351+
352+
auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
353+
if (!targetModuleOp)
354+
return emitSilenceableFailure(getLoc())
355+
<< "target is expected to be module operation";
356+
357+
func::FuncOp funcOp =
358+
targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
359+
if (!funcOp)
360+
return emitSilenceableFailure(getLoc())
361+
<< "function with name '" << getFunctionName() << "' is not found";
362+
363+
std::string errorMessage;
364+
auto transformationResult = func::deduplicateArgsOfFuncOp(
365+
rewriter, funcOp, targetModuleOp, errorMessage);
366+
if (failed(transformationResult))
367+
return emitSilenceableFailure(getLoc()) << errorMessage;
368+
369+
auto [newFuncOp, newCallOp] = *transformationResult;
370+
371+
results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
372+
results.set(cast<OpResult>(getTransformedFunction()), {newFuncOp});
373+
374+
return DiagnosedSilenceableFailure::success();
375+
}
376+
377+
void transform::DeduplicateFuncArgsOp::getEffects(
378+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
379+
transform::consumesHandle(getModuleMutable(), effects);
380+
transform::producesHandle(getOperation()->getOpResults(), effects);
381+
transform::modifiesPayload(effects);
382+
}
383+
333384
//===----------------------------------------------------------------------===//
334385
// Transform op registration
335386
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)