Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion llvm/include/llvm/CodeGenTypes/LowLevelType.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,16 @@ class LLT {
return isVector() ? getElementType() : *this;
}

/// Returns a vector with the same number of elements but the new element
/// type. Must only be called on vector types.
constexpr LLT changeVectorElementType(LLT NewEltTy) const {
return LLT::vector(getElementCount(), NewEltTy);
}

/// If this type is a vector, return a vector with the same number of elements
/// but the new element type. Otherwise, return the new element type.
constexpr LLT changeElementType(LLT NewEltTy) const {
return isVector() ? LLT::vector(getElementCount(), NewEltTy) : NewEltTy;
return isVector() ? changeVectorElementType(NewEltTy) : NewEltTy;
}

/// If this type is a vector, return a vector with the same number of elements
Expand All @@ -223,6 +229,14 @@ class LLT {
: LLT::scalar(NewEltSize);
}

/// Return a vector with the same element type and the new element count. Must
/// be called on vector types.
constexpr LLT changeVectorElementCount(ElementCount EC) const {
assert(isVector() &&
"cannot change vector element count of non-vector type");
return LLT::vector(EC, getElementType());
}

/// Return a vector or scalar with the same element type and the new element
/// count.
constexpr LLT changeElementCount(ElementCount EC) const {
Expand Down
49 changes: 22 additions & 27 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
unsigned EltSize = OrigTy.getScalarSizeInBits();
if (LeftoverSize % EltSize != 0)
return {-1, -1};
LeftoverTy =
LLT::scalarOrVector(ElementCount::getFixed(LeftoverSize / EltSize),
OrigTy.getElementType());
LeftoverTy = OrigTy.changeElementCount(
ElementCount::getFixed(LeftoverSize / EltSize));
} else {
LeftoverTy = LLT::scalar(LeftoverSize);
}
Expand Down Expand Up @@ -1558,10 +1557,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
// combines not being hit). This seems to be a problem related to the
// artifact combiner.
if (SizeOp0 % NarrowSize != 0) {
LLT ImplicitTy = NarrowTy;
if (DstTy.isVector())
ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);

LLT ImplicitTy = DstTy.changeElementType(NarrowTy);
Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
MIRBuilder.buildAnyExt(DstReg, ImplicitReg);

Expand Down Expand Up @@ -3289,7 +3285,8 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changingInstr(MI);

widenScalarSrc(
MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cases directly using LLT::vector/fixed_vector are slightly different, since this doesn't handle scalar inputs which are presumed impossible. This is a direct replacement for the scalarOrVector cases

MI,
VecTy.changeVectorElementType(LLT::scalar(WideTy.getSizeInBits())), 1,
TargetOpcode::G_ANYEXT);

widenScalarDst(MI, WideTy, 0);
Expand Down Expand Up @@ -3321,7 +3318,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {

Register VecReg = MI.getOperand(1).getReg();
LLT VecTy = MRI.getType(VecReg);
LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);
LLT WideVecTy = VecTy.changeVectorElementType(WideTy);

widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
Expand Down Expand Up @@ -3522,9 +3519,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changingInstr(MI);
Register VecReg = MI.getOperand(1).getReg();
LLT VecTy = MRI.getType(VecReg);
LLT WideVecTy = VecTy.isVector()
? LLT::vector(VecTy.getElementCount(), WideTy)
: WideTy;
LLT WideVecTy = VecTy.changeElementType(WideTy);
widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_FPEXT);
widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
Observer.changedInstr(MI);
Expand Down Expand Up @@ -3658,7 +3653,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
// %3:_(<2 x s8>) = G_BITCAST %2
// %4:_(<2 x s8>) = G_BITCAST %3
// %1:_(<4 x s16>) = G_CONCAT_VECTORS %3, %4
DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
DstCastTy = DstTy.changeVectorElementCount(
ElementCount::getFixed(NumDstElt / NumSrcElt));
SrcPartTy = SrcEltTy;
} else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
//
Expand All @@ -3670,7 +3666,8 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
// %3:_(s16) = G_BITCAST %2
// %4:_(s16) = G_BITCAST %3
// %1:_(<2 x s16>) = G_BUILD_VECTOR %3, %4
SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
SrcPartTy = SrcTy.changeVectorElementCount(
ElementCount::getFixed(NumSrcElt / NumDstElt));
DstCastTy = DstEltTy;
}

Expand Down Expand Up @@ -3736,7 +3733,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
unsigned OldNumElts = SrcVecTy.getNumElements();

LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
LLT NewEltTy = CastTy.getScalarType();
Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);

const unsigned NewEltSize = NewEltTy.getSizeInBits();
Expand All @@ -3758,7 +3755,7 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
// Type of the intermediate result vector.
const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
LLT MidTy =
LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
CastTy.changeElementCount(ElementCount::getFixed(NewEltsPerOldElt));

auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);

Expand Down Expand Up @@ -4231,8 +4228,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
// the size of the load without needing to scalarize it.
if (Alignment.value() * 8 > MemSizeInBits &&
isPowerOf2_64(DstTy.getScalarSizeInBits())) {
LLT MoreTy = LLT::fixed_vector(NextPowerOf2(DstTy.getNumElements()),
DstTy.getElementType());
LLT MoreTy = DstTy.changeVectorElementCount(
ElementCount::getFixed(NextPowerOf2(DstTy.getNumElements())));
MachineMemOperand *NewMMO = MF.getMachineMemOperand(&MMO, 0, MoreTy);
auto NewLoad = MIRBuilder.buildLoad(MoreTy, PtrReg, *NewMMO);
MIRBuilder.buildDeleteTrailingVectorElements(LoadMI.getReg(0),
Expand Down Expand Up @@ -5023,8 +5020,7 @@ static void makeDstOps(SmallVectorImpl<DstOp> &DstOps, LLT Ty,
unsigned NumElts) {
LLT LeftoverTy;
assert(Ty.isVector() && "Expected vector type");
LLT EltTy = Ty.getElementType();
LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
LLT NarrowTy = Ty.changeElementCount(ElementCount::getFixed(NumElts));
int NumParts, NumLeftover;
std::tie(NumParts, NumLeftover) =
getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy);
Expand Down Expand Up @@ -5705,7 +5701,8 @@ LegalizerHelper::fewerElementsBitcast(MachineInstr &MI, unsigned int TypeIdx,
auto Unmerge = MIRBuilder.buildUnmerge(SrcNarrowTy, SrcReg);
getUnmergeResults(SrcVRegs, *Unmerge);
} else {
LLT SrcNarrowTy = LLT::fixed_vector(NewElemCount, SrcTy.getElementType());
LLT SrcNarrowTy =
SrcTy.changeVectorElementCount(ElementCount::getFixed(NewElemCount));

// Split the Src and Dst Reg into smaller registers
if (extractGCDType(SrcVRegs, DstTy, SrcNarrowTy, SrcReg) != SrcNarrowTy)
Expand Down Expand Up @@ -6837,8 +6834,7 @@ LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
Observer.changingInstr(MI);
moreElementsVectorSrc(MI, MoreTy, 2);
moreElementsVectorSrc(MI, MoreTy, 3);
LLT CondTy = LLT::fixed_vector(
MoreTy.getNumElements(),
LLT CondTy = MoreTy.changeVectorElementType(
MRI.getType(MI.getOperand(0).getReg()).getElementType());
moreElementsVectorDst(MI, CondTy, 0);
Observer.changedInstr(MI);
Expand Down Expand Up @@ -6930,7 +6926,8 @@ LegalizerHelper::equalizeVectorShuffleLengths(MachineInstr &MI) {

unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts);
unsigned NumConcat = PaddedMaskNumElts / SrcNumElts;
LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy);
LLT PaddedTy =
DstTy.changeVectorElementCount(ElementCount::getFixed(PaddedMaskNumElts));

// Create new source vectors by concatenating the initial
// source vectors with undefined vectors of the same size.
Expand Down Expand Up @@ -9894,9 +9891,7 @@ LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) {
unsigned BitSize = SrcTy.getScalarSizeInBits();
const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

LLT IntTy = LLT::scalar(BitSize);
if (SrcTy.isVector())
IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
LLT IntTy = SrcTy.changeElementType(LLT::scalar(BitSize));
auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg);

// Various masks.
Expand Down