diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index afdd31a748c8c..db41b9f19e7e8 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,12 +154,13 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numRewrites, unsigned numIgnoredOperations) + unsigned numRewrites, unsigned numIgnoredOperations, + unsigned numErased) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), numArgReplacements(numArgReplacements), numRewrites(numRewrites), - numIgnoredOperations(numIgnoredOperations) {} + numIgnoredOperations(numIgnoredOperations), numErased(numErased) {} /// The current number of created operations. unsigned numCreatedOps; @@ -178,6 +179,9 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; + + /// The current number of erased operations/blocks. + unsigned numErased; }; //===----------------------------------------------------------------------===// @@ -292,370 +296,6 @@ static Value buildUnresolvedTargetMaterialization( outputType, outputType, converter, unresolvedMaterializations); } -//===----------------------------------------------------------------------===// -// ArgConverter -//===----------------------------------------------------------------------===// -namespace { -/// This class provides a simple interface for converting the types of block -/// arguments. This is done by creating a new block that contains the new legal -/// types and extracting the block that contains the old illegal types to allow -/// for undoing pending rewrites in the case of failure. -struct ArgConverter { - ArgConverter( - PatternRewriter &rewriter, - SmallVectorImpl &unresolvedMaterializations) - : rewriter(rewriter), - unresolvedMaterializations(unresolvedMaterializations) {} - - /// This structure contains the information pertaining to an argument that has - /// been converted. - struct ConvertedArgInfo { - ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - Value castValue = nullptr) - : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} - - /// The start index of in the new argument list that contains arguments that - /// replace the original. - unsigned newArgIdx; - - /// The number of arguments that replaced the original argument. - unsigned newArgSize; - - /// The cast value that was created to cast from the new arguments to the - /// old. This only used if 'newArgSize' > 1. - Value castValue; - }; - - /// This structure contains information pertaining to a block that has had its - /// signature converted. - struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter) - : origBlock(origBlock), converter(converter) {} - - /// The original block that was requested to have its signature converted. - Block *origBlock; - - /// The conversion information for each of the arguments. The information is - /// std::nullopt if the argument was dropped during conversion. - SmallVector, 1> argInfo; - - /// The type converter used to convert the arguments. - const TypeConverter *converter; - }; - - //===--------------------------------------------------------------------===// - // Rewrite Application - //===--------------------------------------------------------------------===// - - /// Erase any rewrites registered for the blocks within the given operation - /// which is about to be removed. This merely drops the rewrites without - /// undoing them. - void notifyOpRemoved(Operation *op); - - /// Cleanup and undo any generated conversions for the arguments of block. - /// This method replaces the new block with the original, reverting the IR to - /// its original state. - void discardRewrites(Block *block); - - /// Fully replace uses of the old arguments with the new. - void applyRewrites(ConversionValueMapping &mapping); - - /// Materialize any necessary conversions for converted arguments that have - /// live users, using the provided `findLiveUser` to search for a user that - /// survives the conversion process. - LogicalResult - materializeLiveConversions(ConversionValueMapping &mapping, - OpBuilder &builder, - function_ref findLiveUser); - - //===--------------------------------------------------------------------===// - // Conversion - //===--------------------------------------------------------------------===// - - /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. Returns `block` if it did - /// not require conversion. - FailureOr - convertSignature(Block *block, const TypeConverter *converter, - ConversionValueMapping &mapping, - SmallVectorImpl &argReplacements); - - /// Apply the given signature conversion on the given block. The new block - /// containing the updated signature is returned. If no conversions were - /// necessary, e.g. if the block has no arguments, `block` is returned. - /// `converter` is used to generate any necessary cast operations that - /// translate between the origin argument types and those specified in the - /// signature conversion. - Block *applySignatureConversion( - Block *block, const TypeConverter *converter, - TypeConverter::SignatureConversion &signatureConversion, - ConversionValueMapping &mapping, - SmallVectorImpl &argReplacements); - - /// A collection of blocks that have had their arguments converted. This is a - /// map from the new replacement block, back to the original block. - llvm::MapVector conversionInfo; - - /// The pattern rewriter to use when materializing conversions. - PatternRewriter &rewriter; - - /// An ordered set of unresolved materializations during conversion. - SmallVectorImpl &unresolvedMaterializations; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Rewrite Application - -void ArgConverter::notifyOpRemoved(Operation *op) { - if (conversionInfo.empty()) - return; - - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // Drop any rewrites from within. - for (Operation &nestedOp : block) - if (nestedOp.getNumRegions()) - notifyOpRemoved(&nestedOp); - - // Check if this block was converted. - auto *it = conversionInfo.find(&block); - if (it == conversionInfo.end()) - continue; - - // Drop all uses of the original arguments and delete the original block. - Block *origBlock = it->second.origBlock; - for (BlockArgument arg : origBlock->getArguments()) - arg.dropAllUses(); - conversionInfo.erase(it); - } - } -} - -void ArgConverter::discardRewrites(Block *block) { - auto *it = conversionInfo.find(block); - if (it == conversionInfo.end()) - return; - Block *origBlock = it->second.origBlock; - - // Drop all uses of the new block arguments and replace uses of the new block. - for (int i = block->getNumArguments() - 1; i >= 0; --i) - block->getArgument(i).dropAllUses(); - block->replaceAllUsesWith(origBlock); - - // Move the operations back the original block, move the original block back - // into its original location and the delete the new block. - origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); - block->erase(); - - conversionInfo.erase(it); -} - -void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { - for (auto &info : conversionInfo) { - ConvertedBlockInfo &blockInfo = info.second; - Block *origBlock = blockInfo.origBlock; - - // Process the remapping for each of the original arguments. - for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { - std::optional &argInfo = blockInfo.argInfo[i]; - BlockArgument origArg = origBlock->getArgument(i); - - // Handle the case of a 1->0 value mapping. - if (!argInfo) { - if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType())) - origArg.replaceAllUsesWith(newArg); - continue; - } - - // Otherwise this is a 1->1+ value mapping. - Value castValue = argInfo->castValue; - assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); - - // If the argument is still used, replace it with the generated cast. - if (!origArg.use_empty()) { - origArg.replaceAllUsesWith( - mapping.lookupOrDefault(castValue, origArg.getType())); - } - } - - delete origBlock; - blockInfo.origBlock = nullptr; - } -} - -LogicalResult ArgConverter::materializeLiveConversions( - ConversionValueMapping &mapping, OpBuilder &builder, - function_ref findLiveUser) { - for (auto &info : conversionInfo) { - Block *newBlock = info.first; - ConvertedBlockInfo &blockInfo = info.second; - Block *origBlock = blockInfo.origBlock; - - // Process the remapping for each of the original arguments. - for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { - // If the type of this argument changed and the argument is still live, we - // need to materialize a conversion. - BlockArgument origArg = origBlock->getArgument(i); - if (mapping.lookupOrNull(origArg, origArg.getType())) - continue; - Operation *liveUser = findLiveUser(origArg); - if (!liveUser) - continue; - - Value replacementValue = mapping.lookupOrDefault(origArg); - bool isDroppedArg = replacementValue == origArg; - if (isDroppedArg) - rewriter.setInsertionPointToStart(newBlock); - else - rewriter.setInsertionPointAfterValue(replacementValue); - Value newArg; - if (blockInfo.converter) { - newArg = blockInfo.converter->materializeSourceConversion( - rewriter, origArg.getLoc(), origArg.getType(), - isDroppedArg ? ValueRange() : ValueRange(replacementValue)); - assert((!newArg || newArg.getType() == origArg.getType()) && - "materialization hook did not provide a value of the expected " - "type"); - } - if (!newArg) { - InFlightDiagnostic diag = - emitError(origArg.getLoc()) - << "failed to materialize conversion for block argument #" << i - << " that remained live after conversion, type was " - << origArg.getType(); - if (!isDroppedArg) - diag << ", with target type " << replacementValue.getType(); - diag.attachNote(liveUser->getLoc()) - << "see existing live user here: " << *liveUser; - return failure(); - } - mapping.map(origArg, newArg); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Conversion - -FailureOr ArgConverter::convertSignature( - Block *block, const TypeConverter *converter, - ConversionValueMapping &mapping, - SmallVectorImpl &argReplacements) { - assert(block->getParent() && "cannot convert signature of detached block"); - - // If a converter wasn't provided, and the block wasn't already converted, - // there is nothing we can do. - if (!converter) - return failure(); - - // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(block, converter, *conversion, mapping, - argReplacements); - return failure(); -} - -Block *ArgConverter::applySignatureConversion( - Block *block, const TypeConverter *converter, - TypeConverter::SignatureConversion &signatureConversion, - ConversionValueMapping &mapping, - SmallVectorImpl &argReplacements) { - // If no arguments are being changed or added, there is nothing to do. - unsigned origArgCount = block->getNumArguments(); - auto convertedTypes = signatureConversion.getConvertedTypes(); - if (llvm::equal(block->getArgumentTypes(), convertedTypes)) - return block; - - // Split the block at the beginning to get a new block to use for the updated - // signature. - Block *newBlock = block->splitBlock(block->begin()); - block->replaceAllUsesWith(newBlock); - // Unlink the block, but do not erase it yet, so that the change can be rolled - // back. - block->getParent()->getBlocks().remove(block); - - // Map all new arguments to the location of the argument they originate from. - SmallVector newLocs(convertedTypes.size(), - rewriter.getUnknownLoc()); - for (unsigned i = 0; i < origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap || inputMap->replacementValue) - continue; - Location origLoc = block->getArgument(i).getLoc(); - for (unsigned j = 0; j < inputMap->size; ++j) - newLocs[inputMap->inputNo + j] = origLoc; - } - - SmallVector newArgRange( - newBlock->addArguments(convertedTypes, newLocs)); - ArrayRef newArgs(newArgRange); - - // Remap each of the original arguments as determined by the signature - // conversion. - ConvertedBlockInfo info(block, converter); - info.argInfo.resize(origArgCount); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(newBlock); - for (unsigned i = 0; i != origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap) - continue; - BlockArgument origArg = block->getArgument(i); - - // If inputMap->replacementValue is not nullptr, then the argument is - // dropped and a replacement value is provided to be the remappedValue. - if (inputMap->replacementValue) { - assert(inputMap->size == 0 && - "invalid to provide a replacement value when the argument isn't " - "dropped"); - mapping.map(origArg, inputMap->replacementValue); - argReplacements.push_back(origArg); - continue; - } - - // Otherwise, this is a 1->1+ mapping. - auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg; - - // If this is a 1->1 mapping and the types of new and replacement arguments - // match (i.e. it's an identity map), then the argument is mapped to its - // original type. - // FIXME: We simply pass through the replacement argument if there wasn't a - // converter, which isn't great as it allows implicit type conversions to - // appear. We should properly restructure this code to handle cases where a - // converter isn't provided and also to properly handle the case where an - // argument materialization is actually a temporary source materialization - // (e.g. in the case of 1->N). - if (replArgs.size() == 1 && - (!converter || replArgs[0].getType() == origArg.getType())) { - newArg = replArgs.front(); - } else { - Type origOutputType = origArg.getType(); - - // Legalize the argument output type. - Type outputType = origOutputType; - if (Type legalOutputType = converter->convertType(outputType)) - outputType = legalOutputType; - - newArg = buildUnresolvedArgumentMaterialization( - rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, - converter, unresolvedMaterializations); - } - - mapping.map(origArg, newArg); - argReplacements.push_back(origArg); - info.argInfo[i] = - ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); - } - - conversionInfo.insert({newBlock, std::move(info)}); - return newBlock; -} - //===----------------------------------------------------------------------===// // IR rewrites //===----------------------------------------------------------------------===// @@ -702,6 +342,12 @@ class IRRewrite { IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) : kind(kind), rewriterImpl(rewriterImpl) {} + /// Erase the given op (unless it was already erased). + void eraseOp(Operation *op); + + /// Erase the given block (unless it was already erased). + void eraseBlock(Block *block); + const Kind kind; ConversionPatternRewriterImpl &rewriterImpl; }; @@ -744,8 +390,7 @@ class CreateBlockRewrite : public BlockRewrite { auto &blockOps = block->getOperations(); while (!blockOps.empty()) blockOps.remove(blockOps.begin()); - block->dropAllDefinedValueUses(); - block->erase(); + eraseBlock(block); } }; @@ -881,8 +526,7 @@ class SplitBlockRewrite : public BlockRewrite { // Merge back the block that was split out. originalBlock->getOperations().splice(originalBlock->end(), block->getOperations()); - block->dropAllDefinedValueUses(); - block->erase(); + eraseBlock(block); } private: @@ -890,20 +534,59 @@ class SplitBlockRewrite : public BlockRewrite { Block *originalBlock; }; +/// This structure contains the information pertaining to an argument that has +/// been converted. +struct ConvertedArgInfo { + ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, + Value castValue = nullptr) + : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} + + /// The start index of in the new argument list that contains arguments that + /// replace the original. + unsigned newArgIdx; + + /// The number of arguments that replaced the original argument. + unsigned newArgSize; + + /// The cast value that was created to cast from the new arguments to the + /// old. This only used if 'newArgSize' > 1. + Value castValue; +}; + /// Block type conversion. This rewrite is partially reflected in the IR. class BlockTypeConversionRewrite : public BlockRewrite { public: - BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block) - : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {} + BlockTypeConversionRewrite( + ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *origBlock, SmallVector, 1> argInfo, + const TypeConverter *converter) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), + origBlock(origBlock), argInfo(argInfo), converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; } - // TODO: Block type conversions are currently committed in - // `ArgConverter::applyRewrites`. This should be done in the "commit" method. + /// Materialize any necessary conversions for converted arguments that have + /// live users, using the provided `findLiveUser` to search for a user that + /// survives the conversion process. + LogicalResult + materializeLiveConversions(function_ref findLiveUser); + + void commit() override; + void rollback() override; + +private: + /// The original block that was requested to have its signature converted. + Block *origBlock; + + /// The conversion information for each of the arguments. The information is + /// std::nullopt if the argument was dropped during conversion. + SmallVector, 1> argInfo; + + /// The type converter used to convert the arguments. + const TypeConverter *converter; }; /// An operation rewrite. @@ -949,8 +632,8 @@ class MoveOperationRewrite : public OperationRewrite { // The block in which this operation was previously contained. Block *block; - // The original successor of this operation before it was moved. "nullptr" if - // this operation was the only operation in the region. + // The original successor of this operation before it was moved. "nullptr" + // if this operation was the only operation in the region. Operation *insertBeforeOp; }; @@ -1027,6 +710,26 @@ static bool hasRewrite(R &&rewrites, Operation *op) { }); } +/// Find the single rewrite object of the specified type and block among the +/// given rewrites. In debug mode, asserts that there is mo more than one such +/// object. Return "nullptr" if no object was found. +template +static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { + RewriteTy *result = nullptr; + for (auto &rewrite : rewrites) { + auto *rewriteTy = dyn_cast(rewrite.get()); + if (rewriteTy && rewriteTy->getBlock() == block) { +#ifndef NDEBUG + assert(!result && "expected single matching rewrite"); + result = rewriteTy; +#else + return rewriteTy; +#endif // NDEBUG + } + } + return result; +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -1034,7 +737,7 @@ namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) - : argConverter(rewriter, unresolvedMaterializations), + : rewriter(rewriter), eraseRewriter(rewriter.getContext()), notifyCallback(nullptr) {} /// Cleanup and destroy any generated rewrite operations. This method is @@ -1084,15 +787,33 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// removes them from being considered for legalization. void markNestedOpsIgnored(Operation *op); + /// Detach any operations nested in the given operation from their parent + /// blocks, and erase the given operation. This can be used when the nested + /// operations are scheduled for erasure themselves, so deleting the regions + /// of the given operation together with their content would result in + /// double-free. This happens, for example, when rolling back op creation in + /// the reverse order and if the nested ops were created before the parent op. + /// This function does not need to collect nested ops recursively because it + /// is expected to also be called for each nested op when it is about to be + /// deleted. + void detachNestedAndErase(Operation *op); + //===--------------------------------------------------------------------===// // Type Conversion //===--------------------------------------------------------------------===// - /// Convert the signature of the given block. + /// Attempt to convert the signature of the given block, if successful a new + /// block is returned containing the new arguments. Returns `block` if it did + /// not require conversion. FailureOr convertBlockSignature( Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion = nullptr); + /// Convert the types of non-entry block arguments within the given region. + LogicalResult convertNonEntryRegionTypes( + Region *region, const TypeConverter &converter, + ArrayRef blockConversions = {}); + /// Apply a signature conversion on the given region, using `converter` for /// materializations if not null. Block * @@ -1105,10 +826,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); - /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef blockConversions = {}); + /// Apply the given signature conversion on the given block. The new block + /// containing the updated signature is returned. If no conversions were + /// necessary, e.g. if the block has no arguments, `block` is returned. + /// `converter` is used to generate any necessary cast operations that + /// translate between the origin argument types and those specified in the + /// signature conversion. + Block *applySignatureConversion( + Block *block, const TypeConverter *converter, + TypeConverter::SignatureConversion &signatureConversion); //===--------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1140,17 +866,54 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { notifyMatchFailure(Location loc, function_ref reasonCallback) override; + //===--------------------------------------------------------------------===// + // IR Erasure + //===--------------------------------------------------------------------===// + + /// A rewriter that keeps track of erased ops and blocks. It ensures that no + /// operation or block is erased multiple times. This rewriter assumes that + /// no new IR is created between calls to `eraseOp`/`eraseBlock`. + struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { + public: + SingleEraseRewriter(MLIRContext *context) + : RewriterBase(context, /*listener=*/this) {} + + /// Erase the given op (unless it was already erased). + void eraseOp(Operation *op) override { + if (erased.contains(op)) + return; + op->dropAllUses(); + RewriterBase::eraseOp(op); + } + + /// Erase the given block (unless it was already erased). + void eraseBlock(Block *block) override { + if (erased.contains(block)) + return; + block->dropAllDefinedValueUses(); + RewriterBase::eraseBlock(block); + } + + void notifyOperationErased(Operation *op) override { erased.insert(op); } + void notifyBlockErased(Block *block) override { erased.insert(block); } + + /// Pointers to all erased operations and blocks. + SetVector erased; + }; + //===--------------------------------------------------------------------===// // State //===--------------------------------------------------------------------===// + PatternRewriter &rewriter; + + /// This rewriter must be used for erasing ops/blocks. + SingleEraseRewriter eraseRewriter; + // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. ConversionValueMapping mapping; - /// Utility used to convert block arguments. - ArgConverter argConverter; - /// Ordered vector of all of the newly created operations during conversion. SmallVector createdOps; @@ -1207,20 +970,100 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } // namespace detail } // namespace mlir +void IRRewrite::eraseOp(Operation *op) { + rewriterImpl.eraseRewriter.eraseOp(op); +} + +void IRRewrite::eraseBlock(Block *block) { + rewriterImpl.eraseRewriter.eraseBlock(block); +} + +void BlockTypeConversionRewrite::commit() { + // Process the remapping for each of the original arguments. + for (auto [origArg, info] : + llvm::zip_equal(origBlock->getArguments(), argInfo)) { + // Handle the case of a 1->0 value mapping. + if (!info) { + if (Value newArg = + rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + origArg.replaceAllUsesWith(newArg); + continue; + } + + // Otherwise this is a 1->1+ value mapping. + Value castValue = info->castValue; + assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); + + // If the argument is still used, replace it with the generated cast. + if (!origArg.use_empty()) { + origArg.replaceAllUsesWith( + rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType())); + } + } + + delete origBlock; + origBlock = nullptr; +} + void BlockTypeConversionRewrite::rollback() { - // Undo the type conversion. - rewriterImpl.argConverter.discardRewrites(block); -} - -/// Detach any operations nested in the given operation from their parent -/// blocks, and erase the given operation. This can be used when the nested -/// operations are scheduled for erasure themselves, so deleting the regions of -/// the given operation together with their content would result in double-free. -/// This happens, for example, when rolling back op creation in the reverse -/// order and if the nested ops were created before the parent op. This function -/// does not need to collect nested ops recursively because it is expected to -/// also be called for each nested op when it is about to be deleted. -static void detachNestedAndErase(Operation *op) { + // Drop all uses of the new block arguments and replace uses of the new block. + for (int i = block->getNumArguments() - 1; i >= 0; --i) + block->getArgument(i).dropAllUses(); + block->replaceAllUsesWith(origBlock); + + // Move the operations back the original block, move the original block back + // into its original location and the delete the new block. + origBlock->getOperations().splice(origBlock->end(), block->getOperations()); + block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); + eraseBlock(block); +} + +LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( + function_ref findLiveUser) { + // Process the remapping for each of the original arguments. + for (auto it : llvm::enumerate(origBlock->getArguments())) { + // If the type of this argument changed and the argument is still live, we + // need to materialize a conversion. + BlockArgument origArg = it.value(); + if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + continue; + Operation *liveUser = findLiveUser(origArg); + if (!liveUser) + continue; + + Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); + bool isDroppedArg = replacementValue == origArg; + if (isDroppedArg) + rewriterImpl.rewriter.setInsertionPointToStart(getBlock()); + else + rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue); + Value newArg; + if (converter) { + newArg = converter->materializeSourceConversion( + rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(), + isDroppedArg ? ValueRange() : ValueRange(replacementValue)); + assert((!newArg || newArg.getType() == origArg.getType()) && + "materialization hook did not provide a value of the expected " + "type"); + } + if (!newArg) { + InFlightDiagnostic diag = + emitError(origArg.getLoc()) + << "failed to materialize conversion for block argument #" + << it.index() << " that remained live after conversion, type was " + << origArg.getType(); + if (!isDroppedArg) + diag << ", with target type " << replacementValue.getType(); + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; + return failure(); + } + rewriterImpl.mapping.map(origArg, newArg); + } + return success(); +} + +void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { while (!block.getOperations().empty()) @@ -1228,8 +1071,7 @@ static void detachNestedAndErase(Operation *op) { block.dropAllDefinedValueUses(); } } - op->dropAllUses(); - op->erase(); + eraseRewriter.eraseOp(op); } void ConversionPatternRewriterImpl::discardRewrites() { @@ -1248,11 +1090,6 @@ void ConversionPatternRewriterImpl::applyRewrites() { for (OpResult result : repl.first->getResults()) if (Value newValue = mapping.lookupOrNull(result, result.getType())) result.replaceAllUsesWith(newValue); - - // If this operation defines any regions, drop any pending argument - // rewrites. - if (repl.first->getNumRegions()) - argConverter.notifyOpRemoved(repl.first); } // Apply all of the requested argument replacements. @@ -1279,22 +1116,16 @@ void ConversionPatternRewriterImpl::applyRewrites() { // Drop all of the unresolved materialization operations created during // conversion. - for (auto &mat : unresolvedMaterializations) { - mat.getOp()->dropAllUses(); - mat.getOp()->erase(); - } + for (auto &mat : unresolvedMaterializations) + eraseRewriter.eraseOp(mat.getOp()); // In a second pass, erase all of the replaced operations in reverse. This // allows processing nested operations before their parent region is // destroyed. Because we process in reverse order, producers may be deleted // before their users (a pattern deleting a producer and then the consumer) // so we first drop all uses explicitly. - for (auto &repl : llvm::reverse(replacements)) { - repl.first->dropAllUses(); - repl.first->erase(); - } - - argConverter.applyRewrites(mapping); + for (auto &repl : llvm::reverse(replacements)) + eraseRewriter.eraseOp(repl.first); // Commit all rewrites. for (auto &rewrite : rewrites) @@ -1307,7 +1138,8 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - rewrites.size(), ignoredOps.size()); + rewrites.size(), ignoredOps.size(), + eraseRewriter.erased.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1355,6 +1187,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { while (!operationsWithChangedResults.empty() && operationsWithChangedResults.back() >= state.numReplacements) operationsWithChangedResults.pop_back(); + + while (eraseRewriter.erased.size() != state.numErased) + eraseRewriter.erased.pop_back(); } void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { @@ -1443,18 +1278,18 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { FailureOr ConversionPatternRewriterImpl::convertBlockSignature( Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion) { - FailureOr result = - conversion ? argConverter.applySignatureConversion( - block, converter, *conversion, mapping, argReplacements) - : argConverter.convertSignature(block, converter, mapping, - argReplacements); - if (failed(result)) + if (conversion) + return applySignatureConversion(block, converter, *conversion); + + // If a converter wasn't provided, and the block wasn't already converted, + // there is nothing we can do. + if (!converter) return failure(); - if (Block *newBlock = *result) { - if (newBlock != block) - appendRewrite(newBlock); - } - return result; + + // Try to convert the signature for the block with the provided converter. + if (auto conversion = converter->convertBlockSignature(block)) + return applySignatureConversion(block, converter, *conversion); + return failure(); } Block *ConversionPatternRewriterImpl::applySignatureConversion( @@ -1508,6 +1343,102 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( return success(); } +Block *ConversionPatternRewriterImpl::applySignatureConversion( + Block *block, const TypeConverter *converter, + TypeConverter::SignatureConversion &signatureConversion) { + // If no arguments are being changed or added, there is nothing to do. + unsigned origArgCount = block->getNumArguments(); + auto convertedTypes = signatureConversion.getConvertedTypes(); + if (llvm::equal(block->getArgumentTypes(), convertedTypes)) + return block; + + // Split the block at the beginning to get a new block to use for the updated + // signature. + Block *newBlock = block->splitBlock(block->begin()); + block->replaceAllUsesWith(newBlock); + // Unlink the block, but do not erase it yet, so that the change can be rolled + // back. + block->getParent()->getBlocks().remove(block); + + // Map all new arguments to the location of the argument they originate from. + SmallVector newLocs(convertedTypes.size(), + rewriter.getUnknownLoc()); + for (unsigned i = 0; i < origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap || inputMap->replacementValue) + continue; + Location origLoc = block->getArgument(i).getLoc(); + for (unsigned j = 0; j < inputMap->size; ++j) + newLocs[inputMap->inputNo + j] = origLoc; + } + + SmallVector newArgRange( + newBlock->addArguments(convertedTypes, newLocs)); + ArrayRef newArgs(newArgRange); + + // Remap each of the original arguments as determined by the signature + // conversion. + SmallVector, 1> argInfo; + argInfo.resize(origArgCount); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newBlock); + for (unsigned i = 0; i != origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap) + continue; + BlockArgument origArg = block->getArgument(i); + + // If inputMap->replacementValue is not nullptr, then the argument is + // dropped and a replacement value is provided to be the remappedValue. + if (inputMap->replacementValue) { + assert(inputMap->size == 0 && + "invalid to provide a replacement value when the argument isn't " + "dropped"); + mapping.map(origArg, inputMap->replacementValue); + argReplacements.push_back(origArg); + continue; + } + + // Otherwise, this is a 1->1+ mapping. + auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); + Value newArg; + + // If this is a 1->1 mapping and the types of new and replacement arguments + // match (i.e. it's an identity map), then the argument is mapped to its + // original type. + // FIXME: We simply pass through the replacement argument if there wasn't a + // converter, which isn't great as it allows implicit type conversions to + // appear. We should properly restructure this code to handle cases where a + // converter isn't provided and also to properly handle the case where an + // argument materialization is actually a temporary source materialization + // (e.g. in the case of 1->N). + if (replArgs.size() == 1 && + (!converter || replArgs[0].getType() == origArg.getType())) { + newArg = replArgs.front(); + } else { + Type origOutputType = origArg.getType(); + + // Legalize the argument output type. + Type outputType = origOutputType; + if (Type legalOutputType = converter->convertType(outputType)) + outputType = legalOutputType; + + newArg = buildUnresolvedArgumentMaterialization( + rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, + converter, unresolvedMaterializations); + } + + mapping.map(origArg, newArg); + argReplacements.push_back(origArg); + argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); + } + + appendRewrite(newBlock, block, argInfo, + converter); + return newBlock; +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -2635,8 +2566,11 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( }); return liveUserIt == val.user_end() ? nullptr : *liveUserIt; }; - return rewriterImpl.argConverter.materializeLiveConversions( - rewriterImpl.mapping, rewriter, findLiveUser); + for (auto &r : rewriterImpl.rewrites) + if (auto *rewrite = dyn_cast(r.get())) + if (failed(rewrite->materializeLiveConversions(findLiveUser))) + return failure(); + return success(); } /// Replace the results of a materialization operation with the given values.