Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[Scalarizer] No need to gather a scattered extracted element
ExtractElement does not produce a vector out of a vector, so there's no need to
call a gather once done.

Fix #54469

Credits to npopov@redhat.com for the original approach.

Differential Revision: https://reviews.llvm.org/D126012
  • Loading branch information
serge-sans-paille committed Jun 21, 2022
1 parent 271cc58 commit aaf1630
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 3 deletions.
20 changes: 17 additions & 3 deletions llvm/lib/Transforms/Scalar/Scalarizer.cpp
Expand Up @@ -229,6 +229,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
private:
Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr);
void gather(Instruction *Op, const ValueVector &CV);
void replaceUses(Instruction *Op, Value *CV);
bool canTransferMetadata(unsigned Kind);
void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
Expand All @@ -242,6 +243,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {

ScatterMap Scattered;
GatherList Gathered;
bool Scalarized;

SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

Expand Down Expand Up @@ -361,6 +363,8 @@ FunctionPass *llvm::createScalarizerPass() {
bool ScalarizerVisitor::visit(Function &F) {
assert(Gathered.empty() && Scattered.empty());

Scalarized = false;

// To ensure we replace gathered components correctly we need to do an ordered
// traversal of the basic blocks in the function.
ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
Expand Down Expand Up @@ -436,6 +440,15 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Redirect every use of Op to CV and remember Op as a potentially dead
// instruction. When CV is Op itself nothing is rewritten, nothing is
// recorded, and the Scalarized flag is left untouched.
void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
  // Replacing an instruction with itself would be a no-op; bail out early.
  if (CV == Op)
    return;

  // Op is queued for cleanup only after its uses have been rewritten.
  Op->replaceAllUsesWith(CV);
  PotentiallyDeadInstrs.emplace_back(Op);

  // Note that the function was modified even though nothing was gathered.
  Scalarized = true;
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
Expand Down Expand Up @@ -828,7 +841,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {

if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
Value *Res = Op0[CI->getValue().getZExtValue()];
gather(&EEI, {Res});
replaceUses(&EEI, Res);
return true;
}

Expand All @@ -844,7 +857,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
EEI.getName() + ".upto" + Twine(I));
}
gather(&EEI, {Res});
replaceUses(&EEI, Res);
return true;
}

