Skip to content

Commit

Permalink
[NVPTX] Added support for half-precision floating point.
Browse files Browse the repository at this point in the history
Only scalar half-precision operations are supported at the moment.

- Adds general support for 'half' type in NVPTX.
- fp16 math operations are supported on sm_53+ GPUs only
  (can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled),
  fp16 operations are promoted to fp32 and results are converted back
  to fp16 for storage.

Differential Revision: https://reviews.llvm.org/D28540

llvm-svn: 291956
  • Loading branch information
Artem-B committed Jan 13, 2017
1 parent 836a3b4 commit 64dc9be
Show file tree
Hide file tree
Showing 18 changed files with 1,487 additions and 102 deletions.
9 changes: 8 additions & 1 deletion llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp
Expand Up @@ -61,6 +61,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const {
case 6:
OS << "%fd";
break;
case 7:
OS << "%h";
break;
}

unsigned VReg = RegNo & 0x0FFFFFFF;
Expand Down Expand Up @@ -247,8 +250,12 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
O << "s";
else if (Imm == NVPTX::PTXLdStInstCode::Unsigned)
O << "u";
else
else if (Imm == NVPTX::PTXLdStInstCode::Untyped)
O << "b";
else if (Imm == NVPTX::PTXLdStInstCode::Float)
O << "f";
else
llvm_unreachable("Unknown register type");
} else if (!strcmp(Modifier, "vec")) {
if (Imm == NVPTX::PTXLdStInstCode::V2)
O << ".v2";
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTX.h
Expand Up @@ -108,7 +108,8 @@ enum AddressSpace {
enum FromType {
Unsigned = 0,
Signed,
Float
Float,
Untyped
};
enum VecType {
Scalar = 1,
Expand Down
21 changes: 19 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Expand Up @@ -320,6 +320,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,

switch (Cnt->getType()->getTypeID()) {
default: report_fatal_error("Unsupported FP type"); break;
case Type::HalfTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
break;
case Type::FloatTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
Expand Down Expand Up @@ -357,6 +361,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (5 << 28);
} else if (RC == &NVPTX::Float64RegsRegClass) {
Ret = (6 << 28);
} else if (RC == &NVPTX::Float16RegsRegClass) {
Ret = (7 << 28);
} else {
report_fatal_error("Bad register class");
}
Expand Down Expand Up @@ -396,12 +402,15 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
size = ITy->getBitWidth();
if (size < 32)
size = 32;
} else {
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
size = Ty->getPrimitiveSizeInBits();
}
// PTX ABI requires all scalar return values to be at least 32
// bits in size. fp16 normally uses .b16 as its storage type in
// PTX, so its size must be adjusted here, too.
if (size < 32)
size = 32;

O << ".param .b" << size << " func_retval0";
} else if (isa<PointerType>(Ty)) {
Expand Down Expand Up @@ -1376,6 +1385,9 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
}
break;
}
case Type::HalfTyID:
// fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
return "b16";
case Type::FloatTyID:
return "f32";
case Type::DoubleTyID:
Expand Down Expand Up @@ -1601,6 +1613,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
sz = 32;
} else if (isa<PointerType>(Ty))
sz = thePointerTy.getSizeInBits();
else if (Ty->isHalfTy())
// PTX ABI requires all scalar parameters to be at least 32
// bits in size. fp16 normally uses .b16 as its storage type
// in PTX, so its size must be adjusted here, too.
sz = 32;
else
sz = Ty->getPrimitiveSizeInBits();
if (isABI)
Expand Down
71 changes: 68 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Expand Up @@ -42,7 +42,6 @@ FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden,
cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
cl::init(false));


/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
Expand Down Expand Up @@ -520,6 +519,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
case ISD::ADDRSPACECAST:
SelectAddrSpaceCast(N);
return;
case ISD::ConstantFP:
if (tryConstantFP16(N))
return;
break;
default:
break;
}
Expand All @@ -541,6 +544,19 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
}
}

