Skip to content

Commit

Permalink
[ValueTypes] Add support for scalable EVTs
Browse files Browse the repository at this point in the history
Summary:
* Remove a bunch of asserts checking for unsupported scalable types and
  add some more now that they are supported.
* Propagate the scalable flag where necessary.
* Add another `EVT::getExtendedVectorVT` method that takes an
  ElementCount parameter.
* Add `EVT::isExtendedScalableVector` and
  `EVT::getExtendedVectorElementCount` - latter is currently unused.

Reviewers: sdesmalen, efriedma, rengolin, craig.topper, huntergr

Reviewed By: efriedma

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75672
  • Loading branch information
c-rhodes committed Mar 19, 2020
1 parent e26e9ba commit 5ce38fc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
31 changes: 11 additions & 20 deletions llvm/include/llvm/CodeGen/ValueTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ namespace llvm {
MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
return M;

assert(!IsScalable && "We don't support extended scalable types yet");
return getExtendedVectorVT(Context, VT, NumElements);
return getExtendedVectorVT(Context, VT, NumElements, IsScalable);
}

/// Returns the EVT that represents a vector EC.Min elements in length,
Expand All @@ -87,19 +85,15 @@ namespace llvm {
MVT M = MVT::getVectorVT(VT.V, EC);
if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
return M;
assert (!EC.Scalable && "We don't support extended scalable types yet");
return getExtendedVectorVT(Context, VT, EC.Min);
return getExtendedVectorVT(Context, VT, EC);
}

/// Return a vector with the same number of elements as this vector, but
/// with the element type converted to an integer type with the same
/// bitwidth.
EVT changeVectorElementTypeToInteger() const {
if (!isSimple()) {
assert (!isScalableVector() &&
"We don't support extended scalable types yet");
if (!isSimple())
return changeExtendedVectorElementTypeToInteger();
}
MVT EltTy = getSimpleVT().getVectorElementType();
unsigned BitWidth = EltTy.getSizeInBits();
MVT IntTy = MVT::getIntegerVT(BitWidth);
Expand Down Expand Up @@ -156,12 +150,7 @@ namespace llvm {
/// Return true if this is a vector type where the runtime
/// length is machine dependent
bool isScalableVector() const {
// FIXME: We don't support extended scalable types yet, because the
// matching IR type doesn't exist. Once it has been added, this can
// be changed to call isExtendedScalableVector.
if (!isSimple())
return false;
return V.isScalableVector();
return isSimple() ? V.isScalableVector() : isExtendedScalableVector();
}

bool isFixedLengthVector() const {
Expand Down Expand Up @@ -300,9 +289,7 @@ namespace llvm {
if (isSimple())
return V.getVectorElementCount();

assert(!isScalableVector() &&
"We don't support extended scalable types yet");
return {getExtendedVectorNumElements(), false};
return {getExtendedVectorNumElements(), isExtendedScalableVector()};
}

/// Return the size of the specified value type in bits.
Expand Down Expand Up @@ -443,8 +430,10 @@ namespace llvm {
EVT changeExtendedTypeToInteger() const;
EVT changeExtendedVectorElementTypeToInteger() const;
static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
static EVT getExtendedVectorVT(LLVMContext &C, EVT VT,
unsigned NumElements);
static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,
bool IsScalable);
static EVT getExtendedVectorVT(LLVMContext &Context, EVT VT,
ElementCount EC);
bool isExtendedFloatingPoint() const LLVM_READONLY;
bool isExtendedInteger() const LLVM_READONLY;
bool isExtendedScalarInteger() const LLVM_READONLY;
Expand All @@ -458,8 +447,10 @@ namespace llvm {
bool isExtended1024BitVector() const LLVM_READONLY;
bool isExtended2048BitVector() const LLVM_READONLY;
bool isExtendedFixedLengthVector() const LLVM_READONLY;
bool isExtendedScalableVector() const LLVM_READONLY;
EVT getExtendedVectorElementType() const;
unsigned getExtendedVectorNumElements() const LLVM_READONLY;
ElementCount getExtendedVectorElementCount() const LLVM_READONLY;
TypeSize getExtendedSizeInBits() const LLVM_READONLY;
};

Expand Down
11 changes: 7 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,13 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
// Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
// intermediate operands.
EVT BuiltVectorTy =
EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(),
(IntermediateVT.isVector()
? IntermediateVT.getVectorNumElements() * NumParts
: NumIntermediates));
IntermediateVT.isVector()
? EVT::getVectorVT(
*DAG.getContext(), IntermediateVT.getScalarType(),
IntermediateVT.getVectorElementCount() * NumParts)
: EVT::getVectorVT(*DAG.getContext(),
IntermediateVT.getScalarType(),
NumIntermediates);
Val = DAG.getNode(IntermediateVT.isVector() ? ISD::CONCAT_VECTORS
: ISD::BUILD_VECTOR,
DL, BuiltVectorTy, Ops);
Expand Down
27 changes: 23 additions & 4 deletions llvm/lib/CodeGen/ValueTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ EVT EVT::changeExtendedTypeToInteger() const {
EVT EVT::changeExtendedVectorElementTypeToInteger() const {
LLVMContext &Context = LLVMTy->getContext();
EVT IntTy = getIntegerVT(Context, getScalarSizeInBits());
return getVectorVT(Context, IntTy, getVectorNumElements());
return getVectorVT(Context, IntTy, getVectorNumElements(),
isScalableVector());
}

EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
Expand All @@ -32,10 +33,19 @@ EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
return VT;
}

EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT,
unsigned NumElements) {
EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
bool IsScalable) {
EVT ResultVT;
ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), NumElements);
ResultVT.LLVMTy =
VectorType::get(VT.getTypeForEVT(Context), NumElements, IsScalable);
assert(ResultVT.isExtended() && "Type is not extended!");
return ResultVT;
}

EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, ElementCount EC) {
EVT ResultVT;
ResultVT.LLVMTy =
VectorType::get(VT.getTypeForEVT(Context), {EC.Min, EC.Scalable});
assert(ResultVT.isExtended() && "Type is not extended!");
return ResultVT;
}
Expand Down Expand Up @@ -96,6 +106,10 @@ bool EVT::isExtendedFixedLengthVector() const {
return isExtendedVector() && !cast<VectorType>(LLVMTy)->isScalable();
}

bool EVT::isExtendedScalableVector() const {
return isExtendedVector() && cast<VectorType>(LLVMTy)->isScalable();
}

EVT EVT::getExtendedVectorElementType() const {
assert(isExtended() && "Type is not extended!");
return EVT::getEVT(cast<VectorType>(LLVMTy)->getElementType());
Expand All @@ -106,6 +120,11 @@ unsigned EVT::getExtendedVectorNumElements() const {
return cast<VectorType>(LLVMTy)->getNumElements();
}

ElementCount EVT::getExtendedVectorElementCount() const {
assert(isExtended() && "Type is not extended!");
return cast<VectorType>(LLVMTy)->getElementCount();
}

TypeSize EVT::getExtendedSizeInBits() const {
assert(isExtended() && "Type is not extended!");
if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy))
Expand Down

0 comments on commit 5ce38fc

Please sign in to comment.