diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 508ee7416d55d..d015bd5290123 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { - explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter, + explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : rewriter(rewriter), eraseRewriter(rewriter.getContext()), - config(config) {} + : eraseRewriter(ctx), config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Type origOutputType, const TypeConverter *converter); - Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, - Location loc, ValueRange inputs, + Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, + ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter); @@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // State //===--------------------------------------------------------------------===// - PatternRewriter &rewriter; - /// This rewriter must be used for erasing ops/blocks. SingleEraseRewriter eraseRewriter; @@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() { LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( function_ref findLiveUser) { + auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl); + // Process the remapping for each of the original arguments. for (auto it : llvm::enumerate(origBlock->getArguments())) { + OpBuilder::InsertionGuard g(builder); + // If the type of this argument changed and the argument is still live, we // need to materialize a conversion. BlockArgument origArg = it.value(); @@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); bool isDroppedArg = replacementValue == origArg; - if (isDroppedArg) - rewriterImpl.rewriter.setInsertionPointToStart(getBlock()); - else - rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue); + if (!isDroppedArg) + builder.setInsertionPointAfterValue(replacementValue); Value newArg; if (converter) { newArg = converter->materializeSourceConversion( - rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(), + builder, origArg.getLoc(), origArg.getType(), isDroppedArg ? ValueRange() : ValueRange(replacementValue)); assert((!newArg || newArg.getType() == origArg.getType()) && "materialization hook did not provide a value of the expected " @@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( Block *ConversionPatternRewriterImpl::applySignatureConversion( Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { + MLIRContext *ctx = block->getParentOp()->getContext(); + // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); auto convertedTypes = signatureConversion.getConvertedTypes(); @@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // Map all new arguments to the location of the argument they originate from. SmallVector newLocs(convertedTypes.size(), - rewriter.getUnknownLoc()); + Builder(ctx).getUnknownLoc()); for (unsigned i = 0; i < origArgCount; ++i) { auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap || inputMap->replacementValue) @@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( SmallVector, 1> argInfo; argInfo.resize(origArgCount); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(newBlock); for (unsigned i = 0; i != origArgCount; ++i) { auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap) @@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( outputType = legalOutputType; newArg = buildUnresolvedArgumentMaterialization( - rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, + newBlock, origArg.getLoc(), replArgs, origOutputType, outputType, converter); } @@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( return convertOp.getResult(0); } Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( - PatternRewriter &rewriter, Location loc, ValueRange inputs, - Type origOutputType, Type outputType, const TypeConverter *converter) { - return buildUnresolvedMaterialization( - MaterializationKind::Argument, rewriter.getInsertionBlock(), - rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, - converter); + Block *block, Location loc, ValueRange inputs, Type origOutputType, + Type outputType, const TypeConverter *converter) { + return buildUnresolvedMaterialization(MaterializationKind::Argument, block, + block->begin(), loc, inputs, outputType, + origOutputType, converter); } Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( Location loc, Value input, Type outputType, @@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( ConversionPatternRewriter::ConversionPatternRewriter( MLIRContext *ctx, const ConversionConfig &config) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, config)) { + impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { setListener(impl.get()); }