Skip to content

Commit

Permalink
[AggressiveInstCombine] Add udiv and urem instrs to TruncInstComb…
Browse files Browse the repository at this point in the history
…ine DAG

Add `udiv` and `urem` instructions to the DAG post-dominated by `trunc`,
allowing TruncInstCombine to reduce bitwidth of expressions containing these
instructions. It is sufficient to require that all truncated bits of both
operands are zeros: https://alive2.llvm.org/ce/z/yiithn
(`urem` case is identical).

Differential Revision: https://reviews.llvm.org/D109515
  • Loading branch information
anton-afanasyev committed Sep 10, 2021
1 parent ea7b2c1 commit 54d8ebb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
22 changes: 20 additions & 2 deletions llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
Expand Up @@ -65,6 +65,8 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr:
case Instruction::UDiv:
case Instruction::URem:
Ops.push_back(I->getOperand(0));
Ops.push_back(I->getOperand(1));
break;
Expand Down Expand Up @@ -134,6 +136,8 @@ bool TruncInstCombine::buildTruncExpressionDag() {
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr:
case Instruction::UDiv:
case Instruction::URem:
case Instruction::Select: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
Expand All @@ -143,7 +147,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
default:
// TODO: Can handle more cases here:
// 1. shufflevector, extractelement, insertelement
// 2. udiv, urem
// 2. sdiv, srem
// 3. phi node(and loop handling)
// ...
return false;
Expand Down Expand Up @@ -306,6 +310,18 @@ Type *TruncInstCombine::getBestTruncatedType() {
return nullptr;
Itr.second.MinBitWidth = MinBitWidth;
}
if (I->getOpcode() == Instruction::UDiv ||
I->getOpcode() == Instruction::URem) {
unsigned MinBitWidth = 0;
for (const auto &Op : I->operands()) {
KnownBits Known = computeKnownBits(Op);
MinBitWidth =
std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
if (MinBitWidth >= OrigBitWidth)
return nullptr;
}
Itr.second.MinBitWidth = MinBitWidth;
}
}

