Skip to content

Commit

Permalink
[RISCV][GISEL] Legalize G_VSCALE
Browse files Browse the repository at this point in the history
G_VSCALE should be lowered using VLENB.
  • Loading branch information
michaelmaitland committed Mar 25, 2024
1 parent 05dc5d9 commit 4768150
Show file tree
Hide file tree
Showing 8 changed files with 528 additions and 1 deletion.
11 changes: 11 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,17 @@ class MachineIRBuilder {
/// \return a MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildVScale(const DstOp &Res, const ConstantInt &MinElts);

/// Build and insert \p Res = G_VSCALE \p MinElts
///
/// G_VSCALE puts the value of the runtime vscale multiplied by \p MinElts
/// into \p Res.
///
/// This overload accepts the multiplier as an APInt; the out-of-line
/// definition converts it to a ConstantInt in the current function's
/// LLVMContext and forwards to the ConstantInt overload.
///
/// \pre setBasicBlock or setMI must have been called.
/// \pre \p Res must be a generic virtual register with scalar type.
///
/// \return a MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildVScale(const DstOp &Res, const APInt &MinElts);

/// Build and insert a G_INTRINSIC instruction.
///
/// There are four different opcodes based on combinations of whether the
Expand Down
51 changes: 50 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,36 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
case TargetOpcode::G_FLDEXP:
case TargetOpcode::G_STRICT_FLDEXP:
return narrowScalarFLDEXP(MI, TypeIdx, NarrowTy);
// Narrow G_VSCALE by slicing its constant multiplier into NarrowTy-sized
// chunks (plus an optional smaller leftover chunk), emitting one G_VSCALE per
// chunk, and passing the resulting registers to insertParts to rebuild the
// original destination.
//
// NOTE(review): every piece is an independent G_VSCALE of its own chunk, so a
// carry out of "vscale * low chunk" is not propagated into the next piece.
// Presumably this relies on the per-chunk product not overflowing the piece
// width -- confirm.
case TargetOpcode::G_VSCALE: {
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
const APInt &Val = MI.getOperand(1).getCImm()->getValue();
unsigned TotalSize = Ty.getSizeInBits();
unsigned NarrowSize = NarrowTy.getSizeInBits();
// Number of whole NarrowTy pieces that fit in the original type.
int NumParts = TotalSize / NarrowSize;

SmallVector<Register, 4> PartRegs;
for (int I = 0; I != NumParts; ++I) {
// Bit offset of this piece within the original multiplier.
unsigned Offset = I * NarrowSize;
// Extract bits [Offset, Offset + NarrowSize) of the multiplier and use them
// as the multiplier of a NarrowTy G_VSCALE.
auto K =
MIRBuilder.buildVScale(NarrowTy, Val.lshr(Offset).trunc(NarrowSize));
PartRegs.push_back(K.getReg(0));
}
// Bits of the original type not covered by the whole NarrowTy pieces, if the
// original width is not a multiple of NarrowSize.
LLT LeftoverTy;
unsigned LeftoverBits = TotalSize - NumParts * NarrowSize;
SmallVector<Register, 1> LeftoverRegs;
if (LeftoverBits != 0) {
LeftoverTy = LLT::scalar(LeftoverBits);
// The leftover piece takes the remaining high bits of the multiplier.
auto K = MIRBuilder.buildVScale(
LeftoverTy, Val.lshr(NumParts * NarrowSize).trunc(LeftoverBits));
LeftoverRegs.push_back(K.getReg(0));
}

// Combine the narrow pieces (and leftover, if any) into the original
// destination register.
insertParts(MI.getOperand(0).getReg(), Ty, NarrowTy, PartRegs, LeftoverTy,
LeftoverRegs);

// The original wide G_VSCALE is fully replaced by the pieces above.
MI.eraseFromParent();
return Legalized;
}
}
}

