-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[LLVM][IR] Teach constant integer binop folds about vector ConstantInts. #115739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LLVM][IR] Teach constant integer binop folds about vector ConstantInts. #115739
Conversation
The existing logic mostly works, with the main changes being:
* Use `getScalarSizeInBits` instead of `IntegerType::getBitWidth`.
* Use `ConstantInt::get(Type*, ...)` instead of `ConstantInt::get(LLVMContext&, ...)`.
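@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms
Author: Paul Walker (paulwalker-arm)
Changes: The existing logic mostly works, with the main changes being:
* Use `getScalarSizeInBits` instead of `IntegerType::getBitWidth`.
* Use `ConstantInt::get(Type*, ...)` instead of `ConstantInt::get(LLVMContext&, ...)`.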
Full diff: https://github.com/llvm/llvm-project/pull/115739.diff 10 Files Affected:
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index cfe87937c372cd..2dbc6785c08b9d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return nullptr;
case Instruction::ZExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().zext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth));
}
return nullptr;
case Instruction::SExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().sext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth));
}
return nullptr;
case Instruction::Trunc: {
- if (V->getType()->isVectorTy())
- return nullptr;
-
- uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- return ConstantInt::get(V->getContext(),
- CI->getValue().trunc(DestBitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth));
}
return nullptr;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
default:
break;
case Instruction::Add:
- return ConstantInt::get(CI1->getContext(), C1V + C2V);
+ return ConstantInt::get(C1->getType(), C1V + C2V);
case Instruction::Sub:
- return ConstantInt::get(CI1->getContext(), C1V - C2V);
+ return ConstantInt::get(C1->getType(), C1V - C2V);
case Instruction::Mul:
- return ConstantInt::get(CI1->getContext(), C1V * C2V);
+ return ConstantInt::get(C1->getType(), C1V * C2V);
case Instruction::UDiv:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.udiv(C2V));
case Instruction::SDiv:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V));
case Instruction::URem:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.urem(C2V));
+ return ConstantInt::get(C1->getType(), C1V.urem(C2V));
case Instruction::SRem:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
- return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.srem(C2V));
+ return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison
+ return ConstantInt::get(C1->getType(), C1V.srem(C2V));
case Instruction::And:
- return ConstantInt::get(CI1->getContext(), C1V & C2V);
+ return ConstantInt::get(C1->getType(), C1V & C2V);
case Instruction::Or:
- return ConstantInt::get(CI1->getContext(), C1V | C2V);
+ return ConstantInt::get(C1->getType(), C1V | C2V);
case Instruction::Xor:
- return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
+ return ConstantInt::get(C1->getType(), C1V ^ C2V);
case Instruction::Shl:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+ return ConstantInt::get(C1->getType(), C1V.shl(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::LShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.lshr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::AShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.ashr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
}
}
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
return ConstantFP::get(C1->getContext(), C3V);
}
}
- } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
+ }
+
+ if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
// Fast path for splatted constants.
if (Constant *C2Splat = C2->getSplatValue()) {
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 7ae397871bdea2..3d6c4ad780dc24 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -441,6 +441,13 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
? CAZ->getElementValue(Elt)
: nullptr;
+ if (const auto *CI = dyn_cast<ConstantInt>(this))
+ return Elt < cast<VectorType>(getType())
+ ->getElementCount()
+ .getKnownMinValue()
+ ? ConstantInt::get(getContext(), CI->getValue())
+ : nullptr;
+
// FIXME: getNumElements() will fail for non-fixed vector types.
if (isa<ScalableVectorType>(getType()))
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 4b1159cf07e710..4825e588aa0856 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i8)
declare void @use_i1(i1)
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index 33a8e12dfa1a68..6344966d6cac3b 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i32)
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..e3108fc54c4f4c 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare i32 @llvm.abs.i32(i32, i1)
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 4a886afd78a5f0..95f89e4ce11cd5 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC
+; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n32:64"
declare void @use(i32)
@@ -399,10 +400,15 @@ define i32 @test30(i32 %A) {
}
define <2 x i32> @test30vec(<2 x i32> %A) {
-; CHECK-LABEL: @test30vec(
-; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
-; CHECK-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
-; CHECK-NEXT: ret <2 x i32> [[E]]
+; CONSTVEC-LABEL: @test30vec(
+; CONSTVEC-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTVEC-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
+; CONSTVEC-NEXT: ret <2 x i32> [[E]]
+;
+; CONSTSPLAT-LABEL: @test30vec(
+; CONSTSPLAT-NEXT: [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTSPLAT-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962)
+; CONSTSPLAT-NEXT: ret <2 x i32> [[E]]
;
%B = or <2 x i32> %A, <i32 32962, i32 32962>
%C = and <2 x i32> %A, <i32 -65536, i32 -65536>
diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll
index ea7c471594da0a..bae50736de0c33 100644
--- a/llvm/test/Transforms/InstCombine/rotate.ll
+++ b/llvm/test/Transforms/InstCombine/rotate.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index d2ee97f39123b0..d72a1849c7dfd6 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i64)
declare void @use_i32(i32)
diff --git a/llvm/test/Transforms/InstCombine/xor-ashr.ll b/llvm/test/Transforms/InstCombine/xor-ashr.ll
index 0c0554adcf1230..f5ccdeef2f382b 100644
--- a/llvm/test/Transforms/InstCombine/xor-ashr.ll
+++ b/llvm/test/Transforms/InstCombine/xor-ashr.ll
@@ -1,5 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
declare void @use16(i16)
diff --git a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
index 3f1672d66abf0d..b475b8199541d5 100644
--- a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
+++ b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
@@ -81,8 +81,7 @@ define <1 x i1> @test10() {
; CONSTVEC-NEXT: ret <1 x i1> [[RET]]
;
; CONSTSPLAT-LABEL: @test10(
-; CONSTSPLAT-NEXT: [[RET:%.*]] = icmp eq <1 x i64> splat (i64 -1), zeroinitializer
-; CONSTSPLAT-NEXT: ret <1 x i1> [[RET]]
+; CONSTSPLAT-NEXT: ret <1 x i1> zeroinitializer
;
%ret = icmp eq <1 x i64> <i64 bitcast (<1 x double> <double 0xFFFFFFFFFFFFFFFF> to i64)>, zeroinitializer
ret <1 x i1> %ret
; CONSTSPLAT-LABEL: @test30vec(
; CONSTSPLAT-NEXT: [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
; CONSTSPLAT-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962)
; CONSTSPLAT-NEXT: ret <2 x i32> [[E]]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the difference here only in variable naming?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. My guess is that we're now going down a path where the original name is preserved. I was going to ignore it but then remembered how it bugs me when I run the update scripts and end up with a bunch of changes unrelated to my PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nowadays the update scripts are supposed to keep the check names stable even if the generated names change :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
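The existing logic mostly works, with the main changes being:
* Use `getScalarSizeInBits` instead of `IntegerType::getBitWidth`.
* Use `ConstantInt::get(Type*, ...)` instead of `ConstantInt::get(LLVMContext&, ...)`.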