diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td index 4062f310c6521..b64b3fcdb275b 100644 --- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td @@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp }]; } +def DeduplicateFuncArgsOp + : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + This transform takes a module and a function name, and deduplicates + the arguments of the function. The function is expected to be defined in + the module. + + This transform will emit a silenceable failure if: + - The function with the given name does not exist in the module. + - The function does not have duplicate arguments. + - The function does not have a single call. + }]; + + let arguments = (ins TransformHandleTypeInterface:$module, + SymbolRefAttr:$function_name); + let results = (outs TransformHandleTypeInterface:$transformed_module, + TransformHandleTypeInterface:$transformed_function); + + let assemblyFormat = [{ + $function_name + `at` $module attr-dict `:` functional-type(operands, results) + }]; +} + #endif // FUNC_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 2e8b6723a0e53..3576126a487ac 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -18,32 +18,49 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/ArrayRef.h" +#include namespace mlir { +class ModuleOp; + namespace func { class FuncOp; class CallOp; /// Creates a new function operation with the same name as the original -/// function operation, but with the arguments reordered according to -/// the `newArgsOrder` and `newResultsOrder`. +/// function operation, but with the arguments mapped according to +/// the `oldArgToNewArg` and `oldResToNewRes`. /// The `funcOp` operation must have exactly one block. /// Returns the new function operation or failure if `funcOp` doesn't /// have exactly one block. -FailureOr -replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, - llvm::ArrayRef newArgsOrder, - llvm::ArrayRef newResultsOrder); +/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes` +/// maps the whole function arguments and results. +mlir::FailureOr replaceFuncWithNewMapping( + mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, + ArrayRef oldArgIdxToNewArgIdx, ArrayRef oldResIdxToNewResIdx); /// Creates a new call operation with the values as the original -/// call operation, but with the arguments reordered according to -/// the `newArgsOrder` and `newResultsOrder`. -CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, - llvm::ArrayRef newArgsOrder, - llvm::ArrayRef newResultsOrder); +/// call operation, but with the arguments mapped according to +/// the `oldArgToNewArg` and `oldResToNewRes`. +/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes` +/// maps the whole call operation arguments and results. +mlir::func::CallOp replaceCallOpWithNewMapping( + mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, + ArrayRef oldArgIdxToNewArgIdx, ArrayRef oldResIdxToNewResIdx); + +/// This utility function examines all call operations within the given +/// `moduleOp` that target the specified `funcOp`. It identifies duplicate +/// operands in the call operations, creates mappings to deduplicate them, and +/// then applies the transformation to both the function and its call sites. For +/// now, it only supports one call operation for the function operation. The +/// function returns a pair containing the new funcOp and the new callOp. Note: +/// after the transformation, the original funcOp and callOp will be erased. +mlir::FailureOr> +deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, + mlir::ModuleOp moduleOp); } // namespace func } // namespace mlir -#endif // MLIR_DIALECT_FUNC_UTILS_H +#endif // MLIR_DIALECT_FUNC_UTILS_H \ No newline at end of file diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index 935d3e5ac331b..3a42d2a367d70 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter, } } - FailureOr newFuncOpOrFailure = func::replaceFuncWithNewOrder( - rewriter, funcOp, argsInterchange.getArrayRef(), - resultsInterchange.getArrayRef()); + llvm::SmallVector oldArgToNewArg(argsInterchange.size()); + for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange)) + oldArgToNewArg[oldArgIdx] = newArgIdx; + + llvm::SmallVector oldResToNewRes(resultsInterchange.size()); + for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange)) + oldResToNewRes[oldResIdx] = newResIdx; + + FailureOr newFuncOpOrFailure = func::replaceFuncWithNewMapping( + rewriter, funcOp, oldArgToNewArg, oldResToNewRes); if (failed(newFuncOpOrFailure)) return emitSilenceableFailure(getLoc()) << "failed to replace function signature '" << getFunctionName() @@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter, }); for (func::CallOp callOp : callOps) - func::replaceCallOpWithNewOrder(rewriter, callOp, - argsInterchange.getArrayRef(), - resultsInterchange.getArrayRef()); + func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg, + oldResToNewRes); } results.set(cast(getTransformedModule()), {targetModuleOp}); @@ -330,6 +337,51 @@ void transform::ReplaceFuncSignatureOp::getEffects( transform::modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// DeduplicateFuncArgsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto payloadOps = state.getPayloadOps(getModule()); + if (!llvm::hasSingleElement(payloadOps)) + return emitDefiniteFailure() << "requires a single module to operate on"; + + auto targetModuleOp = dyn_cast(*payloadOps.begin()); + if (!targetModuleOp) + return emitSilenceableFailure(getLoc()) + << "target is expected to be module operation"; + + func::FuncOp funcOp = + targetModuleOp.lookupSymbol(getFunctionName()); + if (!funcOp) + return emitSilenceableFailure(getLoc()) + << "function with name '" << getFunctionName() << "' is not found"; + + auto transformationResult = + func::deduplicateArgsOfFuncOp(rewriter, funcOp, targetModuleOp); + if (failed(transformationResult)) + return emitSilenceableFailure(getLoc()) + << "failed to deduplicate function arguments of function " + << funcOp.getName(); + + auto [newFuncOp, newCallOp] = *transformationResult; + + results.set(cast(getTransformedModule()), {targetModuleOp}); + results.set(cast(getTransformedFunction()), {newFuncOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::DeduplicateFuncArgsOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getModuleMutable(), effects); + transform::producesHandle(getOperation()->getOpResults(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index f781ed2d591b4..b4cb0932ef631 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -14,35 +14,101 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "func-utils" using namespace mlir; +/// This method creates an inverse mapping of the provided map `oldToNew`. +/// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps +/// to new index `j`, +/// This method returns a vector where `result[j]` contains all old indices +/// that map to new index `j`. +/// +/// Example: +/// ``` +/// oldIdxToNewIdx = [0, 1, 2, 2, 3] +/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]] +/// ``` +/// +static llvm::SmallVector> +getInverseMapping(ArrayRef oldIdxToNewIdx) { + int numOfNewIdxs = 0; + if (!oldIdxToNewIdx.empty()) + numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx); + llvm::SmallVector> newToOldIdxs(numOfNewIdxs); + for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx)) + newToOldIdxs[newIdx].push_back(oldIdx); + return newToOldIdxs; +} + +/// This method returns a new vector of elements that are mapped from the +/// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes +/// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there +/// is at least one old index that maps to it. Also, It assumes that mapping to +/// the same old index has the same element in the `origElements` vector. +template +static SmallVector getMappedElements( + ArrayRef origElements, + const llvm::SmallVector> &newIdxToOldIdxs) { + SmallVector newElements; + for (const auto &oldIdxs : newIdxToOldIdxs) { + assert(llvm::all_of(oldIdxs, + [&origElements](int idx) -> bool { + return idx >= 0 && + static_cast(idx) < origElements.size(); + }) && + "idx must be less than the number of elements in the original " + "elements"); + assert(!oldIdxs.empty() && "oldIdx must not be empty"); + Element origTypeToCheck = origElements[oldIdxs.front()]; + assert(llvm::all_of(oldIdxs, + [&](int idx) -> bool { + return origElements[idx] == origTypeToCheck; + }) && + "all oldIdxs must be equal"); + newElements.push_back(origTypeToCheck); + } + return newElements; +} + FailureOr -func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, - ArrayRef newArgsOrder, - ArrayRef newResultsOrder) { +func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp, + ArrayRef oldArgIdxToNewArgIdx, + ArrayRef oldResIdxToNewResIdx) { // Generate an empty new function operation with the same name as the // original. - assert(funcOp.getNumArguments() == newArgsOrder.size() && - "newArgsOrder must match the number of arguments in the function"); - assert(funcOp.getNumResults() == newResultsOrder.size() && - "newResultsOrder must match the number of results in the function"); + assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() && + "oldArgIdxToNewArgIdx must match the number of arguments in the " + "function"); + assert( + funcOp.getNumResults() == oldResIdxToNewResIdx.size() && + "oldResIdxToNewResIdx must match the number of results in the function"); if (!funcOp.getBody().hasOneBlock()) return rewriter.notifyMatchFailure( funcOp, "expected function to have exactly one block"); - ArrayRef origInputTypes = funcOp.getFunctionType().getInputs(); - ArrayRef origOutputTypes = funcOp.getFunctionType().getResults(); - SmallVector newInputTypes, newOutputTypes; + // We may have some duplicate arguments in the old function, i.e. + // in the mapping `newArgIdxToOldArgIdxs` for some new argument index + // there may be multiple old argument indices. + llvm::SmallVector> newArgIdxToOldArgIdxs = + getInverseMapping(oldArgIdxToNewArgIdx); + SmallVector newInputTypes = getMappedElements( + funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs); + SmallVector locs; - for (unsigned int idx : newArgsOrder) { - newInputTypes.push_back(origInputTypes[idx]); - locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc()); - } - for (unsigned int idx : newResultsOrder) - newOutputTypes.push_back(origOutputTypes[idx]); + for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs) + locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc()); + + llvm::SmallVector> newResToOldResIdxs = + getInverseMapping(oldResIdxToNewResIdx); + SmallVector newOutputTypes = getMappedElements( + funcOp.getFunctionType().getResults(), newResToOldResIdxs); + rewriter.setInsertionPoint(funcOp); auto newFuncOp = func::FuncOp::create( rewriter, funcOp.getLoc(), funcOp.getName(), @@ -51,21 +117,21 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, Region &newRegion = newFuncOp.getBody(); rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs); newFuncOp.setVisibility(funcOp.getVisibility()); - newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary()); // Map the arguments of the original function to the new function in // the new order and adjust the attributes accordingly. IRMapping operandMapper; SmallVector argAttrs, resultAttrs; funcOp.getAllArgAttrs(argAttrs); - for (unsigned int i = 0; i < newArgsOrder.size(); ++i) { - operandMapper.map(funcOp.getArgument(newArgsOrder[i]), - newFuncOp.getArgument(i)); - newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]); - } + for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx)) + operandMapper.map(funcOp.getArgument(oldArgIdx), + newFuncOp.getArgument(newArgIdx)); + for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs)) + newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]); + funcOp.getAllResultAttrs(resultAttrs); - for (unsigned int i = 0; i < newResultsOrder.size(); ++i) - newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]); + for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs)) + newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]); // Clone the operations from the original function to the new function. rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); @@ -76,12 +142,11 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, auto returnOp = cast( newFuncOp.getFunctionBody().begin()->getTerminator()); SmallVector newReturnValues; - for (unsigned int idx : newResultsOrder) - newReturnValues.push_back(returnOp.getOperand(idx)); + for (const auto &oldResIdxs : newResToOldResIdxs) + newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front())); + rewriter.setInsertionPoint(returnOp); - auto newReturnOp = - func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues); - newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary()); + func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues); rewriter.eraseOp(returnOp); rewriter.eraseOp(funcOp); @@ -90,33 +155,102 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, } func::CallOp -func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp, - ArrayRef newArgsOrder, - ArrayRef newResultsOrder) { - assert( - callOp.getNumOperands() == newArgsOrder.size() && - "newArgsOrder must match the number of operands in the call operation"); - assert( - callOp.getNumResults() == newResultsOrder.size() && - "newResultsOrder must match the number of results in the call operation"); - SmallVector newArgsOrderValues; - for (unsigned int argIdx : newArgsOrder) - newArgsOrderValues.push_back(callOp.getOperand(argIdx)); - SmallVector newResultTypes; - for (unsigned int resIdx : newResultsOrder) - newResultTypes.push_back(callOp.getResult(resIdx).getType()); +func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp, + ArrayRef oldArgIdxToNewArgIdx, + ArrayRef oldResIdxToNewResIdx) { + assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() && + "oldArgIdxToNewArgIdx must match the number of operands in the call " + "operation"); + assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() && + "oldResIdxToNewResIdx must match the number of results in the call " + "operation"); + + SmallVector origOperands = callOp.getOperands(); + SmallVector> newArgIdxToOldArgIdxs = + getInverseMapping(oldArgIdxToNewArgIdx); + SmallVector newOperandsValues = + getMappedElements(origOperands, newArgIdxToOldArgIdxs); + SmallVector> newResToOldResIdxs = + getInverseMapping(oldResIdxToNewResIdx); + SmallVector origResultTypes = llvm::to_vector(callOp.getResultTypes()); + SmallVector newResultTypes = + getMappedElements(origResultTypes, newResToOldResIdxs); // Replace the kernel call operation with a new one that has the - // reordered arguments. + // mapped arguments. rewriter.setInsertionPoint(callOp); auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(), - newResultTypes, newArgsOrderValues); + newResultTypes, newOperandsValues); newCallOp.setNoInlineAttr(callOp.getNoInlineAttr()); - for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder)) - rewriter.replaceAllUsesWith(callOp.getResult(origIndex), - newCallOp.getResult(newIndex)); + for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx)) + rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx), + newCallOp.getResult(newResIdx)); rewriter.eraseOp(callOp); return newCallOp; } + +FailureOr> +func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, + ModuleOp moduleOp) { + SmallVector callOps; + auto traversalResult = moduleOp.walk([&](func::CallOp callOp) { + if (callOp.getCallee() == funcOp.getSymName()) { + if (!callOps.empty()) + // Only support one callOp for now + return WalkResult::interrupt(); + callOps.push_back(callOp); + } + return WalkResult::advance(); + }); + + if (traversalResult.wasInterrupted()) { + LDBG() << "function " << funcOp.getName() << " has more than one callOp"; + return failure(); + } + + if (callOps.empty()) { + LDBG() << "function " << funcOp.getName() << " does not have any callOp"; + return failure(); + } + + func::CallOp callOp = callOps.front(); + + // Create mapping for arguments (deduplicate operands) + SmallVector oldArgIdxToNewArgIdx(callOp.getNumOperands()); + llvm::DenseMap valueToNewArgIdx; + for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) { + auto [iterator, inserted] = valueToNewArgIdx.insert( + {operand, static_cast(valueToNewArgIdx.size())}); + // Reduce the duplicate operands and maintain the original order. + oldArgIdxToNewArgIdx[operandIdx] = iterator->second; + } + + bool hasDuplicateOperands = + valueToNewArgIdx.size() != callOp.getNumOperands(); + if (!hasDuplicateOperands) { + LDBG() << "function " << funcOp.getName() + << " does not have duplicate operands"; + return failure(); + } + + // Create identity mapping for results (no deduplication needed) + SmallVector oldResIdxToNewResIdx(callOp.getNumResults()); + for (int resultIdx : llvm::seq(0, callOp.getNumResults())) + oldResIdxToNewResIdx[resultIdx] = resultIdx; + + // Apply the transformation to create new function and call operations + FailureOr newFuncOpOrFailure = replaceFuncWithNewMapping( + rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx); + if (failed(newFuncOpOrFailure)) { + LDBG() << "failed to replace function signature with name " + << funcOp.getName() << " with new order"; + return failure(); + } + + func::CallOp newCallOp = replaceCallOpWithNewMapping( + rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx); + + return std::make_pair(*newFuncOpOrFailure, newCallOp); +} diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir index e712eee83f36e..29bd58ab52742 100644 --- a/mlir/test/Dialect/Func/func-transform-invalid.mlir +++ b/mlir/test/Dialect/Func/func-transform-invalid.mlir @@ -85,3 +85,92 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func private @func_with_no_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1> + %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1> + return +} + +func.func @func_with_no_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) { + call @func_with_no_duplicate_args(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + // expected-error @+1 {{failed to deduplicate function arguments of function func_with_no_duplicate_args}} + transform.func.deduplicate_func_args @func_with_no_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func private @func_not_found(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1> + %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + // expected-error @+1 {{function with name '@non_existent_func' is not found}} + transform.func.deduplicate_func_args @non_existent_func at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func private @func_with_multiple_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + return +} + +func.func @func_with_multiple_calls_caller1(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) { + call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> () + return +} + +func.func @func_with_multiple_calls_caller2(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) { + call @func_with_multiple_calls(%arg0, %arg0) : (memref<1xi8, 1>, memref<1xi8, 1>) -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + // expected-error @+1 {{failed to deduplicate function arguments of function func_with_multiple_calls}} + transform.func.deduplicate_func_args @func_with_multiple_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func private @func_with_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<1xi8, 1>) { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + %view0 = memref.view %arg1[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + return +} + +func.func @some_other_func() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + // expected-error @+1 {{failed to deduplicate function arguments of function func_with_no_calls}} + transform.func.deduplicate_func_args @func_with_no_calls at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir index 36a66aaa95bfb..8a71511e3ed5b 100644 --- a/mlir/test/Dialect/Func/func-transform.mlir +++ b/mlir/test/Dialect/Func/func-transform.mlir @@ -250,3 +250,65 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK: func.func private @func_with_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) { +func.func private @func_with_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>) { + %c0 = arith.constant 0 : index + // CHECK: %[[VAL_3:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1> + %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + // CHECK: %[[VAL_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1> + %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1> + // CHECK: %[[VAL_5:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1> + %view1 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + return +} + +// CHECK: func.func @func_with_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>) { +func.func @func_with_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>) { + // CHECK: call @func_with_duplicate_args(%[[ARG0]], %[[ARG1]]) : (memref<1xi8, 1>, memref<2xi8, 1>) -> () + call @func_with_duplicate_args(%arg0, %arg1, %arg0) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>) -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + transform.func.deduplicate_func_args @func_with_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK: func.func private @func_with_complex_duplicate_args(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) { +func.func private @func_with_complex_duplicate_args(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<1xi8, 1>, %arg3: memref<3xi8, 1>, %arg4: memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) { + %c0 = arith.constant 0 : index + // CHECK: %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0:.*]]][] : memref<1xi8, 1> to memref<1xi8, 1> + %view0 = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + // CHECK: %[[RET_1:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1> + %view1 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1> + // CHECK: %[[RET_2:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1> + %view2 = memref.view %arg2[%c0][] : memref<1xi8, 1> to memref<1xi8, 1> + // CHECK: %[[RET_3:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1> + %view3 = memref.view %arg3[%c0][] : memref<3xi8, 1> to memref<3xi8, 1> + // CHECK: %[[RET_4:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1> + %view4 = memref.view %arg4[%c0][] : memref<2xi8, 1> to memref<2xi8, 1> + // CHECK: return %[[RET_0]], %[[RET_1]], %[[RET_2]], %[[RET_3]], %[[RET_4]] : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1> + return %view0, %view1, %view2, %view3, %view4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1> +} + +// CHECK: func.func @func_with_complex_duplicate_args_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) { +func.func @func_with_complex_duplicate_args_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) { + // CHECK: %[[RET:.*]]:5 = call @func_with_complex_duplicate_args(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) + %0:5 = call @func_with_complex_duplicate_args(%arg0, %arg1, %arg0, %arg2, %arg1) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) + // CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3, %[[RET]]#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1> + return %0#0, %0#1, %0#2, %0#3, %0#4 : memref<1xi8, 1>, memref<2xi8, 1>, memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op) { + transform.func.deduplicate_func_args @func_with_complex_duplicate_args at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +}