Skip to content

Commit

Permalink
[GISEL] Add IRTranslation for shufflevector on scalable vector types (#…
Browse files Browse the repository at this point in the history
…80378)

Recommits #80378 which was reverted in
#84330. The problem was that the change in
llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir used
217 as an opcode instead of a regex.
  • Loading branch information
michaelmaitland committed Mar 7, 2024
1 parent 8f79cdd commit 96049fc
Show file tree
Hide file tree
Showing 15 changed files with 1,890 additions and 21 deletions.
5 changes: 5 additions & 0 deletions llvm/docs/GlobalISel/GenericOpcode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ Concatenate two vectors and shuffle the elements according to the mask operand.
The mask operand should be an IR Constant which exactly matches the
corresponding mask for the IR shufflevector instruction.

G_SPLAT_VECTOR
^^^^^^^^^^^^^^^^

Create a vector where all elements are the scalar from the source operand.

Vector Reduction Operations
---------------------------

Expand Down
12 changes: 10 additions & 2 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1063,8 +1063,7 @@ class MachineIRBuilder {

/// Build and insert \p Res = G_BUILD_VECTOR with \p Src replicated to fill
/// the number of elements
MachineInstrBuilder buildSplatVector(const DstOp &Res,
const SrcOp &Src);
MachineInstrBuilder buildSplatBuildVector(const DstOp &Res, const SrcOp &Src);

/// Build and insert \p Res = G_BUILD_VECTOR_TRUNC \p Op0, ...
///
Expand Down Expand Up @@ -1099,6 +1098,15 @@ class MachineIRBuilder {
MachineInstrBuilder buildShuffleVector(const DstOp &Res, const SrcOp &Src1,
const SrcOp &Src2, ArrayRef<int> Mask);

/// Build and insert \p Res = G_SPLAT_VECTOR \p Val
///
/// \pre setBasicBlock or setMI must have been called.
/// \pre \p Res must be a generic virtual register with vector type.
/// \pre \p Val must be a generic virtual register with scalar type.
///
/// \return a MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);

/// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
///
/// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Support/TargetOpcodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,9 @@ HANDLE_TARGET_OPCODE(G_EXTRACT_VECTOR_ELT)
/// Generic shufflevector.
HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)

/// Generic splatvector.
HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)

/// Generic count trailing zeroes.
HANDLE_TARGET_OPCODE(G_CTTZ)

Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Target/GenericOpcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,13 @@ def G_SHUFFLE_VECTOR: GenericInstruction {
let hasSideEffects = false;
}

// Generic splatvector.
def G_SPLAT_VECTOR: GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type1:$val);
let hasSideEffects = false;
}

//------------------------------------------------------------------------------
// Vector reductions
//------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ MachineInstrBuilder CSEMIRBuilder::buildConstant(const DstOp &Res,
// For vectors, CSE the element only for now.
LLT Ty = Res.getLLTTy(*getMRI());
if (Ty.isVector())
return buildSplatVector(Res, buildConstant(Ty.getElementType(), Val));
return buildSplatBuildVector(Res, buildConstant(Ty.getElementType(), Val));

