Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/lib/Basic/Targets/Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,14 @@ bool HexagonTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasFastHalfType = true;
HasFloat16 = true;
}
if (CPU.compare("hexagonv81") >= 0)
HasBFloat16 = true;

return true;
}

/// Returns the HasBFloat16 flag held by this target info object.
bool HexagonTargetInfo::hasBFloat16Type() const {
  return HasBFloat16;
}

const char *const HexagonTargetInfo::GCCRegNames[] = {
// Scalar registers:
"r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11",
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Basic/Targets/Hexagon.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
// for modeling predicate registers in HVX, and the bool -> byte
// correspondence matches the HVX architecture.
BoolWidth = BoolAlign = 8;
BFloat16Width = BFloat16Align = 16;
BFloat16Format = &llvm::APFloat::BFloat();
}

llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
Expand Down Expand Up @@ -95,6 +97,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {

bool hasFeature(StringRef Feature) const override;

bool hasBFloat16Type() const override;

bool
initFeatureMap(llvm::StringMap<bool> &Features, DiagnosticsEngine &Diags,
StringRef CPU,
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Target/Hexagon/HexagonCallingConv.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def CC_HexagonStack: CallingConv<[
def CC_Hexagon_Legacy: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
CCIfType<[bf16],
CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
Expand Down Expand Up @@ -55,6 +57,8 @@ def CC_Hexagon_Legacy: CallingConv<[
def CC_Hexagon: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
CCIfType<[bf16],
CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
Expand Down Expand Up @@ -88,6 +92,8 @@ def CC_Hexagon: CallingConv<[
def RetCC_Hexagon: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
CCIfType<[bf16],
CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
Expand Down Expand Up @@ -149,16 +155,16 @@ def CC_Hexagon_HVX: CallingConv<[
CCIfType<[v128i1], CCPromoteToType<v128i8>>>,

CCIfHvx128<
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToReg<[V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15]>>>,
CCIfHvx128<
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
CCAssignToReg<[W0,W1,W2,W3,W4,W5,W6,W7]>>>,
CCIfHvx128<
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToStack<128,128>>>,
CCIfHvx128<
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
// Vector pairs that were not given a W register are passed in a 256-byte
// stack slot. The bf16 pair type is v128bf16; v64bf16 is a single vector
// and is already covered by the 128-byte stack rule.
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
CCAssignToStack<256,128>>>,

CCDelegateTo<CC_Hexagon>
Expand All @@ -175,10 +181,10 @@ def RetCC_Hexagon_HVX: CallingConv<[

// HVX 128-byte mode
CCIfHvx128<
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToReg<[V0]>>>,
CCIfHvx128<
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
CCAssignToReg<[W0]>>>,

CCDelegateTo<RetCC_Hexagon>
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
}
// Turn FP truncstore into trunc + store.
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
// Turn FP extload into load/fpextend.
for (MVT VT : MVT::fp_valuetypes())
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
Expand Down Expand Up @@ -1872,9 +1874,15 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_BF16, MVT::f64, Expand);

setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);

setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);

Expand Down
68 changes: 60 additions & 8 deletions llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
addRegisterClass(MVT::v64f32, &Hexagon::HvxWRRegClass);
addRegisterClass(MVT::v128f16, &Hexagon::HvxWRRegClass);
}
if (Subtarget.useHVXV81Ops()) {
addRegisterClass(MVT::v64bf16, &Hexagon::HvxVRRegClass);
addRegisterClass(MVT::v128bf16, &Hexagon::HvxWRRegClass);
}
}

// Set up operation actions.
Expand Down Expand Up @@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);

