diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp index 9a21251aac8828..c406ba382c28a7 100644 --- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp +++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp @@ -171,6 +171,7 @@ class RISCVAsmParser : public MCTargetAsmParser { OperandMatchResultTy parseMaskReg(OperandVector &Operands); OperandMatchResultTy parseInsnDirectiveOpcode(OperandVector &Operands); OperandMatchResultTy parseGPRAsFPR(OperandVector &Operands); + OperandMatchResultTy parseFRMArg(OperandVector &Operands); bool parseOperand(OperandVector &Operands, StringRef Mnemonic); @@ -276,6 +277,7 @@ struct RISCVOperand : public MCParsedAsmOperand { Immediate, SystemRegister, VType, + FRM, } Kind; bool IsRV64; @@ -302,6 +304,10 @@ struct RISCVOperand : public MCParsedAsmOperand { unsigned Val; }; + struct FRMOp { + RISCVFPRndMode::RoundingMode FRM; + }; + SMLoc StartLoc, EndLoc; union { StringRef Tok; @@ -309,6 +315,7 @@ struct RISCVOperand : public MCParsedAsmOperand { ImmOp Imm; struct SysRegOp SysReg; struct VTypeOp VType; + struct FRMOp FRM; }; RISCVOperand(KindTy K) : Kind(K) {} @@ -335,6 +342,9 @@ struct RISCVOperand : public MCParsedAsmOperand { case KindTy::VType: VType = o.VType; break; + case KindTy::FRM: + FRM = o.FRM; + break; } } @@ -493,18 +503,7 @@ struct RISCVOperand : public MCParsedAsmOperand { } /// Return true if the operand is a valid floating point rounding mode. - bool isFRMArg() const { - if (!isImm()) - return false; - const MCExpr *Val = getImm(); - auto *SVal = dyn_cast(Val); - if (!SVal || SVal->getKind() != MCSymbolRefExpr::VK_None) - return false; - - StringRef Str = SVal->getSymbol().getName(); - - return RISCVFPRndMode::stringToRoundingMode(Str) != RISCVFPRndMode::Invalid; - } + bool isFRMArg() const { return Kind == KindTy::FRM; } bool isImmXLenLI() const { int64_t Imm; @@ -814,6 +813,11 @@ struct RISCVOperand : public MCParsedAsmOperand { return VType.Val; } + RISCVFPRndMode::RoundingMode getFRM() const { + assert(Kind == KindTy::FRM && "Invalid type access!"); + return FRM.FRM; + } + void print(raw_ostream &OS) const override { auto RegName = [](MCRegister Reg) { if (Reg) @@ -840,6 +844,11 @@ struct RISCVOperand : public MCParsedAsmOperand { RISCVVType::printVType(getVType(), OS); OS << '>'; break; + case KindTy::FRM: + OS << "'; + break; } } @@ -887,6 +896,16 @@ struct RISCVOperand : public MCParsedAsmOperand { return Op; } + static std::unique_ptr + createFRMArg(RISCVFPRndMode::RoundingMode FRM, SMLoc S, bool IsRV64) { + auto Op = std::make_unique(KindTy::FRM); + Op->FRM.FRM = FRM; + Op->StartLoc = S; + Op->EndLoc = S; + Op->IsRV64 = IsRV64; + return Op; + } + static std::unique_ptr createVType(unsigned VTypeI, SMLoc S, bool IsRV64) { auto Op = std::make_unique(KindTy::VType); @@ -980,20 +999,9 @@ struct RISCVOperand : public MCParsedAsmOperand { Inst.addOperand(MCOperand::createImm(Imm)); } - // Returns the rounding mode represented by this RISCVOperand. Should only - // be called after checking isFRMArg. - RISCVFPRndMode::RoundingMode getRoundingMode() const { - // isFRMArg has validated the operand, meaning this cast is safe. - auto SE = cast(getImm()); - RISCVFPRndMode::RoundingMode FRM = - RISCVFPRndMode::stringToRoundingMode(SE->getSymbol().getName()); - assert(FRM != RISCVFPRndMode::Invalid && "Invalid rounding mode"); - return FRM; - } - void addFRMArgOperands(MCInst &Inst, unsigned N) const { assert(N == 1 && "Invalid number of operands!"); - Inst.addOperand(MCOperand::createImm(getRoundingMode())); + Inst.addOperand(MCOperand::createImm(getFRM())); } }; } // end anonymous namespace. @@ -1254,12 +1262,6 @@ bool RISCVAsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode, return Error(ErrorLoc, "operand must be formed of letters selected " "in-order from 'iorw' or be 0"); } - case Match_InvalidFRMArg: { - SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc(); - return Error( - ErrorLoc, - "operand must be a valid floating point rounding mode mnemonic"); - } case Match_InvalidBareSymbol: { SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc(); return Error(ErrorLoc, "operand must be a bare symbol name"); @@ -1837,6 +1839,25 @@ OperandMatchResultTy RISCVAsmParser::parseGPRAsFPR(OperandVector &Operands) { return MatchOperand_Success; } +OperandMatchResultTy RISCVAsmParser::parseFRMArg(OperandVector &Operands) { + if (getLexer().isNot(AsmToken::Identifier)) { + TokError("operand must be a valid floating point rounding mode mnemonic"); + return MatchOperand_ParseFail; + } + + StringRef Str = getLexer().getTok().getIdentifier(); + RISCVFPRndMode::RoundingMode FRM = RISCVFPRndMode::stringToRoundingMode(Str); + + if (FRM == RISCVFPRndMode::Invalid) { + TokError("operand must be a valid floating point rounding mode mnemonic"); + return MatchOperand_ParseFail; + } + + Operands.push_back(RISCVOperand::createFRMArg(FRM, getLoc(), isRV64())); + Lex(); // Eat identifier token. + return MatchOperand_Success; +} + OperandMatchResultTy RISCVAsmParser::parseMemOpBaseReg(OperandVector &Operands) { if (getLexer().isNot(AsmToken::LParen)) { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td index 92d8a2bab4c0db..344a3ae173a41f 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -140,7 +140,7 @@ defvar FXIN64X = [FX_64, FX_INX_64]; def FRMArg : AsmOperandClass { let Name = "FRMArg"; let RenderMethod = "addFRMArgOperands"; - let DiagnosticType = "InvalidFRMArg"; + let ParserMethod = "parseFRMArg"; } def frmarg : Operand {