/// Try to select an f16 floating-point constant node.
///
/// FP16 immediates cannot be specified directly in .f16 ops, so the constant
/// has to be loaded into an .f16 register first. This is done by wrapping the
/// value in a target constant and feeding it to a LOAD_CONST_F16 machine node,
/// which then replaces \p N.
///
/// \returns true if \p N produced an f16 value and was replaced; false if the
/// result type of \p N is not f16, in which case \p N is left untouched.
bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
  const bool ProducesHalf = N->getValueType(0) == MVT::f16;
  if (!ProducesHalf)
    return false;

  // Both nodes created below carry the location of the node being replaced.
  SDLoc DL(N);

  // Re-express the constant as a target constant so that it can be used as
  // the immediate operand of the machine node.
  const auto *FPConst = cast<ConstantFPSDNode>(N);
  SDValue Imm =
      CurDAG->getTargetConstantFP(FPConst->getValueAPF(), DL, MVT::f16);

  SDNode *Materialized =
      CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, DL, MVT::f16, Imm);
  ReplaceNode(N, Materialized);
  return true;
}

static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();

Expand Down Expand Up @@ -740,7 +756,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
if ((LD->getExtensionType() == ISD::SEXTLOAD))
fromType = NVPTX::PTXLdStInstCode::Signed;
else if (ScalarVT.isFloatingPoint())
fromType = NVPTX::PTXLdStInstCode::Float;
// f16 uses .b16 as its storage type.
fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
: NVPTX::PTXLdStInstCode::Float;
else
fromType = NVPTX::PTXLdStInstCode::Unsigned;

Expand All @@ -766,6 +784,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_avar;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_avar;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_avar;
break;
Expand Down Expand Up @@ -794,6 +815,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_asi;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_asi;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_asi;
break;
Expand Down Expand Up @@ -823,6 +847,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_ari_64;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_ari_64;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari_64;
break;
Expand All @@ -846,6 +873,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_ari;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_ari;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari;
break;
Expand Down Expand Up @@ -875,6 +905,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_areg_64;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_areg_64;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg_64;
break;
Expand All @@ -898,6 +931,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_areg;
break;
case MVT::f16:
Opcode = NVPTX::LD_f16_areg;
break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg;
break;
Expand Down Expand Up @@ -2173,7 +2209,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
unsigned toTypeWidth = ScalarVT.getSizeInBits();
unsigned int toType;
if (ScalarVT.isFloatingPoint())
toType = NVPTX::PTXLdStInstCode::Float;
// f16 uses .b16 as its storage type.
toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
: NVPTX::PTXLdStInstCode::Float;
else
toType = NVPTX::PTXLdStInstCode::Unsigned;

Expand All @@ -2200,6 +2238,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_avar;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_avar;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_avar;
break;
Expand Down Expand Up @@ -2229,6 +2270,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_asi;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_asi;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_asi;
break;
Expand Down Expand Up @@ -2259,6 +2303,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_ari_64;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_ari_64;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari_64;
break;
Expand All @@ -2282,6 +2329,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_ari;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_ari;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari;
break;
Expand Down Expand Up @@ -2312,6 +2362,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_areg_64;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_areg_64;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg_64;
break;
Expand All @@ -2335,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_areg;
break;
case MVT::f16:
Opcode = NVPTX::ST_f16_areg;
break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg;
break;
Expand Down Expand Up @@ -2786,6 +2842,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
case MVT::i64:
Opc = NVPTX::LoadParamMemI64;
break;
case MVT::f16:
Opc = NVPTX::LoadParamMemF16;
break;
case MVT::f32:
Opc = NVPTX::LoadParamMemF32;
break;
Expand Down Expand Up @@ -2921,6 +2980,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreRetvalI64;
break;
case MVT::f16:
Opcode = NVPTX::StoreRetvalF16;
break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalF32;
break;
Expand Down Expand Up @@ -3054,6 +3116,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreParamI64;
break;
case MVT::f16:
Opcode = NVPTX::StoreParamF16;
break;
case MVT::f32:
Opcode = NVPTX::StoreParamF32;
break;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Expand Up @@ -70,6 +70,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryTextureIntrinsic(SDNode *N);
bool trySurfaceIntrinsic(SDNode *N);
bool tryBFE(SDNode *N);
bool tryConstantFP16(SDNode *N);

inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
Expand Down

0 comments on commit 64dc9be

Please sign in to comment.