Skip to content

Commit

Permalink
[ADT] Deprecate is_splat and replace all uses with all_equal
Browse files Browse the repository at this point in the history
See the discussion thread for more details:
https://discourse.llvm.org/t/adt-is-splat-and-empty-ranges/64692

Reviewed By: dblaikie

Differential Revision: https://reviews.llvm.org/D132335
  • Loading branch information
kuhar committed Aug 23, 2022
1 parent 796124f commit 6fa87ec
Show file tree
Hide file tree
Showing 17 changed files with 36 additions and 50 deletions.
12 changes: 10 additions & 2 deletions llvm/include/llvm/ADT/STLExtras.h
Expand Up @@ -1795,12 +1795,20 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
}

/// Returns true if Range consists of the same value repeated multiple times.
template <typename R> bool is_splat(R &&Range) {
template <typename R>
LLVM_DEPRECATED(
"Use 'all_equal(Range)' or '!empty(Range) && all_equal(Range)' instead.",
"all_equal")
bool is_splat(R &&Range) {
return !llvm::empty(Range) && all_equal(Range);
}

/// Returns true if Values consists of the same value repeated multiple times.
template <typename T> bool is_splat(std::initializer_list<T> Values) {
template <typename T>
LLVM_DEPRECATED(
"Use 'all_equal(Values)' or '!empty(Values) && all_equal(Values)' instead.",
"all_equal")
bool is_splat(std::initializer_list<T> Values) {
return is_splat<std::initializer_list<T>>(std::move(Values));
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/InstructionSimplify.cpp
Expand Up @@ -4997,7 +4997,7 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1,
// value type is same as the input vectors' type.
if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
if (Q.isUndefValue(Op1) && RetTy == InVecTy &&
is_splat(OpShuf->getShuffleMask()))
all_equal(OpShuf->getShuffleMask()))
return Op0;

// All remaining transformation depend on the value of the mask, which is
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/VectorUtils.cpp
Expand Up @@ -398,7 +398,7 @@ bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
// FIXME: We can safely allow undefs here. If Index was specified, we will
// check that the mask elt is defined at the required index.
if (!is_splat(Shuf->getShuffleMask()))
if (!all_equal(Shuf->getShuffleMask()))
return false;

// Match any index.
Expand Down Expand Up @@ -478,7 +478,7 @@ bool llvm::widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
if (SliceFront < 0) {
// Negative values (undef or other "sentinel" values) must be equal across
// the entire slice.
if (!is_splat(MaskSlice))
if (!all_equal(MaskSlice))
return false;
ScaledMask.push_back(SliceFront);
} else {
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -23713,7 +23713,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
// demanded elements analysis. It is further limited to not change a splat
// of an inserted scalar because that may be optimized better by
// load-folding or other target-specific behaviors.
if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
// binop (splat X), (splat C) --> splat (binop X, C)
Expand All @@ -23722,7 +23722,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
Shuf0->getMask());
}
if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
// binop (splat C), (splat X) --> splat (binop C, X)
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Expand Up @@ -3287,7 +3287,7 @@ void SelectionDAGBuilder::visitSelect(const User &I) {
Flags.copyFMF(*FPOp);

// Min/max matching is only viable if all output VTs are the same.
if (is_splat(ValueVTs)) {
if (all_equal(ValueVTs)) {
EVT VT = ValueVTs[0];
LLVMContext &Ctx = *DAG.getContext();
auto &TLI = DAG.getTargetLoweringInfo();
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/IR/Instructions.cpp
Expand Up @@ -2061,7 +2061,7 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
return false;

if (isa<ScalableVectorType>(V1->getType()))
if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !is_splat(Mask))
if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !all_equal(Mask))
return false;

return true;
Expand Down Expand Up @@ -2152,7 +2152,7 @@ Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(ArrayRef<int> Mask,
Type *ResultTy) {
Type *Int32Ty = Type::getInt32Ty(ResultTy->getContext());
if (isa<ScalableVectorType>(ResultTy)) {
assert(is_splat(Mask) && "Unexpected shuffle");
assert(all_equal(Mask) && "Unexpected shuffle");
Type *VecTy = VectorType::get(Int32Ty, Mask.size(), true);
if (Mask[0] == 0)
return Constant::getNullValue(VecTy);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -12890,7 +12890,7 @@ static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {

static bool isSplatShuffle(Value *V) {
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
return is_splat(Shuf->getShuffleMask());
return all_equal(Shuf->getShuffleMask());
return false;
}

Expand Down Expand Up @@ -20827,7 +20827,7 @@ bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
// All non aggregate members of the type must have the same type
SmallVector<EVT> ValueVTs;
ComputeValueVTs(*this, DL, Ty, ValueVTs);
return is_splat(ValueVTs);
return all_equal(ValueVTs);
}

bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Expand Up @@ -726,7 +726,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
InstCombiner::BuilderTy &Builder) {
auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
is_splat(Shuf->getShuffleMask()) &&
all_equal(Shuf->getShuffleMask()) &&
Shuf->getType() == Shuf->getOperand(0)->getType()) {
// trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
// trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Expand Up @@ -3141,7 +3141,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
ArrayRef<int> Mask;
if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
// Check whether every element of Mask is the same constant
if (is_splat(Mask)) {
if (all_equal(Mask)) {
auto *VecTy = cast<VectorType>(SrcType);
auto *EltTy = cast<IntegerType>(VecTy->getElementType());
if (C->isSplat(EltTy->getBitWidth())) {
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Scalar/NewGVN.cpp
Expand Up @@ -3166,7 +3166,7 @@ bool NewGVN::singleReachablePHIPath(
make_filter_range(MP->operands(), ReachableOperandPred);
SmallVector<const Value *, 32> OperandList;
llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList));
bool Okay = is_splat(OperandList);
bool Okay = all_equal(OperandList);
if (Okay)
return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]),
Second);
Expand Down Expand Up @@ -3261,7 +3261,7 @@ void NewGVN::verifyMemoryCongruency() const {
const MemoryDef *MD = cast<MemoryDef>(U);
return ValueToClass.lookup(MD->getMemoryInst());
});
assert(is_splat(PhiOpClasses) &&
assert(all_equal(PhiOpClasses) &&
"All MemoryPhi arguments should be in the same class");
}
}
Expand Down
22 changes: 0 additions & 22 deletions llvm/unittests/ADT/STLExtrasTest.cpp
Expand Up @@ -611,28 +611,6 @@ TEST(STLExtrasTest, AllEqualInitializerList) {
EXPECT_TRUE(all_equal({1, 1, 1}));
}

TEST(STLExtrasTest, IsSplat) {
std::vector<int> V;
EXPECT_FALSE(is_splat(V));

V.push_back(1);
EXPECT_TRUE(is_splat(V));

V.push_back(1);
V.push_back(1);
EXPECT_TRUE(is_splat(V));

V.push_back(2);
EXPECT_FALSE(is_splat(V));
}

TEST(STLExtrasTest, IsSplatInitializerList) {
EXPECT_TRUE(is_splat({1}));
EXPECT_TRUE(is_splat({1, 1}));
EXPECT_FALSE(is_splat({1, 2}));
EXPECT_TRUE(is_splat({1, 1, 1}));
}

TEST(STLExtrasTest, to_address) {
int *V1 = new int;
EXPECT_EQ(V1, to_address(V1));
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpBase.td
Expand Up @@ -2432,7 +2432,7 @@ class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
// 1) all operands involved are of shaped type and
// 2) the indices are not out of range.
class TCopVTEtAreSameAt<list<int> indices> : CPred<
"::llvm::is_splat(::llvm::map_range("
"::llvm::all_equal(::llvm::map_range("
"::mlir::ArrayRef<unsigned>({" # !interleave(indices, ", ") # "}), "
"[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
"}))">;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
Expand Up @@ -77,8 +77,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
auto getOperandElementType = [](OpOperand *operand) {
return operand->get().getType().cast<ShapedType>().getElementType();
};
if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
getOperandElementType)))
if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
getOperandElementType)))
return failure();

