Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms][NFC] Simplify ArgConverter state #81462

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Feb 12, 2024

  • When converting a block signature, ArgConverter creates a new block with the new signature and moves all operation from the old block to the new block. The new block is temporarily inserted into a region that is stored in regionMapping. The old block is not yet deleted, so that the conversion can be rolled back. regionMapping is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handles in the same way in the dialect conversion.
  • regionToConverter is a mapping from regions to type converter. That field is never accessed within ArgConverter. It should be stored in ConversionPatternRewriterImpl instead.
  • convertedBlocks is not needed. Old blocks are already stored in ConvertedBlockInfo.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 12, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes
  • When converting a block signature, ArgConverter creates a new block with the new signature and moves all operation from the old block to the new block. The new block is temporarily inserted into a region that is stored in regionMapping. The old block is not yet deleted, so that the conversion can be rolled back. regionMapping is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handles in the same way in the dialect conversion.
  • regionToConverter is a mapping from regions to type converter. That field is never accessed within ArgConverter. It should be stored in ConversionPatternRewriterImpl instead.

Full diff: https://github.com/llvm/llvm-project/pull/81462.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+18-45)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 489ccd0139c7f2..53717f632621dd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -348,18 +348,6 @@ struct ArgConverter {
     return conversionInfo.count(block) || convertedBlocks.count(block);
   }
 
-  /// Set the type converter to use for the given region.
-  void setConverter(Region *region, const TypeConverter *typeConverter) {
-    assert(typeConverter && "expected valid type converter");
-    regionToConverter[region] = typeConverter;
-  }
-
-  /// Return the type converter to use for the given region, or null if there
-  /// isn't one.
-  const TypeConverter *getConverter(Region *region) {
-    return regionToConverter.lookup(region);
-  }
-
   //===--------------------------------------------------------------------===//
   // Rewrite Application
   //===--------------------------------------------------------------------===//
@@ -409,9 +397,6 @@ struct ArgConverter {
       ConversionValueMapping &mapping,
       SmallVectorImpl<BlockArgument> &argReplacements);
 
-  /// Insert a new conversion into the cache.
-  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
-
   /// A collection of blocks that have had their arguments converted. This is a
   /// map from the new replacement block, back to the original block.
   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
@@ -419,14 +404,6 @@ struct ArgConverter {
   /// The set of original blocks that were converted.
   DenseSet<Block *> convertedBlocks;
 
-  /// A mapping from valid regions, to those containing the original blocks of a
-  /// conversion.
-  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
-
-  /// A mapping of regions to type converters that should be used when
-  /// converting the arguments of blocks within that region.
-  DenseMap<Region *, const TypeConverter *> regionToConverter;
-
   /// The pattern rewriter to use when materializing conversions.
   PatternRewriter &rewriter;
 
@@ -474,9 +451,10 @@ void ArgConverter::discardRewrites(Block *block) {
     block->getArgument(i).dropAllUses();
   block->replaceAllUsesWith(origBlock);
 
-  // Move the operations back the original block and the delete the new block.
+  // Move the operations back the original block, move the original block back
+  // into its original location and the delete the new block.
   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
-  origBlock->moveBefore(block);
+  block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
   block->erase();
 
   convertedBlocks.erase(origBlock);
@@ -510,6 +488,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
             mapping.lookupOrDefault(castValue, origArg.getType()));
       }
     }
+
+    delete origBlock;
+    blockInfo.origBlock = nullptr;
   }
 }
 
@@ -603,6 +584,9 @@ Block *ArgConverter::applySignatureConversion(
   // signature.
   Block *newBlock = block->splitBlock(block->begin());
   block->replaceAllUsesWith(newBlock);
+  // Unlink the block, but do not erase it yet, so that the change can be rolled
+  // back.
+  block->getParent()->getBlocks().remove(block);
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
@@ -679,24 +663,9 @@ Block *ArgConverter::applySignatureConversion(
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  // Remove the original block from the region and return the new one.
-  insertConversion(newBlock, std::move(info));
-  return newBlock;
-}
-
-void ArgConverter::insertConversion(Block *newBlock,
-                                    ConvertedBlockInfo &&info) {
-  // Get a region to insert the old block.
-  Region *region = newBlock->getParent();
-  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
-  if (!mappedRegion)
-    mappedRegion = std::make_unique<Region>(region->getParentOp());
-
-  // Move the original block to the mapped region and emplace the conversion.
-  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
-                                   info.origBlock->getIterator());
-  convertedBlocks.insert(info.origBlock);
+  convertedBlocks.insert(block);
   conversionInfo.insert({newBlock, std::move(info)});
+  return newBlock;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1182,6 +1151,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
 
+  /// A mapping of regions to type converters that should be used when
+  /// converting the arguments of blocks within that region.
+  DenseMap<Region *, const TypeConverter *> regionToConverter;
+
   /// This allows the user to collect the match failure message.
   function_ref<void(Diagnostic &)> notifyCallback;
 
@@ -1459,7 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return nullptr;
 
@@ -1474,7 +1447,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  argConverter.setConverter(region, &converter);
+  regionToConverter[region] = &converter;
   if (region->empty())
     return success();
 
@@ -2154,7 +2127,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 
     // If the region of the block has a type converter, try to convert the block
     // directly.
-    if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
       if (failed(impl.convertBlockSignature(block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));

@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_arg_converter branch from 5af1476 to ba1a808 Compare February 12, 2024 11:25
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from f7010ea to 4bb6521 Compare February 14, 2024 16:18
@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_arg_converter branch from ba1a808 to c7afdb2 Compare February 14, 2024 16:21
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from 4bb6521 to 820bcdd Compare February 14, 2024 16:40
@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_arg_converter branch from c7afdb2 to a79501e Compare February 14, 2024 16:41
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from 820bcdd to 1c69f42 Compare February 16, 2024 15:11
@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_arg_converter branch from a79501e to a7fffc3 Compare February 16, 2024 15:12
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch 2 times, most recently from 8a8a79d to 5cabd6c Compare February 21, 2024 15:20
Base automatically changed from users/matthias-springer/dialect_conversion_modify_op_inplace to main February 21, 2024 15:34
* When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operation from the old block to the new block. The new block is temporarily inserted into a region that is stored in `regionMapping`. The old block is not yet deleted, so that the conversion can be rolled back. `regionMapping` is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handles in the same way in the dialect conversion.
* `regionToConverter` is a mapping from regions to type converter. That field is never accessed within `ArgConverter`. It should be stored in `ConversionPatternRewriterImpl` instead.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_arg_converter branch from a7fffc3 to ae28cd9 Compare February 21, 2024 15:44
@matthias-springer matthias-springer merged commit b49f155 into main Feb 21, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/simplify_arg_converter branch February 21, 2024 15:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants