Skip to content

Commit

Permalink
[RISCV][llvm-mca] Use correct LMUL and SEW for strided loads and stor…
Browse files Browse the repository at this point in the history
…es (#76869)

The pseudos for strided loads and stores use the SEW coming from the
name. For example, vlse8 has SEW=8 and vlse16 has SEW=16.

When llvm-mca tries to look up (VLSE8_V, SEW=S, LMUL=L) in the inverse
pseudo table, a result will only be found when S=8, where S was set by
the previous vsetvli instruction. Instead, for a match to be found, we
must look up (VLSE8_V, SEW=8, LMUL=L'), where L' is the EMUL calculated
by scaling the LMUL from the previous vsetvli by the ratio of the
instruction's EEW (8) to that vsetvli's SEW.
  • Loading branch information
michaelmaitland committed Jan 4, 2024
1 parent 2ab5c47 commit 58f1640
Show file tree
Hide file tree
Showing 2 changed files with 385 additions and 9 deletions.
33 changes: 24 additions & 9 deletions llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,30 +186,37 @@ RISCVInstrumentManager::createInstruments(const MCInst &Inst) {
}

static std::pair<uint8_t, uint8_t>
getEEWAndEMULForUnitStrideLoadStore(unsigned Opcode, RISCVII::VLMUL LMUL,
uint8_t SEW) {
getEEWAndEMUL(unsigned Opcode, RISCVII::VLMUL LMUL, uint8_t SEW) {
uint8_t EEW;
switch (Opcode) {
case RISCV::VLM_V:
case RISCV::VSM_V:
case RISCV::VLE8_V:
case RISCV::VSE8_V:
case RISCV::VLSE8_V:
case RISCV::VSSE8_V:
EEW = 8;
break;
case RISCV::VLE16_V:
case RISCV::VSE16_V:
case RISCV::VLSE16_V:
case RISCV::VSSE16_V:
EEW = 16;
break;
case RISCV::VLE32_V:
case RISCV::VSE32_V:
case RISCV::VLSE32_V:
case RISCV::VSSE32_V:
EEW = 32;
break;
case RISCV::VLE64_V:
case RISCV::VSE64_V:
case RISCV::VLSE64_V:
case RISCV::VSSE64_V:
EEW = 64;
break;
default:
llvm_unreachable("Opcode is not a vector unit stride load nor store");
llvm_unreachable("Could not determine EEW from Opcode");
}

auto EMUL = RISCVVType::getSameRatioLMUL(SEW, LMUL, EEW);
Expand All @@ -218,6 +225,18 @@ getEEWAndEMULForUnitStrideLoadStore(unsigned Opcode, RISCVII::VLMUL LMUL,
return std::make_pair(EEW, *EMUL);
}

/// Reports whether \p Opcode is one of the vector load/store opcodes whose
/// name fixes the element width, so that the EEW and EMUL must be derived
/// from the opcode rather than taken directly from the active SEW/LMUL.
///
/// The recognised opcodes are VLM/VSM, VLE{8,16,32,64}/VSE{8,16,32,64} and
/// VLSE{8,16,32,64}/VSSE{8,16,32,64}.
///
/// \param Opcode the instruction opcode to classify.
/// \returns true for the opcodes listed above, false for everything else.
bool opcodeHasEEWAndEMULInfo(unsigned short Opcode) {
  switch (Opcode) {
  // Mask load/store.
  case RISCV::VLM_V:
  case RISCV::VSM_V:
  // VLE*/VSE* loads and stores.
  case RISCV::VLE8_V:
  case RISCV::VSE8_V:
  case RISCV::VLE16_V:
  case RISCV::VSE16_V:
  case RISCV::VLE32_V:
  case RISCV::VSE32_V:
  case RISCV::VLE64_V:
  case RISCV::VSE64_V:
  // VLSE*/VSSE* (strided) loads and stores.
  case RISCV::VLSE8_V:
  case RISCV::VSSE8_V:
  case RISCV::VLSE16_V:
  case RISCV::VSSE16_V:
  case RISCV::VLSE32_V:
  case RISCV::VSSE32_V:
  case RISCV::VLSE64_V:
  case RISCV::VSSE64_V:
    return true;
  default:
    return false;
  }
}

unsigned RISCVInstrumentManager::getSchedClassID(
const MCInstrInfo &MCII, const MCInst &MCI,
const llvm::SmallVector<Instrument *> &IVec) const {
Expand Down Expand Up @@ -249,13 +268,9 @@ unsigned RISCVInstrumentManager::getSchedClassID(
uint8_t SEW = SI ? SI->getSEW() : 0;

const RISCVVInversePseudosTable::PseudoInfo *RVV = nullptr;
if (Opcode == RISCV::VLM_V || Opcode == RISCV::VSM_V ||
Opcode == RISCV::VLE8_V || Opcode == RISCV::VSE8_V ||
Opcode == RISCV::VLE16_V || Opcode == RISCV::VSE16_V ||
Opcode == RISCV::VLE32_V || Opcode == RISCV::VSE32_V ||
Opcode == RISCV::VLE64_V || Opcode == RISCV::VSE64_V) {
if (opcodeHasEEWAndEMULInfo(Opcode)) {
RISCVII::VLMUL VLMUL = static_cast<RISCVII::VLMUL>(LMUL);
auto [EEW, EMUL] = getEEWAndEMULForUnitStrideLoadStore(Opcode, VLMUL, SEW);
auto [EEW, EMUL] = getEEWAndEMUL(Opcode, VLMUL, SEW);
RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, EMUL, EEW);
} else {
// Check if it depends on LMUL and SEW
Expand Down

0 comments on commit 58f1640

Please sign in to comment.