From cbbf7417717aff35e59d0403c1ec82aaa7fb8afc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 6 Jul 2024 14:28:41 +0200 Subject: [PATCH] fix test --- mlir/docs/DialectConversion.md | 3 +- .../SCF/TransformOps/SCFTransformOps.td | 11 +++ .../mlir/Transforms/DialectConversion.h | 10 +-- .../Conversion/LLVMCommon/TypeConverter.cpp | 28 ++++-- .../Dialect/SCF/TransformOps/CMakeLists.txt | 1 + .../SCF/TransformOps/SCFTransformOps.cpp | 13 ++- .../Transforms/Utils/DialectConversion.cpp | 87 ++++++++++--------- .../FuncToLLVM/func-memref-return.mlir | 4 +- .../Transforms/test-block-legalization.mlir | 44 ++++++++++ 9 files changed, 141 insertions(+), 60 deletions(-) create mode 100644 mlir/test/Transforms/test-block-legalization.mlir diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index db26e6477d5fc..23e74470a835f 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -352,7 +352,8 @@ class TypeConverter { /// This method registers a materialization that will be called when /// converting (potentially multiple) block arguments that were the result of - /// a signature conversion of a single block argument, to a single SSA value. + /// a signature conversion of a single block argument, to a single SSA value + /// with the old argument type. template ::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 7bf914f6456ce..20880d94a83ca 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op]> { + let description = [{ + Collects patterns that lower structured control flow ops to unstructured + control flow. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">; def ForallToForOp : Op>::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index f5620a6a7cd91..32d02d5e438bd 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, type.isVarArg()); }); - // Materialization for memrefs creates descriptor structs from individual - // values constituting them, when descriptors are used, i.e. more than one - // value represents a memref. + // Argument materializations convert from the new block argument types + // (multiple SSA values that make up a memref descriptor) back to the + // original block argument type. The dialect conversion framework will then + // insert a target materialization from the original block argument type to + // a legal type. addArgumentMaterialization( [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> std::optional { @@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, // memref descriptor cannot be built just from a bare pointer. return std::nullopt; } - return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, - inputs); + Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this, + resultType, inputs); + // An argument materialization must return a value of type + // `resultType`, so insert a cast from the memref descriptor type + // (!llvm.struct) to the original memref type. + return builder.create(loc, resultType, desc) + .getResult(0); }); addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc) -> std::optional { + Value desc; if (inputs.size() == 1) { // This is a bare pointer. We allow bare pointers only for function entry // blocks. @@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, if (!block->isEntryBlock() || !isa(block->getParentOp())) return std::nullopt; - return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType, + desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType, inputs[0]); + } else { + desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); } - return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); + // An argument materialization must return a value of type `resultType`, + // so insert a cast from the memref descriptor type (!llvm.struct) to the + // original memref type. + return builder.create(loc, resultType, desc) + .getResult(0); }); // Add generic source and target materializations to handle cases where // non-LLVM types persist after an LLVM conversion. diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt index 1d6f9ebd153f0..06bccab80e7d8 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps MLIRIR MLIRLoopLikeInterface MLIRSCFDialect + MLIRSCFToControlFlow MLIRSCFTransforms MLIRSCFUtils MLIRTransformDialect diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 56ff2709a589e..c4a55c302d0a3 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" + +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp:: conversionTarget); } +void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + populateSCFToControlFlowConversionPatterns(patterns); +} + //===----------------------------------------------------------------------===// // ForallToForOp //===----------------------------------------------------------------------===// @@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); - std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); + std::optional ubConstant = + getConstantIntValue(forOp.getUpperBound()); + std::optional lbConstant = + getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e6c0ee2ab2949..1e0afee2373a9 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, - MaterializationKind kind = MaterializationKind::Target, - Type origOutputType = nullptr) + MaterializationKind kind = MaterializationKind::Target) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind), origOutputType(origOutputType) {} + converterAndKind(converter, kind) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { return converterAndKind.getInt(); } - /// Return the original illegal output type of the input values. - Type getOrigOutputType() const { return origOutputType; } - private: /// The corresponding type converter to use when resolving this /// materialization, and the kind of this materialization. llvm::PointerIntPair converterAndKind; - - /// The original output type. This is only used for argument conversions. - Type origOutputType; }; } // namespace @@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, - Type origOutputType, const TypeConverter *converter); Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, - Type origOutputType, Type outputType, const TypeConverter *converter); @@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( if (replArgs.size() == 1 && (!converter || replArgs[0].getType() == origArg.getType())) { newArg = replArgs.front(); + mapping.map(origArg, newArg); } else { - Type origOutputType = origArg.getType(); - - // Legalize the argument output type. - Type outputType = origOutputType; - if (Type legalOutputType = converter->convertType(outputType)) - outputType = legalOutputType; - - newArg = buildUnresolvedArgumentMaterialization( - newBlock, origArg.getLoc(), replArgs, origOutputType, outputType, - converter); + // Build argument materialization: new block arguments -> old block + // argument type. + Value argMat = buildUnresolvedArgumentMaterialization( + newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter); + mapping.map(origArg, argMat); + + // Build target materialization: old block argument type -> legal type. + // Note: This function returns an "empty" type if no valid conversion to + // a legal type exists. In that case, we continue the conversion with the + // original block argument type. + Type legalOutputType = converter->convertType(origArg.getType()); + if (legalOutputType && legalOutputType != origArg.getType()) { + newArg = buildUnresolvedTargetMaterialization( + origArg.getLoc(), argMat, legalOutputType, converter); + mapping.map(argMat, newArg); + } else { + newArg = argMat; + } } - mapping.map(origArg, newArg); appendRewrite(block, origArg); argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } @@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, - Location loc, ValueRange inputs, Type outputType, Type origOutputType, + Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter) { // Avoid materializing an unnecessary cast. if (inputs.size() == 1 && inputs.front().getType() == outputType) @@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( OpBuilder builder(insertBlock, insertPt); auto convertOp = builder.create(loc, outputType, inputs); - appendRewrite(convertOp, converter, kind, - origOutputType); + appendRewrite(convertOp, converter, kind); return convertOp.getResult(0); } Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( - Block *block, Location loc, ValueRange inputs, Type origOutputType, - Type outputType, const TypeConverter *converter) { + Block *block, Location loc, ValueRange inputs, Type outputType, + const TypeConverter *converter) { return buildUnresolvedMaterialization(MaterializationKind::Argument, block, block->begin(), loc, inputs, outputType, - origOutputType, converter); + converter); } Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( Location loc, Value input, Type outputType, @@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( return buildUnresolvedMaterialization(MaterializationKind::Target, insertBlock, insertPt, loc, input, - outputType, outputType, converter); + outputType, converter); } //===----------------------------------------------------------------------===// @@ -2672,6 +2670,9 @@ static void computeNecessaryMaterializations( ConversionPatternRewriterImpl &rewriterImpl, DenseMap> &inverseMapping, SetVector &necessaryMaterializations) { + // Helper function to check if the given value or a not yet materialized + // replacement of the given value is live. + // Note: `inverseMapping` maps from replaced values to original values. auto isLive = [&](Value value) { auto findFn = [&](Operation *user) { auto matIt = materializationOps.find(user); @@ -2679,12 +2680,18 @@ static void computeNecessaryMaterializations( return !necessaryMaterializations.count(matIt->second); return rewriterImpl.isOpIgnored(user); }; - // This value may be replacing another value that has a live user. - for (Value inv : inverseMapping.lookup(value)) - if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end()) + // A worklist is needed because a value may have gone through a chain of + // replacements and each of the replaced values may have live users. + SmallVector worklist; + worklist.push_back(value); + while (!worklist.empty()) { + Value next = worklist.pop_back_val(); + if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end()) return true; - // Or have live users itself. - return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); + // This value may be replacing another value that has a live user. + llvm::append_range(worklist, inverseMapping.lookup(next)); + } + return false; }; llvm::unique_function lookupRemappedValue = @@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization( switch (mat.getMaterializationKind()) { case MaterializationKind::Argument: // Try to materialize an argument conversion. - // FIXME: The current argument materialization hook expects the original - // output type, even though it doesn't use that as the actual output type - // of the generated IR. The output type is just used as an indicator of - // the type of materialization to do. This behavior is really awkward in - // that it diverges from the behavior of the other hooks, and can be - // easily misunderstood. We should clean up the argument hooks to better - // represent the desired invariants we actually care about. newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); + rewriter, op->getLoc(), outputType, inputOperands); if (newMaterialization) break; - // If an argument materialization failed, fallback to trying a target // materialization. [[fallthrough]]; @@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization( break; } if (newMaterialization) { + assert(newMaterialization.getType() == outputType && + "materialization callback produced value of incorrect type"); replaceMaterialization(rewriterImpl, opResult, newMaterialization, inverseMapping); return success(); diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir index 91ef571cb3bf7..6b9df32fe02dd 100644 --- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s -// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR +// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR -// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR +// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR // These tests were separated from func-memref.mlir because applying // -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting diff --git a/mlir/test/Transforms/test-block-legalization.mlir b/mlir/test/Transforms/test-block-legalization.mlir new file mode 100644 index 0000000000000..d739f95a56947 --- /dev/null +++ b/mlir/test/Transforms/test-block-legalization.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt %s -transform-interpreter | FileCheck %s + +// CHECK-LABEL: func @complex_block_signature_conversion( +// CHECK: %[[cst:.*]] = complex.constant +// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex to !llvm.struct<(f64, f64)> +// Note: Some blocks are omitted. +// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]] +// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>): +// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex +// CHECK: llvm.br ^[[block2:.*]] +// CHECK: ^[[block2]]: +// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex) -> () +func.func @complex_block_signature_conversion() { + %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex + %true = arith.constant true + %0 = scf.if %true -> complex { + scf.yield %cst : complex + } else { + scf.yield %cst : complex + } + + // Regression test to ensure that the a source materialization is inserted. + // The operand of "test.consumer_of_complex" must not change. + "test.consumer_of_complex"(%0) : (complex) -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %toplevel_module + : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %func { + transform.apply_conversion_patterns.dialect_to_llvm "cf" + transform.apply_conversion_patterns.func.func_to_llvm + transform.apply_conversion_patterns.scf.scf_to_control_flow + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + } { + legal_dialects = ["llvm"], + partial_conversion + } : !transform.any_op + transform.yield + } +}