From 0ab5c83ac7067140c564470f0b979d0fc444ab80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 18 Apr 2023 10:04:56 +0200 Subject: [PATCH 01/25] first draft for SROA interfaces --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 2 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 2 +- mlir/include/mlir/Interfaces/CMakeLists.txt | 2 +- ...RegInterfaces.h => MemorySlotInterfaces.h} | 14 +++- ...gInterfaces.td => MemorySlotInterfaces.td} | 79 ++++++++++++++++++- mlir/include/mlir/Transforms/Mem2Reg.h | 2 +- mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 4 +- .../{LLVMMem2Reg.cpp => LLVMMemorySlot.cpp} | 7 +- mlir/lib/Interfaces/CMakeLists.txt | 4 +- ...nterfaces.cpp => MemorySlotInterfaces.cpp} | 6 +- mlir/lib/Transforms/CMakeLists.txt | 2 +- 11 files changed, 101 insertions(+), 23 deletions(-) rename mlir/include/mlir/Interfaces/{Mem2RegInterfaces.h => MemorySlotInterfaces.h} (69%) rename mlir/include/mlir/Interfaces/{Mem2RegInterfaces.td => MemorySlotInterfaces.td} (75%) rename mlir/lib/Dialect/LLVMIR/IR/{LLVMMem2Reg.cpp => LLVMMemorySlot.cpp} (97%) rename mlir/lib/Interfaces/{Mem2RegInterfaces.cpp => MemorySlotInterfaces.cpp} (64%) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 4b30d0c164c81..d4849d4c44d87 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -6,7 +6,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMEnums.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/Mem2RegInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" // Operations that correspond to LLVM intrinsics. With MLIR operation set being // extendable, there is no reason to introduce a hard boundary between "core" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f9025e01c1f69..fefc602ae3729 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -22,7 +22,7 @@ include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/Mem2RegInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" class LLVM_Builder { diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 4e0f7ac5a0400..f18a87aeae71e 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -7,7 +7,7 @@ add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) -add_mlir_interface(Mem2RegInterfaces) +add_mlir_interface(MemorySlotInterfaces) add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(RuntimeVerifiableOpInterface) add_mlir_interface(ShapedOpInterfaces) diff --git a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h similarity index 69% rename from mlir/include/mlir/Interfaces/Mem2RegInterfaces.h rename to mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index c962d98624b42..1c0efee386eaf 100644 --- a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_INTERFACES_MEM2REGINTERFACES_H -#define MLIR_INTERFACES_MEM2REGINTERFACES_H +#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES_H +#define MLIR_INTERFACES_MEMORYSLOTINTERFACES_H #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" @@ -23,6 +23,12 @@ struct MemorySlot { Type elemType; }; +struct DestructibleMemorySlot : public MemorySlot { + /// Maps from an index within the memory slot the type of the pointer that + /// will be generated to access the element directly. + DenseMap elementsPtrs; +}; + /// Returned by operation promotion logic requesting the deletion of an /// operation. enum class DeletionKind { @@ -34,6 +40,6 @@ enum class DeletionKind { } // namespace mlir -#include "mlir/Interfaces/Mem2RegInterfaces.h.inc" +#include "mlir/Interfaces/MemorySlotInterfaces.h.inc" -#endif // MLIR_INTERFACES_MEM2REGINTERFACES_H +#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td similarity index 75% rename from mlir/include/mlir/Interfaces/Mem2RegInterfaces.td rename to mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index b0d0a8cb54bdd..a94b03914dc00 100644 --- a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -1,4 +1,4 @@ -//===-- Mem2RegInterfaces.td - Mem2Reg interfaces ----------*- tablegen -*-===// +//===-- MemorySlotInterfaces.td - MemorySlot interfaces ----*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_INTERFACES_MEM2REGINTERFACES -#define MLIR_INTERFACES_MEM2REGINTERFACES +#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES +#define MLIR_INTERFACES_MEMORYSLOTINTERFACES include "mlir/IR/OpBase.td" @@ -193,4 +193,75 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { ]; } -#endif // MLIR_INTERFACES_MEM2REGINTERFACES +def DestructibleAllocationOpInterface + : OpInterface<"DestructibleAllocationOpInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns the list of slots for which destruction should be attempted, + specifying in which way the slot should be destructed into subslots. The + subslots are indexed by attributes. The type of the pointers of each + subslots to be generated must be provided. + }], + "::llvm::SmallVector<::mlir::DestructibleMemorySlot>", + "getDestructibleSlots", + (ins) + >, + InterfaceMethod<[{ + Destructures this slot into multiple subslots. The newly generated slots + may belong to a different allocator. The original slot must still exist + at the end of this call. + + The builder is located at the beginning of the block where the slot + pointer is defined. + }], + "::llvm::SmallVector<::mlir::MemorySlot>", + "destruct", + (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "::mlir::OpBuilder &":$builder) + >, + InterfaceMethod<[{ + Hook triggered once the destruction of a slot is complete, meaning the + original slot is no longer being refered to and could be deleted. + This will only be called for slots declared by this operation. + }], + "void", "handleDestructionComplete", + (ins "const ::mlir::DestructibleMemorySlot &":$slot) + >, + ]; +} + +def DestructibleAccessorOpInterface + : OpInterface<"DestructibleAccessorOpInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + For a given destructible memory slot, returns whether this operation can + rewire its uses of the slot to use the slots generated after + destruction. This may involve creating new operations, and usually + amounts to checking the pointer types match. + }], + "::llvm::SmallVector<::mlir::MemorySlot>", + "canRewire", + (ins "const ::mlir::DestructibleMemorySlot &":$slot) + >, + InterfaceMethod<[{ + Rewires the use of a slot to the generated subslots. + }], + "void", + "rewire", + (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "::llvm::DenseMap<::mlir::Attribute, ::mlir::Value> &":$subslots) + > + ]; +} + +#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h index e2da88f7e00c7..0d037c3e4283a 100644 --- a/mlir/include/mlir/Transforms/Mem2Reg.h +++ b/mlir/include/mlir/Transforms/Mem2Reg.h @@ -11,7 +11,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" namespace mlir { diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index 9ba03153536c6..7bf95951f1205 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -6,7 +6,7 @@ add_mlir_dialect_library(MLIRLLVMDialect IR/LLVMDialect.cpp IR/LLVMInlining.cpp IR/LLVMInterfaces.cpp - IR/LLVMMem2Reg.cpp + IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp @@ -34,7 +34,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRDataLayoutInterfaces MLIRInferTypeOpInterface MLIRIR - MLIRMem2RegInterfaces + MLIRMemorySlotInterfaces MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp similarity index 97% rename from mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp rename to mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 7fa8ebceed5fb..ddcf2d91b7cd1 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1,4 +1,4 @@ -//===- LLVMMem2Reg.cpp - Mem2Reg Interfaces ---------------------*- C++ -*-===// +//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,13 @@ // //===----------------------------------------------------------------------===// // -// This file implements Mem2Reg-related interfaces for LLVM dialect operations. +// This file implements MemorySlot-related interfaces for LLVM dialect +// operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" using namespace mlir; diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index dbf6e69a45255..665e4c04e8464 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,7 +9,7 @@ set(LLVM_OPTIONAL_SOURCES InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp - Mem2RegInterfaces.cpp + MemorySlotInterfaces.cpp ParallelCombiningOpInterface.cpp RuntimeVerifiableOpInterface.cpp ShapedOpInterfaces.cpp @@ -46,7 +46,7 @@ add_mlir_interface_library(DestinationStyleOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) -add_mlir_interface_library(Mem2RegInterfaces) +add_mlir_interface_library(MemorySlotInterfaces) add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(RuntimeVerifiableOpInterface) add_mlir_interface_library(ShapedOpInterfaces) diff --git a/mlir/lib/Interfaces/Mem2RegInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp similarity index 64% rename from mlir/lib/Interfaces/Mem2RegInterfaces.cpp rename to mlir/lib/Interfaces/MemorySlotInterfaces.cpp index aadd76b44df53..ad15296389797 100644 --- a/mlir/lib/Interfaces/Mem2RegInterfaces.cpp +++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp @@ -1,4 +1,4 @@ -//===-- Mem2RegInterfaces.cpp - Mem2Reg interfaces --------------*- C++ -*-===// +//===-- MemorySlotInterfaces.cpp - MemorySlot interfaces --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,6 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" -#include "mlir/Interfaces/Mem2RegInterfaces.cpp.inc" +#include "mlir/Interfaces/MemorySlotInterfaces.cpp.inc" diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 7b4fb4d6df881..b7e1cd927d6b9 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -28,7 +28,7 @@ add_mlir_library(MLIRTransforms MLIRAnalysis MLIRCopyOpInterface MLIRLoopLikeInterface - MLIRMem2RegInterfaces + MLIRMemorySlotInterfaces MLIRPass MLIRRuntimeVerifiableOpInterface MLIRSideEffectInterfaces From beb7682f463f263771fb6d256f1eac8a0da7a2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 18 Apr 2023 11:08:30 +0200 Subject: [PATCH 02/25] decouple mem2reg analysis and promotion --- mlir/lib/Transforms/Mem2Reg.cpp | 168 +++++++++++++++++++------------- 1 file changed, 99 insertions(+), 69 deletions(-) diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 5952de9ddb63b..7b3c241734d33 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/Mem2Reg.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" @@ -42,7 +43,10 @@ using namespace mlir; /// this, the value stored can be well defined at block boundaries, allowing /// the propagation of replacement through blocks. /// -/// This pass computes this transformation in four main steps: +/// This pass computes this transformation in four main steps. The two first +/// steps are performed during an analysis phase that does not mutate IR. +/// +/// The two steps of the analysis phase are the following: /// - A first step computes the list of operations that transitively use the /// memory slot we would like to promote. The purpose of this phase is to /// identify which uses must be removed to promote the slot, either by rewiring @@ -60,6 +64,9 @@ using namespace mlir; /// existing. Computing this information in advance allows making sure the /// terminators that will forward values are capable of doing so (inability to /// do so aborts promotion at this step). +/// +/// At this point, promotion is guaranteed to happen, and the mutation phase can +/// begin with the following steps: /// - A third step computes the reaching definition of the memory slot at each /// blocking user. This is the core of the mem2reg algorithm, also known as /// load-store forwarding. This analyses loads and stores and propagates which @@ -73,10 +80,6 @@ using namespace mlir; /// - The final fourth step uses the reaching definition to remove blocking uses /// in topological order. /// -/// The two first steps do not mutate IR because promotion can still be aborted -/// at this point. Once the two last steps are reached, promotion is guaranteed -/// to succeed, allowing to start mutating IR. -/// /// For further reading, chapter three of SSA-based Compiler Design [1] /// showcases SSA construction, where mem2reg is an adaptation of the same /// process. @@ -86,26 +89,31 @@ using namespace mlir; namespace { -/// The SlotPromoter handles the state of promoting a memory slot. It wraps a -/// slot and its associated allocator, along with analysis results related to -/// the slot. -class SlotPromoter { +/// Metadata computed during promotion analysis used to compute actual +/// promotion. +struct SlotPromotionMetadata { + /// Blocks for which at least two definitions of the slot values clash. + SmallPtrSet mergePoints; + /// Contains, for each operation, which uses must be eliminated by promotion. + /// This is a DAG structure because an operation that must eliminate some of + /// its uses always comes from a request from an operation that must + /// eliminate some of its own uses. + DenseMap> userToBlockingUses; +}; + +/// Computes metadata for basic slot promotion. This will check that direct slot +/// promotion can be performed, and provide the metadata to execute the +/// promotion. This does not mutate IR. +class SlotPromotionAnalyzer { public: - SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance); + SlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} - /// Prepare data for the promotion of the slot while checking if it can be - /// promoted. Succeeds if the slot can be promoted. This method does not - /// mutate IR. - LogicalResult prepareSlotPromotion(); - - /// Actually promotes the slot by mutating IR. This method must only be - /// called after a successful call to `SlotPromoter::prepareSlotPromotion`. - /// Promoting a slot does not invalidate the preparation of other slots. - void promoteSlot(); + /// Computes the metadata for slot promotion if promotion is possible, returns + /// nothing otherwise. + Optional computeMetadata(); private: - /// This is the first step of the promotion algorithm. /// Computes the transitive uses of the slot that block promotion. This finds /// uses that would block the promotion, checks that the operation has a /// solution to remove the blocking use, and potentially forwards the analysis @@ -113,7 +121,8 @@ class SlotPromoter { /// uses (typically, removing its users because it will delete itself to /// resolve its own blocking uses). This will fail if one of the transitive /// users cannot remove a requested use, and should prevent promotion. - LogicalResult computeBlockingUses(); + LogicalResult computeBlockingUses( + DenseMap> &userToBlockingUses); /// Computes in which blocks the value stored in the slot is actually used, /// meaning blocks leading to a load. This method uses `definingBlocks`, the @@ -122,30 +131,45 @@ class SlotPromoter { SmallPtrSet computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); - /// This is the second step of the promotion algorithm. /// Computes the points in which multiple re-definitions of the slot's value /// (stores) may conflict. - void computeMergePoints(); + void computeMergePoints(SmallPtrSetImpl &mergePoints); /// Ensures predecessors of merge points can properly provide their current /// definition of the value stored in the slot to the merge point. This can /// notably be an issue if the terminator used does not have the ability to /// forward values through block operands. - bool areMergePointsUsable(); + bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); + + MemorySlot slot; + DominanceInfo &dominance; +}; +/// The SlotPromoter handles the state of promoting a memory slot. It wraps a +/// slot and its associated allocator. This will perform the mutation of IR. +class SlotPromoter { +public: + SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, + OpBuilder &builder, DominanceInfo &dominance, + SlotPromotionMetadata metadata); + + /// Actually promotes the slot by mutating IR. This method must only be + /// called after a successful call to `SlotPromoter::prepareSlotPromotion`. + /// Promoting a slot does not invalidate the preparation of other slots. + void promoteSlot(); + +private: /// Computes the reaching definition for all the operations that require /// promotion. `reachingDef` is the value the slot should contain at the /// beginning of the block. This method returns the reached definition at the /// end of the block. Value computeReachingDefInBlock(Block *block, Value reachingDef); - /// This is the third step of the promotion algorithm. /// Computes the reaching definition for all the operations that require /// promotion. `reachingDef` corresponds to the initial value the /// slot will contain before any write, typically a poison value. void computeReachingDefInRegion(Region *region, Value reachingDef); - /// This is the fourth step of the promotion algorithm. /// Removes the blocking uses of the slot, in topological order. void removeBlockingUses(); @@ -156,28 +180,24 @@ class SlotPromoter { MemorySlot slot; PromotableAllocationOpInterface allocator; OpBuilder &builder; - /// Potentially non-initialized default value. Use `lazyDefaultValue` to + /// Potentially non-initialized default value. Use `getLazyDefaultValue` to /// initialize it on demand. Value defaultValue; - /// Blocks where multiple definitions of the slot value clash. - SmallPtrSet mergePoints; - /// Contains, for each operation, which uses must be eliminated by promotion. - /// This is a DAG structure because an operation that must eliminate some of - /// its uses always comes from a request from an operation that must - /// eliminate some of its own uses. - DenseMap> userToBlockingUses; /// Contains the reaching definition at this operation. Reaching definitions /// are only computed for promotable memory operations with blocking uses. DenseMap reachingDefs; DominanceInfo &dominance; + SlotPromotionMetadata metadata; }; } // namespace SlotPromoter::SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance) - : slot(slot), allocator(allocator), builder(builder), dominance(dominance) { + OpBuilder &builder, DominanceInfo &dominance, + SlotPromotionMetadata metadata) + : slot(slot), allocator(allocator), builder(builder), dominance(dominance), + metadata(std::move(metadata)) { bool isResultOrNewBlockArgument = slot.ptr.getDefiningOp() == allocator; if (BlockArgument arg = slot.ptr.dyn_cast()) isResultOrNewBlockArgument = isResultOrNewBlockArgument || @@ -197,7 +217,8 @@ Value SlotPromoter::getLazyDefaultValue() { return defaultValue = allocator.getDefaultValue(slot, builder); } -LogicalResult SlotPromoter::computeBlockingUses() { +LogicalResult SlotPromotionAnalyzer::computeBlockingUses( + DenseMap> &userToBlockingUses) { // The promotion of an operation may require the promotion of further // operations (typically, removing operations that use an operation that must // delete itself). We thus need to start from the use of the slot pointer and @@ -213,7 +234,7 @@ LogicalResult SlotPromoter::computeBlockingUses() { // Then, propagate the requirements for the removal of uses. The // topologically-sorted forward slice allows for all blocking uses of an - // operation to have been computed before we reach it. Operations are + // operation to have been computed before it is reached. Operations are // traversed in topological order of their uses, starting from the slot // pointer. SetVector forwardSlice; @@ -263,8 +284,8 @@ LogicalResult SlotPromoter::computeBlockingUses() { return success(); } -SmallPtrSet -SlotPromoter::computeSlotLiveIn(SmallPtrSetImpl &definingBlocks) { +SmallPtrSet SlotPromotionAnalyzer::computeSlotLiveIn( + SmallPtrSetImpl &definingBlocks) { SmallPtrSet liveIn; // The worklist contains blocks in which it is known that the slot value is @@ -321,7 +342,8 @@ SlotPromoter::computeSlotLiveIn(SmallPtrSetImpl &definingBlocks) { } using IDFCalculator = llvm::IDFCalculatorBase; -void SlotPromoter::computeMergePoints() { +void SlotPromotionAnalyzer::computeMergePoints( + SmallPtrSetImpl &mergePoints) { if (slot.ptr.getParentRegion()->hasOneBlock()) return; @@ -344,7 +366,8 @@ void SlotPromoter::computeMergePoints() { mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end()); } -bool SlotPromoter::areMergePointsUsable() { +bool SlotPromotionAnalyzer::areMergePointsUsable( + SmallPtrSetImpl &mergePoints) { for (Block *mergePoint : mergePoints) for (Block *pred : mergePoint->getPredecessors()) if (!isa(pred->getTerminator())) @@ -353,6 +376,30 @@ bool SlotPromoter::areMergePointsUsable() { return true; } +Optional SlotPromotionAnalyzer::computeMetadata() { + SlotPromotionMetadata metadata; + + // First, find the set of operations that will need to be changed for the + // promotion to happen. These operations need to resolve some of their uses, + // either by rewiring them or simply deleting themselves. If any of them + // cannot find a way to resolve their blocking uses, we abort the promotion. + if (failed(computeBlockingUses(metadata.userToBlockingUses))) + return {}; + + // Then, compute blocks in which two or more definitions of the allocated + // variable may conflict. These blocks will need a new block argument to + // accomodate this. + computeMergePoints(metadata.mergePoints); + + // The slot can be promoted if the block arguments to be created can + // actually be populated with values, which may not be possible depending + // on their predecessors. + if (!areMergePointsUsable(metadata.mergePoints)) + return {}; + + return {std::move(metadata)}; +} + Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { for (Operation &op : block->getOperations()) { if (auto memOp = dyn_cast(op)) { @@ -390,7 +437,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, DfsJob job = dfsStack.pop_back_val(); Block *block = job.block->getBlock(); - if (mergePoints.contains(block)) { + if (metadata.mergePoints.contains(block)) { BlockArgument blockArgument = block->addArgument(slot.elemType, slot.ptr.getLoc()); builder.setInsertionPointToStart(block); @@ -402,7 +449,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, if (auto terminator = dyn_cast(block->getTerminator())) { for (BlockOperand &blockOperand : terminator->getBlockOperands()) { - if (mergePoints.contains(blockOperand.get())) { + if (metadata.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); terminator.getSuccessorOperands(blockOperand.getOperandNumber()) @@ -418,7 +465,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, void SlotPromoter::removeBlockingUses() { llvm::SetVector usersToRemoveUses; - for (auto &user : llvm::make_first_range(userToBlockingUses)) + for (auto &user : llvm::make_first_range(metadata.userToBlockingUses)) usersToRemoveUses.insert(user); SetVector sortedUsersToRemoveUses = mlir::topologicalSort(usersToRemoveUses); @@ -463,7 +510,7 @@ void SlotPromoter::promoteSlot() { // Update terminators in dead branches to forward default if they are // succeeded by a merge points. - for (Block *mergePoint : mergePoints) { + for (Block *mergePoint : metadata.mergePoints) { for (BlockOperand &use : mergePoint->getUses()) { auto user = cast(use.getOwner()); SuccessorOperands succOperands = @@ -478,25 +525,6 @@ void SlotPromoter::promoteSlot() { allocator.handlePromotionComplete(slot, defaultValue); } -LogicalResult SlotPromoter::prepareSlotPromotion() { - // First, find the set of operations that will need to be changed for the - // promotion to happen. These operations need to resolve some of their uses, - // either by rewiring them or simply deleting themselves. If any of them - // cannot find a way to resolve their blocking uses, we abort the promotion. - if (failed(computeBlockingUses())) - return failure(); - - // Then, compute blocks in which two or more definitions of the allocated - // variable may conflict. These blocks will need a new block argument to - // accomodate this. - computeMergePoints(); - - // The slot can be promoted if the block arguments to be created can - // actually be populated with values, which may not be possible depending - // on their predecessors. - return success(areMergePointsUsable()); -} - LogicalResult mlir::tryToPromoteMemorySlots( ArrayRef allocators, OpBuilder &builder, DominanceInfo &dominance) { @@ -508,9 +536,11 @@ LogicalResult mlir::tryToPromoteMemorySlots( if (slot.ptr.use_empty()) continue; - SlotPromoter promoter(slot, allocator, builder, dominance); - if (succeeded(promoter.prepareSlotPromotion())) - toPromote.emplace_back(std::move(promoter)); + SlotPromotionAnalyzer analyzer(slot, dominance); + Optional metadata = analyzer.computeMetadata(); + if (metadata.has_value()) + toPromote.emplace_back(slot, allocator, builder, dominance, + std::move(metadata.value())); } } From b03a2e65d5ae7d2beae3ef60beb869fe5995960c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 18 Apr 2023 14:00:47 +0200 Subject: [PATCH 03/25] address comments --- mlir/lib/Transforms/Mem2Reg.cpp | 62 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 7b3c241734d33..d7fb5d9f02eb8 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -89,9 +89,9 @@ using namespace mlir; namespace { -/// Metadata computed during promotion analysis used to compute actual +/// Information computed during promotion analysis used to perform actual /// promotion. -struct SlotPromotionMetadata { +struct SlotPromotionInfo { /// Blocks for which at least two definitions of the slot values clash. SmallPtrSet mergePoints; /// Contains, for each operation, which uses must be eliminated by promotion. @@ -101,17 +101,17 @@ struct SlotPromotionMetadata { DenseMap> userToBlockingUses; }; -/// Computes metadata for basic slot promotion. This will check that direct slot -/// promotion can be performed, and provide the metadata to execute the +/// Computes information for basic slot promotion. This will check that direct +/// slot promotion can be performed, and provide the information to execute the /// promotion. This does not mutate IR. class SlotPromotionAnalyzer { public: SlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) : slot(slot), dominance(dominance) {} - /// Computes the metadata for slot promotion if promotion is possible, returns - /// nothing otherwise. - Optional computeMetadata(); + /// Computes the information for slot promotion if promotion is possible, + /// returns nothing otherwise. + Optional computeInfo(); private: /// Computes the transitive uses of the slot that block promotion. This finds @@ -151,11 +151,10 @@ class SlotPromoter { public: SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, OpBuilder &builder, DominanceInfo &dominance, - SlotPromotionMetadata metadata); + SlotPromotionInfo info); - /// Actually promotes the slot by mutating IR. This method must only be - /// called after a successful call to `SlotPromoter::prepareSlotPromotion`. - /// Promoting a slot does not invalidate the preparation of other slots. + /// Actually promotes the slot by mutating IR. Promoting a slot does not + /// invalidate the SlotPromotionInfo of other slots. void promoteSlot(); private: @@ -187,7 +186,7 @@ class SlotPromoter { /// are only computed for promotable memory operations with blocking uses. DenseMap reachingDefs; DominanceInfo &dominance; - SlotPromotionMetadata metadata; + SlotPromotionInfo info; }; } // namespace @@ -195,9 +194,9 @@ class SlotPromoter { SlotPromoter::SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, OpBuilder &builder, DominanceInfo &dominance, - SlotPromotionMetadata metadata) + SlotPromotionInfo info) : slot(slot), allocator(allocator), builder(builder), dominance(dominance), - metadata(std::move(metadata)) { + info(std::move(info)) { bool isResultOrNewBlockArgument = slot.ptr.getDefiningOp() == allocator; if (BlockArgument arg = slot.ptr.dyn_cast()) isResultOrNewBlockArgument = isResultOrNewBlockArgument || @@ -376,34 +375,34 @@ bool SlotPromotionAnalyzer::areMergePointsUsable( return true; } -Optional SlotPromotionAnalyzer::computeMetadata() { - SlotPromotionMetadata metadata; +Optional SlotPromotionAnalyzer::computeInfo() { + SlotPromotionInfo info; // First, find the set of operations that will need to be changed for the // promotion to happen. These operations need to resolve some of their uses, // either by rewiring them or simply deleting themselves. If any of them // cannot find a way to resolve their blocking uses, we abort the promotion. - if (failed(computeBlockingUses(metadata.userToBlockingUses))) + if (failed(computeBlockingUses(info.userToBlockingUses))) return {}; // Then, compute blocks in which two or more definitions of the allocated // variable may conflict. These blocks will need a new block argument to // accomodate this. - computeMergePoints(metadata.mergePoints); + computeMergePoints(info.mergePoints); // The slot can be promoted if the block arguments to be created can // actually be populated with values, which may not be possible depending // on their predecessors. - if (!areMergePointsUsable(metadata.mergePoints)) + if (!areMergePointsUsable(info.mergePoints)) return {}; - return {std::move(metadata)}; + return {std::move(info)}; } Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { for (Operation &op : block->getOperations()) { if (auto memOp = dyn_cast(op)) { - if (userToBlockingUses.contains(memOp)) + if (info.userToBlockingUses.contains(memOp)) reachingDefs.insert({memOp, reachingDef}); if (Value stored = memOp.getStored(slot)) @@ -437,7 +436,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, DfsJob job = dfsStack.pop_back_val(); Block *block = job.block->getBlock(); - if (metadata.mergePoints.contains(block)) { + if (info.mergePoints.contains(block)) { BlockArgument blockArgument = block->addArgument(slot.elemType, slot.ptr.getLoc()); builder.setInsertionPointToStart(block); @@ -449,7 +448,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, if (auto terminator = dyn_cast(block->getTerminator())) { for (BlockOperand &blockOperand : terminator->getBlockOperands()) { - if (metadata.mergePoints.contains(blockOperand.get())) { + if (info.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); terminator.getSuccessorOperands(blockOperand.getOperandNumber()) @@ -465,7 +464,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, void SlotPromoter::removeBlockingUses() { llvm::SetVector usersToRemoveUses; - for (auto &user : llvm::make_first_range(metadata.userToBlockingUses)) + for (auto &user : llvm::make_first_range(info.userToBlockingUses)) usersToRemoveUses.insert(user); SetVector sortedUsersToRemoveUses = mlir::topologicalSort(usersToRemoveUses); @@ -480,8 +479,8 @@ void SlotPromoter::removeBlockingUses() { reachingDef = getLazyDefaultValue(); builder.setInsertionPointAfter(toPromote); - if (toPromoteMemOp.removeBlockingUses(slot, userToBlockingUses[toPromote], - builder, reachingDef) == + if (toPromoteMemOp.removeBlockingUses( + slot, info.userToBlockingUses[toPromote], builder, reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); @@ -490,7 +489,8 @@ void SlotPromoter::removeBlockingUses() { auto toPromoteBasic = cast(toPromote); builder.setInsertionPointAfter(toPromote); - if (toPromoteBasic.removeBlockingUses(slot, userToBlockingUses[toPromote], + if (toPromoteBasic.removeBlockingUses(slot, + info.userToBlockingUses[toPromote], builder) == DeletionKind::Delete) toErase.push_back(toPromote); } @@ -510,7 +510,7 @@ void SlotPromoter::promoteSlot() { // Update terminators in dead branches to forward default if they are // succeeded by a merge points. - for (Block *mergePoint : metadata.mergePoints) { + for (Block *mergePoint : info.mergePoints) { for (BlockOperand &use : mergePoint->getUses()) { auto user = cast(use.getOwner()); SuccessorOperands succOperands = @@ -537,10 +537,10 @@ LogicalResult mlir::tryToPromoteMemorySlots( continue; SlotPromotionAnalyzer analyzer(slot, dominance); - Optional metadata = analyzer.computeMetadata(); - if (metadata.has_value()) + Optional info = analyzer.computeInfo(); + if (info.has_value()) toPromote.emplace_back(slot, allocator, builder, dominance, - std::move(metadata.value())); + std::move(info.value())); } } From 4e5475fe7c65392445034f1fa2eda818112f8110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Thu, 20 Apr 2023 10:37:18 +0200 Subject: [PATCH 04/25] first draft of SROA pass structure --- .../mlir/Interfaces/MemorySlotInterfaces.h | 5 + .../mlir/Interfaces/MemorySlotInterfaces.td | 6 +- mlir/include/mlir/Transforms/Mem2Reg.h | 101 +++++++++++++ mlir/include/mlir/Transforms/Passes.h | 1 + mlir/include/mlir/Transforms/Passes.td | 7 + mlir/include/mlir/Transforms/SROA.h | 43 ++++++ mlir/lib/Transforms/CMakeLists.txt | 1 + mlir/lib/Transforms/Mem2Reg.cpp | 138 +++--------------- mlir/lib/Transforms/SROA.cpp | 78 ++++++++++ 9 files changed, 258 insertions(+), 122 deletions(-) create mode 100644 mlir/include/mlir/Transforms/SROA.h create mode 100644 mlir/lib/Transforms/SROA.cpp diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index 1c0efee386eaf..6f5ea6b42920f 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -29,6 +29,11 @@ struct DestructibleMemorySlot : public MemorySlot { DenseMap elementsPtrs; }; +struct SubElementMemorySlot : public MemorySlot { + /// Index of this memory slot in the parent memory slot. + Attribute subelementIndex; +}; + /// Returned by operation promotion logic requesting the deletion of an /// operation. enum class DeletionKind { diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index a94b03914dc00..1378752110d6a 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -248,8 +248,12 @@ def DestructibleAccessorOpInterface rewire its uses of the slot to use the slots generated after destruction. This may involve creating new operations, and usually amounts to checking the pointer types match. + + Returns memory slots representing the accessed subelements, if any. + Returned slots can only be aliased by other SubElementMemorySlots of the + same type coming from the same parent slot. }], - "::llvm::SmallVector<::mlir::MemorySlot>", + "::llvm::SmallVector<::mlir::SubElementMemorySlot>", "canRewire", (ins "const ::mlir::DestructibleMemorySlot &":$slot) >, diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h index 0d037c3e4283a..1a93c6736d07d 100644 --- a/mlir/include/mlir/Transforms/Mem2Reg.h +++ b/mlir/include/mlir/Transforms/Mem2Reg.h @@ -15,6 +15,107 @@ namespace mlir { +/// Information computed during promotion analysis used to perform actual +/// promotion. +struct MemorySlotPromotionInfo { + /// Blocks for which at least two definitions of the slot values clash. + SmallPtrSet mergePoints; + /// Contains, for each operation, which uses must be eliminated by promotion. + /// This is a DAG structure because an operation that must eliminate some of + /// its uses always comes from a request from an operation that must + /// eliminate some of its own uses. + DenseMap> userToBlockingUses; +}; + +/// Computes information for basic slot promotion. This will check that direct +/// slot promotion can be performed, and provide the information to execute the +/// promotion. This does not mutate IR. +class MemorySlotPromotionAnalyzer { +public: + MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} + + /// Computes the information for slot promotion if promotion is possible, + /// returns nothing otherwise. + Optional computeInfo(); + +private: + /// Computes the transitive uses of the slot that block promotion. This finds + /// uses that would block the promotion, checks that the operation has a + /// solution to remove the blocking use, and potentially forwards the analysis + /// if the operation needs further blocking uses resolved to resolve its own + /// uses (typically, removing its users because it will delete itself to + /// resolve its own blocking uses). This will fail if one of the transitive + /// users cannot remove a requested use, and should prevent promotion. + LogicalResult computeBlockingUses( + DenseMap> &userToBlockingUses); + + /// Computes in which blocks the value stored in the slot is actually used, + /// meaning blocks leading to a load. This method uses `definingBlocks`, the + /// set of blocks containing a store to the slot (defining the value of the + /// slot). + SmallPtrSet + computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); + + /// Computes the points in which multiple re-definitions of the slot's value + /// (stores) may conflict. + void computeMergePoints(SmallPtrSetImpl &mergePoints); + + /// Ensures predecessors of merge points can properly provide their current + /// definition of the value stored in the slot to the merge point. This can + /// notably be an issue if the terminator used does not have the ability to + /// forward values through block operands. + bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); + + MemorySlot slot; + DominanceInfo &dominance; +}; + +/// The MemorySlotPromoter handles the state of promoting a memory slot. It +/// wraps a slot and its associated allocator. This will perform the mutation of +/// IR. +class MemorySlotPromoter { +public: + MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, + OpBuilder &builder, DominanceInfo &dominance, + MemorySlotPromotionInfo info); + + /// Actually promotes the slot by mutating IR. Promoting a slot does not + /// invalidate the MemorySlotPromotionInfo of other slots. + void promoteSlot(); + +private: + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` is the value the slot should contain at the + /// beginning of the block. This method returns the reached definition at the + /// end of the block. + Value computeReachingDefInBlock(Block *block, Value reachingDef); + + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` corresponds to the initial value the + /// slot will contain before any write, typically a poison value. + void computeReachingDefInRegion(Region *region, Value reachingDef); + + /// Removes the blocking uses of the slot, in topological order. + void removeBlockingUses(); + + /// Lazily-constructed default value representing the content of the slot when + /// no store has been executed. This function may mutate IR. + Value getLazyDefaultValue(); + + MemorySlot slot; + PromotableAllocationOpInterface allocator; + OpBuilder &builder; + /// Potentially non-initialized default value. Use `getLazyDefaultValue` to + /// initialize it on demand. + Value defaultValue; + /// Contains the reaching definition at this operation. Reaching definitions + /// are only computed for promotable memory operations with blocking uses. + DenseMap reachingDefs; + DominanceInfo &dominance; + MemorySlotPromotionInfo info; +}; + /// Attempts to promote the memory slots of the provided allocators. Succeeds if /// at least one memory slot was promoted. LogicalResult diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index f5f76076c8e07..9110b64d55a63 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -36,6 +36,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_MEM2REG #define GEN_PASS_DECL_PRINTIRPASS #define GEN_PASS_DECL_PRINTOPSTATS +#define GEN_PASS_DECL_SROA #define GEN_PASS_DECL_STRIPDEBUGINFO #define GEN_PASS_DECL_SCCP #define GEN_PASS_DECL_SYMBOLDCE diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 1cc357ca1f9f4..00509f342e8b9 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -214,6 +214,13 @@ def SCCP : Pass<"sccp"> { let constructor = "mlir::createSCCPPass()"; } +def SROA : Pass<"sroa"> { + let summary = "TODO"; + let description = [{ + TODO + }]; +} + def StripDebugInfo : Pass<"strip-debuginfo"> { let summary = "Strip debug info from all operations"; let description = [{ diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h new file mode 100644 index 0000000000000..799c609f775da --- /dev/null +++ b/mlir/include/mlir/Transforms/SROA.h @@ -0,0 +1,43 @@ +//===-- SROA.h - Scalar Replacement Of Aggregates ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_SROA_H +#define MLIR_TRANSFORMS_SROA_H + +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { + +/// Computes information for slot destruction leading to promotion. This will +/// compute whether destructing this slot and subsequent new subslots will +/// lead to only promatble slots being generated. +class MemorySlotDestructionAnalyzer { +public: + MemorySlotDestructionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} + +private: + LogicalResult computeDestructionTree(); + + LogicalResult computeBlockingUses(); + + MemorySlot slot; + DominanceInfo &dominance; +}; + +class MemorySlotDestructor { +public: + void destructSlot(); +private: +}; + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_SROA_H diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index b7e1cd927d6b9..72dd7ab94e909 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms OpStats.cpp PrintIR.cpp SCCP.cpp + SROA.cpp StripDebugInfo.cpp SymbolDCE.cpp SymbolPrivatize.cpp diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index d7fb5d9f02eb8..d7e8aebe6e022 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -87,114 +87,10 @@ using namespace mlir; /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. -namespace { - -/// Information computed during promotion analysis used to perform actual -/// promotion. -struct SlotPromotionInfo { - /// Blocks for which at least two definitions of the slot values clash. - SmallPtrSet mergePoints; - /// Contains, for each operation, which uses must be eliminated by promotion. - /// This is a DAG structure because an operation that must eliminate some of - /// its uses always comes from a request from an operation that must - /// eliminate some of its own uses. - DenseMap> userToBlockingUses; -}; - -/// Computes information for basic slot promotion. This will check that direct -/// slot promotion can be performed, and provide the information to execute the -/// promotion. This does not mutate IR. -class SlotPromotionAnalyzer { -public: - SlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) - : slot(slot), dominance(dominance) {} - - /// Computes the information for slot promotion if promotion is possible, - /// returns nothing otherwise. - Optional computeInfo(); - -private: - /// Computes the transitive uses of the slot that block promotion. This finds - /// uses that would block the promotion, checks that the operation has a - /// solution to remove the blocking use, and potentially forwards the analysis - /// if the operation needs further blocking uses resolved to resolve its own - /// uses (typically, removing its users because it will delete itself to - /// resolve its own blocking uses). This will fail if one of the transitive - /// users cannot remove a requested use, and should prevent promotion. - LogicalResult computeBlockingUses( - DenseMap> &userToBlockingUses); - - /// Computes in which blocks the value stored in the slot is actually used, - /// meaning blocks leading to a load. This method uses `definingBlocks`, the - /// set of blocks containing a store to the slot (defining the value of the - /// slot). - SmallPtrSet - computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); - - /// Computes the points in which multiple re-definitions of the slot's value - /// (stores) may conflict. - void computeMergePoints(SmallPtrSetImpl &mergePoints); - - /// Ensures predecessors of merge points can properly provide their current - /// definition of the value stored in the slot to the merge point. This can - /// notably be an issue if the terminator used does not have the ability to - /// forward values through block operands. - bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); - - MemorySlot slot; - DominanceInfo &dominance; -}; - -/// The SlotPromoter handles the state of promoting a memory slot. It wraps a -/// slot and its associated allocator. This will perform the mutation of IR. -class SlotPromoter { -public: - SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance, - SlotPromotionInfo info); - - /// Actually promotes the slot by mutating IR. Promoting a slot does not - /// invalidate the SlotPromotionInfo of other slots. - void promoteSlot(); - -private: - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` is the value the slot should contain at the - /// beginning of the block. This method returns the reached definition at the - /// end of the block. - Value computeReachingDefInBlock(Block *block, Value reachingDef); - - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` corresponds to the initial value the - /// slot will contain before any write, typically a poison value. - void computeReachingDefInRegion(Region *region, Value reachingDef); - - /// Removes the blocking uses of the slot, in topological order. - void removeBlockingUses(); - - /// Lazily-constructed default value representing the content of the slot when - /// no store has been executed. This function may mutate IR. - Value getLazyDefaultValue(); - - MemorySlot slot; - PromotableAllocationOpInterface allocator; - OpBuilder &builder; - /// Potentially non-initialized default value. Use `getLazyDefaultValue` to - /// initialize it on demand. - Value defaultValue; - /// Contains the reaching definition at this operation. Reaching definitions - /// are only computed for promotable memory operations with blocking uses. - DenseMap reachingDefs; - DominanceInfo &dominance; - SlotPromotionInfo info; -}; - -} // namespace - -SlotPromoter::SlotPromoter(MemorySlot slot, +MemorySlotPromoter::MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, OpBuilder &builder, DominanceInfo &dominance, - SlotPromotionInfo info) + MemorySlotPromotionInfo info) : slot(slot), allocator(allocator), builder(builder), dominance(dominance), info(std::move(info)) { bool isResultOrNewBlockArgument = slot.ptr.getDefiningOp() == allocator; @@ -207,7 +103,7 @@ SlotPromoter::SlotPromoter(MemorySlot slot, "regions of the allocator"); } -Value SlotPromoter::getLazyDefaultValue() { +Value MemorySlotPromoter::getLazyDefaultValue() { if (defaultValue) return defaultValue; @@ -216,7 +112,7 @@ Value SlotPromoter::getLazyDefaultValue() { return defaultValue = allocator.getDefaultValue(slot, builder); } -LogicalResult SlotPromotionAnalyzer::computeBlockingUses( +LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( DenseMap> &userToBlockingUses) { // The promotion of an operation may require the promotion of further // operations (typically, removing operations that use an operation that must @@ -283,7 +179,7 @@ LogicalResult SlotPromotionAnalyzer::computeBlockingUses( return success(); } -SmallPtrSet SlotPromotionAnalyzer::computeSlotLiveIn( +SmallPtrSet MemorySlotPromotionAnalyzer::computeSlotLiveIn( SmallPtrSetImpl &definingBlocks) { SmallPtrSet liveIn; @@ -341,7 +237,7 @@ SmallPtrSet SlotPromotionAnalyzer::computeSlotLiveIn( } using IDFCalculator = llvm::IDFCalculatorBase; -void SlotPromotionAnalyzer::computeMergePoints( +void MemorySlotPromotionAnalyzer::computeMergePoints( SmallPtrSetImpl &mergePoints) { if (slot.ptr.getParentRegion()->hasOneBlock()) return; @@ -365,7 +261,7 @@ void SlotPromotionAnalyzer::computeMergePoints( mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end()); } -bool SlotPromotionAnalyzer::areMergePointsUsable( +bool MemorySlotPromotionAnalyzer::areMergePointsUsable( SmallPtrSetImpl &mergePoints) { for (Block *mergePoint : mergePoints) for (Block *pred : mergePoint->getPredecessors()) @@ -375,8 +271,8 @@ bool SlotPromotionAnalyzer::areMergePointsUsable( return true; } -Optional SlotPromotionAnalyzer::computeInfo() { - SlotPromotionInfo info; +Optional MemorySlotPromotionAnalyzer::computeInfo() { + MemorySlotPromotionInfo info; // First, find the set of operations that will need to be changed for the // promotion to happen. These operations need to resolve some of their uses, @@ -399,7 +295,7 @@ Optional SlotPromotionAnalyzer::computeInfo() { return {std::move(info)}; } -Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { +Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { for (Operation &op : block->getOperations()) { if (auto memOp = dyn_cast(op)) { if (info.userToBlockingUses.contains(memOp)) @@ -413,7 +309,7 @@ Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { return reachingDef; } -void SlotPromoter::computeReachingDefInRegion(Region *region, +void MemorySlotPromoter::computeReachingDefInRegion(Region *region, Value reachingDef) { if (region->hasOneBlock()) { computeReachingDefInBlock(®ion->front(), reachingDef); @@ -462,7 +358,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region, } } -void SlotPromoter::removeBlockingUses() { +void MemorySlotPromoter::removeBlockingUses() { llvm::SetVector usersToRemoveUses; for (auto &user : llvm::make_first_range(info.userToBlockingUses)) usersToRemoveUses.insert(user); @@ -502,7 +398,7 @@ void SlotPromoter::removeBlockingUses() { "after promotion, the slot pointer should not be used anymore"); } -void SlotPromoter::promoteSlot() { +void MemorySlotPromoter::promoteSlot() { computeReachingDefInRegion(slot.ptr.getParentRegion(), {}); // Now that reaching definitions are known, remove all users. @@ -530,21 +426,21 @@ LogicalResult mlir::tryToPromoteMemorySlots( DominanceInfo &dominance) { // Actual promotion may invalidate the dominance analysis, so slot promotion // is prepated in batches. - SmallVector toPromote; + SmallVector toPromote; for (PromotableAllocationOpInterface allocator : allocators) { for (MemorySlot slot : allocator.getPromotableSlots()) { if (slot.ptr.use_empty()) continue; - SlotPromotionAnalyzer analyzer(slot, dominance); - Optional info = analyzer.computeInfo(); + MemorySlotPromotionAnalyzer analyzer(slot, dominance); + Optional info = analyzer.computeInfo(); if (info.has_value()) toPromote.emplace_back(slot, allocator, builder, dominance, std::move(info.value())); } } - for (SlotPromoter &promoter : toPromote) + for (MemorySlotPromoter &promoter : toPromote) promoter.promoteSlot(); return success(!toPromote.empty()); diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp new file mode 100644 index 0000000000000..9cb63bd452870 --- /dev/null +++ b/mlir/lib/Transforms/SROA.cpp @@ -0,0 +1,78 @@ +//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/SROA.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +#define GEN_PASS_DEF_SROA +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +LogicalResult MemorySlotDestructionAnalyzer::computeDestructionTree() { + // TODO: document structure + + SmallVector dfsWorklist{slot}; + while (!dfsWorklist.empty()) { + MemorySlot currentSlot = dfsWorklist.pop_back_val(); + bool mustBeLeaf = false; + + for (Operation *user : currentSlot.ptr.getUsers()) { + if (auto memOp = llvm::dyn_cast(user)) + mustBeLeaf |= + memOp.getStored(currentSlot) || memOp.loadsFrom(currentSlot); + + + } + } + + return success(); +} + +namespace { + +struct SROA : public impl::SROABase { + void runOnOperation() override { + Operation *scopeOp = getOperation(); + bool changed = false; + + for (Region ®ion : scopeOp->getRegions()) { + if (region.getBlocks().empty()) + continue; + + OpBuilder builder(®ion.front(), region.front().begin()); + + // Promoting a slot can allow for further promotion of other slots, + // promotion is tried until no promotion succeeds. + while (true) { + DominanceInfo &dominance = getAnalysis(); + + for (Block &block : region) { + for (Operation &op : block.getOperations()) { + if (auto allocator = + llvm::dyn_cast(op)) { + allocator.getDestructibleSlots(); + } + } + } + + changed = true; + getAnalysisManager().invalidate({}); + } + } + + if (!changed) + markAllAnalysesPreserved(); + } +}; + +} // namespace From a4c4021c4926df7d0b36a5ba9b9e80007e9eef5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Thu, 20 Apr 2023 11:23:00 +0200 Subject: [PATCH 05/25] improve mem2reg interfaces for use with sroa --- mlir/lib/Transforms/Mem2Reg.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index d7e8aebe6e022..d741a12a1ac0c 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -11,8 +11,10 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/GenericIteratedDominanceFrontier.h" namespace mlir { @@ -168,9 +170,9 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( } // Because this pass currently only supports analysing the parent region of - // the slot pointer, if a promotable memory op that needs promotion is - // outside of this region, promotion must fail because it will be impossible - // to provide a valid `reachingDef` for it. + // the slot pointer, if a promotable memory op that needs promotion is outside + // of this region, promotion must fail because it will be impossible to + // provide a valid `reachingDef` for it. for (auto &[toPromote, _] : userToBlockingUses) if (isa(toPromote) && toPromote->getParentRegion() != slot.ptr.getParentRegion()) @@ -295,7 +297,8 @@ Optional MemorySlotPromotionAnalyzer::computeInfo() { return {std::move(info)}; } -Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { +Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, + Value reachingDef) { for (Operation &op : block->getOperations()) { if (auto memOp = dyn_cast(op)) { if (info.userToBlockingUses.contains(memOp)) @@ -310,7 +313,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, Value reaching } void MemorySlotPromoter::computeReachingDefInRegion(Region *region, - Value reachingDef) { + Value reachingDef) { if (region->hasOneBlock()) { computeReachingDefInBlock(®ion->front(), reachingDef); return; From 5e49c8fa79f0a9adbe98155ff0f3cfbe6195d708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Fri, 21 Apr 2023 13:08:11 +0200 Subject: [PATCH 06/25] complete algorithm implementation --- .../mlir/Interfaces/MemorySlotInterfaces.td | 21 +-- mlir/include/mlir/Transforms/SROA.h | 33 ++--- mlir/lib/Transforms/SROA.cpp | 122 +++++++++++++++--- 3 files changed, 127 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 1378752110d6a..b641002683bdd 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -219,7 +219,7 @@ def DestructibleAllocationOpInterface The builder is located at the beginning of the block where the slot pointer is defined. }], - "::llvm::SmallVector<::mlir::MemorySlot>", + "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>", "destruct", (ins "const ::mlir::DestructibleMemorySlot &":$slot, "::mlir::OpBuilder &":$builder) @@ -249,21 +249,24 @@ def DestructibleAccessorOpInterface destruction. This may involve creating new operations, and usually amounts to checking the pointer types match. - Returns memory slots representing the accessed subelements, if any. - Returned slots can only be aliased by other SubElementMemorySlots of the - same type coming from the same parent slot. + In case of success, adds to `subelements` the memory slots representing + the accessed subelements, if any. These slots can only be aliased by + other SubElementMemorySlots of the same type coming from the same parent + slot. }], - "::llvm::SmallVector<::mlir::SubElementMemorySlot>", + "bool", "canRewire", - (ins "const ::mlir::DestructibleMemorySlot &":$slot) + (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "::llvm::SmallVectorImpl<::mlir::SubElementMemorySlot> &":$subelems) >, InterfaceMethod<[{ - Rewires the use of a slot to the generated subslots. + Rewires the use of a slot to the generated subslots, without deleting + any operation. Returns whether the accessor should be deleted. }], - "void", + "::mlir::DeletionKind", "rewire", (ins "const ::mlir::DestructibleMemorySlot &":$slot, - "::llvm::DenseMap<::mlir::Attribute, ::mlir::Value> &":$subslots) + "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots) > ]; } diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h index 799c609f775da..268db01915c22 100644 --- a/mlir/include/mlir/Transforms/SROA.h +++ b/mlir/include/mlir/Transforms/SROA.h @@ -9,34 +9,25 @@ #ifndef MLIR_TRANSFORMS_SROA_H #define MLIR_TRANSFORMS_SROA_H -#include "mlir/IR/Dominance.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Support/LogicalResult.h" namespace mlir { -/// Computes information for slot destruction leading to promotion. This will -/// compute whether destructing this slot and subsequent new subslots will -/// lead to only promatble slots being generated. -class MemorySlotDestructionAnalyzer { -public: - MemorySlotDestructionAnalyzer(MemorySlot slot, DominanceInfo &dominance) - : slot(slot), dominance(dominance) {} - -private: - LogicalResult computeDestructionTree(); - - LogicalResult computeBlockingUses(); - - MemorySlot slot; - DominanceInfo &dominance; +struct MemorySlotDestructionInfo { + DenseMap> userToBlockingUses; + SmallVector accessors; }; -class MemorySlotDestructor { -public: - void destructSlot(); -private: -}; +/// Computes information for slot destruction leading to promotion. This will +/// compute whether this slot can be destructed. Returns nothing if the slot +/// cannot be destructed. +Optional +computeDestructionInfo(DestructibleMemorySlot &slot); + +void destructSlot(DestructibleMemorySlot &slot, + DestructibleAllocationOpInterface allocator, + OpBuilder &builder, MemorySlotDestructionInfo &info); } // namespace mlir diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 9cb63bd452870..20c00e5b9fa6a 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/SROA.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" @@ -18,24 +19,101 @@ namespace mlir { using namespace mlir; -LogicalResult MemorySlotDestructionAnalyzer::computeDestructionTree() { - // TODO: document structure +template +static V &getOrInsertDefault(DenseMap &map, K key) { + return map.try_emplace(key).first->second; +} + +Optional +mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { + MemorySlotDestructionInfo info; + + for (OpOperand &use : slot.ptr.getUses()) { + if (auto accessor = + dyn_cast(use.getOwner())) { + SmallVector subelements; + if (!accessor.canRewire(slot, subelements)) + return {}; + info.accessors.push_back(accessor); + continue; + } + + SmallPtrSet &blockingUses = + getOrInsertDefault(info.userToBlockingUses, use.getOwner()); + blockingUses.insert(&use); + } - SmallVector dfsWorklist{slot}; - while (!dfsWorklist.empty()) { - MemorySlot currentSlot = dfsWorklist.pop_back_val(); - bool mustBeLeaf = false; + SetVector forwardSlice; + mlir::getForwardSlice(slot.ptr, &forwardSlice); + for (Operation *user : forwardSlice) { + // If the next operation has no blocking uses, everything is fine. + if (!info.userToBlockingUses.contains(user)) + continue; + + SmallPtrSet &blockingUses = info.userToBlockingUses[user]; + auto promotable = dyn_cast(user); + + // An operation that has blocking uses must be promoted. If it is not + // promotable, destruction must fail. + if (!promotable) + return {}; + + SmallVector newBlockingUses; + // If the operation decides it cannot deal with removing the blocking uses, + // destruction must fail. + if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) + return {}; + + // Then, register any new blocking uses for coming operations. + for (OpOperand *blockingUse : newBlockingUses) { + assert(llvm::find(user->getResults(), blockingUse->get()) != + user->result_end()); + + SmallPtrSetImpl &newUserBlockingUseSet = + getOrInsertDefault(info.userToBlockingUses, blockingUse->getOwner()); + newUserBlockingUseSet.insert(blockingUse); + } + } - for (Operation *user : currentSlot.ptr.getUsers()) { - if (auto memOp = llvm::dyn_cast(user)) - mustBeLeaf |= - memOp.getStored(currentSlot) || memOp.loadsFrom(currentSlot); + return info; +} - +void mlir::destructSlot(DestructibleMemorySlot &slot, + DestructibleAllocationOpInterface allocator, + OpBuilder &builder, MemorySlotDestructionInfo &info) { + OpBuilder::InsertionGuard guard(builder); + + builder.setInsertionPointToStart(slot.ptr.getParentBlock()); + DenseMap subslots = allocator.destruct(slot, builder); + + llvm::SetVector usersToRewire; + for (auto &[user, _] : info.userToBlockingUses) + usersToRewire.insert(user); + for (DestructibleAccessorOpInterface accessor : info.accessors) + usersToRewire.insert(accessor); + SetVector sortedUsersToRewire = + mlir::topologicalSort(usersToRewire); + + llvm::SmallVector toErase; + for (Operation *toRewire : llvm::reverse(sortedUsersToRewire)) { + builder.setInsertionPointAfter(toRewire); + if (auto promotable = dyn_cast(toRewire)) { + if (promotable.removeBlockingUses(slot, + info.userToBlockingUses[promotable], + builder) == DeletionKind::Delete) + toErase.push_back(promotable); + continue; } + + auto accessor = cast(toRewire); + if (accessor.rewire(slot, subslots) == DeletionKind::Delete) + toErase.push_back(accessor); } - return success(); + for (Operation *toEraseOp : toErase) + toEraseOp->erase(); + + allocator.handleDestructionComplete(slot); } namespace { @@ -51,22 +129,28 @@ struct SROA : public impl::SROABase { OpBuilder builder(®ion.front(), region.front().begin()); - // Promoting a slot can allow for further promotion of other slots, - // promotion is tried until no promotion succeeds. - while (true) { - DominanceInfo &dominance = getAnalysis(); + // Destructing a slot can allow for further destruction of other slots, + // destruction is tried until no destruction succeeds. + bool justDestructed = true; + while (justDestructed) { + justDestructed = false; for (Block &block : region) { for (Operation &op : block.getOperations()) { if (auto allocator = llvm::dyn_cast(op)) { - allocator.getDestructibleSlots(); + for (DestructibleMemorySlot slot : + allocator.getDestructibleSlots()) { + if (auto info = computeDestructionInfo(slot)) { + destructSlot(slot, allocator, builder, info.value()); + justDestructed = true; + } + } } } } - changed = true; - getAnalysisManager().invalidate({}); + changed |= justDestructed; } } From e88b3950927f8e3f75207dd0564d0a8aac3f18c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Mon, 24 Apr 2023 11:29:21 +0200 Subject: [PATCH 07/25] adjust edge case behavior --- mlir/lib/Transforms/SROA.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 20c00e5b9fa6a..af525de1b1a06 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -32,10 +32,10 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { if (auto accessor = dyn_cast(use.getOwner())) { SmallVector subelements; - if (!accessor.canRewire(slot, subelements)) - return {}; - info.accessors.push_back(accessor); - continue; + if (accessor.canRewire(slot, subelements)) { + info.accessors.push_back(accessor); + continue; + } } SmallPtrSet &blockingUses = From 805a4f4841f76d289a08fd1a9a86cea41b7a3b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 25 Apr 2023 09:43:33 +0200 Subject: [PATCH 08/25] begin implementation of SROA interfaces for LLVM IR --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 6 +- mlir/include/mlir/Interfaces/CMakeLists.txt | 9 ++- .../mlir/Interfaces/MemorySlotInterfaces.h | 3 +- .../mlir/Interfaces/MemorySlotInterfaces.td | 24 +++++++- mlir/include/mlir/Transforms/SROA.h | 1 + mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 56 +++++++++++++++++++ mlir/lib/Interfaces/MemorySlotInterfaces.cpp | 3 +- mlir/lib/Transforms/SROA.cpp | 8 ++- 8 files changed, 102 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index fefc602ae3729..9021dde3c99d5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -173,7 +173,8 @@ def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp< // Memory-related operations. def LLVM_AllocaOp : LLVM_Op<"alloca", - [DeclareOpInterfaceMethods]>, + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, LLVM_MemOpPatterns { let arguments = (ins AnyInteger:$arraySize, OptionalAttr:$alignment, @@ -232,7 +233,8 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", } def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index f18a87aeae71e..71a980056a4ea 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -7,7 +7,6 @@ add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) -add_mlir_interface(MemorySlotInterfaces) add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(RuntimeVerifiableOpInterface) add_mlir_interface(ShapedOpInterfaces) @@ -17,6 +16,14 @@ add_mlir_interface(ValueBoundsOpInterface) add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) +set(LLVM_TARGET_DEFINITIONS MemorySlotInterfaces.td) +mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(MemorySlotTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(MemorySlotTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS DataLayoutInterfaces.td) mlir_tablegen(DataLayoutAttrInterface.h.inc -gen-attr-interface-decls) mlir_tablegen(DataLayoutAttrInterface.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index 6f5ea6b42920f..a02da1bba5ffa 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -45,6 +45,7 @@ enum class DeletionKind { } // namespace mlir -#include "mlir/Interfaces/MemorySlotInterfaces.h.inc" +#include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc" +#include "mlir/Interfaces/MemorySlotTypeInterfaces.h.inc" #endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index b641002683bdd..0cd3028a11cd8 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -222,6 +222,7 @@ def DestructibleAllocationOpInterface "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>", "destruct", (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, "::mlir::OpBuilder &":$builder) >, InterfaceMethod<[{ @@ -252,11 +253,13 @@ def DestructibleAccessorOpInterface In case of success, adds to `subelements` the memory slots representing the accessed subelements, if any. These slots can only be aliased by other SubElementMemorySlots of the same type coming from the same parent - slot. + slot. This method must also register the indices it will access within + the `usedIndices` set. }], "bool", "canRewire", (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, "::llvm::SmallVectorImpl<::mlir::SubElementMemorySlot> &":$subelems) >, InterfaceMethod<[{ @@ -271,4 +274,23 @@ def DestructibleAccessorOpInterface ]; } +def DestructibleTypeInterface + : TypeInterface<"DestructibleTypeInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Destructs the type into subelements into a map of index attributes to + types of subelements. + }], + "::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>", + "destruct", + (ins) + > + ]; +} + #endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h index 268db01915c22..a6eec3b0dbc09 100644 --- a/mlir/include/mlir/Transforms/SROA.h +++ b/mlir/include/mlir/Transforms/SROA.h @@ -15,6 +15,7 @@ namespace mlir { struct MemorySlotDestructionInfo { + SmallPtrSet usedIndices; DenseMap> userToBlockingUses; SmallVector accessors; }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index ddcf2d91b7cd1..81cdc276f379a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" using namespace mlir; @@ -50,6 +51,45 @@ void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, erase(); } +SmallVector LLVM::AllocaOp::getDestructibleSlots() { + Type elemType = + getElemType() ? *getElemType() : getResult().getType().getElementType(); + auto destructible = dyn_cast(elemType); + if (!destructible) + return {}; + return {DestructibleMemorySlot{{getResult(), elemType, getLoc()}, + destructible.destruct()}}; +} + +DenseMap +LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + OpBuilder &builder) { + assert(slot.ptr == getResult()); + Type elemType = + getElemType() ? *getElemType() : getResult().getType().getElementType(); + + DenseMap slotMap; + for (auto &[index, type] : + cast(elemType).destruct()) { + if (usedIndices.contains(index)) { + auto subAlloca = builder.create( + getLoc(), LLVM::LLVMPointerType::get(getContext()), type, + getArraySize()); + slotMap.try_emplace(index, + {subAlloca.getResult(), type, getLoc()}); + } + } + + return slotMap; +} + +void LLVM::AllocaOp::handleDestructionComplete( + const DestructibleMemorySlot &slot) { + assert(slot.ptr == getResult()); + erase(); +} + //===----------------------------------------------------------------------===// // Interfaces for LoadOp/StoreOp //===----------------------------------------------------------------------===// @@ -195,6 +235,10 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) { }); } +//===----------------------------------------------------------------------===// +// Interfaces for GEPOp +//===----------------------------------------------------------------------===// + bool LLVM::GEPOp::canUsesBeRemoved( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { @@ -209,3 +253,15 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( OpBuilder &builder) { return DeletionKind::Delete; } + +bool LLVM::GEPOp::canRewire( + const DestructibleMemorySlot &slot, + ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices, + SmallVectorImpl &subelems) { + assert(0 && "todo"); +} + +DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, + DenseMap &subslots) { + assert(0 && "todo"); +} diff --git a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp index ad15296389797..2c9e23250e9ee 100644 --- a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp +++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp @@ -8,4 +8,5 @@ #include "mlir/Interfaces/MemorySlotInterfaces.h" -#include "mlir/Interfaces/MemorySlotInterfaces.cpp.inc" +#include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc" +#include "mlir/Interfaces/MemorySlotTypeInterfaces.cpp.inc" diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index af525de1b1a06..1c08a9c294da9 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -32,7 +32,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { if (auto accessor = dyn_cast(use.getOwner())) { SmallVector subelements; - if (accessor.canRewire(slot, subelements)) { + if (accessor.canRewire(slot, info.usedIndices, subelements)) { info.accessors.push_back(accessor); continue; } @@ -84,7 +84,8 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(slot.ptr.getParentBlock()); - DenseMap subslots = allocator.destruct(slot, builder); + DenseMap subslots = + allocator.destruct(slot, info.usedIndices, builder); llvm::SetVector usersToRewire; for (auto &[user, _] : info.userToBlockingUses) @@ -113,6 +114,9 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, for (Operation *toEraseOp : toErase) toEraseOp->erase(); + assert(slot.ptr.use_empty() && "at the end of destruction, the original slot " + "pointer should no longer be used"); + allocator.handleDestructionComplete(slot); } From 5536c493ef3a0524d083b47ae541d856927e76a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 25 Apr 2023 12:46:22 +0200 Subject: [PATCH 09/25] adjust interfaces --- .../include/mlir/Interfaces/MemorySlotInterfaces.h | 4 ++-- .../mlir/Interfaces/MemorySlotInterfaces.td | 14 ++++++++++++-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 8 ++++++-- mlir/lib/Transforms/SROA.cpp | 4 ++-- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index a02da1bba5ffa..69d18883d59ab 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -23,13 +23,13 @@ struct MemorySlot { Type elemType; }; -struct DestructibleMemorySlot : public MemorySlot { +struct DestructibleMemorySlot : MemorySlot { /// Maps from an index within the memory slot the type of the pointer that /// will be generated to access the element directly. DenseMap elementsPtrs; }; -struct SubElementMemorySlot : public MemorySlot { +struct SubElementMemorySlot : MemorySlot { /// Index of this memory slot in the parent memory slot. Attribute subelementIndex; }; diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 0cd3028a11cd8..43caf80cd7629 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -76,6 +76,9 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> { to memory slots. Loads and stores must be of whole values of the same type as the slot itself. + For a memory operation on a slot to be valid, it must operate on the slot + pointer *only as a pointer to an element of the type of the slot*. + If the same operation does both loads and stores on the same slot, the load must semantically happen first. }]; @@ -259,8 +262,15 @@ def DestructibleAccessorOpInterface "bool", "canRewire", (ins "const ::mlir::DestructibleMemorySlot &":$slot, - "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, - "::llvm::SmallVectorImpl<::mlir::SubElementMemorySlot> &":$subelems) + "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices) + >, + InterfaceMethod<[{ + For a given destructible memory slot, returns the memory slots generated + by this access pointing to subelements of the slot. + }], + "::llvm::SmallVector<::mlir::SubElementMemorySlot>", + "getSubElementMemorySlots", + (ins "const ::mlir::DestructibleMemorySlot &":$slot) >, InterfaceMethod<[{ Rewires the use of a slot to the generated subslots, without deleting diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 81cdc276f379a..157fc1ab4c17c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -256,8 +256,12 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( bool LLVM::GEPOp::canRewire( const DestructibleMemorySlot &slot, - ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices, - SmallVectorImpl &subelems) { + ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices) { + assert(0 && "todo"); +} + +SmallVector +LLVM::GEPOp::getSubElementMemorySlots(const DestructibleMemorySlot &slot) { assert(0 && "todo"); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 1c08a9c294da9..cc15b3b59e2d2 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -28,11 +28,11 @@ Optional mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { MemorySlotDestructionInfo info; + // Initialize the analysis with the immediate users of the slot. for (OpOperand &use : slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { - SmallVector subelements; - if (accessor.canRewire(slot, info.usedIndices, subelements)) { + if (accessor.canRewire(slot, info.usedIndices)) { info.accessors.push_back(accessor); continue; } From 7dfdfe8b683bc99cca7dadaa84d66f34a07e2b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 26 Apr 2023 10:58:11 +0200 Subject: [PATCH 10/25] progress deeper SROA analysis --- .../mlir/Interfaces/MemorySlotInterfaces.h | 20 ++++- .../mlir/Interfaces/MemorySlotInterfaces.td | 12 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 38 ++++----- mlir/lib/Transforms/Mem2Reg.cpp | 12 ++- mlir/lib/Transforms/SROA.cpp | 78 +++++++++++++++++-- 5 files changed, 116 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index 69d18883d59ab..f1ef1b18d0329 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -23,17 +23,33 @@ struct MemorySlot { Type elemType; }; -struct DestructibleMemorySlot : MemorySlot { +struct DestructibleSlotInfo { /// Maps from an index within the memory slot the type of the pointer that /// will be generated to access the element directly. DenseMap elementsPtrs; }; -struct SubElementMemorySlot : MemorySlot { +struct SubElementSlotInfo { /// Index of this memory slot in the parent memory slot. Attribute subelementIndex; }; +struct DestructibleMemorySlot { + MemorySlot slot; + DestructibleSlotInfo info; +}; + +struct SubElementMemorySlot { + MemorySlot slot; + SubElementSlotInfo info; +}; + +struct MaybeDestructibleSubElementMemorySlot { + MemorySlot slot; + SubElementSlotInfo subElementInfo; + Optional destructibleInfo; +}; + /// Returned by operation promotion logic requesting the deletion of an /// operation. enum class DeletionKind { diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 43caf80cd7629..4f3795a55abf5 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -154,6 +154,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { let methods = [ InterfaceMethod<[{ + TODO + Checks that this operation can be promoted to no longer use the provided blocking uses, in the context of promoting `slot`. @@ -161,11 +163,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { removed, that dependency should be added to the `newBlockingUses` argument. Dependent uses must only be uses of results of this operation. }], "bool", "canUsesBeRemoved", - (ins "const ::mlir::MemorySlot &":$slot, - "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, + (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses) >, InterfaceMethod<[{ + TODO + Transforms IR to ensure that the current operation does not use the provided memory slot anymore. In contrast to `PromotableMemOpInterface`, operations implementing this interface must not need access to the @@ -189,8 +192,7 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { }], "::mlir::DeletionKind", "removeBlockingUses", - (ins "const ::mlir::MemorySlot &":$slot, - "const ::llvm::SmallPtrSetImpl &":$blockingUses, + (ins "const ::llvm::SmallPtrSetImpl &":$blockingUses, "::mlir::OpBuilder &":$builder) >, ]; @@ -268,7 +270,7 @@ def DestructibleAccessorOpInterface For a given destructible memory slot, returns the memory slots generated by this access pointing to subelements of the slot. }], - "::llvm::SmallVector<::mlir::SubElementMemorySlot>", + "::llvm::SmallVector<::mlir::MaybeDestructibleSubElementMemorySlot>", "getSubElementMemorySlots", (ins "const ::mlir::DestructibleMemorySlot &":$slot) >, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 157fc1ab4c17c..fe31b4496a2e8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -58,14 +58,14 @@ SmallVector LLVM::AllocaOp::getDestructibleSlots() { if (!destructible) return {}; return {DestructibleMemorySlot{{getResult(), elemType, getLoc()}, - destructible.destruct()}}; + {destructible.destruct()}}}; } DenseMap LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, SmallPtrSetImpl &usedIndices, OpBuilder &builder) { - assert(slot.ptr == getResult()); + assert(slot.slot.ptr == getResult()); Type elemType = getElemType() ? *getElemType() : getResult().getType().getElementType(); @@ -86,7 +86,7 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, void LLVM::AllocaOp::handleDestructionComplete( const DestructibleMemorySlot &slot) { - assert(slot.ptr == getResult()); + assert(slot.slot.ptr == getResult()); erase(); } @@ -169,62 +169,57 @@ static bool forwardToUsers(Operation *op, } bool LLVM::BitcastOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return forwardToUsers(*this, newBlockingUses); } DeletionKind LLVM::BitcastOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::AddrSpaceCastOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return forwardToUsers(*this, newBlockingUses); } DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::LifetimeStartOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::LifetimeStartOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::LifetimeEndOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::LifetimeEndOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::DbgDeclareOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::DbgDeclareOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } @@ -240,7 +235,7 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) { //===----------------------------------------------------------------------===// bool LLVM::GEPOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { // GEP can be removed as long as it is a no-op and its users can be removed. if (!hasAllZeroIndices(*this)) @@ -249,8 +244,7 @@ bool LLVM::GEPOp::canUsesBeRemoved( } DeletionKind LLVM::GEPOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } @@ -260,7 +254,7 @@ bool LLVM::GEPOp::canRewire( assert(0 && "todo"); } -SmallVector +SmallVector LLVM::GEPOp::getSubElementMemorySlots(const DestructibleMemorySlot &slot) { assert(0 && "todo"); } diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index d741a12a1ac0c..4d7badb645a54 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -89,10 +89,9 @@ using namespace mlir; /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. -MemorySlotPromoter::MemorySlotPromoter(MemorySlot slot, - PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance, - MemorySlotPromotionInfo info) +MemorySlotPromoter::MemorySlotPromoter( + MemorySlot slot, PromotableAllocationOpInterface allocator, + OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info) : slot(slot), allocator(allocator), builder(builder), dominance(dominance), info(std::move(info)) { bool isResultOrNewBlockArgument = slot.ptr.getDefiningOp() == allocator; @@ -147,7 +146,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( // If the operation decides it cannot deal with removing the blocking uses, // promotion must fail. if (auto promotable = dyn_cast(user)) { - if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) + if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses)) return failure(); } else if (auto promotable = dyn_cast(user)) { if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) @@ -388,8 +387,7 @@ void MemorySlotPromoter::removeBlockingUses() { auto toPromoteBasic = cast(toPromote); builder.setInsertionPointAfter(toPromote); - if (toPromoteBasic.removeBlockingUses(slot, - info.userToBlockingUses[toPromote], + if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], builder) == DeletionKind::Delete) toErase.push_back(toPromote); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index cc15b3b59e2d2..270fcb5b38f0e 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" namespace mlir { #define GEN_PASS_DEF_SROA @@ -29,7 +30,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { MemorySlotDestructionInfo info; // Initialize the analysis with the immediate users of the slot. - for (OpOperand &use : slot.ptr.getUses()) { + for (OpOperand &use : slot.slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { if (accessor.canRewire(slot, info.usedIndices)) { @@ -43,13 +44,74 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { blockingUses.insert(&use); } + struct AccessCheckJob { + DestructibleAccessorOpInterface accessor; + DestructibleMemorySlot memorySlot; + }; + + DenseMap shouldBePromotable; + SmallVector accessCheckWorklist; + for (DestructibleAccessorOpInterface accessor : info.accessors) + accessCheckWorklist.emplace_back({accessor, slot}); + + while (!accessCheckWorklist.empty()) { + AccessCheckJob job = accessCheckWorklist.pop_back_val(); + for (MaybeDestructibleSubElementMemorySlot &subslot : + job.accessor.getSubElementMemorySlots(job.memorySlot)) { + for (OpOperand &subslotUse : subslot.slot.ptr.getUses()) { + bool shouldAbort = + TypeSwitch(subslotUse.getOwner()) + .Case([&](auto accessor) { + if (!subslot.destructibleInfo) + return true; + + accessCheckWorklist.emplace_back( + {accessor, + DestructibleMemorySlot{ + subslot.slot, subslot.destructibleInfo.value()}}); + return false; + }) + .Case([&](auto memOp) { + Operation *memOpAsOp = memOp; + SmallPtrSet &blockingUses = + getOrInsertDefault(info.userToBlockingUses, memOpAsOp); + blockingUses.insert(&subslotUse); + + shouldBePromotable.insert({memOp, subslot.slot}); + return false; + }) + .Case([&](auto promotableOp) { + Operation *promotableOpAsOp = promotableOp; + SmallPtrSet &blockingUses = + getOrInsertDefault(info.userToBlockingUses, + promotableOpAsOp); + blockingUses.insert(&subslotUse); + return false; + }) + .Default([](auto) { return true; }); + + if (shouldAbort) + return {}; + } + } + } + SetVector forwardSlice; - mlir::getForwardSlice(slot.ptr, &forwardSlice); + mlir::getForwardSlice(slot.slot.ptr, &forwardSlice); for (Operation *user : forwardSlice) { // If the next operation has no blocking uses, everything is fine. if (!info.userToBlockingUses.contains(user)) continue; + // If the operation is a mem op, we just need to check it is promotable if + // necessary. + if (auto memOp = dyn_cast(user)) { + if (!shouldBePromotable.contains(memOp)) + continue; + MemorySlot concernedSlot = shouldBePromotable.at(memOp); + assert(0 && "todo"); + } + SmallPtrSet &blockingUses = info.userToBlockingUses[user]; auto promotable = dyn_cast(user); @@ -61,7 +123,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { SmallVector newBlockingUses; // If the operation decides it cannot deal with removing the blocking uses, // destruction must fail. - if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) + if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses)) return {}; // Then, register any new blocking uses for coming operations. @@ -83,7 +145,7 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, OpBuilder &builder, MemorySlotDestructionInfo &info) { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(slot.ptr.getParentBlock()); + builder.setInsertionPointToStart(slot.slot.ptr.getParentBlock()); DenseMap subslots = allocator.destruct(slot, info.usedIndices, builder); @@ -99,8 +161,7 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, for (Operation *toRewire : llvm::reverse(sortedUsersToRewire)) { builder.setInsertionPointAfter(toRewire); if (auto promotable = dyn_cast(toRewire)) { - if (promotable.removeBlockingUses(slot, - info.userToBlockingUses[promotable], + if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], builder) == DeletionKind::Delete) toErase.push_back(promotable); continue; @@ -114,8 +175,9 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, for (Operation *toEraseOp : toErase) toEraseOp->erase(); - assert(slot.ptr.use_empty() && "at the end of destruction, the original slot " - "pointer should no longer be used"); + assert(slot.slot.ptr.use_empty() && + "at the end of destruction, the original slot " + "pointer should no longer be used"); allocator.handleDestructionComplete(slot); } From 06d4bdd164155ea63cb5b9d17ade8b5a5966d406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 26 Apr 2023 14:10:55 +0200 Subject: [PATCH 11/25] introduce type safety interface --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 7 +- .../mlir/Interfaces/MemorySlotInterfaces.td | 37 ++++++-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 20 +++-- mlir/lib/Transforms/SROA.cpp | 86 ++++++------------- 4 files changed, 74 insertions(+), 76 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9021dde3c99d5..db722850fff7b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -234,6 +234,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, @@ -318,7 +319,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, } def LLVM_LoadOp : LLVM_MemAccessOpBase<"load", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { dag args = (ins Arg, "", [MemRead]>:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, @@ -390,7 +392,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load", } def LLVM_StoreOp : LLVM_MemAccessOpBase<"store", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { dag args = (ins LLVM_LoadableType:$value, Arg,"",[MemWrite]>:$addr, OptionalAttr:$alignment, diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 4f3795a55abf5..0a2cb544c6e1e 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -241,8 +241,35 @@ def DestructibleAllocationOpInterface ]; } +def TypeSafeMemOpInterface : OpInterface<"TypeSafeMemOpInterface"> { + let description = [{ + TODO + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns whether all accesses in this operation to the provided slot are + done in a type-safe manner. To be type-safe, the access must only load + the value in this type as the type of the slot, and without assuming any + context around the slot. For example, a type-safe load must not load + outside the bounds of the slot. + + If the type-safety of the accesses depends on the type-safety of the + accesses to further memory slots, the result of this method will be + conditioned to the type-safety of the accesses to the slots added by + this method to `mustBeSafelyUsed`. + }], + "bool", + "onlyTypeSafeAccesses", + (ins "const ::mlir::MemorySlot &":$slot, + "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed) + > + ]; +} + def DestructibleAccessorOpInterface - : OpInterface<"DestructibleAccessorOpInterface"> { + : OpInterface<"DestructibleAccessorOpInterface", [TypeSafeMemOpInterface]> { let description = [{ TODO }]; @@ -266,14 +293,6 @@ def DestructibleAccessorOpInterface (ins "const ::mlir::DestructibleMemorySlot &":$slot, "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices) >, - InterfaceMethod<[{ - For a given destructible memory slot, returns the memory slots generated - by this access pointing to subelements of the slot. - }], - "::llvm::SmallVector<::mlir::MaybeDestructibleSubElementMemorySlot>", - "getSubElementMemorySlots", - (ins "const ::mlir::DestructibleMemorySlot &":$slot) - >, InterfaceMethod<[{ Rewires the use of a slot to the generated subslots, without deleting any operation. Returns whether the accessor should be deleted. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index fe31b4496a2e8..a6b664f273c6d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -155,6 +155,16 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( return DeletionKind::Delete; } +bool LLVM::LoadOp::onlyTypeSafeAccesses( + const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { + assert(0 && "todo"); +} + +bool LLVM::StoreOp::onlyTypeSafeAccesses( + const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { + assert(0 && "todo"); +} + //===----------------------------------------------------------------------===// // Interfaces for discardable OPs //===----------------------------------------------------------------------===// @@ -248,14 +258,14 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( return DeletionKind::Delete; } -bool LLVM::GEPOp::canRewire( - const DestructibleMemorySlot &slot, - ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices) { +bool LLVM::GEPOp::onlyTypeSafeAccesses( + const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { assert(0 && "todo"); } -SmallVector -LLVM::GEPOp::getSubElementMemorySlots(const DestructibleMemorySlot &slot) { +bool LLVM::GEPOp::canRewire( + const DestructibleMemorySlot &slot, + ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices) { assert(0 && "todo"); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 270fcb5b38f0e..aa0081f0c36a5 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -29,70 +29,45 @@ Optional mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { MemorySlotDestructionInfo info; + SmallVector usedSafelyWorklist; + // Initialize the analysis with the immediate users of the slot. for (OpOperand &use : slot.slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { if (accessor.canRewire(slot, info.usedIndices)) { - info.accessors.push_back(accessor); - continue; + if (accessor.onlyTypeSafeAccesses(slot.slot, usedSafelyWorklist)) { + info.accessors.push_back(accessor); + continue; + } } } + // If it cannot be shown that the operation uses the slot safely, maybe it + // can be promoted out of using the slot? SmallPtrSet &blockingUses = getOrInsertDefault(info.userToBlockingUses, use.getOwner()); blockingUses.insert(&use); } - struct AccessCheckJob { - DestructibleAccessorOpInterface accessor; - DestructibleMemorySlot memorySlot; - }; - - DenseMap shouldBePromotable; - SmallVector accessCheckWorklist; - for (DestructibleAccessorOpInterface accessor : info.accessors) - accessCheckWorklist.emplace_back({accessor, slot}); - - while (!accessCheckWorklist.empty()) { - AccessCheckJob job = accessCheckWorklist.pop_back_val(); - for (MaybeDestructibleSubElementMemorySlot &subslot : - job.accessor.getSubElementMemorySlots(job.memorySlot)) { - for (OpOperand &subslotUse : subslot.slot.ptr.getUses()) { - bool shouldAbort = - TypeSwitch(subslotUse.getOwner()) - .Case([&](auto accessor) { - if (!subslot.destructibleInfo) - return true; - - accessCheckWorklist.emplace_back( - {accessor, - DestructibleMemorySlot{ - subslot.slot, subslot.destructibleInfo.value()}}); - return false; - }) - .Case([&](auto memOp) { - Operation *memOpAsOp = memOp; - SmallPtrSet &blockingUses = - getOrInsertDefault(info.userToBlockingUses, memOpAsOp); - blockingUses.insert(&subslotUse); - - shouldBePromotable.insert({memOp, subslot.slot}); - return false; - }) - .Case([&](auto promotableOp) { - Operation *promotableOpAsOp = promotableOp; - SmallPtrSet &blockingUses = - getOrInsertDefault(info.userToBlockingUses, - promotableOpAsOp); - blockingUses.insert(&subslotUse); - return false; - }) - .Default([](auto) { return true; }); - - if (shouldAbort) - return {}; - } + SmallPtrSet dealtWith; + while (!usedSafelyWorklist.empty()) { + MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val(); + for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) { + if (dealtWith.contains(&subslotUse)) + continue; + dealtWith.insert(&subslotUse); + Operation *subslotUser = subslotUse.getOwner(); + + if (auto memOp = dyn_cast(subslotUser)) + if (memOp.onlyTypeSafeAccesses(mustBeUsedSafely, usedSafelyWorklist)) + continue; + + // If it cannot be shown that the operation uses the slot safely, maybe it + // can be promoted out of using the slot? + SmallPtrSet &blockingUses = + getOrInsertDefault(info.userToBlockingUses, subslotUser); + blockingUses.insert(&subslotUse); } } @@ -103,15 +78,6 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { if (!info.userToBlockingUses.contains(user)) continue; - // If the operation is a mem op, we just need to check it is promotable if - // necessary. - if (auto memOp = dyn_cast(user)) { - if (!shouldBePromotable.contains(memOp)) - continue; - MemorySlot concernedSlot = shouldBePromotable.at(memOp); - assert(0 && "todo"); - } - SmallPtrSet &blockingUses = info.userToBlockingUses[user]; auto promotable = dyn_cast(user); From 40ad9c76a6596c7ad866787a215009aab1de94ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 26 Apr 2023 15:01:51 +0200 Subject: [PATCH 12/25] adjust interfaces + start LLVM IR implem --- .../mlir/Interfaces/MemorySlotInterfaces.td | 21 ++++++++++--------- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 16 +++++++------- mlir/lib/Transforms/SROA.cpp | 15 ++++++------- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 0a2cb544c6e1e..ac9eb2f9c9d9d 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -210,7 +210,8 @@ def DestructibleAllocationOpInterface Returns the list of slots for which destruction should be attempted, specifying in which way the slot should be destructed into subslots. The subslots are indexed by attributes. The type of the pointers of each - subslots to be generated must be provided. + subslots to be generated must be provided. The type of the memory slot + must implement `DestructibleTypeInterface`. }], "::llvm::SmallVector<::mlir::DestructibleMemorySlot>", "getDestructibleSlots", @@ -241,7 +242,7 @@ def DestructibleAllocationOpInterface ]; } -def TypeSafeMemOpInterface : OpInterface<"TypeSafeMemOpInterface"> { +def TypeSafeOpInterface : OpInterface<"TypeSafeOpInterface"> { let description = [{ TODO }]; @@ -261,7 +262,7 @@ def TypeSafeMemOpInterface : OpInterface<"TypeSafeMemOpInterface"> { this method to `mustBeSafelyUsed`. }], "bool", - "onlyTypeSafeAccesses", + "ensureOnlyTypeSafeAccesses", (ins "const ::mlir::MemorySlot &":$slot, "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed) > @@ -269,7 +270,7 @@ def TypeSafeMemOpInterface : OpInterface<"TypeSafeMemOpInterface"> { } def DestructibleAccessorOpInterface - : OpInterface<"DestructibleAccessorOpInterface", [TypeSafeMemOpInterface]> { + : OpInterface<"DestructibleAccessorOpInterface"> { let description = [{ TODO }]; @@ -282,16 +283,16 @@ def DestructibleAccessorOpInterface destruction. This may involve creating new operations, and usually amounts to checking the pointer types match. - In case of success, adds to `subelements` the memory slots representing - the accessed subelements, if any. These slots can only be aliased by - other SubElementMemorySlots of the same type coming from the same parent - slot. This method must also register the indices it will access within - the `usedIndices` set. + This method must also register the indices it will access within the + `usedIndices` set. If the accessor generates new slots mapping to + subelements, they must be registered in `mustBeSafelyUsed` to ensure + they are used in a locally type-safe manner. }], "bool", "canRewire", (ins "const ::mlir::DestructibleMemorySlot &":$slot, - "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices) + "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, + "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed) >, InterfaceMethod<[{ Rewires the use of a slot to the generated subslots, without deleting diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index a6b664f273c6d..7e015322ea59b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -155,14 +155,14 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( return DeletionKind::Delete; } -bool LLVM::LoadOp::onlyTypeSafeAccesses( +bool LLVM::LoadOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { - assert(0 && "todo"); + return getAddr() != slot.ptr || getType() == slot.elemType; } -bool LLVM::StoreOp::onlyTypeSafeAccesses( +bool LLVM::StoreOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { - assert(0 && "todo"); + return getAddr() != slot.ptr || getValue().getType() == slot.elemType; } //===----------------------------------------------------------------------===// @@ -258,14 +258,14 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( return DeletionKind::Delete; } -bool LLVM::GEPOp::onlyTypeSafeAccesses( +bool LLVM::GEPOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { assert(0 && "todo"); } -bool LLVM::GEPOp::canRewire( - const DestructibleMemorySlot &slot, - ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices) { +bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, + SmallPtrSetImpl<::mlir::Attribute> &usedIndices, + SmallVectorImpl &mustBeSafelyUsed) { assert(0 && "todo"); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index aa0081f0c36a5..55ee29adcaf21 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -27,6 +27,8 @@ static V &getOrInsertDefault(DenseMap &map, K key) { Optional mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { + assert(isa(slot.slot.elemType)); + MemorySlotDestructionInfo info; SmallVector usedSafelyWorklist; @@ -35,11 +37,9 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { for (OpOperand &use : slot.slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { - if (accessor.canRewire(slot, info.usedIndices)) { - if (accessor.onlyTypeSafeAccesses(slot.slot, usedSafelyWorklist)) { - info.accessors.push_back(accessor); - continue; - } + if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) { + info.accessors.push_back(accessor); + continue; } } @@ -59,8 +59,9 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { dealtWith.insert(&subslotUse); Operation *subslotUser = subslotUse.getOwner(); - if (auto memOp = dyn_cast(subslotUser)) - if (memOp.onlyTypeSafeAccesses(mustBeUsedSafely, usedSafelyWorklist)) + if (auto memOp = dyn_cast(subslotUser)) + if (memOp.ensureOnlyTypeSafeAccesses(mustBeUsedSafely, + usedSafelyWorklist)) continue; // If it cannot be shown that the operation uses the slot safely, maybe it From 2bd8bbc49c355012ec15c549aba714813c6ee287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Thu, 27 Apr 2023 16:21:53 +0200 Subject: [PATCH 13/25] finalize LLVM SROA implementation --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 6 +- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 7 + mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 4 +- .../mlir/Interfaces/MemorySlotInterfaces.td | 16 +- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 172 ++++++++++++++++-- mlir/lib/Transforms/SROA.cpp | 58 +++--- 6 files changed, 218 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index db722850fff7b..e78e73a41ef0d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -234,7 +234,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, @@ -320,7 +320,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, def LLVM_LoadOp : LLVM_MemAccessOpBase<"load", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { dag args = (ins Arg, "", [MemRead]>:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, @@ -393,7 +393,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load", def LLVM_StoreOp : LLVM_MemAccessOpBase<"store", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { dag args = (ins LLVM_LoadableType:$value, Arg,"",[MemWrite]>:$addr, OptionalAttr:$alignment, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 9ae0ba6365dfa..fe054d814d4d3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include namespace llvm { @@ -103,6 +104,7 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); class LLVMStructType : public Type::TypeBase { public: /// Inherit base constructors. @@ -198,6 +200,11 @@ class LLVMStructType LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; + + /// Destructs the struct into its indexed field types. + Optional> destruct(); + + Type getTypeAtIndex(Attribute index); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index caf4b58b87f56..f861ba4b0af88 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -12,6 +12,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" /// Base class for all LLVM dialect types. class LLVMType traits = []> @@ -24,7 +25,8 @@ class LLVMType traits = []> //===----------------------------------------------------------------------===// def LLVMArrayType : LLVMType<"LLVMArray", "array", [ - DeclareTypeInterfaceMethods]> { + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods]> { let summary = "LLVM array type"; let description = [{ The `!llvm.array` type represents a fixed-size array of element types. diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index ac9eb2f9c9d9d..25358c9bf4151 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -261,7 +261,7 @@ def TypeSafeOpInterface : OpInterface<"TypeSafeOpInterface"> { conditioned to the type-safety of the accesses to the slots added by this method to `mustBeSafelyUsed`. }], - "bool", + "::mlir::LogicalResult", "ensureOnlyTypeSafeAccesses", (ins "const ::mlir::MemorySlot &":$slot, "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed) @@ -316,11 +316,21 @@ def DestructibleTypeInterface let methods = [ InterfaceMethod<[{ Destructs the type into subelements into a map of index attributes to - types of subelements. + types of subelements. Returns nothing if the type cannot be destructed. }], - "::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>", + "::llvm::Optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>", "destruct", (ins) + >, + InterfaceMethod<[{ + Indicates which type is held at the provided index, returning a null + Type if no type could be computed. While this can return information + even when the type cannot be completely destructed, it must be coherent + with the values returned by `destruct` when they exist. + }], + "::mlir::Type", + "getTypeAtIndex", + (ins "::mlir::Attribute":$index) > ]; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 7e015322ea59b..4914f72d19b3f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -13,7 +13,9 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Support/LogicalResult.h" using namespace mlir; @@ -52,13 +54,31 @@ void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, } SmallVector LLVM::AllocaOp::getDestructibleSlots() { + auto numElem = dyn_cast(getArraySize().getDefiningOp()); + if (!numElem) + return {}; + auto numElemIntAttr = numElem.getValue().dyn_cast(); + if (!numElemIntAttr || !numElemIntAttr.getType().isSignlessInteger() || + numElemIntAttr.getInt() != 1) + return {}; + Type elemType = getElemType() ? *getElemType() : getResult().getType().getElementType(); auto destructible = dyn_cast(elemType); if (!destructible) return {}; + + Optional> destructedType = destructible.destruct(); + if (!destructedType) + return {}; + + DenseMap allocaTypeMap; + for (Attribute index : llvm::make_first_range(destructedType.value())) { + allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); + } + return {DestructibleMemorySlot{{getResult(), elemType, getLoc()}, - {destructible.destruct()}}}; + {allocaTypeMap}}}; } DenseMap @@ -69,9 +89,12 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, Type elemType = getElemType() ? *getElemType() : getResult().getType().getElementType(); + builder.setInsertionPointAfter(*this); + DenseMap slotMap; - for (auto &[index, type] : - cast(elemType).destruct()) { + Optional> destructedType = + cast(elemType).destruct(); + for (auto &[index, type] : destructedType.value()) { if (usedIndices.contains(index)) { auto subAlloca = builder.create( getLoc(), LLVM::LLVMPointerType::get(getContext()), type, @@ -155,14 +178,15 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( return DeletionKind::Delete; } -bool LLVM::LoadOp::ensureOnlyTypeSafeAccesses( +LogicalResult LLVM::LoadOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { - return getAddr() != slot.ptr || getType() == slot.elemType; + return success(getAddr() != slot.ptr || getType() == slot.elemType); } -bool LLVM::StoreOp::ensureOnlyTypeSafeAccesses( +LogicalResult LLVM::StoreOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { - return getAddr() != slot.ptr || getValue().getType() == slot.elemType; + return success(getAddr() != slot.ptr || + getValue().getType() == slot.elemType); } //===----------------------------------------------------------------------===// @@ -258,18 +282,142 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( return DeletionKind::Delete; } -bool LLVM::GEPOp::ensureOnlyTypeSafeAccesses( +static std::pair computeReachedGEPType(LLVM::GEPOp gep) { + bool isCheckingPointer = true; + Optional maybeSelectedType = gep.getElemType(); + if (!maybeSelectedType) + return {}; + Type selectedType = maybeSelectedType.value(); + Attribute firstLevelIndex; + for (const auto &index : gep.getIndices()) { + IntegerAttr indexInt = index.dyn_cast(); + if (!indexInt) + return {}; + if (isCheckingPointer) { + isCheckingPointer = false; + if (indexInt.getInt() != 0) + return {}; + continue; + } + assert(!selectedType.isa()); + if (!firstLevelIndex) + firstLevelIndex = indexInt; + auto destructible = selectedType.dyn_cast(); + if (!destructible) + return {}; + Type field = destructible.getTypeAtIndex(indexInt); + if (!field) + return {}; + selectedType = field; + } + return std::make_pair(selectedType, firstLevelIndex); +} + +LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { - assert(0 && "todo"); + if (getBase() != slot.ptr) + return success(); + if (slot.elemType != getElemType()) + return failure(); + auto [reachedType, _] = computeReachedGEPType(*this); + if (!reachedType) + return failure(); + mustBeSafelyUsed.emplace_back( + {getResult(), reachedType, getResult().getLoc()}); + return success(); } bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, - SmallPtrSetImpl<::mlir::Attribute> &usedIndices, + SmallPtrSetImpl &usedIndices, SmallVectorImpl &mustBeSafelyUsed) { - assert(0 && "todo"); + if (getBase() != slot.slot.ptr || slot.slot.elemType != getElemType()) + return false; + auto [reachedType, firstLevelIndex] = computeReachedGEPType(*this); + if (!reachedType || !firstLevelIndex) + return false; + assert(slot.info.elementsPtrs.contains(firstLevelIndex)); + if (!slot.info.elementsPtrs.at(firstLevelIndex).isa()) + return false; + mustBeSafelyUsed.emplace_back( + {getResult(), reachedType, getResult().getLoc()}); + usedIndices.insert(firstLevelIndex); + return true; } DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, DenseMap &subslots) { - assert(0 && "todo"); + IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast(); + const MemorySlot &newSlot = subslots.at(firstLevelIndex); + + ArrayRef remainingIndices = getRawConstantIndices().slice(2); + + // If the GEP would become trivial after this transformation, eliminate it. + if (llvm::all_of(remainingIndices, + [](int32_t index) { return index == 0; })) { + getResult().replaceAllUsesWith(newSlot.ptr); + return DeletionKind::Delete; + } + + // Rewire the indices by popping off the second index. + SmallVector newIndices; + newIndices.reserve(remainingIndices.size() + 1); + newIndices.push_back(0); + newIndices.append(remainingIndices.begin(), remainingIndices.end()); + setRawConstantIndices(newIndices); + + // Rewire the pointed type. + setElemType(newSlot.elemType); + + // Rewire the pointer. + getBaseMutable().assign(newSlot.ptr); + + return DeletionKind::Keep; +} + +//===----------------------------------------------------------------------===// +// Interfaces for destructible types +//===----------------------------------------------------------------------===// + +Optional> LLVM::LLVMStructType::destruct() { + int32_t index = 0; + Type i32 = IntegerType::get(getContext(), 32); + DenseMap destructured; + for (Type elemType : getBody()) { + destructured.insert({IntegerAttr::get(i32, index), elemType}); + index++; + } + return destructured; +} + +Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { + auto indexAttr = index.dyn_cast(); + if (!indexAttr || !indexAttr.getType().isInteger(32)) + return {}; + int32_t indexInt = indexAttr.getInt(); + ArrayRef body = getBody(); + if (indexInt < 0 || body.size() <= ((uint32_t)indexInt)) + return {}; + return body[indexInt]; +} + +Optional> LLVM::LLVMArrayType::destruct() const { + if (getNumElements() > 16) + return {}; + int32_t numElements = getNumElements(); + + Type i32 = IntegerType::get(getContext(), 32); + DenseMap destructured; + for (int32_t index = 0; index < numElements; index++) + destructured.insert({IntegerAttr::get(i32, index), getElementType()}); + return destructured; +} + +Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const { + auto indexAttr = index.dyn_cast(); + if (!indexAttr || !indexAttr.getType().isInteger(32)) + return {}; + int32_t indexInt = indexAttr.getInt(); + if (indexInt < 0 || getNumElements() <= ((uint32_t)indexInt)) + return {}; + return getElementType(); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 55ee29adcaf21..e3094e6aea9f3 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -60,8 +60,8 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { Operation *subslotUser = subslotUse.getOwner(); if (auto memOp = dyn_cast(subslotUser)) - if (memOp.ensureOnlyTypeSafeAccesses(mustBeUsedSafely, - usedSafelyWorklist)) + if (succeeded(memOp.ensureOnlyTypeSafeAccesses(mustBeUsedSafely, + usedSafelyWorklist))) continue; // If it cannot be shown that the operation uses the slot safely, maybe it @@ -127,16 +127,16 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, llvm::SmallVector toErase; for (Operation *toRewire : llvm::reverse(sortedUsersToRewire)) { builder.setInsertionPointAfter(toRewire); - if (auto promotable = dyn_cast(toRewire)) { - if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], - builder) == DeletionKind::Delete) - toErase.push_back(promotable); + if (auto accessor = dyn_cast(toRewire)) { + if (accessor.rewire(slot, subslots) == DeletionKind::Delete) + toErase.push_back(accessor); continue; } - auto accessor = cast(toRewire); - if (accessor.rewire(slot, subslots) == DeletionKind::Delete) - toErase.push_back(accessor); + auto promotable = cast(toRewire); + if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], + builder) == DeletionKind::Delete) + toErase.push_back(promotable); } for (Operation *toEraseOp : toErase) @@ -164,26 +164,32 @@ struct SROA : public impl::SROABase { // Destructing a slot can allow for further destruction of other slots, // destruction is tried until no destruction succeeds. - bool justDestructed = true; - while (justDestructed) { - justDestructed = false; + while (true) { + struct DestructionJob { + DestructibleAllocationOpInterface allocator; + DestructibleMemorySlot slot; + MemorySlotDestructionInfo info; + }; - for (Block &block : region) { - for (Operation &op : block.getOperations()) { + std::vector toDestruct; + + for (Block &block : region) + for (Operation &op : block.getOperations()) if (auto allocator = - llvm::dyn_cast(op)) { + llvm::dyn_cast(op)) for (DestructibleMemorySlot slot : - allocator.getDestructibleSlots()) { - if (auto info = computeDestructionInfo(slot)) { - destructSlot(slot, allocator, builder, info.value()); - justDestructed = true; - } - } - } - } - } - - changed |= justDestructed; + allocator.getDestructibleSlots()) + if (auto info = computeDestructionInfo(slot)) + toDestruct.emplace_back( + {allocator, slot, std::move(info.value())}); + + if (toDestruct.empty()) + break; + + for (DestructionJob &job : toDestruct) + destructSlot(job.slot, job.allocator, builder, job.info); + + changed = true; } } From cd628932abcc09b2dfdd1b28d7c17e326c90630a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Thu, 27 Apr 2023 16:26:44 +0200 Subject: [PATCH 14/25] add move --- mlir/lib/Transforms/SROA.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index e3094e6aea9f3..55853453b5430 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -181,7 +181,7 @@ struct SROA : public impl::SROABase { allocator.getDestructibleSlots()) if (auto info = computeDestructionInfo(slot)) toDestruct.emplace_back( - {allocator, slot, std::move(info.value())}); + {allocator, std::move(slot), std::move(info.value())}); if (toDestruct.empty()) break; From 14f666e77652a6798169c94252c052a1e2859416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Thu, 27 Apr 2023 17:36:10 +0200 Subject: [PATCH 15/25] remove loc from memory slot --- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 4914f72d19b3f..35de2f304152b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -77,8 +77,7 @@ SmallVector LLVM::AllocaOp::getDestructibleSlots() { allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); } - return {DestructibleMemorySlot{{getResult(), elemType, getLoc()}, - {allocaTypeMap}}}; + return {DestructibleMemorySlot{{getResult(), elemType}, {allocaTypeMap}}}; } DenseMap @@ -99,8 +98,7 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, auto subAlloca = builder.create( getLoc(), LLVM::LLVMPointerType::get(getContext()), type, getArraySize()); - slotMap.try_emplace(index, - {subAlloca.getResult(), type, getLoc()}); + slotMap.try_emplace(index, {subAlloca.getResult(), type}); } } @@ -322,8 +320,7 @@ LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( auto [reachedType, _] = computeReachedGEPType(*this); if (!reachedType) return failure(); - mustBeSafelyUsed.emplace_back( - {getResult(), reachedType, getResult().getLoc()}); + mustBeSafelyUsed.emplace_back({getResult(), reachedType}); return success(); } @@ -338,8 +335,7 @@ bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, assert(slot.info.elementsPtrs.contains(firstLevelIndex)); if (!slot.info.elementsPtrs.at(firstLevelIndex).isa()) return false; - mustBeSafelyUsed.emplace_back( - {getResult(), reachedType, getResult().getLoc()}); + mustBeSafelyUsed.emplace_back({getResult(), reachedType}); usedIndices.insert(firstLevelIndex); return true; } From a4fce9a6c5de0a51390083a17068796eef3d9eb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Fri, 28 Apr 2023 17:02:43 +0200 Subject: [PATCH 16/25] beginning of SROA test suite --- mlir/lib/Transforms/SROA.cpp | 11 +-- mlir/test/Transforms/sroa-llvmir.mlir | 124 ++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 8 deletions(-) create mode 100644 mlir/test/Transforms/sroa-llvmir.mlir diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 55853453b5430..3472222e24420 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -20,11 +20,6 @@ namespace mlir { using namespace mlir; -template -static V &getOrInsertDefault(DenseMap &map, K key) { - return map.try_emplace(key).first->second; -} - Optional mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { assert(isa(slot.slot.elemType)); @@ -46,7 +41,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? SmallPtrSet &blockingUses = - getOrInsertDefault(info.userToBlockingUses, use.getOwner()); + info.userToBlockingUses.getOrInsertDefault(use.getOwner()); blockingUses.insert(&use); } @@ -67,7 +62,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? SmallPtrSet &blockingUses = - getOrInsertDefault(info.userToBlockingUses, subslotUser); + info.userToBlockingUses.getOrInsertDefault(subslotUser); blockingUses.insert(&subslotUse); } } @@ -99,7 +94,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { user->result_end()); SmallPtrSetImpl &newUserBlockingUseSet = - getOrInsertDefault(info.userToBlockingUses, blockingUse->getOwner()); + info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner()); newUserBlockingUseSet.insert(blockingUse); } } diff --git a/mlir/test/Transforms/sroa-llvmir.mlir b/mlir/test/Transforms/sroa-llvmir.mlir new file mode 100644 index 0000000000000..5a5f7be2aba47 --- /dev/null +++ b/mlir/test/Transforms/sroa-llvmir.mlir @@ -0,0 +1,124 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @basic_struct +llvm.func @basic_struct() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %3 = llvm.load %2 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @basic_array +llvm.func @basic_array() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %3 = llvm.load %2 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @multi_level_direct +llvm.func @multi_level_direct() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %3 = llvm.load %2 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @multi_level_indirect +llvm.func @multi_level_indirect() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> + %3 = llvm.getelementptr inbounds %2[0, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %4 = llvm.load %3 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %4 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @resolve_alias +// CHECK-SAME: (%[[ARG:.*]]: i32) +llvm.func @resolve_alias(%arg: i32) -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + %3 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.store %arg, %2 : i32, !llvm.ptr + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %4 = llvm.load %3 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %4 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @no_non_single_support +llvm.func @no_non_single_support() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant + %0 = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @no_pointer_indexing +llvm.func @no_pointer_indexing() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[1, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @no_direct_use +llvm.func @no_direct_use() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.call @use(%1) : (!llvm.ptr) -> () + llvm.return %3 : i32 +} + +llvm.func @use(!llvm.ptr) From fc227c0e33f2ee1628c689accb631dfeb88fcacf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 2 May 2023 13:56:44 +0200 Subject: [PATCH 17/25] add more tests --- mlir/test/Transforms/sroa-llvmir.mlir | 61 +++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/mlir/test/Transforms/sroa-llvmir.mlir b/mlir/test/Transforms/sroa-llvmir.mlir index 5a5f7be2aba47..92684c27b13dc 100644 --- a/mlir/test/Transforms/sroa-llvmir.mlir +++ b/mlir/test/Transforms/sroa-llvmir.mlir @@ -85,8 +85,8 @@ llvm.func @no_non_single_support() -> i32 { // CHECK: %[[SIZE:.*]] = llvm.mlir.constant %0 = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr - // CHECK-NOT: = llvm.alloca %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> %3 = llvm.load %2 : !llvm.ptr -> i32 llvm.return %3 : i32 @@ -99,8 +99,8 @@ llvm.func @no_pointer_indexing() -> i32 { // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr - // CHECK-NOT: = llvm.alloca %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca %2 = llvm.getelementptr %1[1, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> %3 = llvm.load %2 : !llvm.ptr -> i32 llvm.return %3 : i32 @@ -113,8 +113,8 @@ llvm.func @no_direct_use() -> i32 { // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr - // CHECK-NOT: = llvm.alloca %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> %3 = llvm.load %2 : !llvm.ptr -> i32 llvm.call @use(%1) : (!llvm.ptr) -> () @@ -122,3 +122,58 @@ llvm.func @no_direct_use() -> i32 { } llvm.func @use(!llvm.ptr) + +// ----- + +// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine +llvm.func @direct_promotable_use_is_fine() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %3 = llvm.load %2 : !llvm.ptr -> i32 + // This is a direct use of the slot but it can be removed because it implements PromotableOpInterface. + llvm.intr.lifetime.start 2, %1 : !llvm.ptr + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %3 : i32 +} + +llvm.func @use(!llvm.ptr) + +// ----- + +// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine_on_accessor +llvm.func @direct_promotable_use_is_fine_on_accessor() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)> + // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] + %3 = llvm.load %2 : !llvm.ptr -> i32 + // This does not provide side-effect info but it can be removed because it implements PromotableOpInterface. + llvm.intr.lifetime.start 2, %2 : !llvm.ptr + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %3 : i32 +} + +llvm.func @use(!llvm.ptr) + +// ----- + + +// CHECK-LABEL: llvm.func @no_dynamic_indexing +llvm.func @no_dynamic_indexing() -> i32 { + // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: = llvm.alloca + %2 = llvm.getelementptr %1[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32> + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +llvm.func @use(!llvm.ptr) From 07cc660c9690402d636443ea2610253b177c0437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 2 May 2023 15:47:08 +0200 Subject: [PATCH 18/25] clarify gep-related code --- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 36 ++++++++++--------- mlir/test/Transforms/sroa-llvmir.mlir | 1 - 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 35de2f304152b..aeb6ea9dd4f6b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -280,26 +281,27 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( return DeletionKind::Delete; } -static std::pair computeReachedGEPType(LLVM::GEPOp gep) { - bool isCheckingPointer = true; +static Type computeReachedGEPType(LLVM::GEPOp gep) { + if (gep.getIndices().empty()) + return {}; + + // Check the pointer indexing only targets the first element. + auto firstIndex = gep.getIndices()[0]; + IntegerAttr indexInt = firstIndex.dyn_cast(); + if (!indexInt || indexInt.getInt() != 0) + return {}; + Optional maybeSelectedType = gep.getElemType(); if (!maybeSelectedType) return {}; Type selectedType = maybeSelectedType.value(); - Attribute firstLevelIndex; - for (const auto &index : gep.getIndices()) { + + // Follow the indexed elements in the gep. + for (const auto &index : llvm::drop_begin(gep.getIndices())) { IntegerAttr indexInt = index.dyn_cast(); if (!indexInt) return {}; - if (isCheckingPointer) { - isCheckingPointer = false; - if (indexInt.getInt() != 0) - return {}; - continue; - } assert(!selectedType.isa()); - if (!firstLevelIndex) - firstLevelIndex = indexInt; auto destructible = selectedType.dyn_cast(); if (!destructible) return {}; @@ -308,7 +310,8 @@ static std::pair computeReachedGEPType(LLVM::GEPOp gep) { return {}; selectedType = field; } - return std::make_pair(selectedType, firstLevelIndex); + + return selectedType; } LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( @@ -317,7 +320,7 @@ LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( return success(); if (slot.elemType != getElemType()) return failure(); - auto [reachedType, _] = computeReachedGEPType(*this); + Type reachedType = computeReachedGEPType(*this); if (!reachedType) return failure(); mustBeSafelyUsed.emplace_back({getResult(), reachedType}); @@ -329,9 +332,10 @@ bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed) { if (getBase() != slot.slot.ptr || slot.slot.elemType != getElemType()) return false; - auto [reachedType, firstLevelIndex] = computeReachedGEPType(*this); - if (!reachedType || !firstLevelIndex) + Type reachedType = computeReachedGEPType(*this); + if (!reachedType || getIndices().size() < 2) return false; + auto firstLevelIndex = cast(getIndices()[1]); assert(slot.info.elementsPtrs.contains(firstLevelIndex)); if (!slot.info.elementsPtrs.at(firstLevelIndex).isa()) return false; diff --git a/mlir/test/Transforms/sroa-llvmir.mlir b/mlir/test/Transforms/sroa-llvmir.mlir index 92684c27b13dc..c57af9fccd7c7 100644 --- a/mlir/test/Transforms/sroa-llvmir.mlir +++ b/mlir/test/Transforms/sroa-llvmir.mlir @@ -163,7 +163,6 @@ llvm.func @use(!llvm.ptr) // ----- - // CHECK-LABEL: llvm.func @no_dynamic_indexing llvm.func @no_dynamic_indexing() -> i32 { // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32) From 207e278cc2faa3ef43611a91fa69c471204006e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 2 May 2023 16:12:59 +0200 Subject: [PATCH 19/25] complete documentation --- .../mlir/Interfaces/MemorySlotInterfaces.h | 20 +++-------------- .../mlir/Interfaces/MemorySlotInterfaces.td | 22 +++++++++---------- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index f1ef1b18d0329..61fe87998284e 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -23,33 +23,19 @@ struct MemorySlot { Type elemType; }; +/// Information about the destruction procedure of a destructible memory slot. struct DestructibleSlotInfo { /// Maps from an index within the memory slot the type of the pointer that /// will be generated to access the element directly. - DenseMap elementsPtrs; -}; - -struct SubElementSlotInfo { - /// Index of this memory slot in the parent memory slot. - Attribute subelementIndex; + DenseMap elementPtrs; }; +/// Memory slot attached with information about its destruction procedure. struct DestructibleMemorySlot { MemorySlot slot; DestructibleSlotInfo info; }; -struct SubElementMemorySlot { - MemorySlot slot; - SubElementSlotInfo info; -}; - -struct MaybeDestructibleSubElementMemorySlot { - MemorySlot slot; - SubElementSlotInfo subElementInfo; - Optional destructibleInfo; -}; - /// Returned by operation promotion logic requesting the deletion of an /// operation. enum class DeletionKind { diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 25358c9bf4151..2b664c6a6526a 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -154,10 +154,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { let methods = [ InterfaceMethod<[{ - TODO - Checks that this operation can be promoted to no longer use the provided - blocking uses, in the context of promoting `slot`. + blocking uses, in order to allow optimization. If the removal procedure of the use will require that other uses get removed, that dependency should be added to the `newBlockingUses` @@ -167,12 +165,11 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses) >, InterfaceMethod<[{ - TODO - Transforms IR to ensure that the current operation does not use the - provided memory slot anymore. In contrast to `PromotableMemOpInterface`, - operations implementing this interface must not need access to the - reaching definition of the content of the slot. + provided blocking uses anymore. In contrast to + `PromotableMemOpInterface`, operations implementing this interface + must not need access to the reaching definition of the content of the + slot. During the transformation, *no operation should be deleted*. The operation can only schedule its own deletion by returning the @@ -201,7 +198,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { def DestructibleAllocationOpInterface : OpInterface<"DestructibleAllocationOpInterface"> { let description = [{ - TODO + Describes operations allocating memory slots of aggregates that can be + destructed into multiple smaller allocations. }]; let cppNamespace = "::mlir"; @@ -244,7 +242,7 @@ def DestructibleAllocationOpInterface def TypeSafeOpInterface : OpInterface<"TypeSafeOpInterface"> { let description = [{ - TODO + Describes operations using memory slots in a type-safe manner. }]; let cppNamespace = "::mlir"; @@ -272,7 +270,7 @@ def TypeSafeOpInterface : OpInterface<"TypeSafeOpInterface"> { def DestructibleAccessorOpInterface : OpInterface<"DestructibleAccessorOpInterface"> { let description = [{ - TODO + Describes operations that can access a sub-element of a destructible slot. }]; let cppNamespace = "::mlir"; @@ -309,7 +307,7 @@ def DestructibleAccessorOpInterface def DestructibleTypeInterface : TypeInterface<"DestructibleTypeInterface"> { let description = [{ - TODO + Describes a type that can be broken down into indexable sub-element types. }]; let cppNamespace = "::mlir"; From a28abf28adfe564dc213d7633485b7181a9c77cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 2 May 2023 16:40:17 +0200 Subject: [PATCH 20/25] address comments --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 3 +- .../mlir/Interfaces/MemorySlotInterfaces.h | 12 ++--- .../mlir/Interfaces/MemorySlotInterfaces.td | 6 +-- mlir/include/mlir/Transforms/SROA.h | 3 ++ mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 21 +++++---- mlir/lib/Transforms/SROA.cpp | 46 +++++++++---------- 6 files changed, 45 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index fe054d814d4d3..ce08a7997e4fc 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -202,8 +202,9 @@ class LLVMStructType Location loc) const; /// Destructs the struct into its indexed field types. - Optional> destruct(); + Optional> getDestructedLayout(); + /// Retyurns which type is stored at a given integer index within the struct. Type getTypeAtIndex(Attribute index); }; diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index 61fe87998284e..704d6fc71deee 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -23,19 +23,13 @@ struct MemorySlot { Type elemType; }; -/// Information about the destruction procedure of a destructible memory slot. -struct DestructibleSlotInfo { - /// Maps from an index within the memory slot the type of the pointer that +/// Memory slot attached with information about its destruction procedure. +struct DestructibleMemorySlot : MemorySlot { + /// Maps an index within the memory slot to the type of the pointer that /// will be generated to access the element directly. DenseMap elementPtrs; }; -/// Memory slot attached with information about its destruction procedure. -struct DestructibleMemorySlot { - MemorySlot slot; - DestructibleSlotInfo info; -}; - /// Returned by operation promotion logic requesting the deletion of an /// operation. enum class DeletionKind { diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 2b664c6a6526a..1d5ff1401058a 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -207,8 +207,8 @@ def DestructibleAllocationOpInterface InterfaceMethod<[{ Returns the list of slots for which destruction should be attempted, specifying in which way the slot should be destructed into subslots. The - subslots are indexed by attributes. The type of the pointers of each - subslots to be generated must be provided. The type of the memory slot + subslots are indexed by attributes. This computes the type of the + pointers of each subslots to be generated. The type of the memory slot must implement `DestructibleTypeInterface`. }], "::llvm::SmallVector<::mlir::DestructibleMemorySlot>", @@ -317,7 +317,7 @@ def DestructibleTypeInterface types of subelements. Returns nothing if the type cannot be destructed. }], "::llvm::Optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>", - "destruct", + "getDestructedLayout", (ins) >, InterfaceMethod<[{ diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h index a6eec3b0dbc09..955d48d47a4a2 100644 --- a/mlir/include/mlir/Transforms/SROA.h +++ b/mlir/include/mlir/Transforms/SROA.h @@ -26,6 +26,9 @@ struct MemorySlotDestructionInfo { Optional computeDestructionInfo(DestructibleMemorySlot &slot); +/// Performs the destruction of a destructible slot given associated destruction +/// information. The provided slot will be destructed in subslots by its +/// allocator. void destructSlot(DestructibleMemorySlot &slot, DestructibleAllocationOpInterface allocator, OpBuilder &builder, MemorySlotDestructionInfo &info); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index aeb6ea9dd4f6b..59993d2e41aad 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -69,7 +69,8 @@ SmallVector LLVM::AllocaOp::getDestructibleSlots() { if (!destructible) return {}; - Optional> destructedType = destructible.destruct(); + Optional> destructedType = + destructible.getDestructedLayout(); if (!destructedType) return {}; @@ -85,7 +86,7 @@ DenseMap LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, SmallPtrSetImpl &usedIndices, OpBuilder &builder) { - assert(slot.slot.ptr == getResult()); + assert(slot.ptr == getResult()); Type elemType = getElemType() ? *getElemType() : getResult().getType().getElementType(); @@ -93,7 +94,7 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, DenseMap slotMap; Optional> destructedType = - cast(elemType).destruct(); + cast(elemType).getDestructedLayout(); for (auto &[index, type] : destructedType.value()) { if (usedIndices.contains(index)) { auto subAlloca = builder.create( @@ -108,7 +109,7 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, void LLVM::AllocaOp::handleDestructionComplete( const DestructibleMemorySlot &slot) { - assert(slot.slot.ptr == getResult()); + assert(slot.ptr == getResult()); erase(); } @@ -330,14 +331,14 @@ LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, SmallPtrSetImpl &usedIndices, SmallVectorImpl &mustBeSafelyUsed) { - if (getBase() != slot.slot.ptr || slot.slot.elemType != getElemType()) + if (getBase() != slot.ptr || slot.elemType != getElemType()) return false; Type reachedType = computeReachedGEPType(*this); if (!reachedType || getIndices().size() < 2) return false; auto firstLevelIndex = cast(getIndices()[1]); - assert(slot.info.elementsPtrs.contains(firstLevelIndex)); - if (!slot.info.elementsPtrs.at(firstLevelIndex).isa()) + assert(slot.elementPtrs.contains(firstLevelIndex)); + if (!slot.elementPtrs.at(firstLevelIndex).isa()) return false; mustBeSafelyUsed.emplace_back({getResult(), reachedType}); usedIndices.insert(firstLevelIndex); @@ -378,7 +379,8 @@ DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, // Interfaces for destructible types //===----------------------------------------------------------------------===// -Optional> LLVM::LLVMStructType::destruct() { +Optional> +LLVM::LLVMStructType::getDestructedLayout() { int32_t index = 0; Type i32 = IntegerType::get(getContext(), 32); DenseMap destructured; @@ -400,7 +402,8 @@ Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { return body[indexInt]; } -Optional> LLVM::LLVMArrayType::destruct() const { +Optional> +LLVM::LLVMArrayType::getDestructedLayout() const { if (getNumElements() > 16) return {}; int32_t numElements = getNumElements(); diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 3472222e24420..d6f8aadbde76f 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { @@ -22,14 +23,20 @@ using namespace mlir; Optional mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { - assert(isa(slot.slot.elemType)); + assert(isa(slot.elemType)); MemorySlotDestructionInfo info; SmallVector usedSafelyWorklist; + auto scheduleAsBlockingUse = [&](OpOperand &use) { + SmallPtrSet &blockingUses = + info.userToBlockingUses.getOrInsertDefault(use.getOwner()); + blockingUses.insert(&use); + }; + // Initialize the analysis with the immediate users of the slot. - for (OpOperand &use : slot.slot.ptr.getUses()) { + for (OpOperand &use : slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) { @@ -40,9 +47,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? - SmallPtrSet &blockingUses = - info.userToBlockingUses.getOrInsertDefault(use.getOwner()); - blockingUses.insert(&use); + scheduleAsBlockingUse(use); } SmallPtrSet dealtWith; @@ -61,14 +66,12 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? - SmallPtrSet &blockingUses = - info.userToBlockingUses.getOrInsertDefault(subslotUser); - blockingUses.insert(&subslotUse); + scheduleAsBlockingUse(subslotUse); } } SetVector forwardSlice; - mlir::getForwardSlice(slot.slot.ptr, &forwardSlice); + mlir::getForwardSlice(slot.ptr, &forwardSlice); for (Operation *user : forwardSlice) { // If the next operation has no blocking uses, everything is fine. if (!info.userToBlockingUses.contains(user)) @@ -90,8 +93,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // Then, register any new blocking uses for coming operations. for (OpOperand *blockingUse : newBlockingUses) { - assert(llvm::find(user->getResults(), blockingUse->get()) != - user->result_end()); + assert(llvm::is_contained(user->getResults(), blockingUse->get())); SmallPtrSetImpl &newUserBlockingUseSet = info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner()); @@ -107,7 +109,7 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, OpBuilder &builder, MemorySlotDestructionInfo &info) { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(slot.slot.ptr.getParentBlock()); + builder.setInsertionPointToStart(slot.ptr.getParentBlock()); DenseMap subslots = allocator.destruct(slot, info.usedIndices, builder); @@ -137,9 +139,8 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, for (Operation *toEraseOp : toErase) toEraseOp->erase(); - assert(slot.slot.ptr.use_empty() && - "at the end of destruction, the original slot " - "pointer should no longer be used"); + assert(slot.ptr.use_empty() && "at the end of destruction, the original slot " + "pointer should no longer be used"); allocator.handleDestructionComplete(slot); } @@ -168,15 +169,12 @@ struct SROA : public impl::SROABase { std::vector toDestruct; - for (Block &block : region) - for (Operation &op : block.getOperations()) - if (auto allocator = - llvm::dyn_cast(op)) - for (DestructibleMemorySlot slot : - allocator.getDestructibleSlots()) - if (auto info = computeDestructionInfo(slot)) - toDestruct.emplace_back( - {allocator, std::move(slot), std::move(info.value())}); + region.walk([&](DestructibleAllocationOpInterface allocator) { + for (DestructibleMemorySlot slot : allocator.getDestructibleSlots()) + if (auto info = computeDestructionInfo(slot)) + toDestruct.emplace_back( + {allocator, std::move(slot), std::move(info.value())}); + }); if (toDestruct.empty()) break; From 8e6c3ecd02eaf879855776a788ecdfe338d58642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Tue, 2 May 2023 17:19:00 +0200 Subject: [PATCH 21/25] address some comments --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 2 +- mlir/include/mlir/Transforms/Passes.td | 13 ++++++-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 31 ++++++++++--------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index ce08a7997e4fc..c06e621c81c81 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -204,7 +204,7 @@ class LLVMStructType /// Destructs the struct into its indexed field types. Optional> getDestructedLayout(); - /// Retyurns which type is stored at a given integer index within the struct. + /// Returns which type is stored at a given integer index within the struct. Type getTypeAtIndex(Attribute index); }; diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 00509f342e8b9..16f80cadc5b8c 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -215,9 +215,18 @@ def SCCP : Pass<"sccp"> { } def SROA : Pass<"sroa"> { - let summary = "TODO"; + let summary = "Scalar Replacement of Aggregates"; let description = [{ - TODO + Scalar Replacement of Aggregates. Replaces allocations of aggregates into + independant allocations of its elements. + + Allocators must implement `DestructibleAllocationOpInterface` to provide + the list of memory slots for which destruction should be attempted. + + This pass will only be applied if all accessors of the aggregate implement + the `DestructibleAccessorOpInterface`. If the accessors provide a view + into the struct, users of the view must ensure it is used in a type-safe + manner and within bounds by implementing `TypeSafeOpInterface`. }]; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 59993d2e41aad..72ff815271cc6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" @@ -55,16 +56,12 @@ void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, } SmallVector LLVM::AllocaOp::getDestructibleSlots() { - auto numElem = dyn_cast(getArraySize().getDefiningOp()); - if (!numElem) - return {}; - auto numElemIntAttr = numElem.getValue().dyn_cast(); - if (!numElemIntAttr || !numElemIntAttr.getType().isSignlessInteger() || - numElemIntAttr.getInt() != 1) + if (!mlir::matchPattern(getArraySize(), m_One())) return {}; - Type elemType = - getElemType() ? *getElemType() : getResult().getType().getElementType(); + Type elemType = getElemType().has_value() + ? *getElemType() + : getResult().getType().getElementType(); auto destructible = dyn_cast(elemType); if (!destructible) return {}; @@ -75,9 +72,8 @@ SmallVector LLVM::AllocaOp::getDestructibleSlots() { return {}; DenseMap allocaTypeMap; - for (Attribute index : llvm::make_first_range(destructedType.value())) { + for (Attribute index : llvm::make_first_range(destructedType.value())) allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); - } return {DestructibleMemorySlot{{getResult(), elemType}, {allocaTypeMap}}}; } @@ -282,7 +278,10 @@ DeletionKind LLVM::GEPOp::removeBlockingUses( return DeletionKind::Delete; } +/// TODO: Support non-opaque pointers. static Type computeReachedGEPType(LLVM::GEPOp gep) { + assert(gep.getBase().getType().cast().isOpaque()); + if (gep.getIndices().empty()) return {}; @@ -295,17 +294,22 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) { Optional maybeSelectedType = gep.getElemType(); if (!maybeSelectedType) return {}; - Type selectedType = maybeSelectedType.value(); + Type selectedType = *maybeSelectedType; // Follow the indexed elements in the gep. for (const auto &index : llvm::drop_begin(gep.getIndices())) { + // Ensure the index is static and obtain it. IntegerAttr indexInt = index.dyn_cast(); if (!indexInt) return {}; + + // Ensure the structure of the type being indexed can be reasoned about. assert(!selectedType.isa()); auto destructible = selectedType.dyn_cast(); if (!destructible) return {}; + + // Follow the type at the index the gep is accessing. Type field = destructible.getTypeAtIndex(indexInt); if (!field) return {}; @@ -360,9 +364,8 @@ DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, } // Rewire the indices by popping off the second index. - SmallVector newIndices; - newIndices.reserve(remainingIndices.size() + 1); - newIndices.push_back(0); + // Start with a single zero, then add the indices beyond the second. + SmallVector newIndices(1); newIndices.append(remainingIndices.begin(), remainingIndices.end()); setRawConstantIndices(newIndices); From 3843f6e6a1cbb1e2637f4f2698838983b37d0815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 3 May 2023 12:13:09 +0200 Subject: [PATCH 22/25] address a round of comments --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 2 +- .../mlir/Interfaces/MemorySlotInterfaces.td | 4 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 30 ++++++++++++------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index c06e621c81c81..7fc62e04942e3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -202,7 +202,7 @@ class LLVMStructType Location loc) const; /// Destructs the struct into its indexed field types. - Optional> getDestructedLayout(); + Optional> getSubelementIndexMap(); /// Returns which type is stored at a given integer index within the struct. Type getTypeAtIndex(Attribute index); diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 1d5ff1401058a..8723d91d5fa9f 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -317,14 +317,14 @@ def DestructibleTypeInterface types of subelements. Returns nothing if the type cannot be destructed. }], "::llvm::Optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>", - "getDestructedLayout", + "getSubelementIndexMap", (ins) >, InterfaceMethod<[{ Indicates which type is held at the provided index, returning a null Type if no type could be computed. While this can return information even when the type cannot be completely destructed, it must be coherent - with the values returned by `destruct` when they exist. + with the types returned by `getSubelementIndexMap` when they exist. }], "::mlir::Type", "getTypeAtIndex", diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 72ff815271cc6..39e5b64921400 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -25,13 +25,18 @@ using namespace mlir; // Interfaces for AllocaOp //===----------------------------------------------------------------------===// +static Type getAllocaElementType(LLVM::AllocaOp alloca) { + // This will become trivial once non-opaque pointers are gone. + return alloca.getElemType().has_value() + ? *alloca.getElemType() + : alloca.getResult().getType().getElementType(); +} + llvm::SmallVector LLVM::AllocaOp::getPromotableSlots() { if (!getOperation()->getBlock()->isEntryBlock()) return {}; - Type elemType = - getElemType() ? *getElemType() : getResult().getType().getElementType(); - return {MemorySlot{getResult(), elemType}}; + return {MemorySlot{getResult(), getAllocaElementType(*this)}}; } Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, @@ -59,15 +64,13 @@ SmallVector LLVM::AllocaOp::getDestructibleSlots() { if (!mlir::matchPattern(getArraySize(), m_One())) return {}; - Type elemType = getElemType().has_value() - ? *getElemType() - : getResult().getType().getElementType(); + Type elemType = getAllocaElementType(*this); auto destructible = dyn_cast(elemType); if (!destructible) return {}; Optional> destructedType = - destructible.getDestructedLayout(); + destructible.getSubelementIndexMap(); if (!destructedType) return {}; @@ -90,7 +93,7 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, DenseMap slotMap; Optional> destructedType = - cast(elemType).getDestructedLayout(); + cast(elemType).getSubelementIndexMap(); for (auto &[index, type] : destructedType.value()) { if (usedIndices.contains(index)) { auto subAlloca = builder.create( @@ -291,6 +294,8 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) { if (!indexInt || indexInt.getInt() != 0) return {}; + // Set the initial type currently being used for indexing. This will be + // updated as the indices get walked over. Optional maybeSelectedType = gep.getElemType(); if (!maybeSelectedType) return {}; @@ -309,13 +314,16 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) { if (!destructible) return {}; - // Follow the type at the index the gep is accessing. + // Follow the type at the index the gep is accessing, making it the new type + // used for indexing. Type field = destructible.getTypeAtIndex(indexInt); if (!field) return {}; selectedType = field; } + // When there are no more indices, the type currently being used for indexing + // is the type of the value pointed at by the returned indexed pointer. return selectedType; } @@ -383,7 +391,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, //===----------------------------------------------------------------------===// Optional> -LLVM::LLVMStructType::getDestructedLayout() { +LLVM::LLVMStructType::getSubelementIndexMap() { int32_t index = 0; Type i32 = IntegerType::get(getContext(), 32); DenseMap destructured; @@ -406,7 +414,7 @@ Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { } Optional> -LLVM::LLVMArrayType::getDestructedLayout() const { +LLVM::LLVMArrayType::getSubelementIndexMap() const { if (getNumElements() > 16) return {}; int32_t numElements = getNumElements(); From a843ed506109d3b4b7f5106a4d3e01c10f26a405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 3 May 2023 14:24:44 +0200 Subject: [PATCH 23/25] address more comments --- .../mlir/Interfaces/MemorySlotInterfaces.td | 2 +- mlir/include/mlir/Transforms/Mem2Reg.h | 6 +++--- mlir/lib/Transforms/SROA.cpp | 14 ++++++-------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index 8723d91d5fa9f..b2c34168d17c8 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -279,7 +279,7 @@ def DestructibleAccessorOpInterface For a given destructible memory slot, returns whether this operation can rewire its uses of the slot to use the slots generated after destruction. This may involve creating new operations, and usually - amounts to checking the pointer types match. + amounts to checking if the pointer types match. This method must also register the indices it will access within the `usedIndices` set. If the accessor generates new slots mapping to diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h index 1a93c6736d07d..39637d6ec7ecf 100644 --- a/mlir/include/mlir/Transforms/Mem2Reg.h +++ b/mlir/include/mlir/Transforms/Mem2Reg.h @@ -21,9 +21,9 @@ struct MemorySlotPromotionInfo { /// Blocks for which at least two definitions of the slot values clash. SmallPtrSet mergePoints; /// Contains, for each operation, which uses must be eliminated by promotion. - /// This is a DAG structure because an operation that must eliminate some of - /// its uses always comes from a request from an operation that must - /// eliminate some of its own uses. + /// This is a DAG structure because if an operation must eliminate some of + /// its uses, it is because the parents of the uses it must eliminate must + /// eliminate uses themselves. DenseMap> userToBlockingUses; }; diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index d6f8aadbde76f..39300cd9fadc7 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -50,13 +50,12 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { scheduleAsBlockingUse(use); } - SmallPtrSet dealtWith; + SmallPtrSet visited; while (!usedSafelyWorklist.empty()) { MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val(); for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) { - if (dealtWith.contains(&subslotUse)) + if (!visited.insert(&subslotUse).second) continue; - dealtWith.insert(&subslotUse); Operation *subslotUser = subslotUse.getOwner(); if (auto memOp = dyn_cast(subslotUser)) @@ -113,16 +112,15 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, DenseMap subslots = allocator.destruct(slot, info.usedIndices, builder); - llvm::SetVector usersToRewire; - for (auto &[user, _] : info.userToBlockingUses) + SetVector usersToRewire; + for (Operation *user : llvm::make_first_range(info.userToBlockingUses)) usersToRewire.insert(user); for (DestructibleAccessorOpInterface accessor : info.accessors) usersToRewire.insert(accessor); - SetVector sortedUsersToRewire = - mlir::topologicalSort(usersToRewire); + usersToRewire = mlir::topologicalSort(usersToRewire); llvm::SmallVector toErase; - for (Operation *toRewire : llvm::reverse(sortedUsersToRewire)) { + for (Operation *toRewire : llvm::reverse(usersToRewire)) { builder.setInsertionPointAfter(toRewire); if (auto accessor = dyn_cast(toRewire)) { if (accessor.rewire(slot, subslots) == DeletionKind::Delete) From 6011b2bfbdac7b867298ea1beb77d74f500e4bd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 3 May 2023 14:55:21 +0200 Subject: [PATCH 24/25] address more comments --- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 39e5b64921400..891f93aa159bc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -392,13 +392,10 @@ DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, Optional> LLVM::LLVMStructType::getSubelementIndexMap() { - int32_t index = 0; Type i32 = IntegerType::get(getContext(), 32); DenseMap destructured; - for (Type elemType : getBody()) { + for (auto const &[index, elemType] : llvm::enumerate(getBody())) destructured.insert({IntegerAttr::get(i32, index), elemType}); - index++; - } return destructured; } @@ -408,20 +405,21 @@ Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { return {}; int32_t indexInt = indexAttr.getInt(); ArrayRef body = getBody(); - if (indexInt < 0 || body.size() <= ((uint32_t)indexInt)) + if (indexInt < 0 || body.size() <= static_cast(indexInt)) return {}; return body[indexInt]; } Optional> LLVM::LLVMArrayType::getSubelementIndexMap() const { - if (getNumElements() > 16) + constexpr size_t maxArraySizeForDestruction = 16; + if (getNumElements() > maxArraySizeForDestruction) return {}; int32_t numElements = getNumElements(); Type i32 = IntegerType::get(getContext(), 32); DenseMap destructured; - for (int32_t index = 0; index < numElements; index++) + for (int32_t index = 0; index < numElements; ++index) destructured.insert({IntegerAttr::get(i32, index), getElementType()}); return destructured; } @@ -431,7 +429,7 @@ Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const { if (!indexAttr || !indexAttr.getType().isInteger(32)) return {}; int32_t indexInt = indexAttr.getInt(); - if (indexInt < 0 || getNumElements() <= ((uint32_t)indexInt)) + if (indexInt < 0 || getNumElements() <= static_cast(indexInt)) return {}; return getElementType(); } From 5abecdec7e8d751931120ea1bf08a8fe7361044a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Wed, 3 May 2023 16:18:14 +0200 Subject: [PATCH 25/25] rename destruct into destructure --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 2 +- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 2 +- .../mlir/Interfaces/MemorySlotInterfaces.h | 4 +- .../mlir/Interfaces/MemorySlotInterfaces.td | 52 ++++++++-------- mlir/include/mlir/Transforms/Passes.td | 6 +- mlir/include/mlir/Transforms/SROA.h | 28 ++++----- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 48 +++++++-------- mlir/lib/Transforms/SROA.cpp | 60 ++++++++++--------- 9 files changed, 104 insertions(+), 102 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index e78e73a41ef0d..65f70d5d8d8e5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -174,7 +174,7 @@ def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp< // Memory-related operations. def LLVM_AllocaOp : LLVM_Op<"alloca", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods]>, LLVM_MemOpPatterns { let arguments = (ins AnyInteger:$arraySize, OptionalAttr:$alignment, @@ -235,7 +235,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 7fc62e04942e3..80de92b12c1e5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -104,7 +104,7 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); class LLVMStructType : public Type::TypeBase { public: /// Inherit base constructors. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index f861ba4b0af88..e26d9d8acc79e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -26,7 +26,7 @@ class LLVMType traits = []> def LLVMArrayType : LLVMType<"LLVMArray", "array", [ DeclareTypeInterfaceMethods, - DeclareTypeInterfaceMethods]> { + DeclareTypeInterfaceMethods]> { let summary = "LLVM array type"; let description = [{ The `!llvm.array` type represents a fixed-size array of element types. diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h index 704d6fc71deee..3134af8866a54 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -23,8 +23,8 @@ struct MemorySlot { Type elemType; }; -/// Memory slot attached with information about its destruction procedure. -struct DestructibleMemorySlot : MemorySlot { +/// Memory slot attached with information about its destructuring procedure. +struct DestructurableMemorySlot : MemorySlot { /// Maps an index within the memory slot to the type of the pointer that /// will be generated to access the element directly. DenseMap elementPtrs; diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index b2c34168d17c8..9d00f71934359 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -195,24 +195,24 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { ]; } -def DestructibleAllocationOpInterface - : OpInterface<"DestructibleAllocationOpInterface"> { +def DestructurableAllocationOpInterface + : OpInterface<"DestructurableAllocationOpInterface"> { let description = [{ Describes operations allocating memory slots of aggregates that can be - destructed into multiple smaller allocations. + destructured into multiple smaller allocations. }]; let cppNamespace = "::mlir"; let methods = [ InterfaceMethod<[{ - Returns the list of slots for which destruction should be attempted, - specifying in which way the slot should be destructed into subslots. The - subslots are indexed by attributes. This computes the type of the + Returns the list of slots for which destructuring should be attempted, + specifying in which way the slot should be destructured into subslots. + The subslots are indexed by attributes. This computes the type of the pointers of each subslots to be generated. The type of the memory slot - must implement `DestructibleTypeInterface`. + must implement `DestructurableTypeInterface`. }], - "::llvm::SmallVector<::mlir::DestructibleMemorySlot>", - "getDestructibleSlots", + "::llvm::SmallVector<::mlir::DestructurableMemorySlot>", + "getDestructurableSlots", (ins) >, InterfaceMethod<[{ @@ -224,18 +224,18 @@ def DestructibleAllocationOpInterface pointer is defined. }], "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>", - "destruct", - (ins "const ::mlir::DestructibleMemorySlot &":$slot, + "destructure", + (ins "const ::mlir::DestructurableMemorySlot &":$slot, "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, "::mlir::OpBuilder &":$builder) >, InterfaceMethod<[{ - Hook triggered once the destruction of a slot is complete, meaning the + Hook triggered once the destructuring of a slot is complete, meaning the original slot is no longer being refered to and could be deleted. This will only be called for slots declared by this operation. }], - "void", "handleDestructionComplete", - (ins "const ::mlir::DestructibleMemorySlot &":$slot) + "void", "handleDestructuringComplete", + (ins "const ::mlir::DestructurableMemorySlot &":$slot) >, ]; } @@ -267,18 +267,18 @@ def TypeSafeOpInterface : OpInterface<"TypeSafeOpInterface"> { ]; } -def DestructibleAccessorOpInterface - : OpInterface<"DestructibleAccessorOpInterface"> { +def DestructurableAccessorOpInterface + : OpInterface<"DestructurableAccessorOpInterface"> { let description = [{ - Describes operations that can access a sub-element of a destructible slot. + Describes operations that can access a sub-element of a destructurable slot. }]; let cppNamespace = "::mlir"; let methods = [ InterfaceMethod<[{ - For a given destructible memory slot, returns whether this operation can + For a given destructurable memory slot, returns whether this operation can rewire its uses of the slot to use the slots generated after - destruction. This may involve creating new operations, and usually + destructuring. This may involve creating new operations, and usually amounts to checking if the pointer types match. This method must also register the indices it will access within the @@ -288,7 +288,7 @@ def DestructibleAccessorOpInterface }], "bool", "canRewire", - (ins "const ::mlir::DestructibleMemorySlot &":$slot, + (ins "const ::mlir::DestructurableMemorySlot &":$slot, "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed) >, @@ -298,14 +298,14 @@ def DestructibleAccessorOpInterface }], "::mlir::DeletionKind", "rewire", - (ins "const ::mlir::DestructibleMemorySlot &":$slot, + (ins "const ::mlir::DestructurableMemorySlot &":$slot, "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots) > ]; } -def DestructibleTypeInterface - : TypeInterface<"DestructibleTypeInterface"> { +def DestructurableTypeInterface + : TypeInterface<"DestructurableTypeInterface"> { let description = [{ Describes a type that can be broken down into indexable sub-element types. }]; @@ -313,8 +313,8 @@ def DestructibleTypeInterface let methods = [ InterfaceMethod<[{ - Destructs the type into subelements into a map of index attributes to - types of subelements. Returns nothing if the type cannot be destructed. + Destructures the type into subelements into a map of index attributes to + types of subelements. Returns nothing if the type cannot be destructured. }], "::llvm::Optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>", "getSubelementIndexMap", @@ -323,7 +323,7 @@ def DestructibleTypeInterface InterfaceMethod<[{ Indicates which type is held at the provided index, returning a null Type if no type could be computed. While this can return information - even when the type cannot be completely destructed, it must be coherent + even when the type cannot be completely destructured, it must be coherent with the types returned by `getSubelementIndexMap` when they exist. }], "::mlir::Type", diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 16f80cadc5b8c..08e048430625c 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -220,11 +220,11 @@ def SROA : Pass<"sroa"> { Scalar Replacement of Aggregates. Replaces allocations of aggregates into independant allocations of its elements. - Allocators must implement `DestructibleAllocationOpInterface` to provide - the list of memory slots for which destruction should be attempted. + Allocators must implement `DestructurableAllocationOpInterface` to provide + the list of memory slots for which destructuring should be attempted. This pass will only be applied if all accessors of the aggregate implement - the `DestructibleAccessorOpInterface`. If the accessors provide a view + the `DestructurableAccessorOpInterface`. If the accessors provide a view into the struct, users of the view must ensure it is used in a type-safe manner and within bounds by implementing `TypeSafeOpInterface`. }]; diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h index 955d48d47a4a2..8c5369f4d5775 100644 --- a/mlir/include/mlir/Transforms/SROA.h +++ b/mlir/include/mlir/Transforms/SROA.h @@ -14,24 +14,24 @@ namespace mlir { -struct MemorySlotDestructionInfo { +struct MemorySlotDestructuringInfo { SmallPtrSet usedIndices; DenseMap> userToBlockingUses; - SmallVector accessors; + SmallVector accessors; }; -/// Computes information for slot destruction leading to promotion. This will -/// compute whether this slot can be destructed. Returns nothing if the slot -/// cannot be destructed. -Optional -computeDestructionInfo(DestructibleMemorySlot &slot); - -/// Performs the destruction of a destructible slot given associated destruction -/// information. The provided slot will be destructed in subslots by its -/// allocator. -void destructSlot(DestructibleMemorySlot &slot, - DestructibleAllocationOpInterface allocator, - OpBuilder &builder, MemorySlotDestructionInfo &info); +/// Computes information for slot destructuring. This will compute whether this +/// slot can be destructured and data to perform the destructuring. Returns +/// nothing if the slot cannot be destructured. +Optional +computeDestructuringInfo(DestructurableMemorySlot &slot); + +/// Performs the destructuring of a destructible slot given associated +/// destructuring information. The provided slot will be destructured in +/// subslots as specified by its allocator. +void destructureSlot(DestructurableMemorySlot &slot, + DestructurableAllocationOpInterface allocator, + OpBuilder &builder, MemorySlotDestructuringInfo &info); } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 891f93aa159bc..7e3bbe1d5caeb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -60,31 +60,31 @@ void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, erase(); } -SmallVector LLVM::AllocaOp::getDestructibleSlots() { +SmallVector LLVM::AllocaOp::getDestructurableSlots() { if (!mlir::matchPattern(getArraySize(), m_One())) return {}; Type elemType = getAllocaElementType(*this); - auto destructible = dyn_cast(elemType); - if (!destructible) + auto destructurable = dyn_cast(elemType); + if (!destructurable) return {}; - Optional> destructedType = - destructible.getSubelementIndexMap(); - if (!destructedType) + Optional> destructuredType = + destructurable.getSubelementIndexMap(); + if (!destructuredType) return {}; DenseMap allocaTypeMap; - for (Attribute index : llvm::make_first_range(destructedType.value())) + for (Attribute index : llvm::make_first_range(destructuredType.value())) allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); - return {DestructibleMemorySlot{{getResult(), elemType}, {allocaTypeMap}}}; + return {DestructurableMemorySlot{{getResult(), elemType}, {allocaTypeMap}}}; } DenseMap -LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, - SmallPtrSetImpl &usedIndices, - OpBuilder &builder) { +LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + OpBuilder &builder) { assert(slot.ptr == getResult()); Type elemType = getElemType() ? *getElemType() : getResult().getType().getElementType(); @@ -92,9 +92,9 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, builder.setInsertionPointAfter(*this); DenseMap slotMap; - Optional> destructedType = - cast(elemType).getSubelementIndexMap(); - for (auto &[index, type] : destructedType.value()) { + Optional> destructuredType = + cast(elemType).getSubelementIndexMap(); + for (auto &[index, type] : destructuredType.value()) { if (usedIndices.contains(index)) { auto subAlloca = builder.create( getLoc(), LLVM::LLVMPointerType::get(getContext()), type, @@ -106,8 +106,8 @@ LLVM::AllocaOp::destruct(const DestructibleMemorySlot &slot, return slotMap; } -void LLVM::AllocaOp::handleDestructionComplete( - const DestructibleMemorySlot &slot) { +void LLVM::AllocaOp::handleDestructuringComplete( + const DestructurableMemorySlot &slot) { assert(slot.ptr == getResult()); erase(); } @@ -310,13 +310,13 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) { // Ensure the structure of the type being indexed can be reasoned about. assert(!selectedType.isa()); - auto destructible = selectedType.dyn_cast(); - if (!destructible) + auto destructurable = selectedType.dyn_cast(); + if (!destructurable) return {}; // Follow the type at the index the gep is accessing, making it the new type // used for indexing. - Type field = destructible.getTypeAtIndex(indexInt); + Type field = destructurable.getTypeAtIndex(indexInt); if (!field) return {}; selectedType = field; @@ -340,7 +340,7 @@ LogicalResult LLVM::GEPOp::ensureOnlyTypeSafeAccesses( return success(); } -bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, +bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot, SmallPtrSetImpl &usedIndices, SmallVectorImpl &mustBeSafelyUsed) { if (getBase() != slot.ptr || slot.elemType != getElemType()) @@ -357,7 +357,7 @@ bool LLVM::GEPOp::canRewire(const DestructibleMemorySlot &slot, return true; } -DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, +DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot, DenseMap &subslots) { IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast(); const MemorySlot &newSlot = subslots.at(firstLevelIndex); @@ -387,7 +387,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructibleMemorySlot &slot, } //===----------------------------------------------------------------------===// -// Interfaces for destructible types +// Interfaces for destructurable types //===----------------------------------------------------------------------===// Optional> @@ -412,8 +412,8 @@ Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { Optional> LLVM::LLVMArrayType::getSubelementIndexMap() const { - constexpr size_t maxArraySizeForDestruction = 16; - if (getNumElements() > maxArraySizeForDestruction) + constexpr size_t maxArraySizeForDestructuring = 16; + if (getNumElements() > maxArraySizeForDestructuring) return {}; int32_t numElements = getNumElements(); diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 39300cd9fadc7..a30b8fb24223a 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -21,11 +21,11 @@ namespace mlir { using namespace mlir; -Optional -mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { - assert(isa(slot.elemType)); +Optional +mlir::computeDestructuringInfo(DestructurableMemorySlot &slot) { + assert(isa(slot.elemType)); - MemorySlotDestructionInfo info; + MemorySlotDestructuringInfo info; SmallVector usedSafelyWorklist; @@ -38,7 +38,7 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { // Initialize the analysis with the immediate users of the slot. for (OpOperand &use : slot.ptr.getUses()) { if (auto accessor = - dyn_cast(use.getOwner())) { + dyn_cast(use.getOwner())) { if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) { info.accessors.push_back(accessor); continue; @@ -80,13 +80,13 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { auto promotable = dyn_cast(user); // An operation that has blocking uses must be promoted. If it is not - // promotable, destruction must fail. + // promotable, destructuring must fail. if (!promotable) return {}; SmallVector newBlockingUses; // If the operation decides it cannot deal with removing the blocking uses, - // destruction must fail. + // destructuring must fail. if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses)) return {}; @@ -103,26 +103,27 @@ mlir::computeDestructionInfo(DestructibleMemorySlot &slot) { return info; } -void mlir::destructSlot(DestructibleMemorySlot &slot, - DestructibleAllocationOpInterface allocator, - OpBuilder &builder, MemorySlotDestructionInfo &info) { +void mlir::destructureSlot(DestructurableMemorySlot &slot, + DestructurableAllocationOpInterface allocator, + OpBuilder &builder, + MemorySlotDestructuringInfo &info) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(slot.ptr.getParentBlock()); DenseMap subslots = - allocator.destruct(slot, info.usedIndices, builder); + allocator.destructure(slot, info.usedIndices, builder); SetVector usersToRewire; for (Operation *user : llvm::make_first_range(info.userToBlockingUses)) usersToRewire.insert(user); - for (DestructibleAccessorOpInterface accessor : info.accessors) + for (DestructurableAccessorOpInterface accessor : info.accessors) usersToRewire.insert(accessor); usersToRewire = mlir::topologicalSort(usersToRewire); llvm::SmallVector toErase; for (Operation *toRewire : llvm::reverse(usersToRewire)) { builder.setInsertionPointAfter(toRewire); - if (auto accessor = dyn_cast(toRewire)) { + if (auto accessor = dyn_cast(toRewire)) { if (accessor.rewire(slot, subslots) == DeletionKind::Delete) toErase.push_back(accessor); continue; @@ -137,10 +138,10 @@ void mlir::destructSlot(DestructibleMemorySlot &slot, for (Operation *toEraseOp : toErase) toEraseOp->erase(); - assert(slot.ptr.use_empty() && "at the end of destruction, the original slot " + assert(slot.ptr.use_empty() && "after destructuring, the original slot " "pointer should no longer be used"); - allocator.handleDestructionComplete(slot); + allocator.handleDestructuringComplete(slot); } namespace { @@ -156,29 +157,30 @@ struct SROA : public impl::SROABase { OpBuilder builder(®ion.front(), region.front().begin()); - // Destructing a slot can allow for further destruction of other slots, - // destruction is tried until no destruction succeeds. + // Destructuring a slot can allow for further destructuring of other + // slots, so destructuring is tried until no destructuring succeeds. while (true) { - struct DestructionJob { - DestructibleAllocationOpInterface allocator; - DestructibleMemorySlot slot; - MemorySlotDestructionInfo info; + struct DestructuringJob { + DestructurableAllocationOpInterface allocator; + DestructurableMemorySlot slot; + MemorySlotDestructuringInfo info; }; - std::vector toDestruct; + std::vector toDestructure; - region.walk([&](DestructibleAllocationOpInterface allocator) { - for (DestructibleMemorySlot slot : allocator.getDestructibleSlots()) - if (auto info = computeDestructionInfo(slot)) - toDestruct.emplace_back( + region.walk([&](DestructurableAllocationOpInterface allocator) { + for (DestructurableMemorySlot slot : + allocator.getDestructurableSlots()) + if (auto info = computeDestructuringInfo(slot)) + toDestructure.emplace_back( {allocator, std::move(slot), std::move(info.value())}); }); - if (toDestruct.empty()) + if (toDestructure.empty()) break; - for (DestructionJob &job : toDestruct) - destructSlot(job.slot, job.allocator, builder, job.info); + for (DestructuringJob &job : toDestructure) + destructureSlot(job.slot, job.allocator, builder, job.info); changed = true; }