// We can only handle the case where we have int/float elements.
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Expand Up @@ -2871,9 +2871,9 @@ LogicalResult spirv::IAddCarryOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");

if (!llvm::is_splat({operand1().getType(), operand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
if (!llvm::all_equal({operand1().getType(), operand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
return emitOpError(
"expected all operand types and struct member types are the same");

Expand Down Expand Up @@ -2920,9 +2920,9 @@ LogicalResult spirv::ISubBorrowOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");

if (!llvm::is_splat({operand1().getType(), operand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
if (!llvm::all_equal({operand1().getType(), operand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
return emitOpError(
"expected all operand types and struct member types are the same");

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Expand Up @@ -1269,7 +1269,7 @@ struct ReorderElementwiseOpsOnTranspose final
// This is an elementwise op, so all transposed operands should have the
// same type. We need to additionally check that all transposes uses the
// same map.
if (!llvm::is_splat(transposeMaps))
if (!llvm::all_equal(transposeMaps))
return rewriter.notifyMatchFailure(op, "different transpose map");

SmallVector<Value, 4> srcValues;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/predicate.td
Expand Up @@ -102,7 +102,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
}

// CHECK-LABEL: OpJAdaptor::verify
// CHECK: ::llvm::is_splat(::llvm::map_range(
// CHECK: ::llvm::all_equal(::llvm::map_range(
// CHECK-SAME: ::mlir::ArrayRef<unsigned>({0, 2, 3}),
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Expand Up @@ -41,7 +41,7 @@ static void collectAllDefs(StringRef selectedDialect,
if (selectedDialect.empty()) {
// If a dialect was not specified, ensure that all found defs belong to the
// same dialect.
if (!llvm::is_splat(llvm::map_range(
if (!llvm::all_equal(llvm::map_range(
defs, [](const auto &def) { return def.getDialect(); }))) {
llvm::PrintFatalError("defs belonging to more than one dialect. Must "
"select one via '--(attr|type)defs-dialect'");
Expand Down

0 comments on commit 6fa87ec

Please sign in to comment.