Skip to content

Commit

Permalink
[InstCombine] canonicalize splat shuffle after cmp
Browse files Browse the repository at this point in the history
cmp (splat V1, M), SplatC --> splat (cmp V1, SplatC'), M

As discussed in PR44588:
https://bugs.llvm.org/show_bug.cgi?id=44588
...we try harder to push shuffles after binops than after compares.

This patch handles the special (but presumably most common case) of
splat shuffles. If both operands are splats, then we can do the
comparison on the non-splat inputs followed by splat of the compare.
That should take care of the regression noted in D73411.

There's another potential fold requested in PR37463 to scalarize the
compare, but that's another patch (and it's not clear if we can do
that without the ability to undo it later):
https://bugs.llvm.org/show_bug.cgi?id=37463

Differential Revision: https://reviews.llvm.org/D73575
  • Loading branch information
rotateright committed Jan 29, 2020
1 parent 2939fc1 commit 87f6314
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 14 deletions.
21 changes: 21 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Expand Up @@ -5369,6 +5369,27 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
}

// Try to canonicalize compare with splatted operand and splat constant.
// TODO: We could generalize this for more than splats. See/use the code in
// InstCombiner::foldVectorBinop().
Constant *C;
if (!LHS->hasOneUse() || !match(RHS, m_Constant(C)))
return nullptr;

// Length-changing splats are ok, so adjust the constants as needed:
// cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M
Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true);
Constant *ScalarM = M->getSplatValue(/* AllowUndefs */ true);
if (ScalarC && ScalarM) {
// We allow undefs in matching, but this transform removes those for safety.
// Demanded elements analysis should be able to recover some/all of that.
C = ConstantVector::getSplat(V1Ty->getVectorNumElements(), ScalarC);
M = ConstantVector::getSplat(M->getType()->getVectorNumElements(), ScalarM);
Value *NewCmp = IsFP ? Builder.CreateFCmp(Pred, V1, C)
: Builder.CreateICmp(Pred, V1, C);
return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
}