Expand Down Expand Up @@ -959,7 +972,7 @@ bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
bool ScalarizerVisitor::finish() {
// The presence of data in Gathered or Scattered, or a set Scalarized flag,
// indicates changes made to the Function.
if (Gathered.empty() && Scattered.empty())
if (Gathered.empty() && Scattered.empty() && !Scalarized)
return false;
for (const auto &GMI : Gathered) {
Instruction *Op = GMI.first;
Expand Down Expand Up @@ -990,6 +1003,7 @@ bool ScalarizerVisitor::finish() {
}
Gathered.clear();
Scattered.clear();
Scalarized = false;

RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

Expand Down
1 change: 1 addition & 0 deletions llvm/test/Transforms/Scalarizer/global-bug.ll
@@ -1,3 +1,4 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes='function(scalarizer)' -S | FileCheck %s

@a = dso_local global i16 0, align 1
Expand Down
99 changes: 99 additions & 0 deletions llvm/test/Transforms/Scalarizer/vector-of-pointer-to-vector.ll
@@ -0,0 +1,99 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt %s -passes='function(scalarizer,dce)' -scalarize-load-store -S | FileCheck %s
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"

; f1: extract a <1 x i32>* from a one-element vector of pointers using a
; run-time index, then load the <1 x i32> it points to. The expected output
; selects on the scalar extracted pointer and loads a single i32 through it.
define <1 x i32> @f1(<1 x <1 x i32>*> %src, i32 %index) {
; CHECK-LABEL: @f1(
; CHECK-NEXT: [[INDEX_IS_0:%.*]] = icmp eq i32 [[INDEX:%.*]], 0
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <1 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[DOTUPTO0:%.*]] = select i1 [[INDEX_IS_0]], <1 x i32>* [[SRC_I0]], <1 x i32>* undef
; CHECK-NEXT: [[DOTUPTO0_I0:%.*]] = bitcast <1 x i32>* [[DOTUPTO0]] to i32*
; CHECK-NEXT: [[DOTI0:%.*]] = load i32, i32* [[DOTUPTO0_I0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[DOTI0]], i32 0
; CHECK-NEXT: ret <1 x i32> [[TMP1]]
;
%1 = extractelement <1 x <1 x i32>*> %src, i32 %index
%2 = load <1 x i32>, <1 x i32>* %1, align 4
ret <1 x i32> %2
}

; f1b: same as f1 but with a constant index 0, so the expected output needs
; no icmp/select and loads directly through the extracted pointer.
define <1 x i32> @f1b(<1 x <1 x i32>*> %src) {
; CHECK-LABEL: @f1b(
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <1 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[SRC_I0_I0:%.*]] = bitcast <1 x i32>* [[SRC_I0]] to i32*
; CHECK-NEXT: [[DOTI0:%.*]] = load i32, i32* [[SRC_I0_I0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[DOTI0]], i32 0
; CHECK-NEXT: ret <1 x i32> [[TMP1]]
;
%1 = extractelement <1 x <1 x i32>*> %src, i32 0
%2 = load <1 x i32>, <1 x i32>* %1, align 4
ret <1 x i32> %2
}

; f2: run-time-indexed extract of a <2 x i32>* from a one-element vector of
; pointers, followed by a <2 x i32> load. The expected output splits the load
; into two i32 loads addressed from the selected scalar pointer.
define <2 x i32> @f2(<1 x <2 x i32>*> %src, i32 %index) {
; CHECK-LABEL: @f2(
; CHECK-NEXT: [[INDEX_IS_0:%.*]] = icmp eq i32 [[INDEX:%.*]], 0
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <2 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[DOTUPTO0:%.*]] = select i1 [[INDEX_IS_0]], <2 x i32>* [[SRC_I0]], <2 x i32>* undef
; CHECK-NEXT: [[DOTUPTO0_I0:%.*]] = bitcast <2 x i32>* [[DOTUPTO0]] to i32*
; CHECK-NEXT: [[DOTUPTO0_I1:%.*]] = getelementptr i32, i32* [[DOTUPTO0_I0]], i32 1
; CHECK-NEXT: [[DOTI0:%.*]] = load i32, i32* [[DOTUPTO0_I0]], align 4
; CHECK-NEXT: [[DOTI1:%.*]] = load i32, i32* [[DOTUPTO0_I1]], align 4
; CHECK-NEXT: [[DOTUPTO01:%.*]] = insertelement <2 x i32> poison, i32 [[DOTI0]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[DOTUPTO01]], i32 [[DOTI1]], i32 1
; CHECK-NEXT: ret <2 x i32> [[TMP1]]
;
%1 = extractelement <1 x <2 x i32>*> %src, i32 %index
%2 = load <2 x i32>, <2 x i32>* %1, align 4
ret <2 x i32> %2
}

; f2b: constant-index variant of f2; the two i32 loads are expected to be
; addressed directly from the extracted pointer, with no icmp/select.
define <2 x i32> @f2b(<1 x <2 x i32>*> %src) {
; CHECK-LABEL: @f2b(
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <2 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[SRC_I0_I0:%.*]] = bitcast <2 x i32>* [[SRC_I0]] to i32*
; CHECK-NEXT: [[SRC_I0_I1:%.*]] = getelementptr i32, i32* [[SRC_I0_I0]], i32 1
; CHECK-NEXT: [[DOTI0:%.*]] = load i32, i32* [[SRC_I0_I0]], align 4
; CHECK-NEXT: [[DOTI1:%.*]] = load i32, i32* [[SRC_I0_I1]], align 4
; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <2 x i32> poison, i32 [[DOTI0]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[DOTUPTO0]], i32 [[DOTI1]], i32 1
; CHECK-NEXT: ret <2 x i32> [[TMP1]]
;
%1 = extractelement <1 x <2 x i32>*> %src, i32 0
%2 = load <2 x i32>, <2 x i32>* %1, align 4
ret <2 x i32> %2
}

; f3: store counterpart of f2. A <2 x i32> value is stored through a pointer
; obtained by a run-time-indexed extract; the expected output is two i32
; stores addressed from the selected scalar pointer.
define void @f3(<1 x <2 x i32>*> %src, i32 %index, <2 x i32> %val) {
; CHECK-LABEL: @f3(
; CHECK-NEXT: [[VAL_I0:%.*]] = extractelement <2 x i32> [[VAL:%.*]], i32 0
; CHECK-NEXT: [[VAL_I1:%.*]] = extractelement <2 x i32> [[VAL]], i32 1
; CHECK-NEXT: [[INDEX_IS_0:%.*]] = icmp eq i32 [[INDEX:%.*]], 0
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <2 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[DOTUPTO0:%.*]] = select i1 [[INDEX_IS_0]], <2 x i32>* [[SRC_I0]], <2 x i32>* undef
; CHECK-NEXT: [[DOTUPTO0_I0:%.*]] = bitcast <2 x i32>* [[DOTUPTO0]] to i32*
; CHECK-NEXT: [[DOTUPTO0_I1:%.*]] = getelementptr i32, i32* [[DOTUPTO0_I0]], i32 1
; CHECK-NEXT: store i32 [[VAL_I0]], i32* [[DOTUPTO0_I0]], align 4
; CHECK-NEXT: store i32 [[VAL_I1]], i32* [[DOTUPTO0_I1]], align 4
; CHECK-NEXT: ret void
;
%1 = extractelement <1 x <2 x i32>*> %src, i32 %index
store <2 x i32> %val, <2 x i32>* %1, align 4
ret void
}

; f3b: constant-index variant of f3; the two i32 stores are expected to be
; addressed directly from the extracted pointer, with no icmp/select.
define void @f3b(<1 x <2 x i32>*> %src, <2 x i32> %val) {
; CHECK-LABEL: @f3b(
; CHECK-NEXT: [[VAL_I0:%.*]] = extractelement <2 x i32> [[VAL:%.*]], i32 0
; CHECK-NEXT: [[VAL_I1:%.*]] = extractelement <2 x i32> [[VAL]], i32 1
; CHECK-NEXT: [[SRC_I0:%.*]] = extractelement <1 x <2 x i32>*> [[SRC:%.*]], i32 0
; CHECK-NEXT: [[SRC_I0_I0:%.*]] = bitcast <2 x i32>* [[SRC_I0]] to i32*
; CHECK-NEXT: [[SRC_I0_I1:%.*]] = getelementptr i32, i32* [[SRC_I0_I0]], i32 1
; CHECK-NEXT: store i32 [[VAL_I0]], i32* [[SRC_I0_I0]], align 4
; CHECK-NEXT: store i32 [[VAL_I1]], i32* [[SRC_I0_I1]], align 4
; CHECK-NEXT: ret void
;
%1 = extractelement <1 x <2 x i32>*> %src, i32 0
store <2 x i32> %val, <2 x i32>* %1, align 4
ret void
}

0 comments on commit aaf1630

Please sign in to comment.