Skip to content

Commit

Permalink
[NVPTX] Improve lowering of v4i8 (#67866)
Browse files Browse the repository at this point in the history
Make v4i8 a legal type and plumb through lowering of relevant instructions.
  • Loading branch information
Artem-B committed Oct 9, 2023
1 parent 67b675e commit cbafb6f
Show file tree
Hide file tree
Showing 15 changed files with 1,897 additions and 540 deletions.
31 changes: 31 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
O << Sym.getName();
}

void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();

switch (Imm) {
default:
return;
case NVPTX::PTXPrmtMode::NONE:
break;
case NVPTX::PTXPrmtMode::F4E:
O << ".f4e";
break;
case NVPTX::PTXPrmtMode::B4E:
O << ".b4e";
break;
case NVPTX::PTXPrmtMode::RC8:
O << ".rc8";
break;
case NVPTX::PTXPrmtMode::ECL:
O << ".ecl";
break;
case NVPTX::PTXPrmtMode::ECR:
O << ".ecr";
break;
case NVPTX::PTXPrmtMode::RC16:
O << ".rc16";
break;
}
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printProtoIdent(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier = nullptr);
// Prints the suffix (e.g. ".f4e") selected by the NVPTX::PTXPrmtMode value
// held in the immediate operand at OpNum. Nothing is printed for
// PTXPrmtMode::NONE or for a value that is not a known mode. Modifier is
// accepted but not consulted by the implementation.
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
};

}
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ enum CmpMode {
FTZ_FLAG = 0x100
};
}

// Mode selectors carried as an immediate operand and turned into text by
// NVPTXInstPrinter::printPrmtMode. The numeric values are spelled out so the
// encoding stays fixed if enumerators are ever reordered; they are the same
// values the compiler would assign implicitly.
// NOTE(review): the meaning of each mode is defined by the PTX ISA `prmt`
// instruction; only the printed suffix is recorded here.
namespace PTXPrmtMode {
enum PrmtMode {
  NONE = 0, // no suffix is printed
  F4E = 1,  // printed as ".f4e"
  B4E = 2,  // printed as ".b4e"
  RC8 = 3,  // printed as ".rc8"
  ECL = 4,  // printed as ".ecl"
  ECR = 5,  // printed as ".ecr"
  RC16 = 6, // printed as ".rc16"
};
} // namespace PTXPrmtMode
}
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
} // namespace llvm
Expand Down
21 changes: 14 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
Expand Down Expand Up @@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16/v4i8 are loaded using ld.b32
fromTypeWidth = 32;
}
Expand Down Expand Up @@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
EVT OrigType = N->getValueType(0);

EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// Vectors of 16-bit types are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
EltVT = OrigType;
NumElts /= 2;
} else if (OrigType == MVT::v4i8) {
EltVT = OrigType;
NumElts = 1;
}
}

Expand Down Expand Up @@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.

EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

if (OrigType != EltVT &&
Expand Down Expand Up @@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 and v4i8 are stored using st.b32
toTypeWidth = 32;
}
Expand Down
Loading

0 comments on commit cbafb6f

Please sign in to comment.