FoldingSetNodeID ID;
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());
Expand All @@ -336,7 +336,7 @@ MachineInstrBuilder CSEMIRBuilder::buildFConstant(const DstOp &Res,
// For vectors, CSE the element only for now.
LLT Ty = Res.getLLTTy(*getMRI());
if (Ty.isVector())
return buildSplatVector(Res, buildFConstant(Ty.getElementType(), Val));
return buildSplatBuildVector(Res, buildFConstant(Ty.getElementType(), Val));

FoldingSetNodeID ID;
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());
Expand Down
27 changes: 21 additions & 6 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1598,10 +1598,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
// We might need to splat the base pointer into a vector if the offsets
// are vectors.
if (WantSplatVector && !PtrTy.isVector()) {
BaseReg =
MIRBuilder
.buildSplatVector(LLT::fixed_vector(VectorWidth, PtrTy), BaseReg)
.getReg(0);
BaseReg = MIRBuilder
.buildSplatBuildVector(LLT::fixed_vector(VectorWidth, PtrTy),
BaseReg)
.getReg(0);
PtrIRTy = FixedVectorType::get(PtrIRTy, VectorWidth);
PtrTy = getLLTForType(*PtrIRTy, *DL);
OffsetIRTy = DL->getIndexType(PtrIRTy);
Expand Down Expand Up @@ -1639,8 +1639,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
LLT IdxTy = MRI->getType(IdxReg);
if (IdxTy != OffsetTy) {
if (!IdxTy.isVector() && WantSplatVector) {
IdxReg = MIRBuilder.buildSplatVector(
OffsetTy.changeElementType(IdxTy), IdxReg).getReg(0);
IdxReg = MIRBuilder
.buildSplatBuildVector(OffsetTy.changeElementType(IdxTy),
IdxReg)
.getReg(0);
}

IdxReg = MIRBuilder.buildSExtOrTrunc(OffsetTy, IdxReg).getReg(0);
Expand Down Expand Up @@ -2997,6 +2999,19 @@ bool IRTranslator::translateExtractElement(const User &U,

bool IRTranslator::translateShuffleVector(const User &U,
MachineIRBuilder &MIRBuilder) {
// A ShuffleVector that has operates on scalable vectors is a splat vector
// where the value of the splat vector is the 0th element of the first
// operand, since the index mask operand is the zeroinitializer (undef and
// poison are treated as zeroinitializer here).
if (U.getOperand(0)->getType()->isScalableTy()) {
Value *Op0 = U.getOperand(0);
auto SplatVal = MIRBuilder.buildExtractVectorElementConstant(
LLT::scalar(Op0->getType()->getScalarSizeInBits()),
getOrCreateVReg(*Op0), 0);
MIRBuilder.buildSplatVector(getOrCreateVReg(U), SplatVal);
return true;
}

ArrayRef<int> Mask;
if (auto *SVI = dyn_cast<ShuffleVectorInst>(&U))
Mask = SVI->getShuffleMask();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8391,7 +8391,7 @@ static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {

// For vector types create a G_BUILD_VECTOR.
if (Ty.isVector())
Val = MIB.buildSplatVector(Ty, Val).getReg(0);
Val = MIB.buildSplatBuildVector(Ty, Val).getReg(0);

return Val;
}
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ MachineInstrBuilder MachineIRBuilder::buildConstant(const DstOp &Res,
auto Const = buildInstr(TargetOpcode::G_CONSTANT)
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
.addCImm(&Val);
return buildSplatVector(Res, Const);
return buildSplatBuildVector(Res, Const);
}

auto Const = buildInstr(TargetOpcode::G_CONSTANT);
Expand Down Expand Up @@ -363,7 +363,7 @@ MachineInstrBuilder MachineIRBuilder::buildFConstant(const DstOp &Res,
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
.addFPImm(&Val);

return buildSplatVector(Res, Const);
return buildSplatBuildVector(Res, Const);
}

auto Const = buildInstr(TargetOpcode::G_FCONSTANT);
Expand Down Expand Up @@ -711,8 +711,8 @@ MachineIRBuilder::buildBuildVectorConstant(const DstOp &Res,
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
}

MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
const SrcOp &Src) {
MachineInstrBuilder MachineIRBuilder::buildSplatBuildVector(const DstOp &Res,
const SrcOp &Src) {
SmallVector<SrcOp, 8> TmpVec(Res.getLLTTy(*getMRI()).getNumElements(), Src);
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
}
Expand Down Expand Up @@ -742,6 +742,14 @@ MachineInstrBuilder MachineIRBuilder::buildShuffleSplat(const DstOp &Res,
return buildShuffleVector(DstTy, InsElt, UndefVec, ZeroMask);
}

MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
const SrcOp &Src) {
LLT DstTy = Res.getLLTTy(*getMRI());
assert(Src.getLLTTy(*getMRI()) == DstTy.getElementType() &&
"Expected Src to match Dst elt ty");
return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
}

MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
const SrcOp &Src1,
const SrcOp &Src2,
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/CodeGen/MachineVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,24 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {

break;
}

case TargetOpcode::G_SPLAT_VECTOR: {
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());

if (!DstTy.isScalableVector())
report("Destination type must be a scalable vector", MI);

if (!SrcTy.isScalar())
report("Source type must be a scalar", MI);

if (DstTy.getScalarType() != SrcTy)
report("Element type of the destination must be the same type as the "
"source type",
MI);

break;
}
case TargetOpcode::G_DYN_STACKALLOC: {
const MachineOperand &DstOp = MI->getOperand(0);
const MachineOperand &AllocOp = MI->getOperand(1);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20920,7 +20920,8 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
unsigned Op = Inst.getOpcode();
if (Op == Instruction::Add || Op == Instruction::Sub ||
Op == Instruction::And || Op == Instruction::Or ||
Op == Instruction::Xor || Op == Instruction::InsertElement)
Op == Instruction::Xor || Op == Instruction::InsertElement ||
Op == Instruction::Xor || Op == Instruction::ShuffleVector)
return false;

if (Inst.getType()->isScalableTy())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@
# DEBUG-NEXT: G_SHUFFLE_VECTOR (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_SPLAT_VECTOR (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_CTTZ (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
Expand Down

0 comments on commit 96049fc

Please sign in to comment.