if (Subtarget.useHVXV81Ops()) {
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
setPromoteTo(ISD::SETCC, MVT::v64bf16, MVT::v64f32);
setPromoteTo(ISD::FADD, MVT::v64bf16, MVT::v64f32);
setPromoteTo(ISD::FSUB, MVT::v64bf16, MVT::v64f32);
setPromoteTo(ISD::FMUL, MVT::v64bf16, MVT::v64f32);
setPromoteTo(ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
setPromoteTo(ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);

setOperationAction(ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);

setOperationAction(ISD::MLOAD, MVT::v64bf16, Custom);
setOperationAction(ISD::MSTORE, MVT::v64bf16, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
setOperationAction(ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);

setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
}

for (MVT P : FloatW) {
setOperationAction(ISD::LOAD, P, Custom);
setOperationAction(ISD::STORE, P, Custom);
Expand Down Expand Up @@ -462,6 +490,10 @@ HexagonTargetLowering::initializeHVXLowering() {

unsigned
HexagonTargetLowering::getPreferredHvxVectorAction(MVT VecTy) const {
// Early exit for invalid input types
if (!VecTy.isVector())
return ~0u;

MVT ElemTy = VecTy.getVectorElementType();
unsigned VecLen = VecTy.getVectorNumElements();
unsigned HwLen = Subtarget.getVectorLength();
Expand Down Expand Up @@ -1667,14 +1699,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
// In case of an MVT::f16 or MVT::bf16 BUILD_VECTOR, since MVT::f16 is
// not a legal type (and bf16 is treated the same way), just bitcast the
// node to use i16 types and bitcast the result back to the original
// element type
if (VecTy.getVectorElementType() == MVT::f16) {
SmallVector<SDValue,64> NewOps;
if (VecTy.getVectorElementType() == MVT::f16 ||
VecTy.getVectorElementType() == MVT::bf16) {
SmallVector<SDValue, 64> NewOps;
for (unsigned i = 0; i != Size; i++)
NewOps.push_back(DAG.getBitcast(MVT::i16, Ops[i]));

SDValue T0 = DAG.getNode(ISD::BUILD_VECTOR, dl,
tyVector(VecTy, MVT::i16), NewOps);
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
SDValue T0 =
DAG.getNode(ISD::BUILD_VECTOR, dl, tyVector(VecTy, MVT::i16), NewOps);
return DAG.getBitcast(tyVector(VecTy, VecTy.getVectorElementType()), T0);
}

// First, split the BUILD_VECTOR for vector pairs. We could generate
Expand All @@ -1698,7 +1731,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
MVT VecTy = ty(Op);
MVT ArgTy = ty(Op.getOperand(0));

if (ArgTy == MVT::f16) {
if (ArgTy == MVT::f16 || ArgTy == MVT::bf16) {
MVT SplatTy = MVT::getVectorVT(MVT::i16, VecTy.getVectorNumElements());
SDValue ToInt16 = DAG.getBitcast(MVT::i16, Op.getOperand(0));
SDValue ToInt32 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, ToInt16);
Expand Down Expand Up @@ -1831,12 +1864,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
if (ElemTy == MVT::i1)
return insertHvxElementPred(VecV, IdxV, ValV, dl, DAG);

if (ElemTy == MVT::f16) {
if (ElemTy == MVT::f16 || ElemTy == MVT::bf16) {
SDValue T0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl,
tyVector(VecTy, MVT::i16),
DAG.getBitcast(tyVector(VecTy, MVT::i16), VecV),
DAG.getBitcast(MVT::i16, ValV), IdxV);
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
return DAG.getBitcast(tyVector(VecTy, ElemTy), T0);
}

return insertHvxElementReg(VecV, IdxV, ValV, dl, DAG);
Expand Down Expand Up @@ -2334,6 +2367,25 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
MVT VecTy = ty(Op);
MVT ArgTy = ty(Op.getOperand(0));
const SDLoc &dl(Op);

if (ArgTy == MVT::v64bf16) {
MVT HalfTy = typeSplit(VecTy).first;
SDValue BF16Vec = Op.getOperand(0);
SDValue Zeroes =
getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
// Interleave zero vector with the bf16 vector, with zeroes in the lower
// half of each 32 bit lane, effectively extending the bf16 values to fp32
// values.
SDValue ShuffVec =
getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
{VecPair.second, VecPair.first,
DAG.getSignedConstant(-4, dl, MVT::i32)},
DAG);
return Result;
}

assert(VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);

SDValue F16Vec = Op.getOperand(0);
Expand Down
13 changes: 12 additions & 1 deletion llvm/lib/Target/Hexagon/HexagonPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ def Fptoui: pf1<fp_to_uint>;
def Sitofp: pf1<sint_to_fp>;
def Uitofp: pf1<uint_to_fp>;


// --(1) Immediate -------------------------------------------------------
//

Expand Down Expand Up @@ -474,6 +473,18 @@ def: OpR_R_pat<F2_conv_df2uw_chop, pf1<fp_to_uint>, i32, F64>;
def: OpR_R_pat<F2_conv_sf2ud_chop, pf1<fp_to_uint>, i64, F32>;
def: OpR_R_pat<F2_conv_df2ud_chop, pf1<fp_to_uint>, i64, F64>;

// Select fp_to_bf16 of an f32 source into a sequence of scalar integer
// operations that yields the 16-bit result in an i32.
def: Pat<(i32 (fp_to_bf16 F32:$v)),
// Outer mux: if F2_sfclass with mask 0x10 matches, produce the constant
// 0x7fff. NOTE(review): 0x10 is presumably the NaN class bit - confirm
// against the Hexagon Programmer's Reference Manual.
(C2_mux (F2_sfclass F32:$v, 0x10), (A2_tfrsi(i32 0x7fff)),
// Inner mux: choose between plain truncation and round-up.
(C2_mux
// Condition: the low 17 bits equal 0x08000, i.e. bit 16 is clear, bit 15 is
// set and bits 14..0 are clear (an exact tie whose truncated result already
// has an even low bit).
(C2_cmpeq
(A2_and F32:$v, (A2_tfrsi (i32 0x1FFFF))),
(A2_tfrsi (i32 0x08000))),
// Tie with even result: take A2_asrh of the source (presumably an
// arithmetic shift right by 16 - confirm) and keep the low 16 bits.
(A2_and (A2_asrh F32:$v), (A2_tfrsi (i32 65535))),
// Otherwise: add bit 15 of the source to the source first, so a set bit 15
// carries into bit 16, then take the shifted value masked to 16 bits.
(A2_and
(A2_asrh
(A2_add F32:$v, (A2_and F32:$v, (A2_tfrsi (i32 0x8000))))),
(A2_tfrsi (i32 65535))))
)>;
// Bitcast is different than [fp|sint|uint]_to_[sint|uint|fp].
def: Pat<(i32 (bitconvert F32:$v)), (I32:$v)>;
def: Pat<(f32 (bitconvert I32:$v)), (F32:$v)>;
Expand Down
Loading