diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 47f8c143b754c1..a4b14ad5eda1c9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1356,6 +1356,27 @@ Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, return nullptr; } +/// Look for the pattern that conditionally negates a value via math operations: +/// cond.splat = sext i1 cond +/// sub = add cond.splat, x +/// xor = xor sub, cond.splat +/// and rewrite it to do the same, but via logical operations: +/// value.neg = sub 0, value +/// cond = select i1 neg, value.neg, value +Instruction *InstCombinerImpl::canonicalizeConditionalNegationViaMathToSelect( + BinaryOperator &I) { + assert(I.getOpcode() == BinaryOperator::Xor && "Only for xor!"); + Value *Cond, *X; + // As per complexity ordering, `xor` is not commutative here. + if (!match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())) || + !match(I.getOperand(1), m_SExt(m_Value(Cond))) || + !Cond->getType()->isIntOrIntVectorTy(1) || + !match(I.getOperand(0), m_c_Add(m_SExt(m_Deferred(Cond)), m_Value(X)))) + return nullptr; + return SelectInst::Create(Cond, Builder.CreateNeg(X, X->getName() + ".neg"), + X); +} + /// This a limited reassociation for a special case (see above) where we are /// checking if two values are either both NAN (unordered) or not-NAN (ordered). /// This could be handled more generally in '-reassociation', but it seems like @@ -4237,5 +4258,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) return Folded; + if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I)) + return Folded; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f700cdb84d573e..bfbc31e10a80a5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -365,6 +365,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *foldLogicOfIsFPClass(BinaryOperator &Operator, Value *LHS, Value *RHS); + Instruction * + canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i); + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI, bool IsAnd, bool IsLogical = false); diff --git a/llvm/test/Transforms/InstCombine/conditional-negation.ll b/llvm/test/Transforms/InstCombine/conditional-negation.ll index b1b704011ec825..18b2ac4c52ed7b 100644 --- a/llvm/test/Transforms/InstCombine/conditional-negation.ll +++ b/llvm/test/Transforms/InstCombine/conditional-negation.ll @@ -4,9 +4,8 @@ ; Basic pattern define i8 @t0(i8 %x, i1 %cond) { ; CHECK-LABEL: @t0( -; CHECK-NEXT: [[COND_SPLAT:%.*]] = sext i1 [[COND:%.*]] to i8 -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND:%.*]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat = sext i1 %cond to i8 @@ -16,9 +15,8 @@ define i8 @t0(i8 %x, i1 %cond) { } define <2 x i8> @t0_vec(<2 x i8> %x, <2 x i1> %cond) { ; CHECK-LABEL: @t0_vec( -; CHECK-NEXT: [[COND_SPLAT:%.*]] = sext <2 x i1> [[COND:%.*]] to <2 x i8> -; CHECK-NEXT: [[SUB:%.*]] = add <2 x i8> [[COND_SPLAT]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[SUB]], [[COND_SPLAT]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub <2 x i8> zeroinitializer, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select <2 x i1> [[COND:%.*]], <2 x i8> [[X_NEG]], <2 x i8> [[X]] ; CHECK-NEXT: ret <2 x i8> [[XOR]] ; %cond.splat = sext <2 x i1> %cond to <2 x i8> @@ -30,10 +28,8 @@ define <2 x i8> @t0_vec(<2 x i8> %x, <2 x i1> %cond) { ; Two different extensions are fine define i8 @t1(i8 %x, i1 %cond) { ; CHECK-LABEL: @t1( -; CHECK-NEXT: [[COND_SPLAT0:%.*]] = sext i1 [[COND:%.*]] to i8 -; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND:%.*]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8 @@ -89,10 +85,9 @@ define <2 x i8> @t3_vec(<2 x i8> %x, <2 x i2> %cond) { ; xor is not commutative here because of complexity ordering define i8 @xor.commuted(i1 %cond) { ; CHECK-LABEL: @xor.commuted( -; CHECK-NEXT: [[COND_SPLAT:%.*]] = sext i1 [[COND:%.*]] to i8 ; CHECK-NEXT: [[X:%.*]] = call i8 @gen.i8() -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[X]], [[COND_SPLAT]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND:%.*]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat = sext i1 %cond to i8 @@ -107,8 +102,8 @@ define i8 @extrause01_v1(i8 %x, i1 %cond) { ; CHECK-LABEL: @extrause01_v1( ; CHECK-NEXT: [[COND_SPLAT:%.*]] = sext i1 [[COND:%.*]] to i8 ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT]]) -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat = sext i1 %cond to i8 @@ -153,9 +148,8 @@ define i8 @extrause001_v2(i8 %x, i1 %cond) { ; CHECK-LABEL: @extrause001_v2( ; CHECK-NEXT: [[COND_SPLAT0:%.*]] = sext i1 [[COND:%.*]] to i8 ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT0]]) -; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8 @@ -167,11 +161,10 @@ define i8 @extrause001_v2(i8 %x, i1 %cond) { } define i8 @extrause010_v2(i8 %x, i1 %cond) { ; CHECK-LABEL: @extrause010_v2( -; CHECK-NEXT: [[COND_SPLAT0:%.*]] = sext i1 [[COND:%.*]] to i8 -; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 +; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND:%.*]] to i8 ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT1]]) -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8 @@ -187,8 +180,8 @@ define i8 @extrause011_v2(i8 %x, i1 %cond) { ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT0]]) ; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT1]]) -; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8 @@ -202,10 +195,10 @@ define i8 @extrause011_v2(i8 %x, i1 %cond) { define i8 @extrause100_v2(i8 %x, i1 %cond) { ; CHECK-LABEL: @extrause100_v2( ; CHECK-NEXT: [[COND_SPLAT0:%.*]] = sext i1 [[COND:%.*]] to i8 -; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 ; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] ; CHECK-NEXT: call void @use.i8(i8 [[SUB]]) -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8 @@ -219,10 +212,10 @@ define i8 @extrause101_v2(i8 %x, i1 %cond) { ; CHECK-LABEL: @extrause101_v2( ; CHECK-NEXT: [[COND_SPLAT0:%.*]] = sext i1 [[COND:%.*]] to i8 ; CHECK-NEXT: call void @use.i8(i8 [[COND_SPLAT0]]) -; CHECK-NEXT: [[COND_SPLAT1:%.*]] = sext i1 [[COND]] to i8 ; CHECK-NEXT: [[SUB:%.*]] = add i8 [[COND_SPLAT0]], [[X:%.*]] ; CHECK-NEXT: call void @use.i8(i8 [[SUB]]) -; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[SUB]], [[COND_SPLAT1]] +; CHECK-NEXT: [[X_NEG:%.*]] = sub i8 0, [[X]] +; CHECK-NEXT: [[XOR:%.*]] = select i1 [[COND]], i8 [[X_NEG]], i8 [[X]] ; CHECK-NEXT: ret i8 [[XOR]] ; %cond.splat0 = sext i1 %cond to i8