diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d7971e8e3caea..6e46898634070 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3740,6 +3740,82 @@ static Instruction *foldIntegerPackFromVector(Instruction &I,
   return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
 }
 
+/// Match \p V as "lshr -> mask -> zext -> shl".
+///
+/// \p Int is the underlying integer being extracted from.
+/// \p Mask is a bitmask identifying which bits of the integer are being
+/// extracted. \p Offset identifies which bit of the result \p V corresponds to
+/// the least significant bit of \p Int.
+static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
+                                  uint64_t &Offset, bool &IsShlNUW,
+                                  bool &IsShlNSW) {
+  Value *ShlOp0;
+  uint64_t ShlAmt = 0;
+  if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
+    return false;
+
+  IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
+  IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
+
+  Value *ZExtOp0;
+  if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
+    return false;
+
+  Value *MaskedOp0;
+  const APInt *ShiftedMaskConst = nullptr;
+  if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
+                                                 m_APInt(ShiftedMaskConst))),
+                                  m_Value(MaskedOp0))))
+    return false;
+
+  uint64_t LShrAmt = 0;
+  if (!match(MaskedOp0,
+             m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
+                         m_Value(Int))))
+    return false;
+
+  if (LShrAmt > ShlAmt)
+    return false;
+  Offset = ShlAmt - LShrAmt;
+
+  Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
+                          : APInt::getBitsSetFrom(
+                                Int->getType()->getScalarSizeInBits(), LShrAmt);
+
+  return true;
+}
+
+/// Try to fold the join of two scalar integers whose bits are unpacked and
+/// zexted from the same source integer.
+static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
+                                           InstCombiner::BuilderTy &Builder) {
+
+  Value *LhsInt, *RhsInt;
+  APInt LhsMask, RhsMask;
+  uint64_t LhsOffset, RhsOffset;
+  bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
+  if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
+                             IsLhsShlNSW))
+    return nullptr;
+  if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
+                             IsRhsShlNSW))
+    return nullptr;
+  if (LhsInt != RhsInt || LhsOffset != RhsOffset)
+    return nullptr;
+
+  APInt Mask = LhsMask | RhsMask;
+
+  Type *DestTy = Lhs->getType();
+  Value *Res = Builder.CreateShl(
+      Builder.CreateZExt(
+          Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
+          LhsInt->getName() + ".zext"),
+      ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
+      IsLhsShlNSW && IsRhsShlNSW);
+  Res->takeName(Lhs);
+  return Res;
+}
+
 // A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
 // track these properities for preservation. Note that we can decompose
 // equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
@@ -3841,6 +3917,8 @@ static Value *foldBitmaskMul(Value *Op0, Value *Op1,
 Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
   if (Value *Res = foldBitmaskMul(LHS, RHS, Builder))
     return Res;
+  if (Value *Res = foldIntegerRepackThroughZExt(LHS, RHS, Builder))
+    return Res;
 
   return nullptr;
 }
@@ -3973,7 +4051,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                               /*NSW=*/true, /*NUW=*/true))
     return R;
 
-  if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
+  if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
     return replaceInstUsesWith(I, Res);
 
   if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
diff --git a/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
new file mode 100644
index 0000000000000..c90f08b7322ac
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
@@ -0,0 +1,242 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine %s -S | FileCheck %s
+
+declare void @use.i32(i32)
+declare void @use.i64(i64)
+
+define i64 @full_shl(i32 %x) {
+; CHECK-LABEL: define i64 @full_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[X_ZEXT:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT:    ret i64 [[LO_SHL]]
+;
+  %lo = and i32 %x, u0xffff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define <2 x i64> @full_shl_vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @full_shl_vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[V_ZEXT:%.*]] = zext <2 x i32> [[V]] to <2 x i64>
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw <2 x i64> [[V_ZEXT]], splat (i64 24)
+; CHECK-NEXT:    ret <2 x i64> [[LO_SHL]]
+;
+  %lo = and <2 x i32> %v, splat(i32 u0xffff)
+  %lo.zext = zext nneg <2 x i32> %lo to <2 x i64>
+  %lo.shl = shl nuw nsw <2 x i64> %lo.zext, splat(i64 24)
+
+  %hi = lshr <2 x i32> %v, splat(i32 16)
+  %hi.zext = zext nneg <2 x i32> %hi to <2 x i64>
+  %hi.shl = shl nuw nsw <2 x i64> %hi.zext, splat(i64 40)
+
+  %res = or disjoint <2 x i64> %lo.shl, %hi.shl
+  ret <2 x i64> %res
+}
+
+; u0xaabbccdd = -1430532899
+define i64 @partial_shl(i32 %x) {
+; CHECK-LABEL: define i64 @partial_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[X_MASK:%.*]] = and i32 [[X]], -1430532899
+; CHECK-NEXT:    [[X_ZEXT:%.*]] = zext i32 [[X_MASK]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT:    ret i64 [[LO_SHL]]
+;
+  %lo = and i32 %x, u0xccdd
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.mask = and i32 %hi, u0xaabb
+  %hi.zext = zext nneg i32 %hi.mask to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_shl(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT:    [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    call void @use.i64(i64 [[LO_SHL]])
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+  call void @use.i64(i64 %lo.shl)
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_zext(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_zext(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LO:%.*]] = and i32 [[X]], 255
+; CHECK-NEXT:    [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT:    call void @use.i64(i64 [[LO_ZEXT]])
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  call void @use.i64(i64 %lo.zext)
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_lshr(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_lshr(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT:    [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    call void @use.i32(i32 [[HI]])
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  call void @use.i32(i32 %hi)
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_non_disjoint(i32 %x) {
+; CHECK-LABEL: define i64 @shl_non_disjoint(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LO:%.*]] = and i32 [[X]], 16711680
+; CHECK-NEXT:    [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    call void @use.i32(i32 [[HI]])
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff0000
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  call void @use.i32(i32 %hi)
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[UPPER_ZEXT:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[S_0:%.*]] = shl nuw i64 [[UPPER_ZEXT]], 32
+; CHECK-NEXT:    [[O_3:%.*]] = or disjoint i64 [[S_0]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[O_3]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+  %o.0 = or i64 %base, %s.0
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %o.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+  %o.2 = or i64 %o.1, %s.2
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %o.2, %s.3
+
+  ret i64 %o.3
+}
+
+define i64 @combine_2(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine_2(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[S_03:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[O:%.*]] = shl nuw i64 [[S_03]], 32
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[O]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %s.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %s.2, %s.3
+
+  %o = or i64 %o.1, %o.3
+  %res = or i64 %o, %base
+
+  ret i64 %res
+}
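
A minimal before/after sketch of the new fold, distilled from the @full_shl test above; this is illustrative only and not part of the patch, and the value names (%x, %x.zext, %res) are placeholders:

  ; before: both operands of the disjoint 'or' extract bits of %x and place
  ; them 24 bits above their position in %x (lo: bits 0..15 -> 24, hi: bits
  ; 16..31 -> 40), i.e. both have the same offset of 24.
  %lo      = and i32 %x, 65535
  %lo.zext = zext i32 %lo to i64
  %lo.shl  = shl nuw nsw i64 %lo.zext, 24
  %hi      = lshr i32 %x, 16
  %hi.zext = zext i32 %hi to i64
  %hi.shl  = shl nuw nsw i64 %hi.zext, 40
  %res     = or disjoint i64 %lo.shl, %hi.shl

  ; after: the union of the two masks covers all of %x, so the 'and' folds
  ; away and a single zext plus shl by the common offset remains.
  %x.zext = zext i32 %x to i64
  %res    = shl nuw nsw i64 %x.zext, 24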