// Calculate minimum allowed bit-width allowed for shrinking the currently
Expand Down Expand Up @@ -397,7 +413,9 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
case Instruction::Xor:
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr: {
case Instruction::AShr:
case Instruction::UDiv:
case Instruction::URem: {
Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
Expand Down
55 changes: 24 additions & 31 deletions llvm/test/Transforms/AggressiveInstCombine/trunc_udivrem.ll
Expand Up @@ -3,10 +3,9 @@

define i16 @udiv_one_arg(i8 %x) {
; CHECK-LABEL: @udiv_one_arg(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
; CHECK-NEXT: [[DIV:%.*]] = udiv i32 [[ZEXT]], 42
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[DIV]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
; CHECK-NEXT: [[DIV:%.*]] = udiv i16 [[ZEXT]], 42
; CHECK-NEXT: ret i16 [[DIV]]
;
%zext = zext i8 %x to i32
%div = udiv i32 %zext, 42
Expand All @@ -16,11 +15,8 @@ define i16 @udiv_one_arg(i8 %x) {

define i16 @udiv_two_args(i16 %x, i16 %y) {
; CHECK-LABEL: @udiv_two_args(
; CHECK-NEXT: [[ZEXTX:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[ZEXTY:%.*]] = zext i16 [[Y:%.*]] to i32
; CHECK-NEXT: [[I0:%.*]] = udiv i32 [[ZEXTX]], [[ZEXTY]]
; CHECK-NEXT: [[R:%.*]] = trunc i32 [[I0]] to i16
; CHECK-NEXT: ret i16 [[R]]
; CHECK-NEXT: [[I0:%.*]] = udiv i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i16 [[I0]]
;
%zextx = zext i16 %x to i32
%zexty = zext i16 %y to i32
Expand All @@ -29,6 +25,7 @@ define i16 @udiv_two_args(i16 %x, i16 %y) {
ret i16 %r
}

; Negative test
define i16 @udiv_big_const(i8 %x) {
; CHECK-LABEL: @udiv_big_const(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
Expand All @@ -44,17 +41,17 @@ define i16 @udiv_big_const(i8 %x) {

define <2 x i16> @udiv_vector(<2 x i8> %x) {
; CHECK-LABEL: @udiv_vector(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: [[S:%.*]] = udiv <2 x i32> [[Z]], <i32 4, i32 10>
; CHECK-NEXT: [[T:%.*]] = trunc <2 x i32> [[S]] to <2 x i16>
; CHECK-NEXT: ret <2 x i16> [[T]]
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
; CHECK-NEXT: [[S:%.*]] = udiv <2 x i16> [[Z]], <i16 4, i16 10>
; CHECK-NEXT: ret <2 x i16> [[S]]
;
%z = zext <2 x i8> %x to <2 x i32>
%s = udiv <2 x i32> %z, <i32 4, i32 10>
%t = trunc <2 x i32> %s to <2 x i16>
ret <2 x i16> %t
}

; Negative test: can only fold to <2 x i16>, requiring new vector type
define <2 x i8> @udiv_vector_need_new_vector_type(<2 x i8> %x) {
; CHECK-LABEL: @udiv_vector_need_new_vector_type(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
Expand All @@ -68,6 +65,7 @@ define <2 x i8> @udiv_vector_need_new_vector_type(<2 x i8> %x) {
ret <2 x i8> %t
}

; Negative test
define <2 x i16> @udiv_vector_big_const(<2 x i8> %x) {
; CHECK-LABEL: @udiv_vector_big_const(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
Expand All @@ -83,11 +81,8 @@ define <2 x i16> @udiv_vector_big_const(<2 x i8> %x) {

define i16 @udiv_exact(i16 %x, i16 %y) {
; CHECK-LABEL: @udiv_exact(
; CHECK-NEXT: [[ZEXTX:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[ZEXTY:%.*]] = zext i16 [[Y:%.*]] to i32
; CHECK-NEXT: [[I0:%.*]] = udiv exact i32 [[ZEXTX]], [[ZEXTY]]
; CHECK-NEXT: [[R:%.*]] = trunc i32 [[I0]] to i16
; CHECK-NEXT: ret i16 [[R]]
; CHECK-NEXT: [[I0:%.*]] = udiv exact i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i16 [[I0]]
;
%zextx = zext i16 %x to i32
%zexty = zext i16 %y to i32
Expand All @@ -99,10 +94,9 @@ define i16 @udiv_exact(i16 %x, i16 %y) {

define i16 @urem_one_arg(i8 %x) {
; CHECK-LABEL: @urem_one_arg(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
; CHECK-NEXT: [[DIV:%.*]] = urem i32 [[ZEXT]], 42
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[DIV]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
; CHECK-NEXT: [[DIV:%.*]] = urem i16 [[ZEXT]], 42
; CHECK-NEXT: ret i16 [[DIV]]
;
%zext = zext i8 %x to i32
%div = urem i32 %zext, 42
Expand All @@ -112,11 +106,8 @@ define i16 @urem_one_arg(i8 %x) {

define i16 @urem_two_args(i16 %x, i16 %y) {
; CHECK-LABEL: @urem_two_args(
; CHECK-NEXT: [[ZEXTX:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[ZEXTY:%.*]] = zext i16 [[Y:%.*]] to i32
; CHECK-NEXT: [[I0:%.*]] = urem i32 [[ZEXTX]], [[ZEXTY]]
; CHECK-NEXT: [[R:%.*]] = trunc i32 [[I0]] to i16
; CHECK-NEXT: ret i16 [[R]]
; CHECK-NEXT: [[I0:%.*]] = urem i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i16 [[I0]]
;
%zextx = zext i16 %x to i32
%zexty = zext i16 %y to i32
Expand All @@ -125,6 +116,7 @@ define i16 @urem_two_args(i16 %x, i16 %y) {
ret i16 %r
}

; Negative test
define i16 @urem_big_const(i8 %x) {
; CHECK-LABEL: @urem_big_const(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
Expand All @@ -140,17 +132,17 @@ define i16 @urem_big_const(i8 %x) {

define <2 x i16> @urem_vector(<2 x i8> %x) {
; CHECK-LABEL: @urem_vector(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: [[S:%.*]] = urem <2 x i32> [[Z]], <i32 4, i32 10>
; CHECK-NEXT: [[T:%.*]] = trunc <2 x i32> [[S]] to <2 x i16>
; CHECK-NEXT: ret <2 x i16> [[T]]
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
; CHECK-NEXT: [[S:%.*]] = urem <2 x i16> [[Z]], <i16 4, i16 10>
; CHECK-NEXT: ret <2 x i16> [[S]]
;
%z = zext <2 x i8> %x to <2 x i32>
%s = urem <2 x i32> %z, <i32 4, i32 10>
%t = trunc <2 x i32> %s to <2 x i16>
ret <2 x i16> %t
}

; Negative test: can only fold to <2 x i16>, requiring new vector type
define <2 x i8> @urem_vector_need_new_vector_type(<2 x i8> %x) {
; CHECK-LABEL: @urem_vector_need_new_vector_type(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
Expand All @@ -164,6 +156,7 @@ define <2 x i8> @urem_vector_need_new_vector_type(<2 x i8> %x) {
ret <2 x i8> %t
}

; Negative test
define <2 x i16> @urem_vector_big_const(<2 x i8> %x) {
; CHECK-LABEL: @urem_vector_big_const(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
Expand Down

0 comments on commit 54d8ebb

Please sign in to comment.