Skip to content

Commit

Permalink
[LLVM][IR] Replace ConstantInt's specialisation of getType() with get…
Browse files Browse the repository at this point in the history
…IntegerType(). (#75217)

The specialisation will not be valid when ConstantInt gains native
support for vector types.

This is largely a mechanical change but with extra attention paid to constant
folding, InstCombineVectorOps.cpp, LoopFlatten.cpp and Verifier.cpp to
remove the need to call `getIntegerType()`.

Co-authored-by: Nikita Popov <github@npopov.com>
  • Loading branch information
paulwalker-arm and nikic committed Dec 18, 2023
1 parent 2f81788 commit dea16eb
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 37 deletions.
7 changes: 4 additions & 3 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3214,7 +3214,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
llvm::Value::MaximumAlignment);

emitAlignmentAssumption(PtrValue, Ptr,
Expand Down Expand Up @@ -17034,7 +17034,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
Value *Op1 = EmitScalarExpr(E->getArg(1));
ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
llvm::Value::MaximumAlignment);

emitAlignmentAssumption(Op1, E->getArg(1),
Expand Down Expand Up @@ -17272,7 +17272,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));

if (getTarget().isLittleEndian())
Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
Index =
ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());

return Builder.CreateExtractElement(Unpacked, Index);
}
Expand Down
7 changes: 3 additions & 4 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,9 @@ class ConstantInt final : public ConstantData {
/// Determine if this constant's value is same as an unsigned char.
bool equalsInt(uint64_t V) const { return Val == V; }

/// getType - Specialize the getType() method to always return an IntegerType,
/// which reduces the amount of casting needed in parts of the compiler.
///
inline IntegerType *getType() const {
/// Variant of the getType() method to always return an IntegerType, which
/// reduces the amount of casting needed in parts of the compiler.
inline IntegerType *getIntegerType() const {
return cast<IntegerType>(Value::getType());
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6079,7 +6079,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());

auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
if (!OffsetConstInt || OffsetConstInt->getBitWidth() > 64)
return nullptr;

APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/ConstantFold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
}

if (GVAlign > 1) {
unsigned DstWidth = CI2->getType()->getBitWidth();
unsigned DstWidth = CI2->getBitWidth();
unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));

Expand Down
11 changes: 6 additions & 5 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,10 +2296,9 @@ void Verifier::verifyFunctionMetadata(
Check(isa<ConstantAsMetadata>(MD->getOperand(0)),
"expected a constant operand for !kcfi_type", MD);
Constant *C = cast<ConstantAsMetadata>(MD->getOperand(0))->getValue();
Check(isa<ConstantInt>(C),
Check(isa<ConstantInt>(C) && isa<IntegerType>(C->getType()),
"expected a constant integer operand for !kcfi_type", MD);
IntegerType *Type = cast<ConstantInt>(C)->getType();
Check(Type->getBitWidth() == 32,
Check(cast<ConstantInt>(C)->getBitWidth() == 32,
"expected a 32-bit integer constant operand for !kcfi_type", MD);
}
}
Expand Down Expand Up @@ -5690,8 +5689,10 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"vector of ints");

auto *Op3 = cast<ConstantInt>(Call.getArgOperand(2));
Check(Op3->getType()->getBitWidth() <= 32,
"third argument of [us][mul|div]_fix[_sat] must fit within 32 bits");
Check(Op3->getType()->isIntegerTy(),
"third operand of [us][mul|div]_fix[_sat] must be an int type");
Check(Op3->getBitWidth() <= 32,
"third operand of [us][mul|div]_fix[_sat] must fit within 32 bits");

if (ID == Intrinsic::smul_fix || ID == Intrinsic::smul_fix_sat ||
ID == Intrinsic::sdiv_fix || ID == Intrinsic::sdiv_fix_sat) {
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
// Promote immediates.
for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
if (CI->getType()->getBitWidth() < DestBW)
if (CI->getBitWidth() < DestBW)
In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
}
}
Expand Down Expand Up @@ -1577,7 +1577,7 @@ Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,

static bool hasZeroSignBit(const Value *V) {
if (const auto *CI = dyn_cast<const ConstantInt>(V))
return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0;
return CI->getValue().isNonNegative();
const Instruction *I = dyn_cast<const Instruction>(V);
if (!I)
return false;
Expand Down Expand Up @@ -1688,7 +1688,7 @@ void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
if (I->getOpcode() != Instruction::Or)
return nullptr;
ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit())
if (!Msb || !Msb->getValue().isSignMask())
return nullptr;
if (!hasZeroSignBit(I->getOperand(0)))
return nullptr;
Expand Down
15 changes: 8 additions & 7 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3763,7 +3763,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
ConstantInt *ExecModeC =
KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
ConstantInt *AssumedExecModeC = ConstantInt::get(
ExecModeC->getType(),
ExecModeC->getIntegerType(),
ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
SPMDCompatibilityTracker.indicateOptimisticFixpoint();
Expand Down Expand Up @@ -3792,7 +3792,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
ConstantInt *MayUseNestedParallelismC =
KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
MayUseNestedParallelismC->getType(), NestedParallelism);
MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
setMayUseNestedParallelismOfKernelEnvironment(
AssumedMayUseNestedParallelismC);

Expand All @@ -3801,7 +3801,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
KernelEnvC);
ConstantInt *AssumedUseGenericStateMachineC =
ConstantInt::get(UseGenericStateMachineC->getType(), false);
ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
setUseGenericStateMachineOfKernelEnvironment(
AssumedUseGenericStateMachineC);
}
Expand Down Expand Up @@ -4280,8 +4280,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
// kernel is executed in.
assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
"Initially non-SPMD kernel has SPMD exec mode!");
setExecModeOfKernelEnvironment(ConstantInt::get(
ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
setExecModeOfKernelEnvironment(
ConstantInt::get(ExecModeC->getIntegerType(),
ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));

++NumOpenMPTargetRegionKernelsSPMD;

Expand Down Expand Up @@ -4332,7 +4333,7 @@ struct AAKernelInfoFunction : AAKernelInfo {

// If not SPMD mode, indicate we use a custom state machine now.
setUseGenericStateMachineOfKernelEnvironment(
ConstantInt::get(UseStateMachineC->getType(), false));
ConstantInt::get(UseStateMachineC->getIntegerType(), false));

// If we don't actually need a state machine we are done here. This can
// happen if there simply are no parallel regions. In the resulting kernel
Expand Down Expand Up @@ -4658,7 +4659,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
AA.KernelEnvC);
ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
MayUseNestedParallelismC->getType(), AA.NestedParallelism);
MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
AA.setMayUseNestedParallelismOfKernelEnvironment(
NewMayUseNestedParallelismC);
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
/// arbitrarily pick 64 bit as our canonical type. The actual bitwidth doesn't
/// matter, we just want a consistent type to simplify CSE.
static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
const unsigned IndexBW = IndexC->getType()->getBitWidth();
const unsigned IndexBW = IndexC->getBitWidth();
if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
return nullptr;
return ConstantInt::get(IndexC->getContext(),
Expand Down Expand Up @@ -2640,7 +2640,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");

// Index is updated to the potentially translated insertion lane.
IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex);
IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
return true;
};

Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,7 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) {
llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS,
const ConstantCandidate &RHS) {
if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
return LHS.ConstInt->getType()->getBitWidth() <
RHS.ConstInt->getType()->getBitWidth();
return LHS.ConstInt->getBitWidth() < RHS.ConstInt->getBitWidth();
return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
});

Expand Down Expand Up @@ -890,7 +889,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
Type *Ty = ConstInfo.BaseExpr->getType();
Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP);
} else {
IntegerType *Ty = ConstInfo.BaseInt->getType();
IntegerType *Ty = ConstInfo.BaseInt->getIntegerType();
Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP);
}

Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
// If the RHS of the compare is equal to the backedge taken count we need
// to add one to get the trip count.
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
Value *NewRHS = ConstantInt::get(
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
ConstantRHS->getValue() + 1);
return setLoopComponents(NewRHS, TripCount, Increment,
IterationInstructions);
}
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6293,7 +6293,7 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
}
case BitMapKind: {
// Type of the bitmap (e.g. i59).
IntegerType *MapTy = BitMap->getType();
IntegerType *MapTy = BitMap->getIntegerType();

// Cast Index to the same type as the bitmap.
// Note: The Index is <= the number of elements in the table, so
Expand Down Expand Up @@ -6668,7 +6668,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
Value *TableIndex;
ConstantInt *TableIndexOffset;
if (UseSwitchConditionAsTableIndex) {
TableIndexOffset = ConstantInt::get(MaxCaseVal->getType(), 0);
TableIndexOffset = ConstantInt::get(MaxCaseVal->getIntegerType(), 0);
TableIndex = SI->getCondition();
} else {
TableIndexOffset = MinCaseVal;
Expand Down Expand Up @@ -6752,7 +6752,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// Get the TableIndex'th bit of the bitmask.
// If this bit is 0 (meaning hole) jump to the default destination,
// else continue with table lookup.
IntegerType *MapTy = TableMask->getType();
IntegerType *MapTy = TableMask->getIntegerType();
Value *MaskIndex =
Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex");
Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted");
Expand Down Expand Up @@ -6975,7 +6975,7 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
// Replace each case with its trailing zeros number.
for (auto &Case : SI->cases()) {
auto *OrigValue = Case.getCaseValue();
Case.setValue(ConstantInt::get(OrigValue->getType(),
Case.setValue(ConstantInt::get(OrigValue->getIntegerType(),
OrigValue->getValue().countr_zero()));
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
// Convert scalar intergers.
if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
return builder.getIntegerAttr(
IntegerType::get(context, constInt->getType()->getBitWidth()),
IntegerType::get(context, constInt->getBitWidth()),
constInt->getValue());
}

Expand Down

0 comments on commit dea16eb

Please sign in to comment.