Expand Down Expand Up @@ -2966,7 +2996,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
case TargetOpcode::G_VECREDUCE_FMIN:
case TargetOpcode::G_VECREDUCE_FMAX:
case TargetOpcode::G_VECREDUCE_FMINIMUM:
case TargetOpcode::G_VECREDUCE_FMAXIMUM:
case TargetOpcode::G_VECREDUCE_FMAXIMUM: {
if (TypeIdx != 0)
return UnableToLegalize;
Observer.changingInstr(MI);
Expand All @@ -2980,6 +3010,25 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changedInstr(MI);
return Legalized;
}
// Widen G_VSCALE in place: extend the constant multiplier operand to the
// width of WideTy and widen the destination with widenScalarDst.
case TargetOpcode::G_VSCALE: {
MachineOperand &SrcMO = MI.getOperand(1);
LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
// Ask the target's legalizer info which extension it prefers when widening a
// constant of the current (narrow) result type.
unsigned ExtOpc = LI.getExtOpcodeForWideningConstant(
MRI.getType(MI.getOperand(0).getReg()));
assert((ExtOpc == TargetOpcode::G_ZEXT || ExtOpc == TargetOpcode::G_SEXT ||
ExtOpc == TargetOpcode::G_ANYEXT) &&
"Illegal Extend");
const APInt &SrcVal = SrcMO.getCImm()->getValue();
// Sign-extend only for G_SEXT; both G_ZEXT and G_ANYEXT zero-extend the
// multiplier. The reference binds to the temporary produced by the
// conditional expression, which lives until the end of this scope.
const APInt &Val = (ExtOpc == TargetOpcode::G_SEXT)
? SrcVal.sext(WideTy.getSizeInBits())
: SrcVal.zext(WideTy.getSizeInBits());
// Notify the observer around the in-place mutation of MI.
Observer.changingInstr(MI);
SrcMO.setCImm(ConstantInt::get(Ctx, Val));
widenScalarDst(MI, WideTy);
Observer.changedInstr(MI);
return Legalized;
}
}
}

