Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][SROA] Reuse allocators to avoid rewalking the IR #91971

Merged
merged 3 commits into from
May 14, 2024

Conversation

Dinistro
Copy link
Contributor

This commit extends the SROA interfaces to ensure the interface instantiations can communicate newly created allocators to the algorithm. This ensures that the SROA implementation does no longer require re-walking the IR to find new allocators.

This commit extends the SROA interfaces to ensure the impementations can
communicate newly created allocators to the algorithm. This ensures that
the SROA implementation does no longer require re-walking the IR to find
new allocators.
@llvmbot
Copy link
Collaborator

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Christian Ulmann (Dinistro)

Changes

This commit extends the SROA interfaces to ensure the interface instantiations can communicate newly created allocators to the algorithm. This ensures that the SROA implementation does no longer require re-walking the IR to find new allocators.


Full diff: https://github.com/llvm/llvm-project/pull/91971.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+8-2)
  • (modified) mlir/include/mlir/Transforms/SROA.h (+4-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+8-5)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+8-5)
  • (modified) mlir/lib/Transforms/SROA.cpp (+56-30)
  • (added) mlir/test/Transforms/sroa.mlir (+31)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+67-12)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+3-2)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index e2409cbec5fde..29960f457373c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::OpBuilder &":$builder)
+           "::mlir::OpBuilder &":$builder,
+           "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
+             $newAllocators)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
+
+        Must return a new destructurable allocation op if this operation
+        produced multiple destructurable slots, nullopt otherwise.
       }],
-      "void", "handleDestructuringComplete",
+      "::std::optional<::mlir::DestructurableAllocationOpInterface>",
+        "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::mlir::OpBuilder &":$builder)
     >,
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index fa84fb1eae73a..d48f809d30766 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -27,8 +27,10 @@ struct SROAStatistics {
   llvm::Statistic *maxSubelementAmount = nullptr;
 };
 
-/// Attempts to destructure the slots of destructurable allocators. Returns
-/// failure if no slot was destructured.
+/// Attempts to destructure the slots of destructurable allocators. Iteratively
+/// retries the destructuring of all slots as destructuring one slot might
+/// enable subsequent destructuring. Returns failure if no slot was
+/// destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
     OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 4fdf847a559ce..3f1e5b1773bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
                                    *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                            const SmallPtrSetImpl<Attribute> &usedIndices,
-                            OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   assert(slot.ptr == getResult());
   builder.setInsertionPointAfter(*this);
 
@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
   }
 
   return slotMap;
 }
 
-void LLVM::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+LLVM::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index e30598e6878f4..631dee2d40538 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
       DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                              const SmallPtrSetImpl<Attribute> &usedIndices,
-                              OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   builder.setInsertionPointAfter(*this);
 
   DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     Type elemType = memrefType.getTypeAtIndex(usedIndex);
     MemRefType elemPtr = MemRefType::get({}, elemType);
     auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(usedIndex,
                                     {subAlloca.getResult(), elemType});
   }
@@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
   return slotMap;
 }
 
-void memref::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+memref::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 4e28fa687ffd4..67cbade07bc94 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
 /// Performs the destructuring of a destructible slot given associated
 /// destructuring information. The provided slot will be destructured in
 /// subslots as specified by its allocator.