return nullptr;
}

Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/gep-inbounds-null.ll
Expand Up @@ -91,8 +91,8 @@ define <2 x i1> @test_vector_index(i8* %base, <2 x i64> %idx) {
; CHECK-LABEL: @test_vector_index(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <2 x i8*> undef, i8* [[BASE:%.*]], i32 0
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <2 x i8*> [[DOTSPLATINSERT]], <2 x i8*> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: [[CND:%.*]] = icmp eq <2 x i8*> [[DOTSPLAT]], zeroinitializer
; CHECK-NEXT: [[TMP0:%.*]] = icmp eq <2 x i8*> [[DOTSPLATINSERT]], zeroinitializer
; CHECK-NEXT: [[CND:%.*]] = shufflevector <2 x i1> [[TMP0]], <2 x i1> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CND]]
;
entry:
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/InstCombine/getelementptr.ll
Expand Up @@ -217,8 +217,8 @@ define <2 x i1> @test13_vector2(i64 %X, <2 x %S*> %P) nounwind {
; CHECK-LABEL: @test13_vector2(
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 2, i64 undef>
; CHECK-NEXT: [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], <i64 -4, i64 -4>
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x i64> [[TMP1]], <i64 -4, i64 -4>
; CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i1> [[TMP2]], <2 x i1> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[C]]
;
%A = getelementptr inbounds %S, <2 x %S*> %P, <2 x i64> zeroinitializer, <2 x i32> <i32 1, i32 1>, i64 %X
Expand All @@ -232,8 +232,8 @@ define <2 x i1> @test13_vector3(i64 %X, <2 x %S*> %P) nounwind {
; CHECK-LABEL: @test13_vector3(
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 2, i64 undef>
; CHECK-NEXT: [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], <i64 4, i64 4>
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x i64> [[TMP1]], <i64 4, i64 4>
; CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i1> [[TMP2]], <2 x i1> undef, <2 x i32> zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[C]]
;
%A = getelementptr inbounds %S, <2 x %S*> %P, <2 x i64> zeroinitializer, <2 x i32> <i32 1, i32 1>, i64 %X
Expand Down
22 changes: 14 additions & 8 deletions llvm/test/Transforms/InstCombine/icmp-vec.ll
Expand Up @@ -291,8 +291,8 @@ define <2 x i1> @same_shuffle_inputs_icmp_extra_use3(<4 x i8> %x, <4 x i8> %y) {

define <4 x i1> @splat_icmp(<4 x i8> %x) {
; CHECK-LABEL: @splat_icmp(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i8> [[X:%.*]], <4 x i8> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <4 x i8> [[SPLATX]], <i8 42, i8 42, i8 42, i8 42>
; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i8> [[X:%.*]], <i8 42, i8 42, i8 42, i8 42>
; CHECK-NEXT: [[CMP:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
; CHECK-NEXT: ret <4 x i1> [[CMP]]
;
%splatx = shufflevector <4 x i8> %x, <4 x i8> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
Expand All @@ -302,8 +302,8 @@ define <4 x i1> @splat_icmp(<4 x i8> %x) {

define <4 x i1> @splat_icmp_undef(<4 x i8> %x) {
; CHECK-LABEL: @splat_icmp_undef(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i8> [[X:%.*]], <4 x i8> undef, <4 x i32> <i32 2, i32 undef, i32 undef, i32 2>
; CHECK-NEXT: [[CMP:%.*]] = icmp ult <4 x i8> [[SPLATX]], <i8 undef, i8 42, i8 undef, i8 42>
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult <4 x i8> [[X:%.*]], <i8 42, i8 42, i8 42, i8 42>
; CHECK-NEXT: [[CMP:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> undef, <4 x i32> <i32 2, i32 2, i32 2, i32 2>
; CHECK-NEXT: ret <4 x i1> [[CMP]]
;
%splatx = shufflevector <4 x i8> %x, <4 x i8> undef, <4 x i32> <i32 2, i32 undef, i32 undef, i32 2>
Expand All @@ -313,8 +313,8 @@ define <4 x i1> @splat_icmp_undef(<4 x i8> %x) {

define <4 x i1> @splat_icmp_larger_size(<2 x i8> %x) {
; CHECK-LABEL: @splat_icmp_larger_size(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <2 x i8> [[X:%.*]], <2 x i8> undef, <4 x i32> <i32 1, i32 undef, i32 1, i32 undef>
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <4 x i8> [[SPLATX]], <i8 42, i8 42, i8 undef, i8 42>
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 42, i8 42>
; CHECK-NEXT: [[CMP:%.*]] = shufflevector <2 x i1> [[TMP1]], <2 x i1> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: ret <4 x i1> [[CMP]]
;
%splatx = shufflevector <2 x i8> %x, <2 x i8> undef, <4 x i32> <i32 1, i32 undef, i32 1, i32 undef>
Expand All @@ -324,15 +324,17 @@ define <4 x i1> @splat_icmp_larger_size(<2 x i8> %x) {

define <4 x i1> @splat_fcmp_smaller_size(<5 x float> %x) {
; CHECK-LABEL: @splat_fcmp_smaller_size(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <5 x float> [[X:%.*]], <5 x float> undef, <4 x i32> <i32 1, i32 undef, i32 1, i32 undef>
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq <4 x float> [[SPLATX]], <float 4.200000e+01, float 4.200000e+01, float undef, float 4.200000e+01>
; CHECK-NEXT: [[TMP1:%.*]] = fcmp oeq <5 x float> [[X:%.*]], <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
; CHECK-NEXT: [[CMP:%.*]] = shufflevector <5 x i1> [[TMP1]], <5 x i1> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: ret <4 x i1> [[CMP]]
;
%splatx = shufflevector <5 x float> %x, <5 x float> undef, <4 x i32> <i32 1, i32 undef, i32 1, i32 undef>
%cmp = fcmp oeq <4 x float> %splatx, <float 42.0, float 42.0, float undef, float 42.0>
ret <4 x i1> %cmp
}

; Negative test

define <4 x i1> @splat_icmp_extra_use(<4 x i8> %x) {
; CHECK-LABEL: @splat_icmp_extra_use(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i8> [[X:%.*]], <4 x i8> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
Expand All @@ -346,6 +348,8 @@ define <4 x i1> @splat_icmp_extra_use(<4 x i8> %x) {
ret <4 x i1> %cmp
}

; Negative test

define <4 x i1> @not_splat_icmp(<4 x i8> %x) {
; CHECK-LABEL: @not_splat_icmp(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i8> [[X:%.*]], <4 x i8> undef, <4 x i32> <i32 3, i32 2, i32 3, i32 3>
Expand All @@ -357,6 +361,8 @@ define <4 x i1> @not_splat_icmp(<4 x i8> %x) {
ret <4 x i1> %cmp
}

; Negative test

define <4 x i1> @not_splat_icmp2(<4 x i8> %x) {
; CHECK-LABEL: @not_splat_icmp2(
; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i8> [[X:%.*]], <4 x i8> undef, <4 x i32> <i32 2, i32 2, i32 2, i32 2>
Expand Down

0 comments on commit 87f6314

Please sign in to comment.