Recommit "[InstCombine] Fold and-reduce idiom"
The checks of the original vector types have been made more thorough.

Differential Revision: https://reviews.llvm.org/D118317
xortator committed Jan 29, 2022 (parent c38c134, commit 3b194ca)
Showing 4 changed files with 71 additions and 26 deletions.
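For context, the idiom this recommit targets can be sketched as follows (a minimal illustration, not taken from the patch; the function and value names are invented, and a datalayout where i64 is a legal integer type is assumed). Lowering of @llvm.vector.reduce.and over a lane-wise equality compare produces the icmp/bitcast/icmp chain matched by the new fold below, which then rewrites it into a single scalar compare:

; Input shape (illustrative only):
declare i1 @llvm.vector.reduce.and.v8i1(<8 x i1>)

define i1 @all_lanes_eq(<8 x i8> %lhs, <8 x i8> %rhs) {
  %eq = icmp eq <8 x i8> %lhs, %rhs
  %all = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %eq)
  ret i1 %all
}

; With this patch, InstCombine is expected to reduce the body to roughly:
;   %lhs.scalar = bitcast <8 x i8> %lhs to i64
;   %rhs.scalar = bitcast <8 x i8> %rhs to i64
;   %all = icmp eq i64 %lhs.scalar, %rhs.scalar
;   ret i1 %all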
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (51 additions, 0 deletions)
@@ -5882,6 +5882,54 @@ static Instruction *foldICmpInvariantGroup(ICmpInst &I) {
  return nullptr;
}

+/// This function folds patterns produced by lowering of reduce idioms, such as
+/// llvm.vector.reduce.and which are lowered into instruction chains. This code
+/// attempts to generate fewer scalar comparisons instead of vector comparisons
+/// when possible.
+static Instruction *foldReductionIdiom(ICmpInst &I,
+                                       InstCombiner::BuilderTy &Builder,
+                                       const DataLayout &DL) {
+  if (I.getType()->isVectorTy())
+    return nullptr;
+  ICmpInst::Predicate OuterPred, InnerPred;
+  Value *LHS, *RHS;
+
+  // Match lowering of @llvm.vector.reduce.and. Turn
+  //   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
+  //   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
+  //   %all_eq = icmp eq i8 %scalar_ne, 0
+  //
+  // into
+  //
+  //   %lhs.scalar = bitcast <8 x i8> %lhs to i64
+  //   %rhs.scalar = bitcast <8 x i8> %rhs to i64
+  //   %all_eq = icmp eq i64 %lhs.scalar, %rhs.scalar
+  if (!match(&I, m_ICmp(OuterPred,
+                        m_OneUse(m_BitCast(m_OneUse(
+                            m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
+                        m_Zero())))
+    return nullptr;
+  auto *LHSTy = dyn_cast<FixedVectorType>(LHS->getType());
+  if (!LHSTy || !LHSTy->getElementType()->isIntegerTy())
+    return nullptr;
+  unsigned NumBits =
+      LHSTy->getNumElements() * LHSTy->getElementType()->getIntegerBitWidth();
+  // TODO: Relax this to "not wider than max legal integer type"?
+  if (!DL.isLegalInteger(NumBits))
+    return nullptr;
+
+  // TODO: Generalize to isEquality and support other patterns.
+  if (OuterPred == ICmpInst::ICMP_EQ && InnerPred == ICmpInst::ICMP_NE) {
+    auto *ScalarTy = Builder.getIntNTy(NumBits);
+    LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
+    RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
+    return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+                            I.getName());
+  }
+
+  return nullptr;
+}

Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
  bool Changed = false;
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -6124,6 +6172,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
  if (Instruction *Res = foldICmpInvariantGroup(I))
    return Res;

+  if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
+    return Res;
+
  return Changed ? &I : nullptr;
}

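A consequence of the DL.isLegalInteger guard above (the subject of the first TODO) is that shapes whose combined width is not a declared legal integer type are left alone. A minimal sketch, assuming a typical "n8:16:32:64" datalayout; the function name and the <3 x i8> width are invented for illustration:

define i1 @not_folded(<3 x i8> %lhs, <3 x i8> %rhs) {
  %vec_ne = icmp ne <3 x i8> %lhs, %rhs
  %scalar_ne = bitcast <3 x i1> %vec_ne to i3
  %all_eq = icmp eq i3 %scalar_ne, 0
  ret i1 %all_eq
}
; 3 x 8 = 24 bits, and isLegalInteger(24) is false for that datalayout,
; so foldReductionIdiom bails out and the vector compare is kept.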
llvm/test/Transforms/InstCombine/icmp-vec.ll (3 additions, 3 deletions)
@@ -404,9 +404,9 @@ define <vscale x 2 x i1> @icmp_logical_or_scalablevec(<vscale x 2 x i64> %x, <vs

define i1 @eq_cast_eq-1(<2 x i4> %x, <2 x i4> %y) {
; CHECK-LABEL: @eq_cast_eq-1(
-; CHECK-NEXT: [[IC:%.*]] = icmp ne <2 x i4> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT: [[R:%.*]] = icmp eq i2 [[TMP1]], 0
+; CHECK-NEXT: [[X_SCALAR:%.*]] = bitcast <2 x i4> [[X:%.*]] to i8
+; CHECK-NEXT: [[Y_SCALAR:%.*]] = bitcast <2 x i4> [[Y:%.*]] to i8
+; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[X_SCALAR]], [[Y_SCALAR]]
; CHECK-NEXT: ret i1 [[R]]
;
%ic = icmp eq <2 x i4> %x, %y
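The body of @eq_cast_eq-1 is truncated above. A hypothetical input of the same shape (not the literal test body, and assuming a datalayout where i8 is a legal integer type) that exercises the new fold:

define i1 @eq_cast_eq_sketch(<2 x i4> %x, <2 x i4> %y) {
  %ic = icmp eq <2 x i4> %x, %y
  %b = bitcast <2 x i1> %ic to i2
  %r = icmp eq i2 %b, -1     ; all lanes equal
  ret i1 %r
}
; InstCombine already canonicalized such comparisons into the
; icmp ne / bitcast / icmp eq 0 form shown in the old CHECK lines; with the
; new fold the result is two <2 x i4> to i8 bitcasts plus one i8 compare.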
llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll (12 additions, 16 deletions)
@@ -100,14 +100,12 @@ define i64 @reduce_and_zext_external_use(<8 x i1> %x) {
define i1 @reduce_and_pointer_cast(i8* %arg, i8* %arg1) {
; CHECK-LABEL: @reduce_and_pointer_cast(
; CHECK-NEXT: bb:
-; CHECK-NEXT: [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT: [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 0
-; CHECK-NEXT: ret i1 [[TMP1]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT: [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT: [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT: ret i1 [[TMP2]]
;
bb:
%ptr1 = bitcast i8* %arg1 to <8 x i8>*
@@ -144,14 +142,12 @@ bb:
define i1 @reduce_and_pointer_cast_ne(i8* %arg, i8* %arg1) {
; CHECK-LABEL: @reduce_and_pointer_cast_ne(
; CHECK-NEXT: bb:
-; CHECK-NEXT: [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT: [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT: ret i1 [[TMP1]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT: [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT: [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT: ret i1 [[TMP2]]
;
bb:
%ptr1 = bitcast i8* %arg1 to <8 x i8>*
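Note that once the compare is scalarized, InstCombine's existing load combining also rewrites the <8 x i8> loads through bitcast pointers into direct i64 loads, which is why the updated CHECK lines above load i64. A rough sketch of the tested shape (illustrative only; the actual test bodies are truncated above, and the names here are invented):

declare i1 @llvm.vector.reduce.and.v8i1(<8 x i1>)

define i1 @reduce_and_pointer_cast_sketch(i8* %arg, i8* %arg1) {
bb:
  %ptr1 = bitcast i8* %arg1 to <8 x i8>*
  %ptr2 = bitcast i8* %arg to <8 x i8>*
  %lhs = load <8 x i8>, <8 x i8>* %ptr1
  %rhs = load <8 x i8>, <8 x i8>* %ptr2
  %cmp = icmp eq <8 x i8> %lhs, %rhs
  %all_eq = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %cmp)
  ret i1 %all_eq
}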
llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll (5 additions, 7 deletions)
@@ -100,13 +100,11 @@ define i64 @reduce_or_zext_external_use(<8 x i1> %x) {
define i1 @reduce_or_pointer_cast(i8* %arg, i8* %arg1) {
; CHECK-LABEL: @reduce_or_pointer_cast(
; CHECK-NEXT: bb:
-; CHECK-NEXT: [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT: [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT: [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT: [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT: [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
; CHECK-NEXT: ret i1 [[DOTNOT]]
;
bb:
