Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Improve lowering of v4i8 #67866

Merged
merged 13 commits into from
Oct 9, 2023
38 changes: 31 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
Expand Down Expand Up @@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16/v4i8 are loaded using ld.b32
fromTypeWidth = 32;
}
Expand Down Expand Up @@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
EVT OrigType = N->getValueType(0);

EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
EltVT = OrigType;
NumElts /= 2;
} else if (OrigType == MVT::v4i8) {
EltVT = OrigType;
NumElts = 1;
}
}

Expand Down Expand Up @@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.

EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

if (OrigType != EltVT &&
Expand Down Expand Up @@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 and v4i8 are stored using st.b32
toTypeWidth = 32;
}
Expand Down Expand Up @@ -3563,6 +3570,23 @@ bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
}

/// Matches an EXTRACT_VECTOR_ELT that reads a constant-indexed element out of
/// a v4i8 vector.
///
/// On success, \p V receives the source vector and \p BitOffset receives an
/// i32 target constant equal to the element index multiplied by 8 (the bit
/// position of the selected byte within the vector).
///
/// \returns true if \p N matched; false if \p N is not an EXTRACT_VECTOR_ELT,
/// the source vector is not v4i8, or the index is not a constant.
bool NVPTXDAGToDAGISel::SelectExtractEltFromV4I8(SDValue N, SDValue &V,
                                                 SDValue &BitOffset) {
  // Verify the opcode before touching any operand: getOperand() asserts on an
  // out-of-range index, so it must not be called on a node that is not known
  // to have operands.
  if (N->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
    return false;

  // Query the type of the value actually used (respecting its result number)
  // rather than result 0 of the defining node.
  SDValue Vector = N->getOperand(0);
  if (Vector.getValueType() != MVT::v4i8)
    return false;

  // Only constant indices can be turned into a fixed bit offset.
  const auto *IdxConst = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!IdxConst)
    return false;

  V = Vector;
  BitOffset = CurDAG->getTargetConstant(IdxConst->getZExtValue() * 8,
                                        SDLoc(N), MVT::i32);
  return true;
}

bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
unsigned int spN) const {
const Value *Src = nullptr;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
SDValue &Offset);
bool SelectADDRsi64(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectExtractEltFromV4I8(SDValue N, SDValue &Value, SDValue &Idx);

bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;

Expand Down
87 changes: 52 additions & 35 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
llvm_unreachable("Unexpected type");
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
(NumElts % 4 == 0 || NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
Expand Down Expand Up @@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
Expand Down Expand Up @@ -491,6 +497,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);

// TODO: we should eventually lower it as PRMT instruction.
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Expand);
setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);

// Operations not directly supported by NVPTX.
Artem-B marked this conversation as resolved.
Show resolved Hide resolved
for (MVT VT :
{MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
Expand Down Expand Up @@ -2150,45 +2160,47 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}

// We can init constant f16x2 with a single .b32 move. Normally it
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// mov.b16 %h1, 0x4000;
// mov.b16 %h2, 0x3C00;
// mov.b32 %hh2, {%h2, %h1};
// Instead we want just a constant move:
// mov.b32 %hh2, 0x40003C00
//
// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
// generates good SASS in both cases.
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op->getValueType(0);
if (!(Isv2x16VT(VT)))
if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
return Op;

if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
isa<ConstantFPSDNode>(Operand);
}))
return Op;
APInt E0;
APInt E1;
if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(1))))
return Op;

E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
->getValueAPF()
.bitcastToAPInt();
E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
->getValueAPF()
.bitcastToAPInt();
} else {
assert(VT == MVT::v2i16);
if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
isa<ConstantSDNode>(Op->getOperand(1))))
return Op;

E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
// Get value of the Nth operand as an APInt(32). Undef values treated as 0.
auto GetOperand = [](SDValue Op, int N) -> APInt {
const SDValue &Operand = Op->getOperand(N);
EVT VT = Op->getValueType(0);
if (Operand->isUndef())
return APInt(32, 0);
APInt Value;
if (VT == MVT::v2f16 || VT == MVT::v2bf16)
Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
else if (VT == MVT::v2i16 || VT == MVT::v4i8)
Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
else
llvm_unreachable("Unsupported type");
return Value.zext(32);
Artem-B marked this conversation as resolved.
Show resolved Hide resolved
};
APInt Value;
if (Isv2x16VT(VT)) {
Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
} else if (VT == MVT::v4i8) {
Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
} else {
llvm_unreachable("Unsupported type");
}
SDValue Const =
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}

Expand Down Expand Up @@ -2631,7 +2643,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
return expandUnalignedStore(Store, DAG);

// v2f16, v2bf16, v2i16 and v4i8 don't need special handling.
if (Isv2x16VT(VT))
if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();

if (VT.isVector())
Expand Down Expand Up @@ -2903,7 +2915,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
else if (Isv2x16VT(EltVT))
else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
Expand All @@ -2929,7 +2941,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2x16 or v4i8 was loaded as an i32. Now we must bitcast it back.
else if (Isv2x16VT(EltVT))
else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);

// If a promoted integer type is used, truncate down to the original
Expand Down Expand Up @@ -5258,9 +5270,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
IsPTXVectorType(VectorVT.getSimpleVT()))
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
// extract_vector_elt, except for v4i8.
// Don't mess with singletons or v2*16 types, we already handle them OK.
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
VectorVT == MVT::v4i8)
return SDValue();

uint64_t VectorBits = VectorVT.getSizeInBits();
Expand Down Expand Up @@ -5289,6 +5302,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
// If element has non-integer type, bitcast it back to the expected type.
if (EltVT != EltIVT)
Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
// Past legalizer, we may need to extend i8 -> i16 to match the register type.
if (EltVT != N->getValueType(0))
Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);

return Result;
}

Expand Down
Loading
Loading