-
Notifications
You must be signed in to change notification settings - Fork 14k
[mlir][func]: Introduce ReplaceFuncSignature tranform operation #143381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-func Author: Aviad Cohen (AviadCo) ChangesThis transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function. Patch is 25.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143381.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 306fbf881de61..1cb9ca7418057 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -98,4 +98,41 @@ def CastAndCallOp : Op<Transform_Dialect,
let hasVerifier = 1;
}
+def ReplaceFuncSignatureOp : Op<Transform_Dialect,
+ "func.replace_func_signature",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ This transform takes a module and a function name, and replaces the
+ signature of the function by reordering the arguments and results
+ according to the interchange arrays. The function is expected to be
+ defined in the module, and the interchange arrays must match the number
+ of arguments and results of the function.
+
+ The `adjust_func_calls` attribute indicates whether the function calls
+ should be adjusted to match the new signature. If set to `true`, the
+ function calls will be adjusted to match the new signature, otherwise
+ they will not be adjusted.
+
+ This transform will emit a silenceable failure if:
+ - The function with the given name does not exist in the module.
+ - The interchange arrays do not match the number of arguments/results.
+ - The interchange arrays contain out of bound indices.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$module,
+ SymbolRefAttr:$function_name,
+ DenseI32ArrayAttr:$args_interchange,
+ DenseI32ArrayAttr:$results_interchange,
+ UnitAttr:$adjust_func_calls);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $function_name
+ `args_interchange` `=` $args_interchange
+ `results_interchange` `=` $results_interchange
+ `at` $module attr-dict `:` functional-type(operands, results)
+ }];
+}
+
#endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
new file mode 100644
index 0000000000000..ff0746f55a96f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -0,0 +1,42 @@
+//===- Utils.h - General Func transformation utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various transformation utilities for
+// the Func dialect. These are not passes by themselves but are used
+// either by passes, optimization sequences, or in turn by other transformation
+// utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_UTILS_H
+#define MLIR_DIALECT_FUNC_UTILS_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+
+namespace func {
+
+// Creates a new function operation with the same name as the original
+// function operation, but with the arguments reordered according to
+// the `newArgsOrder` and `newResultsOrder`.
+mlir::func::FuncOp replaceFuncWithNewOrder(mlir::func::FuncOp funcOp,
+ mlir::ArrayRef<int> newArgsOrder,
+ mlir::ArrayRef<int> newResultsOrder);
+// Creates a new call operation with the values as the original
+// call operation, but with the arguments reordered according to
+// the `newArgsOrder` and `newResultsOrder`.
+mlir::func::CallOp
+replaceCallOpWithNewOrder(mlir::func::CallOp callOp,
+ mlir::ArrayRef<int> newArgsOrder,
+ mlir::ArrayRef<int> newResultsOrder);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..c5b657aefc0e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -366,8 +366,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let arguments =
(ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);
diff --git a/mlir/lib/Dialect/Func/CMakeLists.txt b/mlir/lib/Dialect/Func/CMakeLists.txt
index ec999ffdb99da..a834aae8fbf81 100644
--- a/mlir/lib/Dialect/Func/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9966d7339e1b4..0a814f7cfdd13 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -226,6 +227,98 @@ void transform::CastAndCallOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ReplaceFuncSignatureOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payloadOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(payloadOps))
+ return emitDefiniteFailure() << "requires a single module to operate on";
+
+ auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+ if (!targetModuleOp)
+ return emitSilenceableFailure(getLoc())
+ << "target is expected to be module operation";
+
+ func::FuncOp funcOp =
+ targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+ if (!funcOp)
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' not found";
+
+ int numArgs = funcOp.getNumArguments();
+ int numResults = funcOp.getNumResults();
+ // Check that the number of arguments and results matches the
+ // interchange sizes.
+ if (numArgs != (int)getArgsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' has " << numArgs
+ << " arguments, but " << getArgsInterchange().size()
+ << " args interchange were given";
+
+ if (numResults != (int)getResultsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' has "
+ << numResults << " results, but " << getResultsInterchange().size()
+ << " results interchange were given";
+
+ // Check that the args and results interchanges are unique.
+ SetVector<int> argsInterchange, resultsInterchange;
+ argsInterchange.insert_range(getArgsInterchange());
+ resultsInterchange.insert_range(getResultsInterchange());
+ if (argsInterchange.size() != getArgsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "args interchange must be unique";
+
+ if (resultsInterchange.size() != getResultsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "results interchange must be unique";
+
+ // Check that the args and results interchange indices are in bounds.
+ for (auto index : argsInterchange) {
+ if (index < 0 || index >= numArgs) {
+ return emitSilenceableFailure(getLoc())
+ << "args interchange index " << index
+ << " is out of bounds for function with name '"
+ << getFunctionName() << "' with " << numArgs << " arguments";
+ }
+ }
+ for (auto index : resultsInterchange) {
+ if (index < 0 || index >= numResults) {
+ return emitSilenceableFailure(getLoc())
+ << "results interchange index " << index
+ << " is out of bounds for function with name '"
+ << getFunctionName() << "' with " << numResults << " results";
+ }
+ }
+
+ func::replaceFuncWithNewOrder(funcOp, argsInterchange.getArrayRef(),
+ resultsInterchange.getArrayRef());
+ if (getAdjustFuncCalls()) {
+ SmallVector<func::CallOp> callOps;
+ targetModuleOp.walk([&](func::CallOp callOp) {
+ if (callOp.getCallee() == getFunctionName().getRootReference().str())
+ callOps.push_back(callOp);
+ });
+
+ for (auto callOp : callOps)
+ func::replaceCallOpWithNewOrder(callOp, argsInterchange.getArrayRef(),
+ resultsInterchange.getArrayRef());
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ReplaceFuncSignatureOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getModuleMutable(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/CMakeLists.txt b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..e39a8c8c25d03
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRFuncUtils
+ Utils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Utils
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRComplexDialect
+ MLIRDialect
+ MLIRDialectUtils
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
new file mode 100644
index 0000000000000..3bae13c354b20
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -0,0 +1,103 @@
+//===- Utils.cpp - Utilities to support the Func dialect ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the Func dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+func::FuncOp func::replaceFuncWithNewOrder(func::FuncOp funcOp,
+ ArrayRef<int> newArgsOrder,
+ ArrayRef<int> newResultsOrder) {
+ // Generate an empty new function operation with the same name as the
+ // original.
+ assert(funcOp.getNumArguments() == newArgsOrder.size());
+ assert(funcOp.getNumResults() == newResultsOrder.size());
+ auto origInputTypes = funcOp.getFunctionType().getInputs();
+ auto origOutputTypes = funcOp.getFunctionType().getResults();
+ SmallVector<Type> newInputTypes, newOutputTypes;
+ for (unsigned int i = 0; i < origInputTypes.size(); ++i)
+ newInputTypes.push_back(origInputTypes[newArgsOrder[i]]);
+ for (unsigned int i = 0; i < origOutputTypes.size(); ++i)
+ newOutputTypes.push_back(origOutputTypes[newResultsOrder[i]]);
+ IRRewriter rewriter(funcOp);
+ rewriter.setInsertionPoint(funcOp);
+ auto newFuncOp = rewriter.create<func::FuncOp>(
+ funcOp.getLoc(), funcOp.getName(),
+ rewriter.getFunctionType(newInputTypes, newOutputTypes));
+ newFuncOp.addEntryBlock();
+ newFuncOp.setVisibility(funcOp.getVisibility());
+ newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
+
+ // Map the arguments of the original function to the new function in
+ // the new order and adjust the attributes accordingly.
+ IRMapping operandMapper;
+ SmallVector<DictionaryAttr> argAttrs, resultAttrs;
+ funcOp.getAllArgAttrs(argAttrs);
+ for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
+ operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
+ newFuncOp.getArgument(i));
+ newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
+ }
+ funcOp.getAllResultAttrs(resultAttrs);
+ for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
+ newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+
+ // Clone the operations from the original function to the new function.
+ rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+ for (Operation &op : funcOp.getOps())
+ rewriter.clone(op, operandMapper);
+
+ // Handle the return operation.
+ auto returnOp = cast<func::ReturnOp>(
+ newFuncOp.getFunctionBody().begin()->getTerminator());
+ SmallVector<Value> newReturnValues;
+ for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
+ newReturnValues.push_back(returnOp.getOperand(newResultsOrder[i]));
+ rewriter.setInsertionPoint(returnOp);
+ auto newReturnOp =
+ rewriter.create<func::ReturnOp>(newFuncOp.getLoc(), newReturnValues);
+ newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
+ rewriter.eraseOp(returnOp);
+
+ rewriter.eraseOp(funcOp);
+
+ return newFuncOp;
+}
+
+func::CallOp func::replaceCallOpWithNewOrder(func::CallOp callOp,
+ ArrayRef<int> newArgsOrder,
+ ArrayRef<int> newResultsOrder) {
+ assert(callOp.getNumOperands() == newArgsOrder.size());
+ assert(callOp.getNumResults() == newResultsOrder.size());
+ IRRewriter rewriter(callOp);
+ SmallVector<Value> newArgsOrderValues;
+ for (auto argIdx : newArgsOrder)
+ newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+ SmallVector<Type> newResultTypes;
+ for (auto resIdx : newResultsOrder)
+ newResultTypes.push_back(callOp.getResult(resIdx).getType());
+
+ // Replace the kernel call operation with a new one that has the
+ // reordered arguments.
+ auto newCallOp = rewriter.create<func::CallOp>(
+ callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues);
+ newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
+ newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
+ for (auto [newIndex, origIndex] : llvm::enumerate(newResultsOrder))
+ rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
+ newCallOp.getResult(newIndex));
+ rewriter.eraseOp(callOp);
+
+ return newCallOp;
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 6aab07b0cb38a..9cb91055d6143 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -118,3 +118,135 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order_no_result_no_calls(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+ func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+ %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+ // CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+ %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+ // CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+ %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+ transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order_no_result(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+ func.func private @func_with_reverse_order_no_result(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+ %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+ // CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+ %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+ // CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+ %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+ return
+ }
+
+ // CHECK: func.func @func_with_reverse_order_no_result_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) {
+ func.func @func_with_reverse_order_no_result_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: call @func_with_reverse_order_no_result(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> ()
+ call @func_with_reverse_order_no_result(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
+ transform.func.replace_func_signature @func_with_reverse_order_no_result args_interchange = [0, 2, 1] results_interchange = [] at %module {adjust_func_calls} : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
+ func.func private @func_with_reverse_order(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : inde...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Aviad Cohen (AviadCo) ChangesThis transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function. Patch is 25.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143381.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 306fbf881de61..1cb9ca7418057 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -98,4 +98,41 @@ def CastAndCallOp : Op<Transform_Dialect,
let hasVerifier = 1;
}
+def ReplaceFuncSignatureOp : Op<Transform_Dialect,
+ "func.replace_func_signature",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ This transform takes a module and a function name, and replaces the
+ signature of the function by reordering the arguments and results
+ according to the interchange arrays. The function is expected to be
+ defined in the module, and the interchange arrays must match the number
+ of arguments and results of the function.
+
+ The `adjust_func_calls` attribute indicates whether the function calls
+ should be adjusted to match the new signature. If set to `true`, the
+ function calls will be adjusted to match the new signature, otherwise
+ they will not be adjusted.
+
+ This transform will emit a silenceable failure if:
+ - The function with the given name does not exist in the module.
+ - The interchange arrays do not match the number of arguments/results.
+ - The interchange arrays contain out of bound indices.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$module,
+ SymbolRefAttr:$function_name,
+ DenseI32ArrayAttr:$args_interchange,
+ DenseI32ArrayAttr:$results_interchange,
+ UnitAttr:$adjust_func_calls);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $function_name
+ `args_interchange` `=` $args_interchange
+ `results_interchange` `=` $results_interchange
+ `at` $module attr-dict `:` functional-type(operands, results)
+ }];
+}
+
#endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
new file mode 100644
index 0000000000000..ff0746f55a96f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -0,0 +1,42 @@
+//===- Utils.h - General Func transformation utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various transformation utilities for
+// the Func dialect. These are not passes by themselves but are used
+// either by passes, optimization sequences, or in turn by other transformation
+// utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_UTILS_H
+#define MLIR_DIALECT_FUNC_UTILS_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+
+namespace func {
+
+// Creates a new function operation with the same name as the original
+// function operation, but with the arguments reordered according to
+// the `newArgsOrder` and `newResultsOrder`.
+mlir::func::FuncOp replaceFuncWithNewOrder(mlir::func::FuncOp funcOp,
+ mlir::ArrayRef<int> newArgsOrder,
+ mlir::ArrayRef<int> newResultsOrder);
+// Creates a new call operation with the values as the original
+// call operation, but with the arguments reordered according to
+// the `newArgsOrder` and `newResultsOrder`.
+mlir::func::CallOp
+replaceCallOpWithNewOrder(mlir::func::CallOp callOp,
+ mlir::ArrayRef<int> newArgsOrder,
+ mlir::ArrayRef<int> newResultsOrder);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..c5b657aefc0e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -366,8 +366,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let arguments =
(ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);
diff --git a/mlir/lib/Dialect/Func/CMakeLists.txt b/mlir/lib/Dialect/Func/CMakeLists.txt
index ec999ffdb99da..a834aae8fbf81 100644
--- a/mlir/lib/Dialect/Func/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9966d7339e1b4..0a814f7cfdd13 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -226,6 +227,98 @@ void transform::CastAndCallOp::getEffects(
transform::modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ReplaceFuncSignatureOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payloadOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(payloadOps))
+ return emitDefiniteFailure() << "requires a single module to operate on";
+
+ auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+ if (!targetModuleOp)
+ return emitSilenceableFailure(getLoc())
+ << "target is expected to be module operation";
+
+ func::FuncOp funcOp =
+ targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+ if (!funcOp)
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' not found";
+
+ int numArgs = funcOp.getNumArguments();
+ int numResults = funcOp.getNumResults();
+ // Check that the number of arguments and results matches the
+ // interchange sizes.
+ if (numArgs != (int)getArgsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' has " << numArgs
+ << " arguments, but " << getArgsInterchange().size()
+ << " args interchange were given";
+
+ if (numResults != (int)getResultsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "function with name '" << getFunctionName() << "' has "
+ << numResults << " results, but " << getResultsInterchange().size()
+ << " results interchange were given";
+
+ // Check that the args and results interchanges are unique.
+ SetVector<int> argsInterchange, resultsInterchange;
+ argsInterchange.insert_range(getArgsInterchange());
+ resultsInterchange.insert_range(getResultsInterchange());
+ if (argsInterchange.size() != getArgsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "args interchange must be unique";
+
+ if (resultsInterchange.size() != getResultsInterchange().size())
+ return emitSilenceableFailure(getLoc())
+ << "results interchange must be unique";
+
+ // Check that the args and results interchange indices are in bounds.
+ for (auto index : argsInterchange) {
+ if (index < 0 || index >= numArgs) {
+ return emitSilenceableFailure(getLoc())
+ << "args interchange index " << index
+ << " is out of bounds for function with name '"
+ << getFunctionName() << "' with " << numArgs << " arguments";
+ }
+ }
+ for (auto index : resultsInterchange) {
+ if (index < 0 || index >= numResults) {
+ return emitSilenceableFailure(getLoc())
+ << "results interchange index " << index
+ << " is out of bounds for function with name '"
+ << getFunctionName() << "' with " << numResults << " results";
+ }
+ }
+
+ func::replaceFuncWithNewOrder(funcOp, argsInterchange.getArrayRef(),
+ resultsInterchange.getArrayRef());
+ if (getAdjustFuncCalls()) {
+ SmallVector<func::CallOp> callOps;
+ targetModuleOp.walk([&](func::CallOp callOp) {
+ if (callOp.getCallee() == getFunctionName().getRootReference().str())
+ callOps.push_back(callOp);
+ });
+
+ for (auto callOp : callOps)
+ func::replaceCallOpWithNewOrder(callOp, argsInterchange.getArrayRef(),
+ resultsInterchange.getArrayRef());
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ReplaceFuncSignatureOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getModuleMutable(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/CMakeLists.txt b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..e39a8c8c25d03
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRFuncUtils
+ Utils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Utils
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRComplexDialect
+ MLIRDialect
+ MLIRDialectUtils
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
new file mode 100644
index 0000000000000..3bae13c354b20
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -0,0 +1,103 @@
+//===- Utils.cpp - Utilities to support the Func dialect ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the Func dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+func::FuncOp func::replaceFuncWithNewOrder(func::FuncOp funcOp,
+ ArrayRef<int> newArgsOrder,
+ ArrayRef<int> newResultsOrder) {
+ // Generate an empty new function operation with the same name as the
+ // original.
+ assert(funcOp.getNumArguments() == newArgsOrder.size());
+ assert(funcOp.getNumResults() == newResultsOrder.size());
+ auto origInputTypes = funcOp.getFunctionType().getInputs();
+ auto origOutputTypes = funcOp.getFunctionType().getResults();
+ SmallVector<Type> newInputTypes, newOutputTypes;
+ for (unsigned int i = 0; i < origInputTypes.size(); ++i)
+ newInputTypes.push_back(origInputTypes[newArgsOrder[i]]);
+ for (unsigned int i = 0; i < origOutputTypes.size(); ++i)
+ newOutputTypes.push_back(origOutputTypes[newResultsOrder[i]]);
+ IRRewriter rewriter(funcOp);
+ rewriter.setInsertionPoint(funcOp);
+ auto newFuncOp = rewriter.create<func::FuncOp>(
+ funcOp.getLoc(), funcOp.getName(),
+ rewriter.getFunctionType(newInputTypes, newOutputTypes));
+ newFuncOp.addEntryBlock();
+ newFuncOp.setVisibility(funcOp.getVisibility());
+ newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
+
+ // Map the arguments of the original function to the new function in
+ // the new order and adjust the attributes accordingly.
+ IRMapping operandMapper;
+ SmallVector<DictionaryAttr> argAttrs, resultAttrs;
+ funcOp.getAllArgAttrs(argAttrs);
+ for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
+ operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
+ newFuncOp.getArgument(i));
+ newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
+ }
+ funcOp.getAllResultAttrs(resultAttrs);
+ for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
+ newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+
+ // Clone the operations from the original function to the new function.
+ rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+ for (Operation &op : funcOp.getOps())
+ rewriter.clone(op, operandMapper);
+
+ // Handle the return operation.
+ auto returnOp = cast<func::ReturnOp>(
+ newFuncOp.getFunctionBody().begin()->getTerminator());
+ SmallVector<Value> newReturnValues;
+ for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
+ newReturnValues.push_back(returnOp.getOperand(newResultsOrder[i]));
+ rewriter.setInsertionPoint(returnOp);
+ auto newReturnOp =
+ rewriter.create<func::ReturnOp>(newFuncOp.getLoc(), newReturnValues);
+ newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
+ rewriter.eraseOp(returnOp);
+
+ rewriter.eraseOp(funcOp);
+
+ return newFuncOp;
+}
+
+func::CallOp func::replaceCallOpWithNewOrder(func::CallOp callOp,
+ ArrayRef<int> newArgsOrder,
+ ArrayRef<int> newResultsOrder) {
+ assert(callOp.getNumOperands() == newArgsOrder.size());
+ assert(callOp.getNumResults() == newResultsOrder.size());
+ IRRewriter rewriter(callOp);
+ SmallVector<Value> newArgsOrderValues;
+ for (auto argIdx : newArgsOrder)
+ newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+ SmallVector<Type> newResultTypes;
+ for (auto resIdx : newResultsOrder)
+ newResultTypes.push_back(callOp.getResult(resIdx).getType());
+
+ // Replace the kernel call operation with a new one that has the
+ // reordered arguments.
+ auto newCallOp = rewriter.create<func::CallOp>(
+ callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues);
+ newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
+ newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
+ for (auto [newIndex, origIndex] : llvm::enumerate(newResultsOrder))
+ rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
+ newCallOp.getResult(newIndex));
+ rewriter.eraseOp(callOp);
+
+ return newCallOp;
+}
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
index 6aab07b0cb38a..9cb91055d6143 100644
--- a/mlir/test/Dialect/Func/func-transform.mlir
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -118,3 +118,135 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order_no_result_no_calls(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+ func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+ %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+ // CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+ %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+ // CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+ %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
+ transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order_no_result(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
+ func.func private @func_with_reverse_order_no_result(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
+ %view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
+ // CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
+ %view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
+ // CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
+ %view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
+ return
+ }
+
+ // CHECK: func.func @func_with_reverse_order_no_result_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) {
+ func.func @func_with_reverse_order_no_result_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
+ // CHECK: call @func_with_reverse_order_no_result(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> ()
+ call @func_with_reverse_order_no_result(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
+ transform.func.replace_func_signature @func_with_reverse_order_no_result args_interchange = [0, 2, 1] results_interchange = [] at %module {adjust_func_calls} : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func.func private @func_with_reverse_order(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
+ func.func private @func_with_reverse_order(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : inde...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the patch!
Please address the comments and consider making this transformation operate on FunctionOpInterface
.
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
Thanks for review! |
9881989
to
144ee52
Compare
What I mean is the transformation logic itself, not the transformation operation. It should be applicable to, e.g., LLVM dialect functions as well. There's no reason why it should be specific to the function dialect functions. |
This transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function.
I see. |
This transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function.