Skip to content

Commit

Permalink
[Analysis] add query to get splat value from array of ints
Browse files Browse the repository at this point in the history
I was debug stepping through an x86 shuffle lowering and
noticed we were doing an N^2 search for splat index. I
didn't find the equivalent functionality anywhere else in
LLVM, so here's a helper that takes an array of int and
returns a splatted index while ignoring undefs (any
negative value).

This might also be used inside existing
ShuffleVectorInst/ShuffleVectorSDNode functions and/or
help with D72467.

Differential Revision: https://reviews.llvm.org/D74064
  • Loading branch information
rotateright committed Feb 5, 2020
1 parent 043e478 commit 686a038
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/VectorUtils.h
Expand Up @@ -301,6 +301,11 @@ Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Analysis/VectorUtils.cpp
Expand Up @@ -307,6 +307,24 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
return nullptr;
}

int llvm::getSplatIndex(ArrayRef<int> Mask) {
int SplatIndex = -1;
for (int M : Mask) {
// Ignore invalid (undefined) mask elements.
if (M < 0)
continue;

// There can be only 1 non-negative mask element value if this is a splat.
if (SplatIndex != -1 && SplatIndex != M)
return -1;

// Initialize the splat index to the 1st non-negative mask element.
SplatIndex = M;
}
assert((SplatIndex == -1 || SplatIndex >= 0) && "Negative index?");
return SplatIndex;
}

/// Get splat value if the input is a splat vector or return nullptr.
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
Expand Down
11 changes: 11 additions & 0 deletions llvm/unittests/Analysis/VectorUtilsTest.cpp
Expand Up @@ -98,6 +98,17 @@ TEST_F(BasicTest, isSplat) {
EXPECT_FALSE(isSplatValue(SplatWithUndefC));
}

TEST_F(BasicTest, getSplatIndex) {
EXPECT_EQ(getSplatIndex({0,0,0}), 0);
EXPECT_EQ(getSplatIndex({1,0,0}), -1); // no splat
EXPECT_EQ(getSplatIndex({0,1,1}), -1); // no splat
EXPECT_EQ(getSplatIndex({42,42,42}), 42); // array size is independent of splat index
EXPECT_EQ(getSplatIndex({42,42,-1}), 42); // ignore negative
EXPECT_EQ(getSplatIndex({-1,42,-1}), 42); // ignore negatives
EXPECT_EQ(getSplatIndex({-4,42,-42}), 42); // ignore all negatives
EXPECT_EQ(getSplatIndex({-4,-1,-42}), -1); // all negative values map to -1
}

TEST_F(VectorUtilsTest, isSplatValue_00) {
parseAssembly(
"define <2 x i8> @test(<2 x i8> %x) {\n"
Expand Down

0 comments on commit 686a038

Please sign in to comment.