[Matrix] Adjust lifetime.ends during multiply fusion. #84914

Merged on Mar 16, 2024 (5 commits)

Conversation

fhahn
Contributor

@fhahn fhahn commented Mar 12, 2024

At the moment, loads introduced by multiply fusion may be placed after an object's lifetime has been terminated by lifetime.end. This introduces reads of dead objects.

To avoid this, first collect all lifetime.end calls in the function. During fusion, we deal with any lifetime.end calls that may alias any of the loads.

Such lifetime.end calls are either moved when possible (both the lifetime.end and the store are in the same block) or deleted.
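
As a rough illustration of the decision described above, here is a minimal standalone sketch. It is a toy model, not the pass itself: `EndInfo`, `Action`, and `classify` are hypothetical names, and the boolean fields stand in for the dominator-tree and alias-analysis queries the pass makes.

```cpp
#include <cassert>

// Toy stand-in for the facts queried about one lifetime.end marker.
// All names here are hypothetical; none of them are LLVM APIs.
struct EndInfo {
  bool storeDominatesEnd; // marker is only reached after the fused store
  bool mayAliasLoad0;     // marker's object may overlap the first operand
  bool mayAliasLoad1;     // marker's object may overlap the second operand
  int block;              // id of the block containing the marker
};

enum class Action { Keep, MoveAfterStore, Remove };

// Decide what to do with one marker, given the block id of the fused store.
Action classify(const EndInfo &end, int storeBlock) {
  // A marker that is only reached after the store cannot precede the new loads.
  if (end.storeDominatesEnd)
    return Action::Keep;
  // A marker for an unrelated object does not affect the loaded memory.
  if (!end.mayAliasLoad0 && !end.mayAliasLoad1)
    return Action::Keep;
  // Same block as the store: sink the marker below the store.
  if (end.block == storeBlock)
    return Action::MoveAfterStore;
  // Otherwise drop the marker, which conservatively extends the lifetime.
  return Action::Remove;
}
```

For example, a marker that may alias the first operand and sits in the store's block is classified as `MoveAfterStore`, while the same marker in a different block is classified as `Remove`.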

@llvmbot
Collaborator

llvmbot commented Mar 12, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/84914.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+37-3)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll (+6-15)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 67c011b747acfd..bf8752bb32bacd 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -990,12 +990,15 @@ class LowerMatrixIntrinsics {
     bool Changed = false;
     SmallVector<CallInst *, 16> MaybeFusableInsts;
     SmallVector<Instruction *, 16> MatrixInsts;
+    SmallSetVector<IntrinsicInst *, 16> LifetimeEnds;
 
     // First, collect all instructions with shape information and candidates for
     // fusion (currently only matrix multiplies).
     ReversePostOrderTraversal<Function *> RPOT(&Func);
     for (auto *BB : RPOT)
       for (Instruction &I : *BB) {
+        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
+          LifetimeEnds.insert(cast<IntrinsicInst>(&I));
         if (ShapeMap.find(&I) == ShapeMap.end())
           continue;
         if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -1010,7 +1013,7 @@ class LowerMatrixIntrinsics {
 
     // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
-      LowerMatrixMultiplyFused(CI, FusedInsts);
+      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
     Changed = !FusedInsts.empty();
 
@@ -1856,8 +1859,10 @@ class LowerMatrixIntrinsics {
   ///
   /// Call finalizeLowering on lowered instructions.  Instructions that are
   /// completely eliminated by fusion are added to \p FusedInsts.
-  void LowerMatrixMultiplyFused(CallInst *MatMul,
-                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
+  void
+  LowerMatrixMultiplyFused(CallInst *MatMul,
+                           SmallPtrSetImpl<Instruction *> &FusedInsts,
+                           SmallSetVector<IntrinsicInst *, 16> &LifetimeEnds) {
     if (!FuseMatrix || !DT)
       return;
 
@@ -1946,6 +1951,35 @@ class LowerMatrixIntrinsics {
       for (Instruction *I : ToHoist)
         I->moveBefore(MatMul);
 
+      // Deal with lifetime.end calls that might be between Load0/Load1 and the
+      // store. To avoid introducing loads to dead objects (i.e. after their
+      // lifetime has been terminated by @llvm.lifetime.end), either sink them
+      // after the store if in the same block, or remove the lifetime.end marker
+      // otherwise. This might pessimize further optimizations, by extending the
+      // lifetime of the object until the function returns, but should be
+      // conservatively correct.
+      MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
+      MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
+      for (IntrinsicInst *End : make_early_inc_range(LifetimeEnds)) {
+        if (DT->dominates(Store, End))
+          continue;
+        MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
+        if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
+          continue;
+
+        // If both lifetime.end and the store are in the same block, extend the
+        // lifetime until after the store, so the new lifetime covers the loads
+        // we introduce later.
+        if (Store->getParent() == End->getParent()) {
+          End->moveAfter(Store);
+          continue;
+        }
+
+        // Otherwise remove the conflicting lifetime.end marker.
+        ToRemove.push_back(End);
+        LifetimeEnds.remove(End);
+      }
+
       emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
       return;
     }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
index ef8665b7969097..9c2b75f5d5756a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
@@ -6,15 +6,11 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 ; Tests to make sure no loads are introduced after a lifetime.end by multiply
 ; fusion.
 
-; FIXME: Currently the tests are mis-compiled, with loads being introduced after
-;       llvm.lifetime.end calls.
-
 define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias %C) {
 ; CHECK-LABEL: @lifetime_for_first_arg_before_multiply(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.*]] = alloca <4 x double>, align 32
 ; CHECK-NEXT:    call void @init(ptr [[A]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -77,6 +73,7 @@ define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -95,7 +92,6 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[B:%.*]] = alloca <4 x double>, align 32
 ; CHECK-NEXT:    call void @init(ptr [[B]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -158,6 +154,7 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -177,7 +174,6 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
 ; CHECK-NEXT:    [[A:%.*]] = alloca <8 x double>, align 64
 ; CHECK-NEXT:    call void @init(ptr [[A]])
 ; CHECK-NEXT:    [[GEP_8:%.*]] = getelementptr i8, ptr [[A]], i64 8
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[GEP_8]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -240,6 +236,7 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -261,7 +258,6 @@ define void @lifetime_for_first_arg_before_multiply_lifetime_does_not_dominate(p
 ; CHECK-NEXT:    call void @init(ptr [[A]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
@@ -352,7 +348,6 @@ define void @lifetime_for_second_arg_before_multiply_lifetime_does_not_dominate(
 ; CHECK-NEXT:    call void @init(ptr [[B]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
@@ -441,10 +436,9 @@ define void @lifetime_for_ptr_first_arg_before_multiply(ptr noalias %A, ptr noal
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
@@ -528,15 +522,13 @@ define void @lifetime_for_both_ptr_args_before_multiply(ptr noalias %A, ptr noal
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B:%.*]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[TMP1]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
@@ -618,7 +610,6 @@ define void @lifetime_for_ptr_select_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:    [[P:%.*]] = select i1 [[C_0:%.*]], ptr [[A:%.*]], ptr [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[C_1:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[P]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[P]], i64 0

Collaborator

@francisvm francisvm left a comment

I can't think of a cheaper way to do this either. LGTM, thanks!

@Gerolf-Apple
Contributor

Is the potential lifetime extension something you expect rarely? We should discuss this more on sample IR. I'm worried that we don't have a good grasp of the trade-off. Another possible short-term solution is to suppress fusion in this scenario (but again, I have no sense of how frequently the lifetime.end markers would interfere with fusion).

Co-authored-by: Francis Visoiu Mistrih <890283+francisvm@users.noreply.github.com>
@fhahn
Contributor Author

fhahn commented Mar 12, 2024

> Is the potential lifetime extension something you expect rarely? We should discuss this more on sample IR. I'm worried that we don't have a good grasp of the trade-off. Another possible short-term solution is to suppress fusion in this scenario (but again, I have no sense of how frequently the lifetime.end markers would interfere with fusion).

How often lifetime markers cause issues depends on the use case, but with C++ templates, temporary objects can be quite common. Longer lifetimes for a few objects could in some cases increase things like stack size, but in practice the benefit from fusion (when we deem it profitable) should far outweigh any gains from slightly shorter lifetimes. Without fusion, the generated code also likely has higher register pressure, as we need to load all data up front.

Instead of dropping the lifetime.end markers in some cases, we could also move them to something like the first block that post-dominates both the block of the lifetime.end and the store, but requiring the post-dominator tree would further increase compile time.
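
To make the post-dominator alternative concrete, here is a minimal standalone sketch on a toy CFG. This is not what the patch does, and it does not use LLVM's PostDominatorTree. `postDominators` and `nearestCommonPostDominator` are hypothetical helpers, and the sketch assumes at most 32 blocks, a single exit block, and that every block reaches the exit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// succs[b] lists the successors of block b; exitBlock is the unique exit.
// Returns pdom, where bit p of pdom[b] is set iff block p post-dominates b.
std::vector<uint32_t>
postDominators(const std::vector<std::vector<int>> &succs, int exitBlock) {
  const int n = static_cast<int>(succs.size());
  const uint32_t all = n == 32 ? ~0u : ((1u << n) - 1);
  std::vector<uint32_t> pdom(n, all);
  pdom[exitBlock] = 1u << exitBlock;
  // Iterate to a fixed point: pdom[b] = {b} plus the intersection over successors.
  for (bool changed = true; changed;) {
    changed = false;
    for (int b = 0; b < n; ++b) {
      if (b == exitBlock)
        continue;
      uint32_t meet = all;
      for (int s : succs[b])
        meet &= pdom[s];
      const uint32_t next = meet | (1u << b);
      if (next != pdom[b]) {
        pdom[b] = next;
        changed = true;
      }
    }
  }
  return pdom;
}

// Number of set bits in x.
int countBits(uint32_t x) {
  int c = 0;
  for (; x; x &= x - 1)
    ++c;
  return c;
}

// The common post-dominators of a and b form a chain, so the nearest one is
// the candidate with the most post-dominators of its own.
int nearestCommonPostDominator(const std::vector<uint32_t> &pdom, int a, int b) {
  const uint32_t common = pdom[a] & pdom[b];
  int best = -1, bestCount = -1;
  for (int p = 0; p < static_cast<int>(pdom.size()); ++p) {
    if (!(common & (1u << p)))
      continue;
    const int count = countBits(pdom[p]);
    if (count > bestCount) {
      best = p;
      bestCount = count;
    }
  }
  return best;
}
```

With blocks numbered entry = 0, then = 1, exit = 2, as in the `..._lifetime_does_not_dominate` tests, a lifetime.end in `then` and a store in `exit` would give `exit` as the candidate block for the moved marker.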

@fhahn fhahn merged commit e77378c into llvm:main Mar 16, 2024
4 checks passed
@fhahn fhahn deleted the matrix-lifetime branch March 16, 2024 19:41
@fhahn
Contributor Author

fhahn commented Mar 16, 2024

I evaluated the patch on the codebase where this caused a mis-compile; fusion triggered around 1200 times there. With the original version of the patch, about 600 lifetime.end markers were removed.

I adjusted the patch to also ignore a lifetime.end if both loads and the store are in the same BB and the lifetime.end is in a different one. In that case, the marker can't be between the loads and the store. With the committed version of the patch, only 3 lifetime.end markers are removed and 6 are moved.
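
The extra check described above can be sketched as a small predicate. This is an illustration under assumptions, not the committed code: `Blocks` and `skipByBlockPlacement` are hypothetical names, and blocks are modelled as plain integer ids.

```cpp
#include <cassert>

// Hypothetical block ids for the two operand loads, the store, and one
// lifetime.end marker.
struct Blocks {
  int load0, load1, store, end;
};

// Returns true if, judging by block placement alone, the marker cannot sit
// between the loads and the store and can therefore be left untouched.
bool skipByBlockPlacement(const Blocks &b) {
  const bool fusedOpsInSameBlock = b.load0 == b.store && b.load1 == b.store;
  return fusedOpsInSameBlock && b.end != b.store;
}
```

A marker in the same block as the loads and store, or any marker when the loads and store span several blocks, still has to go through the dominance and alias checks.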

fhahn added a commit to fhahn/llvm-project that referenced this pull request Mar 18, 2024
PR: llvm#84914

(cherry-picked from e77378c)