-static void destructureSlot(DestructurableMemorySlot &slot,
-                            DestructurableAllocationOpInterface allocator,
-                            OpBuilder &builder, const DataLayout &dataLayout,
-                            MemorySlotDestructuringInfo &info,
-                            const SROAStatistics &statistics) {
+static void destructureSlot(
+    DestructurableMemorySlot &slot,
+    DestructurableAllocationOpInterface allocator, OpBuilder &builder,
+    const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
+    const SROAStatistics &statistics) {
   OpBuilder::InsertionGuard guard(builder);
 
   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
   DenseMap<Attribute, MemorySlot> subslots =
-      allocator.destructure(slot, info.usedIndices, builder);
+      allocator.destructure(slot, info.usedIndices, builder, newAllocators);
 
   if (statistics.slotsWithMemoryBenefit &&
       slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
   if (statistics.destructuredAmount)
     (*statistics.destructuredAmount)++;
 
-  allocator.handleDestructuringComplete(slot, builder);
+  std::optional<DestructurableAllocationOpInterface> newAllocator =
+      allocator.handleDestructuringComplete(slot, builder);
+  // Add newly created allocators to the worklist for further processing.
+  if (newAllocator)
+    newAllocators.push_back(*newAllocator);
 }
 
 LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
     SROAStatistics statistics) {
   bool destructuredAny = false;
 
-  for (DestructurableAllocationOpInterface allocator : allocators) {
-    for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
-      std::optional<MemorySlotDestructuringInfo> info =
-          computeDestructuringInfo(slot, dataLayout);
-      if (!info)
-        continue;
+  SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
+                                                            allocators.end());
+  SmallVector<DestructurableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(allocators.size());
+  // Destructuring a slot can allow for further destructuring of other
+  // slots, destructuring is tried until no destructuring succeeds.
+  while (true) {
+    bool changesInThisRound = false;
+
+    for (DestructurableAllocationOpInterface allocator : workList) {
+      bool destructuredAnySlot = false;
+      for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+        std::optional<MemorySlotDestructuringInfo> info =
+            computeDestructuringInfo(slot, dataLayout);
+        if (!info)
+          continue;
 
-      destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
-      destructuredAny = true;
+        destructureSlot(slot, allocator, builder, dataLayout, *info,
+                        newWorkList, statistics);
+        destructuredAnySlot = true;
+
+        // A break is required, since destructuring a slot may invalidate the
+        // remaning slots of an allocator.
+        break;
+      }
+      if (!destructuredAnySlot)
+        newWorkList.push_back(allocator);
+      changesInThisRound |= destructuredAnySlot;
     }
+
+    if (!changesInThisRound)
+      break;
+    destructuredAny |= changesInThisRound;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Destructuring a slot can allow for further destructuring of other
-      // slots, destructuring is tried until no destructuring succeeds.
-      while (true) {
-        SmallVector<DestructurableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to destructure the slots of.
-        // TODO: Update list on the fly to avoid repeated visiting of the same
-        // allocators.
-        region.walk([&](DestructurableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
-                                               statistics)))
-          break;
+      SmallVector<DestructurableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to destructure the slots of.
+      region.walk([&](DestructurableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt to destructure as many slots as possible.
+      if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
+                                                statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir
new file mode 100644
index 0000000000000..c9e80a6cf8dd1
--- /dev/null
+++ b/mlir/test/Transforms/sroa.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with mutliple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
+  %3 = memref.load %1[%0] {first}: memref<2xi32>
+  %4 = memref.load %2[%0] {second} : memref<4xi32>
+  return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially destructured.
+
+func.func private @consumer(memref<2xi32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  // CHECK: test.multi_slot_alloca : () -> memref<2xi32>
+  // CHECK: test.multi_slot_alloca : () -> memref<i32>
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  func.call @consumer(%1) : (memref<2xi32>) -> ()
+  %3 = memref.load %1[%0] : memref<2xi32>
+  %4 = memref.load %2[%0] : memref<4xi32>
+  return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index d22d48b139a04..c8949e8d1af72 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
   // Not relevant for testing.
 }
 
-std::optional<PromotableAllocationOpInterface>
-TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue,
-                                             OpBuilder &builder) {
-  if (defaultValue && defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
+/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
+static std::optional<TestMultiSlotAlloca>
+createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
+                                TestMultiSlotAlloca oldOp) {
 
-  if (getNumResults() == 1) {
-    erase();
+  if (oldOp.getNumResults() == 1) {
+    oldOp.erase();
     return std::nullopt;
   }
 
   SmallVector<Type> newTypes;
   SmallVector<Value> remainingValues;
 
-  for (Value oldResult : getResults()) {
+  for (Value oldResult : oldOp.getResults()) {
     if (oldResult == slot.ptr)
       continue;
     remainingValues.push_back(oldResult);
@@ -1222,12 +1220,69 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
   }
 
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(*this);
-  auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+  builder.setInsertionPoint(oldOp);
+  auto replacement =
+      builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
   for (auto [oldResult, newResult] :
        llvm::zip_equal(remainingValues, replacement.getResults()))
     oldResult.replaceAllUsesWith(newResult);
 
-  erase();
+  oldOp.erase();
   return replacement;
 }
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+                                             Value defaultValue,
+                                             OpBuilder &builder) {
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
+
+SmallVector<DestructurableMemorySlot>
+TestMultiSlotAlloca::getDestructurableSlots() {
+  SmallVector<DestructurableMemorySlot> slots;
+  for (Value result : getResults()) {
+    auto memrefType = cast<MemRefType>(result.getType());
+    auto destructurable =
+        llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+    if (!destructurable)
+      continue;
+
+    std::optional<DenseMap<Attribute, Type>> destructuredType =
+        destructurable.getSubelementIndexMap();
+    if (!destructuredType)
+      continue;
+    slots.emplace_back(
+        DestructurableMemorySlot{{result, memrefType}, *destructuredType});
+  }
+  return slots;
+}
+
+DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = slot.elementPtrs.lookup(usedIndex);
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(0), elemType});
+  }
+
+  return slotMap;
+}
+
+std::optional<DestructurableAllocationOpInterface>
+TestMultiSlotAlloca::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e16ea2407314e..7fc3d22d18958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
 }
 
 //===----------------------------------------------------------------------===//
-// Test Mem2Reg
+// Test Mem2Reg & SROA
 //===----------------------------------------------------------------------===//
 
 def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
-    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
   let results = (outs Variadic<MemRefOf<[I32]>>:$results);
   let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
 }

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

mlir/test/lib/Dialect/Test/TestOpDefs.cpp Outdated Show resolved Hide resolved
@Dinistro Dinistro merged commit 0b5b202 into main May 14, 2024
3 of 4 checks passed
@Dinistro Dinistro deleted the users/dinistro/sroa-retry-without-rewalking branch May 14, 2024 08:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants