Skip to content

Commit

Permalink
[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.
Browse files Browse the repository at this point in the history
NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors

However, ConstantExprs are being deprecated and ConstantVector is
not space efficient for larger vector types. By extending ConstantInt
we can represent vector splats by only storing the underlying scalar
value.

More specifically:

 * ConstantInt gains an ElementCount variant of get().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitcodeReader/Writer support is extended to allow vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there's likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags are added to allow wider testing and thus help with the
migration:

  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: For similar reasons as above, code generation doesn't work
out-the-box.
  • Loading branch information
paulwalker-arm committed Feb 9, 2024
1 parent 2cb61a1 commit 6c6baf8
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 38 deletions.
12 changes: 10 additions & 2 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ConstantInt final : public ConstantData {

APInt Val;

ConstantInt(IntegerType *Ty, const APInt &V);
ConstantInt(Type *Ty, const APInt &V);

void destroyConstantImpl();

Expand Down Expand Up @@ -123,6 +123,12 @@ class ConstantInt final : public ConstantData {
/// type is the integer type that corresponds to the bit width of the value.
static ConstantInt *get(LLVMContext &Context, const APInt &V);

/// Return a ConstantInt with the specified value and an implied Type. The
/// type is the vector type whose integer element type corresponds to the bit
/// width of the value.
static ConstantInt *get(LLVMContext &Context, ElementCount EC,
const APInt &V);

/// Return a ConstantInt constructed from the string strStart with the given
/// radix.
static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix);
Expand All @@ -136,7 +142,7 @@ class ConstantInt final : public ConstantData {
/// Return the constant's value.
inline const APInt &getValue() const { return Val; }

/// getBitWidth - Return the bitwidth of this constant.
/// getBitWidth - Return the scalar bitwidth of this constant.
unsigned getBitWidth() const { return Val.getBitWidth(); }

/// Return the constant as a 64-bit unsigned integer value after it
Expand Down Expand Up @@ -281,6 +287,8 @@ class ConstantFP final : public ConstantData {

static Constant *get(Type *Ty, StringRef Str);
static ConstantFP *get(LLVMContext &Context, const APFloat &V);
static ConstantFP *get(LLVMContext &Context, ElementCount EC,
const APFloat &V);
static Constant *getNaN(Type *Ty, bool Negative = false,
uint64_t Payload = 0);
static Constant *getQNaN(Type *Ty, bool Negative = false,
Expand Down
55 changes: 28 additions & 27 deletions llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() {
V = Constant::getNullValue(CurTy);
break;
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
if (!CurTy->isIntegerTy() || Record.empty())
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
return error("Invalid integer const record");
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
break;
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
if (!CurTy->isIntegerTy() || Record.empty())
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
return error("Invalid wide integer const record");

APInt VInt =
readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
V = ConstantInt::get(Context, VInt);

auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
V = ConstantInt::get(CurTy, VInt);
break;
}
case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval]
if (Record.empty())
return error("Invalid float const record");
if (CurTy->isHalfTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
else if (CurTy->isBFloatTy())
V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
APInt(16, (uint32_t)Record[0])));
else if (CurTy->isFloatTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
APInt(32, (uint32_t)Record[0])));
else if (CurTy->isDoubleTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
APInt(64, Record[0])));
else if (CurTy->isX86_FP80Ty()) {

auto *ScalarTy = CurTy->getScalarType();
if (ScalarTy->isHalfTy())
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
else if (ScalarTy->isBFloatTy())
V = ConstantFP::get(
CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
else if (ScalarTy->isFloatTy())
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
APInt(32, (uint32_t)Record[0])));
else if (ScalarTy->isDoubleTy())
V = ConstantFP::get(
CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
else if (ScalarTy->isX86_FP80Ty()) {
// Bits are not stored the same way as a normal i80 APInt, compensate.
uint64_t Rearrange[2];
Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
Rearrange[1] = Record[0] >> 48;
V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
APInt(80, Rearrange)));
} else if (CurTy->isFP128Ty())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
APInt(128, Record)));
else if (CurTy->isPPC_FP128Ty())
V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
APInt(128, Record)));
V = ConstantFP::get(
CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
} else if (ScalarTy->isFP128Ty())
V = ConstantFP::get(CurTy,
APFloat(APFloat::IEEEquad(), APInt(128, Record)));
else if (ScalarTy->isPPC_FP128Ty())
V = ConstantFP::get(
CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
else
V = UndefValue::get(CurTy);
break;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
}
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
Code = bitc::CST_CODE_FLOAT;
Type *Ty = CFP->getType();
Type *Ty = CFP->getType()->getScalarType();
if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
Ty->isDoubleTy()) {
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
Expand Down
37 changes: 32 additions & 5 deletions llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1409,16 +1409,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
AsmWriterContext &WriterCtx) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
if (CI->getType()->isIntegerTy(1)) {
Out << (CI->getZExtValue() ? "true" : "false");
return;
if (CI->getType()->isVectorTy()) {
Out << "splat (";
WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
Out << " ";
}
Out << CI->getValue();

if (CI->getType()->getScalarType()->isIntegerTy(1))
Out << (CI->getZExtValue() ? "true" : "false");
else
Out << CI->getValue();

if (CI->getType()->isVectorTy())
Out << ")";

return;
}

if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
const APFloat &APF = CFP->getValueAPF();

if (CFP->getType()->isVectorTy()) {
Out << "splat (";
WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
Out << " ";
}

if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
&APF.getSemantics() == &APFloat::IEEEdouble()) {
// We would like to output the FP constant value in exponential notation,
Expand All @@ -1444,6 +1460,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
// Reparse stringized version!
if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
Out << StrVal;

if (CFP->getType()->isVectorTy())
Out << ")";

return;
}
}
Expand All @@ -1469,6 +1489,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
}
}
Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);

if (CFP->getType()->isVectorTy())
Out << ")";

return;
}

Expand All @@ -1483,7 +1507,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
/*Upper=*/true);
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
/*Upper=*/true);
return;
} else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
Out << 'L';
Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
Expand All @@ -1506,6 +1529,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
/*Upper=*/true);
} else
llvm_unreachable("Unsupported floating point type");

if (CFP->getType()->isVectorTy())
Out << ")";

return;
}

Expand Down
82 changes: 79 additions & 3 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
using namespace llvm;
using namespace PatternMatch;

// As set of temporary options to help migrate how splats are represented.
static cl::opt<bool> UseConstantIntForFixedLengthSplat(
"use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantInt's native fixed-length vector splat support."));
static cl::opt<bool> UseConstantFPForFixedLengthSplat(
"use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantFP's native fixed-length vector splat support."));
static cl::opt<bool> UseConstantIntForScalableSplat(
"use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantInt's native scalable vector splat support."));
static cl::opt<bool> UseConstantFPForScalableSplat(
"use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantFP's native scalable vector splat support."));

//===----------------------------------------------------------------------===//
// Constant Class
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
// ConstantInt
//===----------------------------------------------------------------------===//

ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
ConstantInt::ConstantInt(Type *Ty, const APInt &V)
: ConstantData(Ty, ConstantIntVal), Val(V) {
assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
assert(V.getBitWidth() ==
cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
"Invalid constant for type");
}

ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
Expand Down Expand Up @@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
return Slot.get();
}

// Get a ConstantInt vector with each lane set to the same APInt.
ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
const APInt &V) {
// Get an existing value or the insertion position.
std::unique_ptr<ConstantInt> &Slot =
Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
if (!Slot) {
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
VectorType *VTy = VectorType::get(ITy, EC);
Slot.reset(new ConstantInt(VTy, V));
}

#ifndef NDEBUG
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
VectorType *VTy = VectorType::get(ITy, EC);
assert(Slot->getType() == VTy);
#endif
return Slot.get();
}

Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);

Expand Down Expand Up @@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
return Slot.get();
}

// Get a ConstantFP vector with each lane set to the same APFloat.
ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
const APFloat &V) {
// Get an existing value or the insertion position.
std::unique_ptr<ConstantFP> &Slot =
Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
if (!Slot) {
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
VectorType *VTy = VectorType::get(EltTy, EC);
Slot.reset(new ConstantFP(VTy, V));
}

#ifndef NDEBUG
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
VectorType *VTy = VectorType::get(EltTy, EC);
assert(Slot->getType() == VTy);
#endif
return Slot.get();
}

Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
Expand All @@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {

ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
: ConstantData(Ty, ConstantFPVal), Val(V) {
assert(&V.getSemantics() == &Ty->getFltSemantics() &&
assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
"FP type Mismatch");
}

Expand Down Expand Up @@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {

Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (!EC.isScalable()) {
// Maintain special handling of zero.
if (!V->isNullValue()) {
if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
return ConstantInt::get(V->getContext(), EC,
cast<ConstantInt>(V)->getValue());
if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
return ConstantFP::get(V->getContext(), EC,
cast<ConstantFP>(V)->getValue());
}

// If this splat is compatible with ConstantDataVector, use it instead of
// ConstantVector.
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
Expand All @@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
return get(Elts);
}

// Maintain special handling of zero.
if (!V->isNullValue()) {
if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
return ConstantInt::get(V->getContext(), EC,
cast<ConstantInt>(V)->getValue());
if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
return ConstantFP::get(V->getContext(), EC,
cast<ConstantFP>(V)->getValue());
}

Type *VTy = VectorType::get(V->getType(), EC);

if (V->isNullValue())
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/LLVMContextImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
IntZeroConstants.clear();
IntOneConstants.clear();
IntConstants.clear();
IntSplatConstants.clear();
FPConstants.clear();
FPSplatConstants.clear();
CDSConstants.clear();

// Destroy attribute node lists.
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/IR/LLVMContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1488,8 +1488,12 @@ class LLVMContextImpl {
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
IntSplatConstants;

DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
FPSplatConstants;

FoldingSet<AttributeImpl> AttrsSet;
FoldingSet<AttributeListImpl> AttrsLists;
Expand Down

0 comments on commit 6c6baf8

Please sign in to comment.