
[VectorCombine] Enable transform 'scalarizeLoadExtract' for non-constant indexes #65445


Merged: 1 commit merged from veccom-load-extract-2 into llvm:main on Sep 26, 2023

Conversation

benshi001 (Member) commented Sep 6, 2023

Enable the transform if a non-constant index is guaranteed to be safe via a UREM/AND.

This PR is stacked on #65443
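
For intuition on why a UREM or AND makes a variable index safe, here is a minimal standalone C++ sketch (illustration only, not part of the patch) of the bound guarantee the pass relies on:

#include <cassert>
#include <cstdint>

// Illustration (not LLVM code): for a vector of NumElts elements, an
// index produced by `urem NumElts` is always < NumElts, and an index
// produced by `and (NumElts - 1)` is always < NumElts when NumElts is
// a power of two. Either pattern proves the scalarized element load
// stays inside the footprint of the original vector load.
int main() {
  const uint64_t NumElts = 4; // e.g. a <4 x i32> vector
  for (uint64_t Idx = 0; Idx < 1024; ++Idx) {
    uint64_t ByUrem = Idx % NumElts;      // idx.clamped = urem i64 %idx, 4
    uint64_t ByAnd = Idx & (NumElts - 1); // idx.clamped = and i64 %idx, 3
    assert(ByUrem < NumElts && ByAnd < NumElts);
  }
  return 0;
}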

benshi001 (Member, Author) commented:

This PR is stacked on #65443

benshi001 (Member, Author) commented:

ping ...

llvmbot (Member) commented Sep 13, 2023

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Changes: Enable the transform if a non-constant index is guaranteed to be safe via a UREM/AND.

This PR is stacked on #65443
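
As background for the ScalarizationResult changes in the diff below: canScalarizeAccess reports one of three states, and a SafeWithFreeze result carries a value (ToFreeze) that the caller must consume via freeze() or discard(). A simplified, hypothetical C++ sketch of that shape (the real class lives in VectorCombine.cpp; details here are abridged):

// Simplified sketch, not the actual LLVM class. Because this patch now
// stores SafeWithFreeze results in a DenseMap (which copies them), the
// real class also relaxes its destructor assertion for that state.
struct Value; // stand-in for llvm::Value

class ScalarizationResultSketch {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze } Status;
  Value *ToFreeze = nullptr;

  ScalarizationResultSketch(StatusTy S, Value *V = nullptr)
      : Status(S), ToFreeze(V) {}

public:
  static ScalarizationResultSketch unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResultSketch safe() { return {StatusTy::Safe}; }
  static ScalarizationResultSketch safeWithFreeze(Value *V) {
    return {StatusTy::SafeWithFreeze, V};
  }

  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  bool isSafe() const { return Status == StatusTy::Safe; }
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  // Drop the pending freeze obligation without emitting a freeze.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }
};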

Full diff: https://github.com/llvm/llvm-project/pull/65445.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+30-22)
  • (modified) llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll (+36-27)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..830804ddd3b8024 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -969,7 +970,11 @@ class ScalarizationResult {
 public:
   ScalarizationResult(const ScalarizationResult &Other) = default;
   ~ScalarizationResult() {
-    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+    // The object may be copied to another scope if it is in state
+    // StatusTy::SafeWithFreeze.
+    if (Status != StatusTy::SafeWithFreeze)
+      assert(!ToFreeze &&
+             "freeze() or discard() not called with ToFreeze being set");
   }
 
   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
@@ -1134,19 +1139,20 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *FixedVT = cast<FixedVectorType>(I.getType());
+  auto *VecTy = cast<VectorType>(I.getType());
   auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
-  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
     return false;
 
   InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                           LI->getPointerAddressSpace());
   InstructionCost ScalarizedCost = 0;
 
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
   // Check if all users of the load are extracts with no memory modifications
   // between the load and the extract. Compute the cost of both the original
   // code and the scalarized version.
@@ -1155,9 +1161,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
-      return false;
-
     // Check if any instruction between the load and the extract may modify
     // memory.
     if (LastCheckedInst->comesBefore(UI)) {
@@ -1172,22 +1175,23 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       LastCheckedInst = UI;
     }
 
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
-    if (!ScalarIdx.isSafe()) {
-      // TODO: Freeze index if it is safe to do so.
-      ScalarIdx.discard();
+    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
+    if (ScalarIdx.isUnsafe()) {
       return false;
+    } else if (ScalarIdx.isSafeWithFreeze()) {
+      NeedFreeze.insert(std::make_pair(UI, ScalarIdx));
+      ScalarIdx.discard();
     }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     OriginalCost +=
-        TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                                Index ? Index->getZExtValue() : -1);
     ScalarizedCost +=
-        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+        TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
                             Align(1), LI->getPointerAddressSpace());
-    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+    ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
   if (ScalarizedCost >= OriginalCost)
@@ -1196,16 +1200,21 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   // Replace extracts with narrow scalar loads.
   for (User *U : LI->users()) {
     auto *EI = cast<ExtractElementInst>(U);
-    Builder.SetInsertPoint(EI);
-
     Value *Idx = EI->getOperand(1);
+
+    // Insert 'freeze' for poison indexes.
+    DenseMap<ExtractElementInst *, ScalarizationResult>::iterator It;
+    if ((It = NeedFreeze.find(EI)) != NeedFreeze.end())
+      It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+    Builder.SetInsertPoint(EI);
     Value *GEP =
-        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+        Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
-        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+        VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
 
     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
-        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+        LI->getAlign(), VecTy->getElementType(), Idx, DL);
     NewLoad->setAlignment(ScalarOpAlignment);
 
     replaceValue(*EI, *NewLoad);
@@ -1727,9 +1736,6 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= widenSubvectorLoad(I);
         break;
-      case Instruction::Load:
-        MadeChange |= scalarizeLoadExtract(I);
-        break;
       default:
         break;
       }
@@ -1743,6 +1749,8 @@ bool VectorCombine::run() {
     if (Opcode == Instruction::Store)
       MadeChange |= foldSingleElementStore(I);
 
+    if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+      MadeChange |= scalarizeLoadExtract(I);
 
     // If this is an early pipeline invocation of this pass, we are done.
     if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96c..42b3f9afeb56ee8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
 
 define i32 @vscale_load_extract_idx_0(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
 
 define i32 @vscale_load_extract_idx_2(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -309,9 +309,10 @@ declare void @llvm.assume(i1)
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -324,9 +325,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -384,9 +386,10 @@ entry:
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -399,9 +402,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -789,11 +793,14 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_only_first
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(ptr %x, i64 %idx.0, i64 %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
-; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
+; CHECK-NEXT:    [[IDX_1_FROZEN:%.*]] = freeze i64 [[IDX_1:%.*]]
+; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1_FROZEN]], 15
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
@@ -809,11 +816,13 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(ptr %x, i64 %idx.0, i64 noundef %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
 ; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
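Taken together, the replacement loop above emits, for each extract whose index needs it, a freeze of the raw index ahead of the clamping instruction, then a GEP into the vector and a narrow scalar load. A hedged C++ sketch of those two steps using LLVM's IRBuilder (the free-standing helpers and their names are illustrative, not the pass's actual interface):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Freeze the raw index feeding a clamping `and`/`urem`. `and poison, 3`
// is still poison, so the clamp alone does not bound a possibly-poison
// index; freezing it first pins poison to some fixed value, after which
// the clamp provably yields an in-range result.
static void freezeClampInput(IRBuilder<> &Builder, Instruction *ClampI,
                             Value *RawIdx) {
  IRBuilder<>::InsertPointGuard Guard(Builder);
  Builder.SetInsertPoint(ClampI);
  Value *Frozen = Builder.CreateFreeze(RawIdx, RawIdx->getName() + ".frozen");
  // Rewire the clamp (e.g. `and i64 %idx, 3`) to use the frozen index.
  ClampI->replaceUsesOfWith(RawIdx, Frozen);
}

// Turn one extractelement of the vector load into a scalar load,
// mirroring the replacement loop in the diff.
static Value *emitScalarLoad(IRBuilder<> &Builder, VectorType *VecTy,
                             Value *Ptr, ExtractElementInst *EI) {
  Builder.SetInsertPoint(EI);
  Value *Idx = EI->getIndexOperand();
  Value *GEP =
      Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
  return Builder.CreateLoad(VecTy->getElementType(), GEP,
                            EI->getName() + ".scalar");
}

This matches the CHECK lines in the updated tests: the freeze lands before the and/urem, and each extract becomes an inbounds GEP plus a narrow i32 load.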

Commit: [VectorCombine] Enable transform 'scalarizeLoadExtract' for non-constant indexes

Enable the transform if a non-constant index is guaranteed to be safe
via a UREM/AND.
benshi001 requested a review from nikic on September 25, 2023

nikic (Contributor) left a comment:

LGTM

benshi001 merged commit ea0ee55 into llvm:main on Sep 26, 2023
benshi001 deleted the veccom-load-extract-2 branch on September 26, 2023 at 06:02
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Sep 28, 2023
Local branch amd-gfx 9e0aff5 Merged main:feb7b1914d51 into amd-gfx:e2e3938ffce3
Remote branch main ea0ee55 [VectorCombine] Enable transform scalarizeLoadExtract for non constant indexes (llvm#65445)