static void getUnmergePieces(SmallVectorImpl<Register> &Pieces,
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,13 @@ MachineInstrBuilder MachineIRBuilder::buildVScale(const DstOp &Res,
return VScale;
}

/// Build G_VSCALE from an APInt multiplier.
///
/// The multiplier is converted to a ConstantInt in the LLVMContext of the
/// function being built, and the work is delegated to the ConstantInt
/// overload of buildVScale.
MachineInstrBuilder MachineIRBuilder::buildVScale(const DstOp &Res,
                                                  const APInt &MinElts) {
  auto &Ctx = getMF().getFunction().getContext();
  const ConstantInt &MinEltsCI = *ConstantInt::get(Ctx, MinElts);
  return buildVScale(Res, MinEltsCI);
}

static unsigned getIntrinsicOpcode(bool HasSideEffects, bool IsConvergent) {
if (HasSideEffects && IsConvergent)
return TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS;
Expand Down
48 changes: 48 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.clampScalar(0, s32, sXLen)
.lowerForCartesianProduct({s32, sXLen, p0}, {p0});

// G_VSCALE: clamp the result (type index 0) to exactly sXLen, then mark the
// sXLen form as custom so it is handled in legalizeCustom.
getActionDefinitionsBuilder(G_VSCALE)
.clampScalar(0, sXLen, sXLen)
.customFor({sXLen});

getLegacyLegalizerInfo().computeTables();
}

Expand Down Expand Up @@ -527,6 +531,48 @@ bool RISCVLegalizerInfo::shouldBeInConstantPool(APInt APImm,
return !(!SeqLo.empty() && (SeqLo.size() + 2) <= STI.getMaxBuildIntsCost());
}

/// Custom-legalize G_VSCALE by rewriting it as arithmetic on G_READ_VLENB.
///
/// Operand 0 of \p MI is the destination register and operand 1 is the
/// constant multiplier. The result written to the destination is
/// (VLENB / 8) * multiplier, emitted with the cheapest shift/multiply form
/// for the given multiplier.
///
/// \returns false (leaving \p MI untouched) when the subtarget's real minimum
/// VLEN is smaller than RISCV::RVVBitsPerBlock; otherwise erases \p MI and
/// returns true.
bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
                                        MachineIRBuilder &MIB) const {
  const LLT XLenTy(STI.getXLenVT());
  Register Dst = MI.getOperand(0).getReg();

  // Scalable vector types for lmul=1 are defined with a 64 bit known minimum
  // size (e.g. <vscale x 2 x i32>). VLENB is expressed in bytes, hence
  // vscale == VLENB / 8.
  static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!");

  // Support for VLEN==32 is incomplete.
  if (STI.getRealMinVLen() < RISCV::RVVBitsPerBlock)
    return false;

  // Emits a read of VLENB into a fresh XLen-typed value.
  auto ReadVLENB = [&] {
    return MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {});
  };

  // VLENB is assumed to be a multiple of 8. The best shift is picked by hand
  // because SimplifyDemandedBits isn't always able to simplify it.
  const uint64_t Multiplier = MI.getOperand(1).getCImm()->getZExtValue();
  if (!isPowerOf2_64(Multiplier)) {
    auto VLENB = ReadVLENB();
    if (Multiplier % 8 == 0) {
      // A multiple of 8 can absorb the divide-by-8, so VLENB never needs to
      // be shifted: VLENB * (Multiplier / 8).
      auto Factor = MIB.buildConstant(XLenTy, Multiplier / 8);
      MIB.buildMul(Dst, VLENB, Factor);
    } else {
      // General case: (VLENB >> 3) * Multiplier.
      auto ShiftAmt = MIB.buildConstant(XLenTy, 3);
      auto VScale = MIB.buildLShr(XLenTy, VLENB, ShiftAmt);
      auto Factor = MIB.buildConstant(XLenTy, Multiplier);
      MIB.buildMul(Dst, VScale, Factor);
    }
  } else {
    // Power-of-two multiplier: fold it into the divide-by-8 as one shift.
    const uint64_t Exponent = Log2_64(Multiplier);
    if (Exponent == 3) {
      // Multiplying by 8 cancels the divide; VLENB itself is the answer.
      MIB.buildInstr(RISCV::G_READ_VLENB, {Dst}, {});
    } else if (Exponent < 3) {
      auto VLENB = ReadVLENB();
      auto ShiftAmt = MIB.buildConstant(XLenTy, 3 - Exponent);
      MIB.buildLShr(Dst, VLENB, ShiftAmt);
    } else {
      auto VLENB = ReadVLENB();
      auto ShiftAmt = MIB.buildConstant(XLenTy, Exponent - 3);
      MIB.buildShl(Dst, VLENB, ShiftAmt);
    }
  }

  MI.eraseFromParent();
  return true;
}

bool RISCVLegalizerInfo::legalizeCustom(
LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const {
Expand Down Expand Up @@ -584,6 +630,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
}
case TargetOpcode::G_VASTART:
return legalizeVAStart(MI, MIRBuilder);
case TargetOpcode::G_VSCALE:
return legalizeVScale(MI, MIRBuilder);
}

llvm_unreachable("expected switch to return");
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class RISCVLegalizerInfo : public LegalizerInfo {
GISelChangeObserver &Observer) const;

bool legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
/// Custom legalization for G_VSCALE: replaces \p MI with a sequence built
/// from G_READ_VLENB. Returns false without modifying \p MI when the
/// subtarget's real minimum VLEN is below RISCV::RVVBitsPerBlock.
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
};
} // end namespace llvm
#endif
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,11 @@ def G_FCLASS : RISCVGenericInstruction {
let hasSideEffects = false;
}
def : GINodeEquiv<G_FCLASS, riscv_fclass>;

// Pseudo equivalent to a RISCVISD::READ_VLENB.
// Takes no input operands and defines a single result of generic type
// `type0`; it is declared free of side effects.
def G_READ_VLENB : RISCVGenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins);
let hasSideEffects = false;
}
// Declare G_READ_VLENB as the GlobalISel equivalent of the riscv_read_vlenb
// node.
def : GINodeEquiv<G_READ_VLENB, riscv_read_vlenb>;
Loading

0 comments on commit 4768150

Please sign in to comment.