diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 3dd8414e950ae3..ec505381cc868f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -322,20 +322,15 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } -/// If we have a shift-by-constant of a bin op (bitwise logic op or add/sub w/ -/// shl) that itself has a shift-by-constant operand with identical opcode, we -/// may be able to convert that into 2 independent shifts followed by the logic -/// op. This eliminates a use of an intermediate value (reduces dependency -/// chain). -static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I, +/// If we have a shift-by-constant of a bitwise logic op that itself has a +/// shift-by-constant operand with identical opcode, we may be able to convert +/// that into 2 independent shifts followed by the logic op. This eliminates a +/// use of an intermediate value (reduces dependency chain). +static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { assert(I.isShift() && "Expected a shift as input"); - auto *BinInst = dyn_cast<BinaryOperator>(I.getOperand(0)); - if (!BinInst || - (!BinInst->isBitwiseLogicOp() && - BinInst->getOpcode() != Instruction::Add && - BinInst->getOpcode() != Instruction::Sub) || - !BinInst->hasOneUse()) + auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0)); + if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) return nullptr; Constant *C0, *C1; @@ -343,12 +338,6 @@ static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I, return nullptr; Instruction::BinaryOps ShiftOpcode = I.getOpcode(); - // Transform for add/sub only works with shl. 
- if ((BinInst->getOpcode() == Instruction::Add || - BinInst->getOpcode() == Instruction::Sub) && - ShiftOpcode != Instruction::Shl) - return nullptr; - Type *Ty = I.getType(); // Find a matching one-use shift by constant. The fold is not valid if the sum @@ -363,25 +352,19 @@ static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; - // Logic ops and Add are commutative, so check each operand for a match. Sub - // is not so we cannot reoder if we match operand(1) and need to keep the - // operands in their original positions. - bool FirstShiftIsOp1 = false; - if (matchFirstShift(BinInst->getOperand(0))) - Y = BinInst->getOperand(1); - else if (matchFirstShift(BinInst->getOperand(1))) { - Y = BinInst->getOperand(0); - FirstShiftIsOp1 = BinInst->getOpcode() == Instruction::Sub; - } else + // Logic ops are commutative, so check each operand for a match. + if (matchFirstShift(LogicInst->getOperand(0))) + Y = LogicInst->getOperand(1); + else if (matchFirstShift(LogicInst->getOperand(1))) + Y = LogicInst->getOperand(0); + else return nullptr; - // shift (binop (shift X, C0), Y), C1 -> binop (shift X, C0+C1), (shift Y, C1) + // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1); - Value *Op1 = FirstShiftIsOp1 ? NewShift2 : NewShift1; - Value *Op2 = FirstShiftIsOp1 ? 
NewShift1 : NewShift2; - return BinaryOperator::Create(BinInst->getOpcode(), Op1, Op2); + return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); } Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { @@ -480,7 +463,7 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { return replaceOperand(I, 1, Rem); } - if (Instruction *Logic = foldShiftOfShiftedBinOp(I, Builder)) + if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) return Logic; return nullptr; diff --git a/llvm/test/Transforms/InstCombine/shift-logic.ll b/llvm/test/Transforms/InstCombine/shift-logic.ll index e0fbe819666683..ceafffca1a654e 100644 --- a/llvm/test/Transforms/InstCombine/shift-logic.ll +++ b/llvm/test/Transforms/InstCombine/shift-logic.ll @@ -335,9 +335,9 @@ define i64 @lshr_mul_negative_nsw(i64 %0) { define i8 @shl_add(i8 %x, i8 %y) { ; CHECK-LABEL: @shl_add( -; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 5 -; CHECK-NEXT: [[TMP2:%.*]] = shl i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SH1:%.*]] = add i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[SH0:%.*]] = shl i8 [[X:%.*]], 3 +; CHECK-NEXT: [[R:%.*]] = add i8 [[SH0]], [[Y:%.*]] +; CHECK-NEXT: [[SH1:%.*]] = shl i8 [[R]], 2 ; CHECK-NEXT: ret i8 [[SH1]] ; %sh0 = shl i8 %x, 3 @@ -348,9 +348,9 @@ define i8 @shl_add(i8 %x, i8 %y) { define <2 x i8> @shl_add_nonuniform(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @shl_add_nonuniform( -; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = shl <2 x i8> [[Y:%.*]], -; CHECK-NEXT: [[SH1:%.*]] = add <2 x i8> [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[SH0:%.*]] = shl <2 x i8> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = add <2 x i8> [[SH0]], [[Y:%.*]] +; CHECK-NEXT: [[SH1:%.*]] = shl <2 x i8> [[R]], ; CHECK-NEXT: ret <2 x i8> [[SH1]] ; %sh0 = shl <2 x i8> %x, @@ -363,9 +363,9 @@ define <2 x i8> @shl_add_nonuniform(<2 x i8> %x, <2 x i8> %y) { define <2 x i64> @shl_add_undef(<2 x i64> %x, <2 x i64> %py) { ; CHECK-LABEL: @shl_add_undef( ; 
CHECK-NEXT: [[Y:%.*]] = srem <2 x i64> [[PY:%.*]], -; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i64> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = shl <2 x i64> [[Y]], -; CHECK-NEXT: [[SH1:%.*]] = add <2 x i64> [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[SH0:%.*]] = shl <2 x i64> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = add <2 x i64> [[Y]], [[SH0]] +; CHECK-NEXT: [[SH1:%.*]] = shl <2 x i64> [[R]], ; CHECK-NEXT: ret <2 x i64> [[SH1]] ; %y = srem <2 x i64> %py, ; thwart complexity-based canonicalization @@ -419,9 +419,9 @@ define <2 x i64> @lshr_add_undef(<2 x i64> %x, <2 x i64> %py) { define i8 @shl_sub(i8 %x, i8 %y) { ; CHECK-LABEL: @shl_sub( -; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 5 -; CHECK-NEXT: [[TMP2:%.*]] = shl i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SH1:%.*]] = sub i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[SH0:%.*]] = shl i8 [[X:%.*]], 3 +; CHECK-NEXT: [[R:%.*]] = sub i8 [[SH0]], [[Y:%.*]] +; CHECK-NEXT: [[SH1:%.*]] = shl i8 [[R]], 2 ; CHECK-NEXT: ret i8 [[SH1]] ; %sh0 = shl i8 %x, 3 @@ -433,9 +433,9 @@ define i8 @shl_sub(i8 %x, i8 %y) { ; Make sure we don't commute operands for sub define i8 @shl_sub_no_commute(i8 %x, i8 %y) { ; CHECK-LABEL: @shl_sub_no_commute( -; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[Y:%.*]], 5 -; CHECK-NEXT: [[TMP2:%.*]] = shl i8 [[X:%.*]], 2 -; CHECK-NEXT: [[SH1:%.*]] = sub i8 [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[SH0:%.*]] = shl i8 [[Y:%.*]], 3 +; CHECK-NEXT: [[R:%.*]] = sub i8 [[X:%.*]], [[SH0]] +; CHECK-NEXT: [[SH1:%.*]] = shl i8 [[R]], 2 ; CHECK-NEXT: ret i8 [[SH1]] ; %sh0 = shl i8 %y, 3 @@ -446,9 +446,9 @@ define i8 @shl_sub_no_commute(i8 %x, i8 %y) { define <2 x i8> @shl_sub_nonuniform(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @shl_sub_nonuniform( -; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = shl <2 x i8> [[Y:%.*]], -; CHECK-NEXT: [[SH1:%.*]] = sub <2 x i8> [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[SH0:%.*]] = shl <2 x i8> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> [[SH0]], [[Y:%.*]] +; CHECK-NEXT: [[SH1:%.*]] = shl 
<2 x i8> [[R]], ; CHECK-NEXT: ret <2 x i8> [[SH1]] ; %sh0 = shl <2 x i8> %x, @@ -461,9 +461,9 @@ define <2 x i8> @shl_sub_nonuniform(<2 x i8> %x, <2 x i8> %y) { define <2 x i64> @shl_sub_undef(<2 x i64> %x, <2 x i64> %py) { ; CHECK-LABEL: @shl_sub_undef( ; CHECK-NEXT: [[Y:%.*]] = srem <2 x i64> [[PY:%.*]], -; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i64> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = shl <2 x i64> [[Y]], -; CHECK-NEXT: [[SH1:%.*]] = sub <2 x i64> [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[SH0:%.*]] = shl <2 x i64> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = sub <2 x i64> [[Y]], [[SH0]] +; CHECK-NEXT: [[SH1:%.*]] = shl <2 x i64> [[R]], ; CHECK-NEXT: ret <2 x i64> [[SH1]] ; %y = srem <2 x i64> %py, ; thwart complexity-based canonicalization