diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6742ac6802078b..2ba326d791c26b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2036,12 +2036,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   }
   case Intrinsic::vector_reduce_xor: {
     if (IID == Intrinsic::vector_reduce_xor) {
-      // Convert vector_reduce_xor(zext(<n x i1>)) to
-      // (ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1)).
-      // Convert vector_reduce_xor(sext(<n x i1>)) to
-      // -(ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1)).
-      // Convert vector_reduce_xor(<n x i1>) to
-      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN) & 1).
+      // Exclusive disjunction reduction over the vector with
+      // (potentially-extended) i1 element type is actually a
+      // (potentially-extended) parity check:
+      //   vector_reduce_xor(?ext(<n x i1>))
+      //     -->
+      //   ?ext(trunc(vector_reduce_add(<n x i1>) to i1))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
@@ -2050,12 +2050,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
           Value *V = Builder.CreateBitCast(
               Vect, Builder.getIntNTy(FTy->getNumElements()));
           Value *Res = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
-          Res = Builder.CreateAnd(Res, ConstantInt::get(Res->getType(), 1));
-          if (Res->getType() != II->getType())
-            Res = Builder.CreateZExtOrTrunc(Res, II->getType());
-          if (Arg != Vect &&
-              cast<CastInst>(Arg)->getOpcode() == Instruction::SExt)
-            Res = Builder.CreateNeg(Res);
+          Res = Builder.CreateTrunc(Res,
+                                    IntegerType::get(Res->getContext(), 1));
+          if (Arg != Vect)
+            Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
+                                     II->getType());
           return replaceInstUsesWith(CI, Res);
         }
       }
diff --git a/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
index e4022355f5dacd..bbaf32945e1bb3 100644
--- a/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-xor-sext-zext-i1.ll
@@ -18,9 +18,9 @@ define i32 @reduce_xor_sext(<4 x i1> %x) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP1]]), !range [[RNG1:![0-9]+]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = and i4 [[TMP2]], 1
-; CHECK-NEXT:    [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
-; CHECK-NEXT:    [[TMP5:%.*]] = sub nsw i32 0, [[TMP4]]
-; CHECK-NEXT:    ret i32 [[TMP5]]
+; CHECK-NEXT:    [[SEXT:%.*]] = sub nsw i4 0, [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i4 [[SEXT]] to i32
+; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %res = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> %sext)
@@ -45,8 +45,8 @@ define i16 @reduce_xor_sext_same(<16 x i1> %x) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[TMP1]]), !range [[RNG2:![0-9]+]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = and i16 [[TMP2]], 1
-; CHECK-NEXT:    [[TMP4:%.*]] = sub nsw i16 0, [[TMP3]]
-; CHECK-NEXT:    ret i16 [[TMP4]]
+; CHECK-NEXT:    [[SEXT:%.*]] = sub nsw i16 0, [[TMP3]]
+; CHECK-NEXT:    ret i16 [[SEXT]]
 ;
   %sext = sext <16 x i1> %x to <16 x i16>
   %res = call i16 @llvm.vector.reduce.xor.v16i16(<16 x i16> %sext)