Skip to content

Commit

Permalink
[AArch64]SME2 Multi-vector - Index/Single/Multi Array Vectors FMA sou…
Browse files Browse the repository at this point in the history
…rces

This patch adds the assembly/disassembly for the following instructions:
  INT:
     SMLAL
     SMLSL
     UMLAL
     UMLSL
  FP:
    BFMLAL
    BFMLSL
    FMLAL
    FMLSL
For multiple and indexed vector, Multiple and Single vector and
Multi vectors, for 1, 2 and 4 ZA registers.

The reference can be found here:

https://developer.arm.com/documentation/ddi0602/2022-09

It also adds a new immediate:
  uimm3s2range for off3
  uimm2s2range for off2
to represent the vector select offset.
The new operands have the range between the first and the last vector position.

Depends on: D135563

Reviewed By: aemerson, sdesmalen

Differential Revision: https://reviews.llvm.org/D135785
  • Loading branch information
CarolineConcatto committed Oct 20, 2022
1 parent d12d485 commit 3fee935
Show file tree
Hide file tree
Showing 22 changed files with 9,283 additions and 8 deletions.
25 changes: 25 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Expand Up @@ -1373,6 +1373,31 @@ def sme_elm_idx0_15 : Operand<i64>, ImmLeaf<i64, [{
let PrintMethod = "printMatrixIndex";
}

class UImmScaledMemoryIndexedRange<int Width, int Scale, int OffsetVal> : AsmOperandClass {
let Name = "UImm" # Width # "s" # Scale # "Range";
let DiagnosticType = "InvalidMemoryIndexedRange" # Scale # "UImm" # Width;
let RenderMethod = "addImmScaledRangeOperands<" # Scale # ">";
let PredicateMethod = "isUImmScaled<" # Width # ", " # Scale # ", " # OffsetVal # ", /*IsRange=*/true>";
let ParserMethod = "tryParseImmRange";
}

def UImm2s2RangeOperand : UImmScaledMemoryIndexedRange<2, 2, 1>;
def UImm3s2RangeOperand : UImmScaledMemoryIndexedRange<3, 2, 1>;

def uimm2s2range : Operand<i64>, ImmLeaf<i64,
[{ return Imm >= 0 && Imm <= 6 && ((Imm % 2) == 0); }], UImmS2XForm> {
let PrintMethod = "printImmRangeScale<2, 1>";
let ParserMatchClass = UImm2s2RangeOperand;
}


def uimm3s2range : Operand<i64>, ImmLeaf<i64,
[{ return Imm >= 0 && Imm <= 14 && ((Imm % 2) == 0); }], UImmS2XForm> {
let PrintMethod = "printImmRangeScale<2, 1>";
let ParserMatchClass = UImm3s2RangeOperand;
}


// 8-bit immediate for AdvSIMD where 64-bit values of the form:
// aaaaaaaa bbbbbbbb cccccccc dddddddd eeeeeeee ffffffff gggggggg hhhhhhhh
// are encoded as the eight bit value 'abcdefgh'.
Expand Down
72 changes: 72 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Expand Up @@ -272,6 +272,78 @@ defm ADD_VG4_4ZZ : sme2_sqdmulh_add_vector_vg4_single<"add", 0b011000>;

defm SQDMULH_2ZZ : sme2_sqdmulh_add_vector_vg2_single<"sqdmulh", 0b100000>;
defm SQDMULH_4ZZ : sme2_sqdmulh_add_vector_vg4_single<"sqdmulh", 0b100000>;

defm FMLAL_MZZI : sme2_mla_long_array_index<"fmlal", 0b10, 0b00>;
defm FMLAL_VG2_M2ZZI : sme2_fp_mla_long_array_vg2_index<"fmlal", 0b00>;
defm FMLAL_VG4_M4ZZI : sme2_fp_mla_long_array_vg4_index<"fmlal", 0b00>;
defm FMLAL_MZZ : sme2_mla_long_array_single<"fmlal", 0b00, 0b00>;
defm FMLAL_VG2_M2ZZ : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b00>;
defm FMLAL_VG4_M4ZZ : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b00>;
defm FMLAL_VG2_M2Z2Z : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b00>;
defm FMLAL_VG4_M4Z4Z : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b00>;

defm FMLSL_MZZI : sme2_mla_long_array_index<"fmlsl", 0b10, 0b01>;
defm FMLSL_VG2_M2ZZI : sme2_fp_mla_long_array_vg2_index<"fmlsl", 0b01>;
defm FMLSL_VG4_M4ZZI : sme2_fp_mla_long_array_vg4_index<"fmlsl", 0b01>;
defm FMLSL_MZZ : sme2_mla_long_array_single<"fmlsl", 0b00, 0b01>;
defm FMLSL_VG2_M2ZZ : sme2_fp_mla_long_array_vg2_single<"fmlsl", 0b01>;
defm FMLSL_VG4_M4ZZ : sme2_fp_mla_long_array_vg4_single<"fmlsl", 0b01>;
defm FMLSL_VG2_M2Z2Z : sme2_fp_mla_long_array_vg2_multi<"fmlsl", 0b01>;
defm FMLSL_VG4_M4Z4Z : sme2_fp_mla_long_array_vg4_multi<"fmlsl", 0b01>;

defm BFMLAL_MZZI : sme2_mla_long_array_index<"bfmlal", 0b10, 0b10>;
defm BFMLAL_VG2_M2ZZI : sme2_fp_mla_long_array_vg2_index<"bfmlal", 0b10>;
defm BFMLAL_VG4_M4ZZI : sme2_fp_mla_long_array_vg4_index<"bfmlal", 0b10>;
defm BFMLAL_MZZ : sme2_mla_long_array_single<"bfmlal", 0b00, 0b10>;
defm BFMLAL_VG2_M2ZZ : sme2_fp_mla_long_array_vg2_single<"bfmlal", 0b10>;
defm BFMLAL_VG4_M4ZZ : sme2_fp_mla_long_array_vg4_single<"bfmlal", 0b10>;
defm BFMLAL_VG2_M2Z2Z : sme2_fp_mla_long_array_vg2_multi<"bfmlal", 0b10>;
defm BFMLAL_VG4_M4Z4Z : sme2_fp_mla_long_array_vg4_multi<"bfmlal", 0b10>;

defm BFMLSL_MZZI : sme2_mla_long_array_index<"bfmlsl", 0b10, 0b11>;
defm BFMLSL_VG2_M2ZZI : sme2_fp_mla_long_array_vg2_index<"bfmlsl", 0b11>;
defm BFMLSL_VG4_M4ZZI : sme2_fp_mla_long_array_vg4_index<"bfmlsl", 0b11>;
defm BFMLSL_MZZ : sme2_mla_long_array_single<"bfmlsl", 0b00, 0b11>;
defm BFMLSL_VG2_M2ZZ : sme2_fp_mla_long_array_vg2_single<"bfmlsl", 0b11>;
defm BFMLSL_VG4_M4ZZ : sme2_fp_mla_long_array_vg4_single<"bfmlsl", 0b11>;
defm BFMLSL_VG2_M2Z2Z : sme2_fp_mla_long_array_vg2_multi<"bfmlsl", 0b11>;
defm BFMLSL_VG4_M4Z4Z : sme2_fp_mla_long_array_vg4_multi<"bfmlsl", 0b11>;

defm SMLAL_MZZI : sme2_mla_long_array_index<"smlal", 0b11, 0b00>;
defm SMLAL_VG2_M2ZZI : sme2_int_mla_long_array_vg2_index<"smlal", 0b00>;
defm SMLAL_VG4_M4ZZI : sme2_int_mla_long_array_vg4_index<"smlal", 0b00>;
defm SMLAL_MZZ : sme2_mla_long_array_single<"smlal",0b01, 0b00>;
defm SMLAL_VG2_M2ZZ : sme2_int_mla_long_array_vg2_single<"smlal", 0b00>;
defm SMLAL_VG4_M4ZZ : sme2_int_mla_long_array_vg4_single<"smlal", 0b00>;
defm SMLAL_VG2_M2Z2Z : sme2_int_mla_long_array_vg2_multi<"smlal", 0b00>;
defm SMLAL_VG4_M4Z4Z : sme2_int_mla_long_array_vg4_multi<"smlal", 0b00>;

defm SMLSL_MZZI : sme2_mla_long_array_index<"smlsl", 0b11, 0b01>;
defm SMLSL_VG2_M2ZZI : sme2_int_mla_long_array_vg2_index<"smlsl", 0b01>;
defm SMLSL_VG4_M4ZZI : sme2_int_mla_long_array_vg4_index<"smlsl", 0b01>;
defm SMLSL_MZZ : sme2_mla_long_array_single<"smlsl",0b01, 0b01>;
defm SMLSL_VG2_M2ZZ : sme2_int_mla_long_array_vg2_single<"smlsl", 0b01>;
defm SMLSL_VG4_M4ZZ : sme2_int_mla_long_array_vg4_single<"smlsl", 0b01>;
defm SMLSL_VG2_M2Z2Z : sme2_int_mla_long_array_vg2_multi<"smlsl", 0b01>;
defm SMLSL_VG4_M4Z4Z : sme2_int_mla_long_array_vg4_multi<"smlsl", 0b01>;

defm UMLAL_MZZI : sme2_mla_long_array_index<"umlal", 0b11, 0b10>;
defm UMLAL_VG2_M2ZZI : sme2_int_mla_long_array_vg2_index<"umlal", 0b10>;
defm UMLAL_VG4_M4ZZI : sme2_int_mla_long_array_vg4_index<"umlal", 0b10>;
defm UMLAL_MZZ : sme2_mla_long_array_single<"umlal",0b01, 0b10>;
defm UMLAL_VG2_M2ZZ : sme2_int_mla_long_array_vg2_single<"umlal", 0b10>;
defm UMLAL_VG4_M4ZZ : sme2_int_mla_long_array_vg4_single<"umlal", 0b10>;
defm UMLAL_VG2_M2Z2Z : sme2_int_mla_long_array_vg2_multi<"umlal", 0b10>;
defm UMLAL_VG4_M4Z4Z : sme2_int_mla_long_array_vg4_multi<"umlal", 0b10>;

defm UMLSL_MZZI : sme2_mla_long_array_index<"umlsl", 0b11, 0b11>;
defm UMLSL_VG2_M2ZZI : sme2_int_mla_long_array_vg2_index<"umlsl", 0b11>;
defm UMLSL_VG4_M4ZZI : sme2_int_mla_long_array_vg4_index<"umlsl", 0b11>;
defm UMLSL_MZZ : sme2_mla_long_array_single<"umlsl",0b01, 0b11>;
defm UMLSL_VG2_M2ZZ : sme2_int_mla_long_array_vg2_single<"umlsl", 0b11>;
defm UMLSL_VG4_M4ZZ : sme2_int_mla_long_array_vg4_single<"umlsl", 0b11>;
defm UMLSL_VG2_M2Z2Z : sme2_int_mla_long_array_vg2_multi<"umlsl", 0b11>;
defm UMLSL_VG4_M4Z4Z : sme2_int_mla_long_array_vg4_multi<"umlsl", 0b11>;
}


