Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Remove SEW operand for load/store and SEW-aware pseudos #90396

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/TargetParser/RISCVTargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ enum VLMUL : uint8_t {
LMUL_F2
};

enum VSEW : uint8_t {
SEW_8 = 0,
SEW_16,
SEW_32,
SEW_64,
};

enum {
TAIL_UNDISTURBED_MASK_UNDISTURBED = 0,
TAIL_AGNOSTIC = 1,
Expand Down
56 changes: 51 additions & 5 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/TargetParser/RISCVISAInfo.h"
#include "llvm/TargetParser/RISCVTargetParser.h"
#include "llvm/TargetParser/SubtargetFeature.h"
#include <cstdint>

namespace llvm {

Expand Down Expand Up @@ -123,6 +126,12 @@ enum {
// 3 -> widening case
TargetOverlapConstraintTypeShift = UsesVXRMShift + 1,
TargetOverlapConstraintTypeMask = 3ULL << TargetOverlapConstraintTypeShift,

HasImplictSEWShift = TargetOverlapConstraintTypeShift + 2,
HasImplictSEWMask = 1 << HasImplictSEWShift,

VSEWShift = HasImplictSEWShift + 1,
VSEWMask = 0b11 << VSEWShift,
};

// Helper functions to read TSFlags.
Expand Down Expand Up @@ -171,14 +180,29 @@ static inline bool hasRoundModeOp(uint64_t TSFlags) {
/// \returns true if this instruction uses vxrm
static inline bool usesVXRM(uint64_t TSFlags) { return TSFlags & UsesVXRMMask; }

/// \returns true if this instruction has implict SEW value.
static inline bool hasImplictSEW(uint64_t TSFlags) {
wangpc-pp marked this conversation as resolved.
Show resolved Hide resolved
return TSFlags & HasImplictSEWMask;
}

/// \returns the VSEW for the instruction.
static inline VSEW getVSEW(uint64_t TSFlags) {
return static_cast<VSEW>((TSFlags & VSEWMask) >> VSEWShift);
}

/// \returns true if there is a SEW value for the instruction.
static inline bool hasSEW(uint64_t TSFlags) {
return hasSEWOp(TSFlags) || hasImplictSEW(TSFlags);
}

static inline unsigned getVLOpNum(const MCInstrDesc &Desc) {
const uint64_t TSFlags = Desc.TSFlags;
// This method is only called if we expect to have a VL operand, and all
// instructions with VL also have SEW.
assert(hasSEWOp(TSFlags) && hasVLOp(TSFlags));
unsigned Offset = 2;
// This method is only called if we expect to have a VL operand.
assert(hasVLOp(TSFlags));
// Some instructions don't have SEW operand.
unsigned Offset = 1 + hasSEWOp(TSFlags);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think hasSEWOp returns bool. Is it correct to add unsigned + bool?

if (hasVecPolicyOp(TSFlags))
Offset = 3;
Offset = Offset + 1;
return Desc.getNumOperands() - Offset;
}

Expand All @@ -191,6 +215,28 @@ static inline unsigned getSEWOpNum(const MCInstrDesc &Desc) {
return Desc.getNumOperands() - Offset;
}

static inline unsigned getLog2SEW(uint64_t TSFlags) {
return 3 + RISCVII::getVSEW(TSFlags);
}

static inline MachineOperand getSEWOp(const MachineInstr &MI) {
uint64_t TSFlags = MI.getDesc().TSFlags;
assert(hasSEW(TSFlags) && "The instruction doesn't have SEW value!");
if (hasSEWOp(TSFlags))
return MI.getOperand(getSEWOpNum(MI.getDesc()));

return MachineOperand::CreateImm(getLog2SEW(TSFlags));
}

static inline unsigned getLog2SEW(const MachineInstr &MI) {
uint64_t TSFlags = MI.getDesc().TSFlags;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about to implement getLog2SEW with getSEWOp?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this before, but it seems to be silly because for the code path that RISCVII::hasSEWOp returns false, we will create an immediate operand and then extract the imm.

assert(RISCVII::hasSEW(TSFlags) && "The instruction doesn't have SEW value!");
if (RISCVII::hasSEWOp(TSFlags))
return MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

return getLog2SEW(TSFlags);
}

static inline unsigned getVecPolicyOpNum(const MCInstrDesc &Desc) {
assert(hasVecPolicyOp(Desc.TSFlags));
return Desc.getNumOperands() - 1;
Expand Down
35 changes: 23 additions & 12 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,11 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
Operands.push_back(VL);

MVT XLenVT = Subtarget->getXLenVT();
SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
Operands.push_back(SEWOp);
// Add SEW operand if it is indexed or mask load/store instruction.
if (Log2SEW == 0 || IndexVT) {
SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
Operands.push_back(SEWOp);
}

// At the IR layer, all the masked load intrinsics have policy operands,
// none of the others do. All have passthru operands. For our pseudos,
Expand Down Expand Up @@ -2226,7 +2229,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
selectVLOp(Node->getOperand(2), VL);

unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
SDValue SEW = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);

// If VL=1, then we don't need to do a strided load and can just do a
// regular load.
Expand All @@ -2243,7 +2245,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
Operands.push_back(CurDAG->getRegister(RISCV::X0, XLenVT));
uint64_t Policy = RISCVII::MASK_AGNOSTIC | RISCVII::TAIL_AGNOSTIC;
SDValue PolicyOp = CurDAG->getTargetConstant(Policy, DL, XLenVT);
Operands.append({VL, SEW, PolicyOp, Ld->getChain()});
Operands.append({VL, PolicyOp, Ld->getChain()});

RISCVII::VLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
const RISCV::VLEPseudo *P = RISCV::getVLEPseudo(
Expand Down Expand Up @@ -2970,7 +2972,7 @@ static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,

const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
const uint64_t TSFlags = MCID.TSFlags;
if (!RISCVII::hasSEWOp(TSFlags))
if (!RISCVII::hasSEW(TSFlags))
return false;
assert(RISCVII::hasVLOp(TSFlags));

Expand All @@ -2980,7 +2982,9 @@ static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
unsigned VLIdx =
User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
const unsigned Log2SEW = RISCVII::hasSEWOp(TSFlags)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the assignment here backwards? If hasSEWOp is true then should we get it from getLog2SEW(TSFlags)?

? User->getConstantOperandVal(VLIdx + 1)
: RISCVII::getLog2SEW(TSFlags);

if (UserOpNo == VLIdx)
return false;
Expand Down Expand Up @@ -3696,12 +3700,18 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
return false;
}

SDLoc DL(N);

// The vector policy operand may be present for masked intrinsics
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TrueTSFlags);
unsigned TrueVLIndex =
True.getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
bool HasSEWOp = RISCVII::hasSEWOp(TrueTSFlags);
unsigned TrueVLIndex = True.getNumOperands() - HasVecPolicyOp - HasChainOp -
HasGlueOp - 1 - HasSEWOp;
SDValue TrueVL = True.getOperand(TrueVLIndex);
SDValue SEW = True.getOperand(TrueVLIndex + 1);
SDValue SEW =
HasSEWOp ? True.getOperand(TrueVLIndex + 1)
: CurDAG->getTargetConstant(RISCVII::getLog2SEW(TrueTSFlags), DL,
Subtarget->getXLenVT());

auto GetMinVL = [](SDValue LHS, SDValue RHS) {
if (LHS == RHS)
Expand Down Expand Up @@ -3732,8 +3742,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
!True->getFlags().hasNoFPExcept())
return false;

SDLoc DL(N);

// From the preconditions we checked above, we know the mask and thus glue
// for the result node will be taken from True.
if (IsMasked) {
Expand Down Expand Up @@ -3799,7 +3807,10 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (HasRoundingMode)
Ops.push_back(True->getOperand(TrueVLIndex - 1));

Ops.append({VL, SEW, PolicyOp});
Ops.push_back(VL);
if (RISCVII::hasSEWOp(TrueTSFlags))
Ops.push_back(SEW);
Ops.push_back(PolicyOp);

// Result node should have chain operand of True.
if (HasChainOp)
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17857,7 +17857,6 @@ static MachineBasicBlock *emitVFROUND_NOEXCEPT_MASK(MachineInstr &MI,
.add(MI.getOperand(3))
.add(MachineOperand::CreateImm(7)) // frm = DYN
.add(MI.getOperand(4))
.add(MI.getOperand(5))
.add(MI.getOperand(6))
.add(MachineOperand::CreateReg(RISCV::FRM,
/*IsDef*/ false,
Expand Down
20 changes: 8 additions & 12 deletions llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ static unsigned getVLOpNum(const MachineInstr &MI) {
return RISCVII::getVLOpNum(MI.getDesc());
}

static unsigned getSEWOpNum(const MachineInstr &MI) {
return RISCVII::getSEWOpNum(MI.getDesc());
}

static bool isVectorConfigInstr(const MachineInstr &MI) {
return MI.getOpcode() == RISCV::PseudoVSETVLI ||
MI.getOpcode() == RISCV::PseudoVSETVLIX0 ||
Expand Down Expand Up @@ -166,9 +162,9 @@ static bool isNonZeroLoadImmediate(const MachineInstr &MI) {
/// Return true if this is an operation on mask registers. Note that
/// this includes both arithmetic/logical ops and load/store (vlm/vsm).
static bool isMaskRegOp(const MachineInstr &MI) {
if (!RISCVII::hasSEWOp(MI.getDesc().TSFlags))
if (!RISCVII::hasSEW(MI.getDesc().TSFlags))
return false;
const unsigned Log2SEW = MI.getOperand(getSEWOpNum(MI)).getImm();
const unsigned Log2SEW = RISCVII::getLog2SEW(MI);
// A Log2SEW of 0 is an operation on mask registers only.
return Log2SEW == 0;
}
Expand Down Expand Up @@ -383,7 +379,7 @@ DemandedFields getDemanded(const MachineInstr &MI,
Res.demandVTYPE();
// Start conservative on the unlowered form too
uint64_t TSFlags = MI.getDesc().TSFlags;
if (RISCVII::hasSEWOp(TSFlags)) {
if (RISCVII::hasSEW(TSFlags)) {
Res.demandVTYPE();
if (RISCVII::hasVLOp(TSFlags))
Res.demandVL();
Expand All @@ -405,7 +401,7 @@ DemandedFields getDemanded(const MachineInstr &MI,
}

// Store instructions don't use the policy fields.
if (RISCVII::hasSEWOp(TSFlags) && MI.getNumExplicitDefs() == 0) {
if (RISCVII::hasSEW(TSFlags) && MI.getNumExplicitDefs() == 0) {
Res.TailPolicy = false;
Res.MaskPolicy = false;
}
Expand Down Expand Up @@ -940,7 +936,7 @@ static VSETVLIInfo computeInfoForInstr(const MachineInstr &MI, uint64_t TSFlags,

RISCVII::VLMUL VLMul = RISCVII::getLMul(TSFlags);

unsigned Log2SEW = MI.getOperand(getSEWOpNum(MI)).getImm();
unsigned Log2SEW = RISCVII::getLog2SEW(MI);
// A Log2SEW of 0 is an operation on mask registers only.
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
Expand Down Expand Up @@ -1176,7 +1172,7 @@ static VSETVLIInfo adjustIncoming(VSETVLIInfo PrevInfo, VSETVLIInfo NewInfo,
void RISCVInsertVSETVLI::transferBefore(VSETVLIInfo &Info,
const MachineInstr &MI) const {
uint64_t TSFlags = MI.getDesc().TSFlags;
if (!RISCVII::hasSEWOp(TSFlags))
if (!RISCVII::hasSEW(TSFlags))
return;

const VSETVLIInfo NewInfo = computeInfoForInstr(MI, TSFlags, *ST, MRI);
Expand Down Expand Up @@ -1256,7 +1252,7 @@ bool RISCVInsertVSETVLI::computeVLVTYPEChanges(const MachineBasicBlock &MBB,
for (const MachineInstr &MI : MBB) {
transferBefore(Info, MI);

if (isVectorConfigInstr(MI) || RISCVII::hasSEWOp(MI.getDesc().TSFlags))
if (isVectorConfigInstr(MI) || RISCVII::hasSEW(MI.getDesc().TSFlags))
HadVectorOp = true;

transferAfter(Info, MI);
Expand Down Expand Up @@ -1385,7 +1381,7 @@ void RISCVInsertVSETVLI::emitVSETVLIs(MachineBasicBlock &MBB) {
}

uint64_t TSFlags = MI.getDesc().TSFlags;
if (RISCVII::hasSEWOp(TSFlags)) {
if (RISCVII::hasSEW(TSFlags)) {
if (PrevInfo != CurInfo) {
// If this is the first implicit state change, and the state change
// requested can be proven to produce the same register contents, we
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ class RVInstCommon<dag outs, dag ins, string opcodestr, string argstr,
// 3 -> widening case
bits<2> TargetOverlapConstraintType = 0;
let TSFlags{22-21} = TargetOverlapConstraintType;

bit HasImplictSEW = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you add a comment what it means to have an implicit SEW?

let TSFlags{23} = HasImplictSEW;

// The actual SEW value is 8 * (2 ^ VSEW).
bits<2> VSEW = 0;
let TSFlags{25-24} = VSEW;
}

class RVInst<dag outs, dag ins, string opcodestr, string argstr,
Expand Down
22 changes: 8 additions & 14 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,

// If the producing instruction does not depend on vsetvli, do not
// convert COPY to vmv.v.v. For example, VL1R_V or PseudoVRELOAD.
if (!RISCVII::hasSEWOp(TSFlags) || !RISCVII::hasVLOp(TSFlags))
if (!RISCVII::hasSEW(TSFlags) || !RISCVII::hasVLOp(TSFlags))
return false;

// Found the definition.
Expand Down Expand Up @@ -410,9 +410,9 @@ void RISCVInstrInfo::copyPhysRegVector(
MIB = MIB.addReg(ActualSrcReg, getKillRegState(KillSrc));
if (UseVMV) {
const MCInstrDesc &Desc = DefMBBI->getDesc();
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
MIB.addImm(0); // tu, mu
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
MIB.add(RISCVII::getSEWOp(*DefMBBI)); // SEW
MIB.addImm(0); // tu, mu
MIB.addReg(RISCV::VL, RegState::Implicit);
MIB.addReg(RISCV::VTYPE, RegState::Implicit);
}
Expand Down Expand Up @@ -1706,8 +1706,7 @@ bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
return false;

// SEW
if (RISCVII::hasSEWOp(TSFlags) &&
!checkImmOperand(RISCVII::getSEWOpNum(Desc)))
if (RISCVII::hasSEW(TSFlags) && !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
return false;

// Mask
Expand Down Expand Up @@ -2463,10 +2462,6 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
return false;
}
}
if (!RISCVII::hasSEWOp(TSFlags)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check hasSEWOp && !hasVLOperand that would throw an error SEW operand without VL operand.

ErrInfo = "VL operand w/o SEW operand?";
return false;
}
}
if (RISCVII::hasSEWOp(TSFlags)) {
unsigned OpIdx = RISCVII::getSEWOpNum(Desc);
Expand Down Expand Up @@ -3521,8 +3516,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
case CASE_FP_WIDEOP_OPCODE_LMULS_MF4(FWADD_WV):
case CASE_FP_WIDEOP_OPCODE_LMULS_MF4(FWSUB_WV): {
assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags) &&
MI.getNumExplicitOperands() == 7 &&
"Expect 7 explicit operands rd, rs2, rs1, rm, vl, sew, policy");
MI.getNumExplicitOperands() == 6 &&
"Expect 6 explicit operands rd, rs2, rs1, rm, vl, policy");
// If the tail policy is undisturbed we can't convert.
if ((MI.getOperand(RISCVII::getVecPolicyOpNum(MI.getDesc())).getImm() &
1) == 0)
Expand All @@ -3545,8 +3540,7 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
.add(MI.getOperand(2))
.add(MI.getOperand(3))
.add(MI.getOperand(4))
.add(MI.getOperand(5))
.add(MI.getOperand(6));
.add(MI.getOperand(5));
break;
}
case CASE_WIDEOP_OPCODE_LMULS(WADD_WV):
Expand Down
Loading
Loading