Skip to content

Commit

Permalink
[mlir][Transforms][NFC] Simplify BlockTypeConversionRewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-springer committed Mar 4, 2024
1 parent 9606655 commit 2c59864
Showing 1 changed file with 40 additions and 44 deletions.
84 changes: 40 additions & 44 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {

void commit() override;

void cleanup() override;

void rollback() override;

private:
Expand Down Expand Up @@ -788,24 +786,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
Block *block, const TypeConverter *converter,
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);

/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
applySignatureConversion(Region *region,
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);

/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(Region *region, const TypeConverter &converter,
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);

/// Apply the given signature conversion on the given block. The new block
Expand All @@ -815,7 +816,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
Block *block, const TypeConverter *converter,
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -990,24 +992,8 @@ void BlockTypeConversionRewrite::commit() {
}
}

void BlockTypeConversionRewrite::cleanup() {
assert(origBlock->empty() && "expected empty block");
origBlock->dropAllDefinedValueUses();
delete origBlock;
origBlock = nullptr;
}

void BlockTypeConversionRewrite::rollback() {
// Drop all uses of the new block arguments and replace uses of the new block.
for (int i = block->getNumArguments() - 1; i >= 0; --i)
block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);

// Move the operations back the original block, move the original block back
// into its original location and the delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
eraseBlock(block);
}

LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
Expand Down Expand Up @@ -1223,10 +1209,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Type Conversion

FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *block, const TypeConverter *converter,
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
return applySignatureConversion(block, converter, *conversion);
return applySignatureConversion(rewriter, block, converter, *conversion);

// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
Expand All @@ -1235,35 +1222,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(

// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
return applySignatureConversion(block, converter, *conversion);
return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
return *convertBlockSignature(&region->front(), converter, &conversion);
return *convertBlockSignature(rewriter, &region->front(), converter,
&conversion);
return nullptr;
}

FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
Region *region, const TypeConverter &converter,
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
return nullptr;

if (failed(convertNonEntryRegionTypes(region, converter)))
if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();

FailureOr<Block *> newEntry =
convertBlockSignature(&region->front(), &converter, entryConversion);
FailureOr<Block *> newEntry = convertBlockSignature(
rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}

LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
Expand All @@ -1284,16 +1275,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);

if (failed(convertBlockSignature(&block, &converter, blockConversion)))
if (failed(convertBlockSignature(rewriter, &block, &converter,
blockConversion)))
return failure();
}
return success();
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
Block *block, const TypeConverter *converter,
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
MLIRContext *ctx = eraseRewriter.getContext();
MLIRContext *ctx = rewriter.getContext();

// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
Expand All @@ -1303,11 +1296,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(

// Split the block at the beginning to get a new block to use for the updated
// signature.
Block *newBlock = block->splitBlock(block->begin());
Block *newBlock = rewriter.splitBlock(block, block->begin());
block->replaceAllUsesWith(newBlock);
// Unlink the block, but do not erase it yet, so that the change can be rolled
// back.
block->getParent()->getBlocks().remove(block);

// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
Expand Down Expand Up @@ -1383,6 +1373,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(

appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);

// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
rewriter.eraseBlock(block);

return newBlock;
}

Expand Down Expand Up @@ -1590,7 +1585,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->applySignatureConversion(region, conversion, converter);
return impl->applySignatureConversion(*this, region, conversion, converter);
}

FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Expand All @@ -1599,7 +1594,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertRegionTypes(region, converter, entryConversion);
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}

LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Expand All @@ -1608,7 +1603,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
return impl->convertNonEntryRegionTypes(*this, region, converter,
blockConversions);
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Expand Down Expand Up @@ -2102,7 +2098,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(block, converter))) {
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
Expand Down

0 comments on commit 2c59864

Please sign in to comment.