diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index a4d86d751c2f5..9bbedb292fbb5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2159,20 +2159,30 @@ static Instruction *matchOrConcat(Instruction &Or,
       LowerSrc->getType()->getScalarSizeInBits() != HalfWidth)
     return nullptr;
 
-  // Find matching bswap instructions.
-  // TODO: Add more patterns (bitreverse?)
+  auto ConcatIntrinsicCalls = [&](Intrinsic::ID id, Value *Lo, Value *Hi) {
+    Value *NewLower = Builder.CreateZExt(Lo, Ty);
+    Value *NewUpper = Builder.CreateZExt(Hi, Ty);
+    NewUpper = Builder.CreateShl(NewUpper, HalfWidth);
+    Value *BinOp = Builder.CreateOr(NewLower, NewUpper);
+    Function *F = Intrinsic::getDeclaration(Or.getModule(), id, Ty);
+    return Builder.CreateCall(F, BinOp);
+  };
+
+  // BSWAP: Push the concat down, swapping the lower/upper sources.
+  // concat(bswap(x),bswap(y)) -> bswap(concat(x,y))
   Value *LowerBSwap, *UpperBSwap;
-  if (!match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) ||
-      !match(UpperSrc, m_BSwap(m_Value(UpperBSwap))))
-    return nullptr;
+  if (match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) &&
+      match(UpperSrc, m_BSwap(m_Value(UpperBSwap))))
+    return ConcatIntrinsicCalls(Intrinsic::bswap, UpperBSwap, LowerBSwap);
 
-  // Push the concat down, swapping the lower/upper sources.
-  Value *NewLower = Builder.CreateZExt(UpperBSwap, Ty);
-  Value *NewUpper = Builder.CreateZExt(LowerBSwap, Ty);
-  NewUpper = Builder.CreateShl(NewUpper, HalfWidth);
-  Value *BinOp = Builder.CreateOr(NewLower, NewUpper);
-  Function *F = Intrinsic::getDeclaration(Or.getModule(), Intrinsic::bswap, Ty);
-  return Builder.CreateCall(F, BinOp);
+  // BITREVERSE: Push the concat down, swapping the lower/upper sources.
+  // concat(bitreverse(x),bitreverse(y)) -> bitreverse(concat(x,y))
+  Value *LowerBRev, *UpperBRev;
+  if (match(LowerSrc, m_BitReverse(m_Value(LowerBRev))) &&
+      match(UpperSrc, m_BitReverse(m_Value(UpperBRev))))
+    return ConcatIntrinsicCalls(Intrinsic::bitreverse, UpperBRev, LowerBRev);
+
+  return nullptr;
 }
 
 /// If all elements of two constant vectors are 0/-1 and inverses, return true.
diff --git a/llvm/test/Transforms/InstCombine/or-concat.ll b/llvm/test/Transforms/InstCombine/or-concat.ll
index 77cdaa9a37ddc..4148e4900f7e3 100644
--- a/llvm/test/Transforms/InstCombine/or-concat.ll
+++ b/llvm/test/Transforms/InstCombine/or-concat.ll
@@ -72,16 +72,8 @@ declare i32 @llvm.bswap.i32(i32)
 
 define i64 @concat_bitreverse32_unary_split(i64 %a0) {
 ; CHECK-LABEL: @concat_bitreverse32_unary_split(
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[A0]] to i32
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[TMP2]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[TMP3]])
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
-; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = shl nuw i64 [[TMP7]], 32
-; CHECK-NEXT:    [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP6]]
-; CHECK-NEXT:    ret i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.bitreverse.i64(i64 [[A0:%.*]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
 ;
   %1 = lshr i64 %a0, 32
   %2 = trunc i64 %1 to i32
@@ -98,15 +90,10 @@ define i64 @concat_bitreverse32_unary_split(i64 %a0) {
 define i64 @concat_bitreverse32_unary_flip(i64 %a0) {
 ; CHECK-LABEL: @concat_bitreverse32_unary_flip(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[A0]] to i32
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[TMP2]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[TMP3]])
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
-; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = shl nuw i64 [[TMP6]], 32
-; CHECK-NEXT:    [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP7]]
-; CHECK-NEXT:    ret i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[A0]], 32
+; CHECK-NEXT:    [[TMP3:%.*]] = or i64 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.bitreverse.i64(i64 [[TMP3]])
+; CHECK-NEXT:    ret i64 [[TMP4]]
 ;
   %1 = lshr i64 %a0, 32
   %2 = trunc i64 %1 to i32
@@ -122,13 +109,12 @@ define i64 @concat_bitreverse32_unary_flip(i64 %a0) {
 
 define i64 @concat_bitreverse32_binary(i32 %a0, i32 %a1) {
 ; CHECK-LABEL: @concat_bitreverse32_binary(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[A0:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call i32 @llvm.bitreverse.i32(i32 [[A1:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i32 [[TMP1]] to i64
-; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP2]] to i64
-; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw i64 [[TMP4]], 32
-; CHECK-NEXT:    [[TMP6:%.*]] = or i64 [[TMP5]], [[TMP3]]
-; CHECK-NEXT:    ret i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[A1:%.*]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[A0:%.*]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 32
+; CHECK-NEXT:    [[TMP4:%.*]] = or i64 [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call i64 @llvm.bitreverse.i64(i64 [[TMP4]])
+; CHECK-NEXT:    ret i64 [[TMP5]]
 ;
   %1 = tail call i32 @llvm.bitreverse.i32(i32 %a0)
   %2 = tail call i32 @llvm.bitreverse.i32(i32 %a1)