From 2ea50df34f0263a8f0a99a60b855f8e52e0fceb2 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 16 Jan 2024 15:34:23 -0500 Subject: [PATCH 1/4] [mlir][transform] Add an op for replacing values with function calls Adds `transform.func.cast_and_call` that takes a set of inputs and outputs and replaces the uses of those outputs with a call to a function at a specified insertion point. The idea with this operation is to allow users to author independent IR outside of a to-be-compiled module, and then match and replace a slice of the program with a call to the external function. Additionally adds a mechanism for populating a type converter with a set of conversion materialization functions that allow insertion of casts on the inputs/outputs to and from the types of the function signature. --- .../Func/TransformOps/FuncTransformOps.td | 65 ++++++ .../Tensor/TransformOps/TensorTransformOps.td | 13 ++ .../Transform/IR/TransformInterfaces.td | 22 ++ .../Func/TransformOps/FuncTransformOps.cpp | 197 ++++++++++++++++++ .../TransformOps/TensorTransformOps.cpp | 40 ++++ .../lib/Dialect/Transform/IR/TransformOps.cpp | 4 + mlir/test/Dialect/Func/func-transform.mlir | 120 +++++++++++ .../Dialect/Tensor/transform-op-casting.mlir | 65 ++++++ 8 files changed, 526 insertions(+) create mode 100644 mlir/test/Dialect/Func/func-transform.mlir create mode 100644 mlir/test/Dialect/Tensor/transform-op-casting.mlir diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td index 7a7e991c78618..e5086c26c55a4 100644 --- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td @@ -12,6 +12,8 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/RegionKindInterface.td" include "mlir/IR/OpBase.td" def ApplyFuncToLLVMConversionPatternsOp : Op, + DeclareOpInterfaceMethods, + AttrSizedOperandSegments, + ReportTrackingListenerFailuresOpTrait] + # GraphRegionNoTerminator.traits> { + let summary = "Casts values to the signature of a function and replaces them " + "with a call"; + let description = [{ + This transform takes a set of |input| and |output| value handles and + attempts to cast them to the function signature of the attached function + op, then builds a call to the function and replaces the users of the + outputs. It is the responsibility of the user to ensure that the slice of + the program replaced by this operation makes sense, i.e. there is no + verification that the inputs to this operation have any relation to the + outputs outside of basic dominance requirements needed for the replacement. + + The casting materialization functions are specified in the graph region of + this op. They must implement the `TypeConversionOpInterface`. The order of + ops within the region is irrelevant. + + The target function can be specified by a symbol name or by a handle to the + operation. + + This transform only reads the target handles and only replaces the users of + the outputs with the results of the call. No handles are consumed and no + operations are removed. Users are expected to run cleanup separately if + desired. + + This transform will emit a silenceable failure if: + - The set of outputs isn't unique + - The handle for the insertion point does not include exactly one operation + - The insertion point op does not dominate any of the output users + - The insertion point op is not dominated by any of the inputs + - The function signature does not match the number of inputs/outputs + - Any of the input conversions fail to be materialized + + This transform will emit a definite failure if it fails to resolve the + target function, or if it fails to materialize the conversion from the call + results to the output types. + }]; + + let arguments = (ins + TransformHandleTypeInterface:$insertion_point, + UnitAttr:$insert_after, + Optional:$inputs, + Optional:$outputs, + OptionalAttr:$function_name, + Optional:$function); + let results = (outs TransformHandleTypeInterface:$result); + let regions = (region MaxSizedRegion<1>:$conversions); + + let assemblyFormat = [{ + ($function_name^)? ($function^)? + ( `(` $inputs^ `)` )? + ( `->` $outputs^ )? + (`after` $insert_after^):(`before`)? $insertion_point + ($conversions^)? attr-dict `:` functional-type(operands, results) + }]; + let hasVerifier = 1; +} + #endif // FUNC_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td index 8556d9570fd12..28e9249c82e30 100644 --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -169,4 +169,17 @@ def MakeLoopIndependentOp }]; } +def TypeConversionCastOp : Op]> { + let description = [{ + Indicates that tensor ops (such as tensor.generate) should be replaced with + constants (arith.constant) when possible. + }]; + let arguments = (ins UnitAttr:$ignore_dynamic_info); + + let assemblyFormat = + "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict"; +} + #endif // TENSOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index f29efaee620d8..3b601f42a6452 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -280,6 +280,28 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> { ]; } +def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> { + let description = [{ + This interface should be implemented by ops that populate type casting + of a `transform.cast_and_inline` op. It provides a method to populate a + type converter with source/target materialization patterns. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate the given type converter with source/target materialization + functions. + }], + /*returnType=*/"void", + /*name=*/"populateTypeMaterializations", + /*arguments=*/(ins "::mlir::TypeConverter &":$converter) + > + ]; +} + def TypeConverterBuilderOpInterface : OpInterface<"TypeConverterBuilderOpInterface"> { let description = [{ diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index 9e9b6bcea790d..14b6e633520d6 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter( return success(); } +//===----------------------------------------------------------------------===// +// CastAndCallOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector inputs; + if (getInputs()) + for (Value input : state.getPayloadValues(getInputs())) + inputs.push_back(input); + SmallVector outputs; + if (getOutputs()) + for (Value output : state.getPayloadValues(getOutputs())) + outputs.push_back(output); + + // Verify that the set of output values to be replaced is unique. + llvm::SmallDenseSet outputSet; + for (Value output : outputs) { + outputSet.insert(output); + } + if (outputSet.size() != outputs.size()) { + return emitSilenceableFailure(getLoc()) + << "cast and call output values must be unique"; + } + + // Get the insertion point for the call. + auto insertionOps = state.getPayloadOps(getInsertionPoint()); + if (!llvm::hasSingleElement(insertionOps)) { + return emitSilenceableFailure(getLoc()) + << "Only one op can be specified as an insertion point"; + } + bool insertAfter = getInsertAfter(); + Operation *insertionPoint = *insertionOps.begin(); + + // Check that all inputs dominate the insertion point, and the insertion + // point dominates all users of the outputs. + DominanceInfo dom(insertionPoint); + for (Value output : outputs) { + for (Operation *user : output.getUsers()) { + // If we are inserting after the insertion point operation, the + // insertion point operation must properly dominate the user. Otherwise + // basic dominance is enough. + bool doesDominate = insertAfter + ? dom.properlyDominates(insertionPoint, user) + : dom.dominates(insertionPoint, user); + if (!doesDominate) { + return emitDefiniteFailure() + << "User " << user << " is not dominated by insertion point " + << insertionPoint; + } + } + } + + for (Value input : inputs) { + // If we are inserting before the insertion point operation, the + // input must properly dominate the insertion point operation. Otherwise + // basic dominance is enough. + bool doesDominate = insertAfter + ? dom.dominates(input, insertionPoint) + : dom.properlyDominates(input, insertionPoint); + if (!doesDominate) { + return emitDefiniteFailure() + << "input " << input << " does not dominate insertion point " + << insertionPoint; + } + } + + // Get the function to inline. This can either be specified by symbol or as a + // transform handle. + func::FuncOp targetFunction = nullptr; + if (getFunctionName()) { + targetFunction = SymbolTable::lookupNearestSymbolFrom( + insertionPoint, *getFunctionName()); + if (!targetFunction) { + return emitDefiniteFailure() + << "unresolved symbol " << *getFunctionName(); + } + } else if (getFunction()) { + auto payloadOps = state.getPayloadOps(getFunction()); + if (!llvm::hasSingleElement(payloadOps)) { + return emitDefiniteFailure() << "requires a single function to call"; + } + targetFunction = dyn_cast(*payloadOps.begin()); + if (!targetFunction) { + return emitDefiniteFailure() << "invalid non-function callee"; + } + } else { + llvm_unreachable("Invalid CastAndCall op without a function to call"); + return emitDefiniteFailure(); + } + assert(targetFunction && "no target function found"); + + // Verify that the function argument and result lengths match the inputs and + // outputs given to this op. + if (targetFunction.getNumArguments() != inputs.size()) { + return emitSilenceableFailure(targetFunction.getLoc()) + << "mismatch between number of function arguments " + << targetFunction.getNumArguments() << " and number of inputs " + << inputs.size(); + } + if (targetFunction.getNumResults() != outputs.size()) { + return emitSilenceableFailure(targetFunction.getLoc()) + << "mismatch between number of function results " + << targetFunction->getNumResults() << " and number of outputs " + << outputs.size(); + } + + // Gather all specified converters. + MLIRContext *ctx = insertionPoint->getContext(); + mlir::TypeConverter converter; + if (!getRegion().empty()) { + for (Operation &op : getRegion().front()) { + cast(&op) + .populateTypeMaterializations(converter); + } + } + + OpBuilder builder(ctx); + if (insertAfter) + builder.setInsertionPointAfter(insertionPoint); + else + builder.setInsertionPoint(insertionPoint); + + for (auto [input, type] : + llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { + if (input.getType() != type) { + Value newInput = converter.materializeSourceConversion( + builder, input.getLoc(), type, input); + if (!newInput) { + return emitSilenceableFailure(input.getLoc()) + << "Failed to materialize conversion of " << input << " to type " + << type; + } + input = newInput; + } + } + + auto callOp = builder.create(insertionPoint->getLoc(), + targetFunction, inputs); + + // Cast the call results back to the expected types. If any conversions fail + // this is a definite failure as the call has been constructed at this point. + for (auto [output, newOutput] : + llvm::zip_equal(outputs, callOp.getResults())) { + Value convertedOutput = newOutput; + if (output.getType() != newOutput.getType()) { + convertedOutput = converter.materializeTargetConversion( + builder, output.getLoc(), output.getType(), newOutput); + if (!convertedOutput) { + return emitSilenceableFailure(output.getLoc()) + << "Failed to materialize conversion of " << newOutput + << " to type " << output.getType(); + } + } + output.replaceAllUsesExcept(convertedOutput, callOp); + } + results.set(cast(getResult()), {callOp}); + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::CastAndCallOp::verify() { + if (!getRegion().empty()) { + for (Operation &op : getRegion().front()) { + if (!isa(&op)) { + InFlightDiagnostic diag = emitOpError() + << "expected children ops to implement " + "TypeConversionOpInterface"; + diag.attachNote(op.getLoc()) << "op without interface"; + return diag; + } + } + } + if (!getFunction() && !getFunctionName()) { + return emitOpError() << "expected a function handle or name to call"; + } + if (getFunction() && getFunctionName()) { + return emitOpError() << "function handle and name are mutually exclusive"; + } + return success(); +} + +void transform::CastAndCallOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getInsertionPoint(), effects); + if (getInputs()) + transform::onlyReadsHandle(getInputs(), effects); + if (getOutputs()) + transform::onlyReadsHandle(getOutputs(), effects); + if (getFunction()) + transform::onlyReadsHandle(getFunction(), effects); + transform::producesHandle(getResult(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index ed27423870471..0c89ba2a1f189 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace tensor; @@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( tensor::populateRewriteAsConstantPatterns(patterns); } +//===----------------------------------------------------------------------===// +// TypeConversionCastOp +//===----------------------------------------------------------------------===// + +void transform::TypeConversionCastOp::populateTypeMaterializations( + TypeConverter &converter) { + bool ignoreDynamicInfo = getIgnoreDynamicInfo(); + converter.addSourceMaterialization([ignoreDynamicInfo]( + OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) { + return std::nullopt; + } + Value input = inputs[0]; + if (!ignoreDynamicInfo && + !tensor::preservesStaticInformation(resultType, input.getType())) { + return std::nullopt; + } + if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { + return std::nullopt; + } + return builder.create(loc, resultType, input).getResult(); + }); + converter.addTargetMaterialization([](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) { + return std::nullopt; + } + Value input = inputs[0]; + if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { + return std::nullopt; + } + return builder.create(loc, resultType, input).getResult(); + }); +} + //===----------------------------------------------------------------------===// // MakeLoopIndependentOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 485d4448e7c36..f2a57383cc5bf 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -16,10 +16,12 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -30,11 +32,13 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include #define DEBUG_TYPE "transform-dialect" diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir new file mode 100644 index 0000000000000..6aab07b0cb38a --- /dev/null +++ b/mlir/test/Dialect/Func/func-transform.mlir @@ -0,0 +1,120 @@ +// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @basic_cast_and_call +func.func @basic_cast_and_call() { + // CHECK-NEXT: call @second() + "test.foo"() : () -> () + // CHECK-NEXT: test.foo + // CHECK-NEXT: call @third() + func.return +} + +func.func @second() { + "test.bar"() : () -> () + func.return +} + +func.func private @third() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op + transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op + transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @non_empty_arg_and_out +func.func @non_empty_arg_and_out(%arg0 : index) -> i32 { + // CHECK-NEXT: %[[FOO:.+]] = "test.foo" + %0 = "test.foo"(%arg0) : (index) -> (index) + // CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32 + %1 = "test.bar"(%0) : (index) -> (i32) + // CHECK: return %[[CALL]] : i32 + func.return %1 : i32 +} + +func.func private @second(%arg1 : index) -> i32 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value + %out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value + transform.func.cast_and_call %f#1(%in) -> %out before %bar + : (!transform.any_op, !transform.any_value, + !transform.any_value, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @multi_arg_and_result +func.func @multi_arg_and_result(%arg0 : index) -> (index, index) { + // CHECK-NEXT: %[[FOO:.+]] = "test.foo" + %0 = "test.foo"(%arg0) : (index) -> (index) + %1 = "test.bar"(%0) : (index) -> (index) + %2 = "test.bar"(%0) : (index) -> (index) + // CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index) + // CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index + func.return %1, %2 : index, index +} + +func.func private @second(%arg1: index, %arg2: index) -> (index, index) + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value + %in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value + %ins = transform.merge_handles %in0, %in1 : !transform.any_value + + %outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value + + transform.func.cast_and_call %f#1(%ins) -> %outs after %foo + : (!transform.any_op, !transform.any_value, + !transform.any_value, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @nested_call +func.func @nested_call() { + // CHECK-NEXT: %[[ARG:.+]] = "test.arg" + // CHECK-NEXT: test.foo + %0 = "test.arg"() : () -> (index) + "test.foo"() ({ + // CHECK-NEXT: call @second(%[[ARG]]) : (index) -> () + "test.bar"(%0) : (index) -> () + }) : () -> () +} + +func.func private @second(%arg1: index) -> () + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value + + transform.func.cast_and_call %f#1(%in) before %bar + : (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir new file mode 100644 index 0000000000000..fd2fc8a1883a3 --- /dev/null +++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s + +func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> { + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32> + func.return %0 : tensor<13x13xf32> +} + +func.func private @concat_replacement(%arg0: tensor, %arg1: tensor) -> tensor + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value + %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value + transform.func.cast_and_call %f#1(%ins) -> %out before %concat { + transform.type_conversion.tensor.cast + } : (!transform.any_op, !transform.any_value, + !transform.any_value, !transform.any_op) -> !transform.any_op + transform.apply_dce to %f#0 : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func.func @cast_to_dynamic +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32> +// CHECK-DAG: %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor +// CHECK-DAG: %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor +// CHECK: %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]]) +// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor to tensor<13x13xf32> +// CHECK: return %[[CAST_RES]] : tensor<13x13xf32> + +// ----- + +func.func @cast_to_static(%arg0: tensor) -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor into tensor + func.return %0 : tensor +} + +func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op + %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value + %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value + transform.func.cast_and_call %f#1(%ins) -> %out before %concat { + transform.type_conversion.tensor.cast ignore_dynamic_info + } : (!transform.any_op, !transform.any_value, + !transform.any_value, !transform.any_op) -> !transform.any_op + transform.apply_dce to %f#0 : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func.func @cast_to_static +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor to tensor<4x5xf32> +// CHECK: %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]]) +// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor +// CHECK: return %[[CAST_RES]] : tensor From e6211958bc210b909de4c76f17c169e3fb44ece8 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 18 Jan 2024 00:10:21 -0500 Subject: [PATCH 2/4] Collapse TypeConversion interface into converter builder interface and address comments --- .../Func/TransformOps/FuncTransformOps.td | 16 ++--- .../MemRef/TransformOps/MemRefTransformOps.td | 3 +- .../Tensor/TransformOps/TensorTransformOps.td | 15 +++-- .../Transform/IR/TransformInterfaces.td | 43 ++++++-------- .../Func/TransformOps/FuncTransformOps.cpp | 58 +++++++++---------- .../TransformOps/TensorTransformOps.cpp | 6 +- .../Dialect/Tensor/transform-op-casting.mlir | 12 ++-- .../TestTransformDialectExtension.td | 3 +- 8 files changed, 75 insertions(+), 81 deletions(-) diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td index e5086c26c55a4..afb08ebd5eb43 100644 --- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td @@ -38,22 +38,22 @@ def CastAndCallOp : Op]> { + ["getTypeConverter", + "getTypeConverterType"]>]> { let description = [{ This operation provides an "LLVMTypeConverter" that lowers memref types to LLVM types. diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td index 28e9249c82e30..39e1d7fa3494a 100644 --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -169,12 +169,17 @@ def MakeLoopIndependentOp }]; } -def TypeConversionCastOp : Op]> { +def TypeConversionCastShapeDynamicDimsOp : Op]> { let description = [{ - Indicates that tensor ops (such as tensor.generate) should be replaced with - constants (arith.constant) when possible. + Populates a type converter with conversion materialization functions that + cast a tensor value between two cast-compatible tensors. See `tensor.cast` + for more information on cast compatibility between tensors. + + If `ignore_dynamic_info` is not set, this will set an additional constraint + that source materializations do not cast dynamic dimensions to static ones. }]; let arguments = (ins UnitAttr:$ignore_dynamic_info); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index 3b601f42a6452..1ef094436881a 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -280,34 +280,12 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> { ]; } -def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> { - let description = [{ - This interface should be implemented by ops that populate type casting - of a `transform.cast_and_inline` op. It provides a method to populate a - type converter with source/target materialization patterns. - }]; - - let cppNamespace = "::mlir::transform"; - - let methods = [ - InterfaceMethod< - /*desc=*/[{ - Populate the given type converter with source/target materialization - functions. - }], - /*returnType=*/"void", - /*name=*/"populateTypeMaterializations", - /*arguments=*/(ins "::mlir::TypeConverter &":$converter) - > - ]; -} - def TypeConverterBuilderOpInterface : OpInterface<"TypeConverterBuilderOpInterface"> { let description = [{ This interface should be implemented by ops that specify a type converter - for a dialect conversion. Such ops can be used with - "apply_conversion_patterns". + for a dialect conversion, or to populate a type converter with + conversions. Such ops can be used with "apply_conversion_patterns". }]; let cppNamespace = "::mlir::transform"; @@ -319,7 +297,11 @@ def TypeConverterBuilderOpInterface }], /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>", /*name=*/"getTypeConverter", - /*arguments=*/(ins) + /*arguments=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return std::make_unique<::mlir::TypeConverter>(); + }] >, StaticInterfaceMethod< /*desc=*/[{ @@ -332,6 +314,17 @@ def TypeConverterBuilderOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ return "TypeConverter"; }] >, + InterfaceMethod< + /*desc=*/[{ + Populate the given type converter with source/target materialization + functions. + }], + /*returnType=*/"void", + /*name=*/"populateTypeMaterializations", + /*arguments=*/(ins "::mlir::TypeConverter &":$converter), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return; }] + >, ]; } diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index 14b6e633520d6..9e79b086c0be8 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -47,21 +47,19 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, transform::TransformState &state) { SmallVector inputs; if (getInputs()) - for (Value input : state.getPayloadValues(getInputs())) - inputs.push_back(input); - SmallVector outputs; - if (getOutputs()) - for (Value output : state.getPayloadValues(getOutputs())) - outputs.push_back(output); + llvm::append_range(inputs, state.getPayloadValues(getInputs())); - // Verify that the set of output values to be replaced is unique. - llvm::SmallDenseSet outputSet; - for (Value output : outputs) { - outputSet.insert(output); - } - if (outputSet.size() != outputs.size()) { - return emitSilenceableFailure(getLoc()) - << "cast and call output values must be unique"; + SetVector outputs; + if (getOutputs()) { + for (auto output : state.getPayloadValues(getOutputs())) + outputs.insert(output); + + // Verify that the set of output values to be replaced is unique. + if (outputs.size() != + llvm::range_size(state.getPayloadValues(getOutputs()))) { + return emitSilenceableFailure(getLoc()) + << "cast and call output values must be unique"; + } } // Get the insertion point for the call. @@ -106,7 +104,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, } } - // Get the function to inline. This can either be specified by symbol or as a + // Get the function to call. This can either be specified by symbol or as a // transform handle. func::FuncOp targetFunction = nullptr; if (getFunctionName()) { @@ -129,7 +127,6 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, llvm_unreachable("Invalid CastAndCall op without a function to call"); return emitDefiniteFailure(); } - assert(targetFunction && "no target function found"); // Verify that the function argument and result lengths match the inputs and // outputs given to this op. @@ -147,37 +144,34 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, } // Gather all specified converters. - MLIRContext *ctx = insertionPoint->getContext(); mlir::TypeConverter converter; if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { - cast(&op) + cast(&op) .populateTypeMaterializations(converter); } } - OpBuilder builder(ctx); if (insertAfter) - builder.setInsertionPointAfter(insertionPoint); + rewriter.setInsertionPointAfter(insertionPoint); else - builder.setInsertionPoint(insertionPoint); + rewriter.setInsertionPoint(insertionPoint); for (auto [input, type] : llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { if (input.getType() != type) { Value newInput = converter.materializeSourceConversion( - builder, input.getLoc(), type, input); + rewriter, input.getLoc(), type, input); if (!newInput) { - return emitSilenceableFailure(input.getLoc()) - << "Failed to materialize conversion of " << input << " to type " - << type; + return emitDefiniteFailure() << "Failed to materialize conversion of " + << input << " to type " << type; } input = newInput; } } - auto callOp = builder.create(insertionPoint->getLoc(), - targetFunction, inputs); + auto callOp = rewriter.create(insertionPoint->getLoc(), + targetFunction, inputs); // Cast the call results back to the expected types. If any conversions fail // this is a definite failure as the call has been constructed at this point. @@ -186,14 +180,14 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, Value convertedOutput = newOutput; if (output.getType() != newOutput.getType()) { convertedOutput = converter.materializeTargetConversion( - builder, output.getLoc(), output.getType(), newOutput); + rewriter, output.getLoc(), output.getType(), newOutput); if (!convertedOutput) { - return emitSilenceableFailure(output.getLoc()) + return emitDefiniteFailure() << "Failed to materialize conversion of " << newOutput << " to type " << output.getType(); } } - output.replaceAllUsesExcept(convertedOutput, callOp); + rewriter.replaceAllUsesExcept(output, convertedOutput, callOp); } results.set(cast(getResult()), {callOp}); return DiagnosedSilenceableFailure::success(); @@ -202,10 +196,10 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, LogicalResult transform::CastAndCallOp::verify() { if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { - if (!isa(&op)) { + if (!isa(&op)) { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement " - "TypeConversionOpInterface"; + "TypeConverterBuilderOpInterface"; diag.attachNote(op.getLoc()) << "op without interface"; return diag; } diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 0c89ba2a1f189..38f1824a3634a 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -131,11 +131,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( } //===----------------------------------------------------------------------===// -// TypeConversionCastOp +// TypeConversionCastTensorShapeOp //===----------------------------------------------------------------------===// -void transform::TypeConversionCastOp::populateTypeMaterializations( - TypeConverter &converter) { +void transform::TypeConversionCastShapeDynamicDimsOp:: + populateTypeMaterializations(TypeConverter &converter) { bool ignoreDynamicInfo = getIgnoreDynamicInfo(); converter.addSourceMaterialization([ignoreDynamicInfo]( OpBuilder &builder, Type resultType, diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir index fd2fc8a1883a3..16a1fa2b0ba9c 100644 --- a/mlir/test/Dialect/Tensor/transform-op-casting.mlir +++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir @@ -12,10 +12,10 @@ module attributes {transform.with_named_sequence} { %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op - %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value - %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value + %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value + %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value transform.func.cast_and_call %f#1(%ins) -> %out before %concat { - transform.type_conversion.tensor.cast + transform.type_conversion.tensor.cast_shape_dynamic_dims } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op transform.apply_dce to %f#0 : !transform.any_op @@ -46,10 +46,10 @@ module attributes {transform.with_named_sequence} { %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op - %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value - %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value + %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value + %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value transform.func.cast_and_call %f#1(%ins) -> %out before %concat { - transform.type_conversion.tensor.cast ignore_dynamic_info + transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op transform.apply_dce to %f#0 : !transform.any_op diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 54036f7929d1b..c00cc560e83e9 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp def TestTypeConverterOp : Op]> { + [DeclareOpInterfaceMethods]> { let arguments = (ins); let results = (outs); let assemblyFormat = "attr-dict"; From 2565400bf4f9377fbf9969a2865bc5fe9b96348d Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 18 Jan 2024 10:44:32 -0500 Subject: [PATCH 3/4] Add warning about potential invalidation cases to the op --- .../mlir/Dialect/Func/TransformOps/FuncTransformOps.td | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td index afb08ebd5eb43..c36fdd1505562 100644 --- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td @@ -58,6 +58,13 @@ def CastAndCallOp : Op Date: Fri, 19 Jan 2024 11:38:02 -0500 Subject: [PATCH 4/4] Add note about interface function implementations --- .../mlir/Dialect/Transform/IR/TransformInterfaces.td | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index 1ef094436881a..8f7b8f1999e0c 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -285,7 +285,13 @@ def TypeConverterBuilderOpInterface let description = [{ This interface should be implemented by ops that specify a type converter for a dialect conversion, or to populate a type converter with - conversions. Such ops can be used with "apply_conversion_patterns". + conversions. + + When such ops are intended to be used with "apply_conversion_patterns" or + other operations that expect a type converter, a non-default implementation + of `getTypeConverter` should be implemented. For use with "cast_and_call" + like ops that construct a type converter iteratively, non-default + `populateTypeMaterializations` should be implemented. }]; let cppNamespace = "::mlir::transform";