Skip to content

Commit

Permalink
[Matrix] Adjust lifetime.ends during multiply fusion. (#84914)
Browse files Browse the repository at this point in the history
At the moment, loads introduced by multiply fusion may be placed after
an object's lifetime has been terminated by lifetime.end. This introduces
reads of dead objects.

To avoid this, first collect all lifetime.end calls in the function.
During fusion, we deal with any lifetime.end calls that may alias any of
the loads.

Such lifetime.end calls are either moved when possible (both the
lifetime.end and the store are in the same block) or deleted.

PR: #84914
  • Loading branch information
fhahn committed Mar 16, 2024
1 parent 0847c90 commit e77378c
Show file tree
Hide file tree
Showing 2 changed files with 527 additions and 18 deletions.
61 changes: 58 additions & 3 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
Expand Down Expand Up @@ -990,12 +991,15 @@ class LowerMatrixIntrinsics {
bool Changed = false;
SmallVector<CallInst *, 16> MaybeFusableInsts;
SmallVector<Instruction *, 16> MatrixInsts;
SmallVector<IntrinsicInst *, 16> LifetimeEnds;

// First, collect all instructions with shape information and candidates for
// fusion (currently only matrix multiplies).
ReversePostOrderTraversal<Function *> RPOT(&Func);
for (auto *BB : RPOT)
for (Instruction &I : *BB) {
if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
if (ShapeMap.find(&I) == ShapeMap.end())
continue;
if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
Expand All @@ -1010,7 +1014,7 @@ class LowerMatrixIntrinsics {

// Third, try to fuse candidates.
for (CallInst *CI : MaybeFusableInsts)
LowerMatrixMultiplyFused(CI, FusedInsts);
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

Changed = !FusedInsts.empty();

Expand Down Expand Up @@ -1856,8 +1860,10 @@ class LowerMatrixIntrinsics {
///
/// Call finalizeLowering on lowered instructions. Instructions that are
/// completely eliminated by fusion are added to \p FusedInsts.
void LowerMatrixMultiplyFused(CallInst *MatMul,
SmallPtrSetImpl<Instruction *> &FusedInsts) {
void
LowerMatrixMultiplyFused(CallInst *MatMul,
SmallPtrSetImpl<Instruction *> &FusedInsts,
SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
if (!FuseMatrix || !DT)
return;

Expand Down Expand Up @@ -1946,6 +1952,55 @@ class LowerMatrixIntrinsics {
for (Instruction *I : ToHoist)
I->moveBefore(MatMul);

// Deal with lifetime.end calls that might be between Load0/Load1 and the
// store. To avoid introducing loads to dead objects (i.e. after the
// lifetime has been terminated by @llvm.lifetime.end), either sink them
// after the store if in the same block, or remove the lifetime.end marker
// otherwise. This might pessimize further optimizations, by extending the
// lifetime of the object until the function returns, but should be
// conservatively correct.
MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
BasicBlock *StoreParent = Store->getParent();
bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
LoadOp1->getParent() == StoreParent;
for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
IntrinsicInst *End = LifetimeEnds[Idx];
auto Inc = make_scope_exit([&Idx]() { Idx++; });
// If the lifetime.end is guaranteed to be before the loads or after the
// store, it won't interfere with fusion.
if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
continue;
if (DT->dominates(Store, End))
continue;
// If all fusable ops are in the same block and the lifetime.end is in a
// different block, it won't interfere with fusion.
if (FusableOpsInSameBlock && End->getParent() != StoreParent)
continue;

// If the loads don't alias the lifetime.end, it won't interfere with
// fusion.
MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
if (!EndLoc.Ptr)
continue;
if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
continue;

// If both lifetime.end and the store are in the same block, extend the
// lifetime until after the store, so the new lifetime covers the loads
// we introduce later.
if (End->getParent() == StoreParent) {
End->moveAfter(Store);
continue;
}

// Otherwise remove the conflicting lifetime.end marker.
ToRemove.push_back(End);
std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
LifetimeEnds.pop_back();
Inc.release();
}

emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
return;
}
Expand Down

0 comments on commit e77378c

Please sign in to comment.