Skip to content

Commit

Permalink
[RISCV][MC] Add FLI instruction support for the experimental zfa exte…
Browse files Browse the repository at this point in the history
…nsion

This implements experimental support for the RISCV Zfa extension as specified here: https://github.com/riscv/riscv-isa-manual/releases/download/draft-20221119-5234c63/riscv-spec.pdf, Ch. 25. This extension has not been ratified. Once ratified, it'll move out of experimental status.

This change adds assembly support for load-immediate instructions (fli.s/fli.d/fli.h). The assembly prefers decimal constants in C-like syntax. In my implementation, an integer encoding ranging from 0 to 31 can also be accepted, but for the MCInst printer, the constant is specified in decimal notation by default.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D140460
  • Loading branch information
joshua-arch1 committed Mar 7, 2023
1 parent 97f6283 commit 8a002d4
Show file tree
Hide file tree
Showing 7 changed files with 816 additions and 0 deletions.
101 changes: 101 additions & 0 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Expand Up @@ -158,6 +158,7 @@ class RISCVAsmParser : public MCTargetAsmParser {
#include "RISCVGenAsmMatcher.inc"

OperandMatchResultTy parseCSRSystemRegister(OperandVector &Operands);
OperandMatchResultTy parseFPImm(OperandVector &Operands);
OperandMatchResultTy parseImmediate(OperandVector &Operands);
OperandMatchResultTy parseRegister(OperandVector &Operands,
bool AllowParens = false);
Expand Down Expand Up @@ -274,6 +275,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {
Token,
Register,
Immediate,
FPImmediate,
SystemRegister,
VType,
FRM,
Expand All @@ -290,6 +292,10 @@ struct RISCVOperand final : public MCParsedAsmOperand {
bool IsRV64;
};

struct FPImmOp {
uint64_t Val;
};

struct SysRegOp {
const char *Data;
unsigned Length;
Expand All @@ -315,6 +321,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {
StringRef Tok;
RegOp Reg;
ImmOp Imm;
FPImmOp FPImm;
struct SysRegOp SysReg;
struct VTypeOp VType;
struct FRMOp FRM;
Expand All @@ -335,6 +342,9 @@ struct RISCVOperand final : public MCParsedAsmOperand {
case KindTy::Immediate:
Imm = o.Imm;
break;
case KindTy::FPImmediate:
FPImm = o.FPImm;
break;
case KindTy::Token:
Tok = o.Tok;
break;
Expand Down Expand Up @@ -480,6 +490,12 @@ struct RISCVOperand final : public MCParsedAsmOperand {
bool isFRMArg() const { return Kind == KindTy::FRM; }
bool isRTZArg() const { return isFRMArg() && FRM.FRM == RISCVFPRndMode::RTZ; }

/// Return true if the operand is a valid fli.s floating-point immediate.
bool isLoadFPImm() const {
return Kind == KindTy::FPImmediate &&
RISCVLoadFPImm::getLoadFP32Imm(APInt(32, getFPConst())) != -1;
}

bool isImmXLenLI() const {
int64_t Imm;
RISCVMCExpr::VariantKind VK = RISCVMCExpr::VK_RISCV_None;
Expand Down Expand Up @@ -793,6 +809,11 @@ struct RISCVOperand final : public MCParsedAsmOperand {
return Imm.Val;
}

uint64_t getFPConst() const {
assert(Kind == KindTy::FPImmediate && "Invalid type access!");
return FPImm.Val;
}

StringRef getToken() const {
assert(Kind == KindTy::Token && "Invalid type access!");
return Tok;
Expand Down Expand Up @@ -825,6 +846,8 @@ struct RISCVOperand final : public MCParsedAsmOperand {
case KindTy::Immediate:
OS << *getImm();
break;
case KindTy::FPImmediate:
break;
case KindTy::Register:
OS << "<register " << RegName(getReg()) << ">";
break;
Expand Down Expand Up @@ -880,6 +903,14 @@ struct RISCVOperand final : public MCParsedAsmOperand {
return Op;
}

static std::unique_ptr<RISCVOperand> createFPImm(uint64_t Val, SMLoc S) {
auto Op = std::make_unique<RISCVOperand>(KindTy::FPImmediate);
Op->FPImm.Val = Val;
Op->StartLoc = S;
Op->EndLoc = S;
return Op;
}

static std::unique_ptr<RISCVOperand> createSysReg(StringRef Str, SMLoc S,
unsigned Encoding) {
auto Op = std::make_unique<RISCVOperand>(KindTy::SystemRegister);
Expand Down Expand Up @@ -940,6 +971,12 @@ struct RISCVOperand final : public MCParsedAsmOperand {
addExpr(Inst, getImm(), isRV64Imm());
}

void addFPImmOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
int Imm = RISCVLoadFPImm::getLoadFP32Imm(APInt(32, getFPConst()));
Inst.addOperand(MCOperand::createImm(Imm));
}

void addFenceArgOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
Inst.addOperand(MCOperand::createImm(Fence.Val));
Expand Down Expand Up @@ -1246,6 +1283,10 @@ bool RISCVAsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
"operand must be a valid system register "
"name or an integer in the range");
}
case Match_InvalidLoadFPImm: {
SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc();
return Error(ErrorLoc, "operand must be a valid floating-point constant");
}
case Match_InvalidBareSymbol: {
SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc();
return Error(ErrorLoc, "operand must be a bare symbol name");
Expand Down Expand Up @@ -1519,6 +1560,66 @@ RISCVAsmParser::parseCSRSystemRegister(OperandVector &Operands) {
return MatchOperand_NoMatch;
}

OperandMatchResultTy RISCVAsmParser::parseFPImm(OperandVector &Operands) {
SMLoc S = getLoc();

// Handle negation, as that still comes through as a separate token.
bool IsNegative = parseOptionalToken(AsmToken::Minus);

const AsmToken &Tok = getTok();
if (!Tok.is(AsmToken::Real) && !Tok.is(AsmToken::Integer) &&
!Tok.is(AsmToken::Identifier)) {
TokError("invalid floating point immediate");
return MatchOperand_NoMatch;
}

// Parse special floats (inf/nan/min) representation.
if (Tok.is(AsmToken::Identifier)) {
if (Tok.getString().compare_insensitive("inf") == 0) {
APFloat SpecialVal = APFloat::getInf(APFloat::IEEEsingle());
Operands.push_back(RISCVOperand::createFPImm(
SpecialVal.bitcastToAPInt().getZExtValue(), S));
} else if (Tok.getString().compare_insensitive("nan") == 0) {
APFloat SpecialVal = APFloat::getNaN(APFloat::IEEEsingle());
Operands.push_back(RISCVOperand::createFPImm(
SpecialVal.bitcastToAPInt().getZExtValue(), S));
} else if (Tok.getString().compare_insensitive("min") == 0) {
unsigned SpecialVal = RISCVLoadFPImm::getFPImm(1);
Operands.push_back(RISCVOperand::createFPImm(SpecialVal, S));
} else {
TokError("invalid floating point literal");
return MatchOperand_ParseFail;
}
} else if (Tok.is(AsmToken::Integer)) {
// Parse integer representation.
if (Tok.getIntVal() > 31 || IsNegative) {
TokError("encoded floating point value out of range");
return MatchOperand_ParseFail;
}
unsigned F = RISCVLoadFPImm::getFPImm(Tok.getIntVal());
Operands.push_back(RISCVOperand::createFPImm(F, S));
} else {
// Parse FP representation.
APFloat RealVal(APFloat::IEEEsingle());
auto StatusOrErr =
RealVal.convertFromString(Tok.getString(), APFloat::rmTowardZero);
if (errorToBool(StatusOrErr.takeError())) {
TokError("invalid floating point representation");
return MatchOperand_ParseFail;
}

if (IsNegative)
RealVal.changeSign();

Operands.push_back(RISCVOperand::createFPImm(
RealVal.bitcastToAPInt().getZExtValue(), S));
}

Lex(); // Eat the token.

return MatchOperand_Success;
}

OperandMatchResultTy RISCVAsmParser::parseImmediate(OperandVector &Operands) {
SMLoc S = getLoc();
SMLoc E;
Expand Down
124 changes: 124 additions & 0 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
Expand Up @@ -14,6 +14,8 @@
#define LLVM_LIB_TARGET_RISCV_MCTARGETDESC_RISCVBASEINFO_H

#include "MCTargetDesc/RISCVMCTargetDesc.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/MC/MCInstrDesc.h"
Expand Down Expand Up @@ -340,6 +342,128 @@ inline static bool isValidRoundingMode(unsigned Mode) {
}
} // namespace RISCVFPRndMode

//===----------------------------------------------------------------------===//
// Floating-point Immediates
//

// We expect an 5-bit binary encoding of a floating-point constant here.
static const std::pair<uint8_t, uint8_t> LoadFPImmArr[] = {
{0b00000001, 0b000}, {0b01101111, 0b000}, {0b01110000, 0b000},
{0b01110111, 0b000}, {0b01111000, 0b000}, {0b01111011, 0b000},
{0b01111100, 0b000}, {0b01111101, 0b000}, {0b01111101, 0b010},
{0b01111101, 0b100}, {0b01111101, 0b110}, {0b01111110, 0b000},
{0b01111110, 0b010}, {0b01111110, 0b100}, {0b01111110, 0b110},
{0b01111111, 0b000}, {0b01111111, 0b010}, {0b01111111, 0b100},
{0b01111111, 0b110}, {0b10000000, 0b000}, {0b10000000, 0b010},
{0b10000000, 0b100}, {0b10000001, 0b000}, {0b10000010, 0b000},
{0b10000011, 0b000}, {0b10000110, 0b000}, {0b10000111, 0b000},
{0b10001110, 0b000}, {0b10001111, 0b000}, {0b11111111, 0b000},
{0b11111111, 0b100},
};

static inline int getLoadFPImm(uint8_t Sign, uint8_t Exp, uint8_t Mantissa) {
if (Sign == 0b1 && Exp == 0b01111111 && Mantissa == 0b000)
return 0;

if (Sign == 0b0) {
auto EMI = llvm::find(LoadFPImmArr, std::make_pair(Exp, Mantissa));
if (EMI != std::end(LoadFPImmArr))
return std::distance(std::begin(LoadFPImmArr), EMI) + 1;
}

return -1;
}

namespace RISCVLoadFPImm {
inline static uint32_t getFPImm(unsigned Imm) {
uint8_t Sign;
uint8_t Exp;
uint8_t Mantissa;

if (Imm == 0) {
Sign = 0b1;
Exp = 0b01111111;
Mantissa = 0b000;
} else {
Sign = 0b0;
Exp = LoadFPImmArr[Imm - 1].first;
Mantissa = LoadFPImmArr[Imm - 1].second;
}

return Sign << 31 | Exp << 23 | Mantissa << 20;
}

/// getLoadFP32Imm - Return a 5-bit binary encoding of the 32-bit
/// floating-point immediate value. If the value cannot be represented as a
/// 5-bit binary encoding, then return -1.
static inline int getLoadFP32Imm(const APInt &Imm) {
if ((Imm.extractBitsAsZExtValue(9, 23) == 0b001110001 &&
Imm.extractBitsAsZExtValue(23, 0) == 0) ||
Imm.extractBitsAsZExtValue(32, 0) == 0)
return 1;

if (Imm.extractBitsAsZExtValue(20, 0) != 0)
return -1;

uint8_t Sign = Imm.extractBitsAsZExtValue(1, 31);
uint8_t Exp = Imm.extractBitsAsZExtValue(8, 23);
uint8_t Mantissa = Imm.extractBitsAsZExtValue(3, 20);
return getLoadFPImm(Sign, Exp, Mantissa);
}

static inline int getLoadFP32Imm(const APFloat &FPImm) {
return getLoadFP32Imm(FPImm.bitcastToAPInt());
}

/// getLoadFP64Imm - Return a 5-bit binary encoding of the 64-bit
/// floating-point immediate value. If the value cannot be represented as a
/// 5-bit binary encoding, then return -1.
static inline int getLoadFP64Imm(const APInt &Imm) {
if (Imm.extractBitsAsZExtValue(49, 0) != 0)
return -1;

uint8_t Sign = Imm.extractBitsAsZExtValue(1, 63);
uint8_t Mantissa = Imm.extractBitsAsZExtValue(3, 49);
uint8_t Exp;
if (Imm.extractBitsAsZExtValue(11, 52) == 1)
Exp = 0b00000001;
else if (Imm.extractBitsAsZExtValue(11, 52) == 2047)
Exp = 0b11111111;
else
Exp = Imm.extractBitsAsZExtValue(11, 52) - 1023 + 127;

return getLoadFPImm(Sign, Exp, Mantissa);
}

static inline int getLoadFP64Imm(const APFloat &FPImm) {
return getLoadFP64Imm(FPImm.bitcastToAPInt());
}

/// getLoadFP16Imm - Return a 5-bit binary encoding of the 16-bit
/// floating-point immediate value. If the value cannot be represented as a
/// 5-bit binary encoding, then return -1.
static inline int getLoadFP16Imm(const APInt &Imm) {
if (Imm.extractBitsAsZExtValue(7, 0) != 0)
return -1;

uint8_t Sign = Imm.extractBitsAsZExtValue(1, 15);
uint8_t Mantissa = Imm.extractBitsAsZExtValue(3, 7);
uint8_t Exp;
if (Imm.extractBitsAsZExtValue(5, 10) == 1)
Exp = 0b00000001;
else if (Imm.extractBitsAsZExtValue(5, 10) == 31)
Exp = 0b11111111;
else
Exp = Imm.extractBitsAsZExtValue(5, 10) - 15 + 127;

return getLoadFPImm(Sign, Exp, Mantissa);
}

static inline int getLoadFP16Imm(const APFloat &FPImm) {
return getLoadFP16Imm(FPImm.bitcastToAPInt());
}
} // namespace RISCVLoadFPImm

namespace RISCVSysReg {
struct SysReg {
const char *Name;
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
Expand Up @@ -154,6 +154,20 @@ void RISCVInstPrinter::printFRMArg(const MCInst *MI, unsigned OpNo,
O << ", " << RISCVFPRndMode::roundingModeToString(FRMArg);
}

void RISCVInstPrinter::printFPImmOperand(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI,
raw_ostream &O) {
const MCOperand &MO = MI->getOperand(OpNo);
if (MO.getImm() == 1)
O << "min";
else if (MO.getImm() == 30)
O << "inf";
else if (MO.getImm() == 31)
O << "nan";
else
O << bit_cast<float>(RISCVLoadFPImm::getFPImm(MO.getImm()));
}

void RISCVInstPrinter::printZeroOffsetMemOp(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI,
raw_ostream &O) {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
Expand Up @@ -40,6 +40,8 @@ class RISCVInstPrinter : public MCInstPrinter {
const MCSubtargetInfo &STI, raw_ostream &O);
void printFRMArg(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
raw_ostream &O);
void printFPImmOperand(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printZeroOffsetMemOp(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI, raw_ostream &O);
void printVTypeI(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
Expand Down

0 comments on commit 8a002d4

Please sign in to comment.