diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 704597148dfac..635a2cb00f388 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -152,15 +152,11 @@ namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites, - unsigned numIgnoredOperations, unsigned numErased) - : numUnresolvedMaterializations(numUnresolvedMaterializations), - numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), + RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, + unsigned numErased) + : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), numErased(numErased) {} - /// The current number of unresolved materializations. - unsigned numUnresolvedMaterializations; - /// The current number of rewrites performed. unsigned numRewrites; @@ -171,109 +167,10 @@ struct RewriterState { unsigned numErased; }; -//===----------------------------------------------------------------------===// -// UnresolvedMaterialization - -/// This class represents an unresolved materialization, i.e. a materialization -/// that was inserted during conversion that needs to be legalized at the end of -/// the conversion process. -class UnresolvedMaterialization { -public: - /// The type of materialization. - enum Kind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to a legal one. - Argument, - - /// This materialization materializes a conversion from an illegal type to a - /// legal one. - Target - }; - - UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, - const TypeConverter *converter = nullptr, - Kind kind = Target, Type origOutputType = nullptr) - : op(op), converterAndKind(converter, kind), - origOutputType(origOutputType) {} - - /// Return the temporary conversion operation inserted for this - /// materialization. - UnrealizedConversionCastOp getOp() const { return op; } - - /// Return the type converter of this materialization (which may be null). - const TypeConverter *getConverter() const { - return converterAndKind.getPointer(); - } - - /// Return the kind of this materialization. - Kind getKind() const { return converterAndKind.getInt(); } - - /// Set the kind of this materialization. - void setKind(Kind kind) { converterAndKind.setInt(kind); } - - /// Return the original illegal output type of the input values. - Type getOrigOutputType() const { return origOutputType; } - -private: - /// The unresolved materialization operation created during conversion. - UnrealizedConversionCastOp op; - - /// The corresponding type converter to use when resolving this - /// materialization, and the kind of this materialization. - llvm::PointerIntPair converterAndKind; - - /// The original output type. This is only used for argument conversions. - Type origOutputType; -}; -} // namespace - -/// Build an unresolved materialization operation given an output type and set -/// of input operands. -static Value buildUnresolvedMaterialization( - UnresolvedMaterialization::Kind kind, Block *insertBlock, - Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, - Type origOutputType, const TypeConverter *converter, - SmallVectorImpl &unresolvedMaterializations) { - // Avoid materializing an unnecessary cast. - if (inputs.size() == 1 && inputs.front().getType() == outputType) - return inputs.front(); - - // Create an unresolved materialization. We use a new OpBuilder to avoid - // tracking the materialization like we do for other operations. - OpBuilder builder(insertBlock, insertPt); - auto convertOp = - builder.create(loc, outputType, inputs); - unresolvedMaterializations.emplace_back(convertOp, converter, kind, - origOutputType); - return convertOp.getResult(0); -} -static Value buildUnresolvedArgumentMaterialization( - PatternRewriter &rewriter, Location loc, ValueRange inputs, - Type origOutputType, Type outputType, const TypeConverter *converter, - SmallVectorImpl &unresolvedMaterializations) { - return buildUnresolvedMaterialization( - UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), - rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, - converter, unresolvedMaterializations); -} -static Value buildUnresolvedTargetMaterialization( - Location loc, Value input, Type outputType, const TypeConverter *converter, - SmallVectorImpl &unresolvedMaterializations) { - Block *insertBlock = input.getParentBlock(); - Block::iterator insertPt = insertBlock->begin(); - if (OpResult inputRes = dyn_cast(input)) - insertPt = ++inputRes.getOwner()->getIterator(); - - return buildUnresolvedMaterialization( - UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input, - outputType, outputType, converter, unresolvedMaterializations); -} - //===----------------------------------------------------------------------===// // IR rewrites //===----------------------------------------------------------------------===// -namespace { /// An IR rewrite that can be committed (upon success) or rolled back (upon /// failure). /// @@ -299,7 +196,8 @@ class IRRewrite { MoveOperation, ModifyOperation, ReplaceOperation, - CreateOperation + CreateOperation, + UnresolvedMaterialization }; virtual ~IRRewrite() = default; @@ -605,7 +503,7 @@ class OperationRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::CreateOperation; + rewrite->getKind() <= Kind::UnresolvedMaterialization; } protected: @@ -752,6 +650,70 @@ class CreateOperationRewrite : public OperationRewrite { void rollback() override; }; + +/// The type of materialization. +enum MaterializationKind { + /// This materialization materializes a conversion for an illegal block + /// argument type, to a legal one. + Argument, + + /// This materialization materializes a conversion from an illegal type to a + /// legal one. + Target +}; + +/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" +/// op. Unresolved materializations are erased at the end of the dialect +/// conversion. +class UnresolvedMaterializationRewrite : public OperationRewrite { +public: + UnresolvedMaterializationRewrite( + ConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, + MaterializationKind kind = MaterializationKind::Target, + Type origOutputType = nullptr) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + converterAndKind(converter, kind), origOutputType(origOutputType) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::UnresolvedMaterialization; + } + + UnrealizedConversionCastOp getOperation() const { + return cast(op); + } + + void rollback() override; + + void cleanup() override; + + /// Return the type converter of this materialization (which may be null). + const TypeConverter *getConverter() const { + return converterAndKind.getPointer(); + } + + /// Return the kind of this materialization. + MaterializationKind getMaterializationKind() const { + return converterAndKind.getInt(); + } + + /// Set the kind of this materialization. + void setMaterializationKind(MaterializationKind kind) { + converterAndKind.setInt(kind); + } + + /// Return the original illegal output type of the input values. + Type getOrigOutputType() const { return origOutputType; } + +private: + /// The corresponding type converter to use when resolving this + /// materialization, and the kind of this materialization. + llvm::PointerIntPair + converterAndKind; + + /// The original output type. This is only used for argument conversions. + Type origOutputType; +}; } // namespace /// Return "true" if there is an operation rewrite that matches the specified @@ -794,14 +756,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { : rewriter(rewriter), eraseRewriter(rewriter.getContext()), notifyCallback(nullptr) {} - /// Cleanup and destroy any generated rewrite operations. This method is - /// invoked when the conversion process fails. - void discardRewrites(); - - /// Apply all requested operation rewrites. This method is invoked when the - /// conversion process succeeds. - void applyRewrites(); - //===--------------------------------------------------------------------===// // State Management //===--------------------------------------------------------------------===// @@ -809,6 +763,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return the current state of the rewriter. RewriterState getCurrentState(); + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites(); + /// Reset the state of the rewriter to a previously saved point. void resetState(RewriterState state); @@ -841,17 +799,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// removes them from being considered for legalization. void markNestedOpsIgnored(Operation *op); - /// Detach any operations nested in the given operation from their parent - /// blocks, and erase the given operation. This can be used when the nested - /// operations are scheduled for erasure themselves, so deleting the regions - /// of the given operation together with their content would result in - /// double-free. This happens, for example, when rolling back op creation in - /// the reverse order and if the nested ops were created before the parent op. - /// This function does not need to collect nested ops recursively because it - /// is expected to also be called for each nested op when it is about to be - /// deleted. - void detachNestedAndErase(Operation *op); - //===--------------------------------------------------------------------===// // Type Conversion //===--------------------------------------------------------------------===// @@ -890,6 +837,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion); + //===--------------------------------------------------------------------===// + // Materializations + //===--------------------------------------------------------------------===// + /// Build an unresolved materialization operation given an output type and set + /// of input operands. + Value buildUnresolvedMaterialization(MaterializationKind kind, + Block *insertBlock, + Block::iterator insertPt, Location loc, + ValueRange inputs, Type outputType, + Type origOutputType, + const TypeConverter *converter); + + Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, + Location loc, ValueRange inputs, + Type origOutputType, + Type outputType, + const TypeConverter *converter); + + Value buildUnresolvedTargetMaterialization(Location loc, Value input, + Type outputType, + const TypeConverter *converter); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -969,10 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // replacing a value with one of a different type. ConversionValueMapping mapping; - /// Ordered vector of all unresolved type conversion materializations during - /// conversion. - SmallVector unresolvedMaterializations; - /// Ordered list of block operations (creations, splits, motions). SmallVector> rewrites; @@ -1162,24 +1127,15 @@ void CreateOperationRewrite::rollback() { eraseOp(op); } -void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) { - for (Region ®ion : op->getRegions()) { - for (Block &block : region.getBlocks()) { - while (!block.getOperations().empty()) - block.getOperations().remove(block.getOperations().begin()); - block.dropAllDefinedValueUses(); - } +void UnresolvedMaterializationRewrite::rollback() { + if (getMaterializationKind() == MaterializationKind::Target) { + for (Value input : op->getOperands()) + rewriterImpl.mapping.erase(input); } - eraseRewriter.eraseOp(op); + eraseOp(op); } -void ConversionPatternRewriterImpl::discardRewrites() { - undoRewrites(); - - // Remove any newly created ops. - for (UnresolvedMaterialization &materialization : unresolvedMaterializations) - detachNestedAndErase(materialization.getOp()); -} +void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); } void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. @@ -1187,39 +1143,20 @@ void ConversionPatternRewriterImpl::applyRewrites() { rewrite->commit(); for (auto &rewrite : rewrites) rewrite->cleanup(); - - // Drop all of the unresolved materialization operations created during - // conversion. - for (auto &mat : unresolvedMaterializations) - eraseRewriter.eraseOp(mat.getOp()); } //===----------------------------------------------------------------------===// // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(unresolvedMaterializations.size(), rewrites.size(), - ignoredOps.size(), eraseRewriter.erased.size()); + return RewriterState(rewrites.size(), ignoredOps.size(), + eraseRewriter.erased.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { // Undo any rewrites. undoRewrites(state.numRewrites); - // Pop all of the newly inserted materializations. - while (unresolvedMaterializations.size() != - state.numUnresolvedMaterializations) { - UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val(); - UnrealizedConversionCastOp op = mat.getOp(); - - // If this was a target materialization, drop the mapping that was inserted. - if (mat.getKind() == UnresolvedMaterialization::Target) { - for (Value input : op->getOperands()) - mapping.erase(input); - } - detachNestedAndErase(op); - } - // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); @@ -1280,8 +1217,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( if (currentTypeConverter && desiredType && newOperandType != desiredType) { Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); Value castValue = buildUnresolvedTargetMaterialization( - operandLoc, newOperand, desiredType, currentTypeConverter, - unresolvedMaterializations); + operandLoc, newOperand, desiredType, currentTypeConverter); mapping.map(mapping.lookupOrDefault(newOperand), castValue); newOperand = castValue; } @@ -1463,7 +1399,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( newArg = buildUnresolvedArgumentMaterialization( rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, - converter, unresolvedMaterializations); + converter); } mapping.map(origArg, newArg); @@ -1476,6 +1412,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( return newBlock; } +//===----------------------------------------------------------------------===// +// Materializations +//===----------------------------------------------------------------------===// + +/// Build an unresolved materialization operation given an output type and set +/// of input operands. +Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( + MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, + Location loc, ValueRange inputs, Type outputType, Type origOutputType, + const TypeConverter *converter) { + // Avoid materializing an unnecessary cast. + if (inputs.size() == 1 && inputs.front().getType() == outputType) + return inputs.front(); + + // Create an unresolved materialization. We use a new OpBuilder to avoid + // tracking the materialization like we do for other operations. + OpBuilder builder(insertBlock, insertPt); + auto convertOp = + builder.create(loc, outputType, inputs); + appendRewrite(convertOp, converter, kind, + origOutputType); + return convertOp.getResult(0); +} +Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( + PatternRewriter &rewriter, Location loc, ValueRange inputs, + Type origOutputType, Type outputType, const TypeConverter *converter) { + return buildUnresolvedMaterialization( + MaterializationKind::Argument, rewriter.getInsertionBlock(), + rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, + converter); +} +Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( + Location loc, Value input, Type outputType, + const TypeConverter *converter) { + Block *insertBlock = input.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast(input)) + insertPt = ++inputRes.getOwner()->getIterator(); + + return buildUnresolvedMaterialization(MaterializationKind::Target, + insertBlock, insertPt, loc, input, + outputType, outputType, converter); +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -2528,18 +2508,18 @@ LogicalResult OperationConverter::convertOperations( for (auto *op : toConvert) if (failed(convert(rewriter, op))) - return rewriterImpl.discardRewrites(), failure(); + return rewriterImpl.undoRewrites(), failure(); // Now that all of the operations have been converted, finalize the conversion // process to ensure any lingering conversion artifacts are cleaned up and // legalized. if (failed(finalize(rewriter))) - return rewriterImpl.discardRewrites(), failure(); + return rewriterImpl.undoRewrites(), failure(); // After a successful conversion, apply rewrites if this is not an analysis // conversion. if (mode == OpConversionMode::Analysis) { - rewriterImpl.discardRewrites(); + rewriterImpl.undoRewrites(); } else { rewriterImpl.applyRewrites(); } @@ -2645,11 +2625,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, /// Compute all of the unresolved materializations that will persist beyond the /// conversion process, and require inserting a proper user materialization for. static void computeNecessaryMaterializations( - DenseMap &materializationOps, + DenseMap + &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap> &inverseMapping, - SetVector &necessaryMaterializations) { + SetVector &necessaryMaterializations) { auto isLive = [&](Value value) { auto findFn = [&](Operation *user) { auto matIt = materializationOps.find(user); @@ -2684,14 +2665,17 @@ static void computeNecessaryMaterializations( return Value(); }; - SetVector worklist; - for (auto &mat : rewriterImpl.unresolvedMaterializations) { - materializationOps.try_emplace(mat.getOp(), &mat); - worklist.insert(&mat); + SetVector worklist; + for (auto &rewrite : rewriterImpl.rewrites) { + auto *mat = dyn_cast(rewrite.get()); + if (!mat) + continue; + materializationOps.try_emplace(mat->getOperation(), mat); + worklist.insert(mat); } while (!worklist.empty()) { - UnresolvedMaterialization *mat = worklist.pop_back_val(); - UnrealizedConversionCastOp op = mat->getOp(); + UnresolvedMaterializationRewrite *mat = worklist.pop_back_val(); + UnrealizedConversionCastOp op = mat->getOperation(); // We currently only handle target materializations here. assert(op->getNumResults() == 1 && "unexpected materialization type"); @@ -2733,7 +2717,7 @@ static void computeNecessaryMaterializations( auto isBlockArg = [](Value v) { return isa(v); }; if (llvm::any_of(op->getOperands(), isBlockArg) || llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { - mat->setKind(UnresolvedMaterialization::Argument); + mat->setMaterializationKind(MaterializationKind::Argument); } // If the materialization does not have any live users, we don't need to @@ -2743,7 +2727,7 @@ static void computeNecessaryMaterializations( // value replacement even if the types differ in some cases. When those // patterns are fixed, we can drop the argument special case here. bool isMaterializationLive = isLive(opResult); - if (mat->getKind() == UnresolvedMaterialization::Argument) + if (mat->getMaterializationKind() == MaterializationKind::Argument) isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); if (!isMaterializationLive) continue; @@ -2763,8 +2747,9 @@ static void computeNecessaryMaterializations( /// Legalize the given unresolved materialization. Returns success if the /// materialization was legalized, failure otherise. static LogicalResult legalizeUnresolvedMaterialization( - UnresolvedMaterialization &mat, - DenseMap &materializationOps, + UnresolvedMaterializationRewrite &mat, + DenseMap + &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap> &inverseMapping) { @@ -2784,7 +2769,7 @@ static LogicalResult legalizeUnresolvedMaterialization( return Value(); }; - UnrealizedConversionCastOp op = mat.getOp(); + UnrealizedConversionCastOp op = mat.getOperation(); if (!rewriterImpl.ignoredOps.insert(op)) return success(); @@ -2834,8 +2819,8 @@ static LogicalResult legalizeUnresolvedMaterialization( rewriter.setInsertionPoint(op); Value newMaterialization; - switch (mat.getKind()) { - case UnresolvedMaterialization::Argument: + switch (mat.getMaterializationKind()) { + case MaterializationKind::Argument: // Try to materialize an argument conversion. // FIXME: The current argument materialization hook expects the original // output type, even though it doesn't use that as the actual output type @@ -2852,7 +2837,7 @@ static LogicalResult legalizeUnresolvedMaterialization( // If an argument materialization failed, fallback to trying a target // materialization. [[fallthrough]]; - case UnresolvedMaterialization::Target: + case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), outputType, inputOperands); break; @@ -2880,14 +2865,12 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, std::optional>> &inverseMapping) { - if (rewriterImpl.unresolvedMaterializations.empty()) - return success(); inverseMapping = rewriterImpl.mapping.getInverse(); // As an initial step, compute all of the inserted materializations that we // expect to persist beyond the conversion process. - DenseMap materializationOps; - SetVector necessaryMaterializations; + DenseMap materializationOps; + SetVector necessaryMaterializations; computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, *inverseMapping, necessaryMaterializations);