Expand Down
122 changes: 114 additions & 8 deletions llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
Expand Up @@ -272,6 +272,7 @@ class AArch64AsmParser : public MCTargetAsmParser {
OperandMatchResultTy tryParseMatrixTileList(OperandVector &Operands);
OperandMatchResultTy tryParseSVEPattern(OperandVector &Operands);
OperandMatchResultTy tryParseGPR64x8(OperandVector &Operands);
OperandMatchResultTy tryParseImmRange(OperandVector &Operands);

public:
enum AArch64MatchResultTy {
Expand Down Expand Up @@ -327,6 +328,7 @@ class AArch64Operand : public MCParsedAsmOperand {
enum KindTy {
k_Immediate,
k_ShiftedImm,
k_ImmRange,
k_CondCode,
k_Register,
k_MatrixRegister,
Expand Down Expand Up @@ -417,6 +419,11 @@ class AArch64Operand : public MCParsedAsmOperand {
unsigned ShiftAmount;
};

struct ImmRangeOp {
unsigned First;
unsigned Last;
};

struct CondCodeOp {
AArch64CC::CondCode Code;
};
Expand Down Expand Up @@ -478,6 +485,7 @@ class AArch64Operand : public MCParsedAsmOperand {
struct VectorIndexOp VectorIndex;
struct ImmOp Imm;
struct ShiftedImmOp ShiftedImm;
struct ImmRangeOp ImmRange;
struct CondCodeOp CondCode;
struct FPImmOp FPImm;
struct BarrierOp Barrier;
Expand Down Expand Up @@ -511,6 +519,9 @@ class AArch64Operand : public MCParsedAsmOperand {
case k_ShiftedImm:
ShiftedImm = o.ShiftedImm;
break;
case k_ImmRange:
ImmRange = o.ImmRange;
break;
case k_CondCode:
CondCode = o.CondCode;
break;
Expand Down Expand Up @@ -589,6 +600,16 @@ class AArch64Operand : public MCParsedAsmOperand {
return ShiftedImm.ShiftAmount;
}

unsigned getFirstImmVal() const {
assert(Kind == k_ImmRange && "Invalid access!");
return ImmRange.First;
}

unsigned getLastImmVal() const {
assert(Kind == k_ImmRange && "Invalid access!");
return ImmRange.Last;
}

AArch64CC::CondCode getCondCode() const {
assert(Kind == k_CondCode && "Invalid access!");
return CondCode.Code;
Expand Down Expand Up @@ -752,18 +773,30 @@ class AArch64Operand : public MCParsedAsmOperand {
return isImmScaled<Bits, Scale>(true);
}

template <int Bits, int Scale> DiagnosticPredicate isUImmScaled() const {
return isImmScaled<Bits, Scale>(false);
template <int Bits, int Scale, int Offset = 0, bool IsRange = false>
DiagnosticPredicate isUImmScaled() const {
if (IsRange && isImmRange() &&
(getLastImmVal() != getFirstImmVal() + Offset))
return DiagnosticPredicateTy::NoMatch;

return isImmScaled<Bits, Scale, IsRange>(false);
}

template <int Bits, int Scale>
template <int Bits, int Scale, bool IsRange = false>
DiagnosticPredicate isImmScaled(bool Signed) const {
if (!isImm())
if ((!isImm() && !isImmRange()) || (isImm() && IsRange) ||
(isImmRange() && !IsRange))
return DiagnosticPredicateTy::NoMatch;

const MCConstantExpr *MCE = dyn_cast<MCConstantExpr>(getImm());
if (!MCE)
return DiagnosticPredicateTy::NoMatch;
int64_t Val;
if (isImmRange())
Val = getFirstImmVal();
else {
const MCConstantExpr *MCE = dyn_cast<MCConstantExpr>(getImm());
if (!MCE)
return DiagnosticPredicateTy::NoMatch;
Val = MCE->getValue();
}

int64_t MinVal, MaxVal;
if (Signed) {
Expand All @@ -775,7 +808,6 @@ class AArch64Operand : public MCParsedAsmOperand {
MaxVal = ((int64_t(1) << Bits) - 1) * Scale;
}

int64_t Val = MCE->getValue();
if (Val >= MinVal && Val <= MaxVal && (Val % Scale) == 0)
return DiagnosticPredicateTy::Match;

Expand Down Expand Up @@ -875,6 +907,8 @@ class AArch64Operand : public MCParsedAsmOperand {

bool isShiftedImm() const { return Kind == k_ShiftedImm; }

bool isImmRange() const { return Kind == k_ImmRange; }

/// Returns the immediate value as a pair of (imm, shift) if the immediate is
/// a shifted immediate by value 'Shift' or '0', or if it is an unshifted
/// immediate that can be shifted by 'Shift'.
Expand Down Expand Up @@ -1770,6 +1804,12 @@ class AArch64Operand : public MCParsedAsmOperand {
Inst.addOperand(MCOperand::createImm(MCE->getValue() / Scale));
}

template <int Scale>
void addImmScaledRangeOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
Inst.addOperand(MCOperand::createImm(getFirstImmVal() / Scale));
}

template <typename T>
void addLogicalImmOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
Expand Down Expand Up @@ -2111,6 +2151,17 @@ class AArch64Operand : public MCParsedAsmOperand {
return Op;
}

static std::unique_ptr<AArch64Operand> CreateImmRange(unsigned First,
unsigned Last, SMLoc S,
SMLoc E,
MCContext &Ctx) {
auto Op = std::make_unique<AArch64Operand>(k_ImmRange, Ctx);
Op->ImmRange.First = First;
Op->ImmRange.Last = Last;
Op->EndLoc = E;
return Op;
}

static std::unique_ptr<AArch64Operand>
CreateCondCode(AArch64CC::CondCode Code, SMLoc S, SMLoc E, MCContext &Ctx) {
auto Op = std::make_unique<AArch64Operand>(k_CondCode, Ctx);
Expand Down Expand Up @@ -2273,6 +2324,12 @@ void AArch64Operand::print(raw_ostream &OS) const {
OS << ", lsl #" << AArch64_AM::getShiftValue(Shift) << ">";
break;
}
case k_ImmRange: {
OS << "<immrange ";
OS << getFirstImmVal();
OS << ":" << getLastImmVal() << ">";
break;
}
case k_CondCode:
OS << "<condcode " << getCondCode() << ">";
break;
Expand Down Expand Up @@ -2999,6 +3056,10 @@ AArch64AsmParser::tryParseImmWithOptionalShift(OperandVector &Operands) {
// Operand should start from # or should be integer, emit error otherwise.
return MatchOperand_NoMatch;

if (getTok().is(AsmToken::Integer) &&
getLexer().peekTok().is(AsmToken::Colon))
return tryParseImmRange(Operands);

const MCExpr *Imm = nullptr;
if (parseSymbolicImmVal(Imm))
return MatchOperand_ParseFail;
Expand Down Expand Up @@ -5351,6 +5412,15 @@ bool AArch64AsmParser::showMatchError(SMLoc Loc, unsigned ErrCode,
return Error(Loc, "immediate must be an integer in range [1, 32].");
case Match_InvalidImm1_64:
return Error(Loc, "immediate must be an integer in range [1, 64].");
case Match_InvalidMemoryIndexedRange2UImm2:
case Match_InvalidMemoryIndexedRange2UImm3:
return Error(
Loc,
"vector select offset must be an immediate range of the form "
"<immf>:<imml>, "
"where the first immediate is a multiple of 2 in the range [0, 6] or "
"[0, 14] "
"depending on the instruction, and the second immediate is immf + 1.");
case Match_InvalidSVEAddSubImm8:
return Error(Loc, "immediate must be an integer in range [0, 255]"
" with a shift amount of 0");
Expand Down Expand Up @@ -5996,6 +6066,8 @@ bool AArch64AsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
case Match_InvalidImm1_16:
case Match_InvalidImm1_32:
case Match_InvalidImm1_64:
case Match_InvalidMemoryIndexedRange2UImm2:
case Match_InvalidMemoryIndexedRange2UImm3:
case Match_InvalidSVEAddSubImm8:
case Match_InvalidSVEAddSubImm16:
case Match_InvalidSVEAddSubImm32:
Expand Down Expand Up @@ -7290,3 +7362,37 @@ AArch64AsmParser::tryParseGPR64x8(OperandVector &Operands) {
AArch64Operand::CreateReg(X8Reg, RegKind::Scalar, SS, getLoc(), ctx));
return MatchOperand_Success;
}

OperandMatchResultTy
AArch64AsmParser::tryParseImmRange(OperandVector &Operands) {
SMLoc S = getLoc();

if (getTok().isNot(AsmToken::Integer))
return MatchOperand_NoMatch;

if (getLexer().peekTok().isNot(AsmToken::Colon))
return MatchOperand_NoMatch;

const MCExpr *ImmF;
if (getParser().parseExpression(ImmF))
return MatchOperand_NoMatch;

if (getTok().isNot(AsmToken::Colon))
return MatchOperand_NoMatch;

Lex(); // Eat ':'
if (getTok().isNot(AsmToken::Integer))
return MatchOperand_NoMatch;

SMLoc E = getTok().getLoc();
const MCExpr *ImmL;
if (getParser().parseExpression(ImmL))
return MatchOperand_NoMatch;

unsigned ImmFVal = dyn_cast<MCConstantExpr>(ImmF)->getValue();
unsigned ImmLVal = dyn_cast<MCConstantExpr>(ImmL)->getValue();

Operands.push_back(
AArch64Operand::CreateImmRange(ImmFVal, ImmLVal, S, E, getContext()));
return MatchOperand_Success;
}
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
Expand Up @@ -1223,6 +1223,15 @@ void AArch64InstPrinter::printImmScale(const MCInst *MI, unsigned OpNum,
<< formatImm(Scale * MI->getOperand(OpNum).getImm()) << markup(">");
}

template <int Scale, int Offset>
void AArch64InstPrinter::printImmRangeScale(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI,
raw_ostream &O) {
unsigned FirstImm = Scale * MI->getOperand(OpNum).getImm();
O << formatImm(FirstImm);
O << ":" << formatImm(FirstImm + Offset);
}

void AArch64InstPrinter::printUImm12Offset(const MCInst *MI, unsigned OpNum,
unsigned Scale, raw_ostream &O) {
const MCOperand MO = MI->getOperand(OpNum);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h
Expand Up @@ -130,6 +130,10 @@ class AArch64InstPrinter : public MCInstPrinter {
void printImmScale(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI, raw_ostream &O);

template <int Scale, int Offset>
void printImmRangeScale(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI, raw_ostream &O);

template <bool IsSVEPrefetch = false>
void printPrefetchOp(const MCInst *MI, unsigned OpNum,
const MCSubtargetInfo &STI, raw_ostream &O);
Expand Down

0 comments on commit 3fee935

Please sign in to comment.