Skip to content

Commit 93813a6

Browse files
mallick-qc (Muntasir Mallick) and kaushik-quicinc
authored and committed
[Hexagon] Enable soft bf16 in hexagon
This patch adds: 1. Support for recognizing the bf16 type in the frontend, and isel/ABI support for scalar bf16 programs. Limitations: fp_to_bf16 is generated with a TableGen pattern instead of being lowered via expansion. This is because we do not have support for the fcanonicalize instruction, which should prevent an SNaN from being converted to an infinity due to truncation. 2. Vector codegen support for bf16. Patch By: Fateme Hosseini. Co-authored-by: Muntasir Mallick <mallick@qti.qualcomm.com> Co-authored-by: Kaushik Kulkarni <quic_kauskulk@quicinc.com>
1 parent 1347b23 commit 93813a6

File tree

11 files changed

+910
-481
lines changed

11 files changed

+910
-481
lines changed

clang/lib/Basic/Targets/Hexagon.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,14 @@ bool HexagonTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
155155
HasFastHalfType = true;
156156
HasFloat16 = true;
157157
}
158+
if (CPU.compare("hexagonv81") >= 0)
159+
HasBFloat16 = true;
160+
158161
return true;
159162
}
160163

164+
/// Reports whether the bf16 type is available, as recorded in HasBFloat16.
bool HexagonTargetInfo::hasBFloat16Type() const {
  return HasBFloat16;
}
165+
161166
const char *const HexagonTargetInfo::GCCRegNames[] = {
162167
// Scalar registers:
163168
"r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11",

clang/lib/Basic/Targets/Hexagon.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
6464
// for modeling predicate registers in HVX, and the bool -> byte
6565
// correspondence matches the HVX architecture.
6666
BoolWidth = BoolAlign = 8;
67+
BFloat16Width = BFloat16Align = 16;
68+
BFloat16Format = &llvm::APFloat::BFloat();
6769
}
6870

6971
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
@@ -95,6 +97,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
9597

9698
bool hasFeature(StringRef Feature) const override;
9799

100+
bool hasBFloat16Type() const override;
101+
98102
bool
99103
initFeatureMap(llvm::StringMap<bool> &Features, DiagnosticsEngine &Diags,
100104
StringRef CPU,

llvm/lib/Target/Hexagon/HexagonCallingConv.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def CC_HexagonStack: CallingConv<[
2525
def CC_Hexagon_Legacy: CallingConv<[
2626
CCIfType<[i1,i8,i16],
2727
CCPromoteToType<i32>>,
28+
CCIfType<[bf16],
29+
CCBitConvertToType<i32>>,
2830
CCIfType<[f32],
2931
CCBitConvertToType<i32>>,
3032
CCIfType<[f64],
@@ -55,6 +57,8 @@ def CC_Hexagon_Legacy: CallingConv<[
5557
def CC_Hexagon: CallingConv<[
5658
CCIfType<[i1,i8,i16],
5759
CCPromoteToType<i32>>,
60+
CCIfType<[bf16],
61+
CCBitConvertToType<i32>>,
5862
CCIfType<[f32],
5963
CCBitConvertToType<i32>>,
6064
CCIfType<[f64],
@@ -88,6 +92,8 @@ def CC_Hexagon: CallingConv<[
8892
def RetCC_Hexagon: CallingConv<[
8993
CCIfType<[i1,i8,i16],
9094
CCPromoteToType<i32>>,
95+
CCIfType<[bf16],
96+
CCBitConvertToType<i32>>,
9197
CCIfType<[f32],
9298
CCBitConvertToType<i32>>,
9399
CCIfType<[f64],
@@ -149,16 +155,16 @@ def CC_Hexagon_HVX: CallingConv<[
149155
CCIfType<[v128i1], CCPromoteToType<v128i8>>>,
150156

151157
CCIfHvx128<
152-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
158+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
153159
CCAssignToReg<[V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15]>>>,
154160
CCIfHvx128<
155-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
161+
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
156162
CCAssignToReg<[W0,W1,W2,W3,W4,W5,W6,W7]>>>,
157163
CCIfHvx128<
158-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
164+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
159165
CCAssignToStack<128,128>>>,
160166
CCIfHvx128<
161-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
167+
// 256-byte stack slots hold HVX vector pairs, so the bf16 entry is the pair
// type v128bf16 (matching the W-register rule), not the single-vector v64bf16.
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
162168
CCAssignToStack<256,128>>>,
163169

164170
CCDelegateTo<CC_Hexagon>
@@ -175,10 +181,10 @@ def RetCC_Hexagon_HVX: CallingConv<[
175181

176182
// HVX 128-byte mode
177183
CCIfHvx128<
178-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
184+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
179185
CCAssignToReg<[V0]>>>,
180186
CCIfHvx128<
181-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
187+
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
182188
CCAssignToReg<[W0]>>>,
183189

184190
CCDelegateTo<RetCC_Hexagon>

llvm/lib/Target/Hexagon/HexagonISelLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
16771677
}
16781678
// Turn FP truncstore into trunc + store.
16791679
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
1680+
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
1681+
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
16801682
// Turn FP extload into load/fpextend.
16811683
for (MVT VT : MVT::fp_valuetypes())
16821684
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
@@ -1872,9 +1874,15 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
18721874
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
18731875
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
18741876
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
1877+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
1878+
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
1879+
setOperationAction(ISD::FP_TO_BF16, MVT::f64, Expand);
18751880

18761881
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
18771882
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
1883+
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
1884+
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
1885+
18781886
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
18791887
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
18801888

llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
8888
addRegisterClass(MVT::v64f32, &Hexagon::HvxWRRegClass);
8989
addRegisterClass(MVT::v128f16, &Hexagon::HvxWRRegClass);
9090
}
91+
if (Subtarget.useHVXV81Ops()) {
92+
addRegisterClass(MVT::v64bf16, &Hexagon::HvxVRRegClass);
93+
addRegisterClass(MVT::v128bf16, &Hexagon::HvxWRRegClass);
94+
}
9195
}
9296

9397
// Set up operation actions.
@@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
162166
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
163167
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);
164168

169+
if (Subtarget.useHVXV81Ops()) {
170+
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
171+
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
172+
setPromoteTo(ISD::SETCC, MVT::v64bf16, MVT::v64f32);
173+
setPromoteTo(ISD::FADD, MVT::v64bf16, MVT::v64f32);
174+
setPromoteTo(ISD::FSUB, MVT::v64bf16, MVT::v64f32);
175+
setPromoteTo(ISD::FMUL, MVT::v64bf16, MVT::v64f32);
176+
setPromoteTo(ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
177+
setPromoteTo(ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);
178+
179+
setOperationAction(ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
180+
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
181+
setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);
182+
183+
setOperationAction(ISD::MLOAD, MVT::v64bf16, Custom);
184+
setOperationAction(ISD::MSTORE, MVT::v64bf16, Custom);
185+
setOperationAction(ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
186+
setOperationAction(ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);
187+
188+
setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
189+
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
190+
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
191+
}
192+
165193
for (MVT P : FloatW) {
166194
setOperationAction(ISD::LOAD, P, Custom);
167195
setOperationAction(ISD::STORE, P, Custom);
@@ -462,6 +490,10 @@ HexagonTargetLowering::initializeHVXLowering() {
462490

463491
unsigned
464492
HexagonTargetLowering::getPreferredHvxVectorAction(MVT VecTy) const {
493+
// Early exit for invalid input types
494+
if (!VecTy.isVector())
495+
return ~0u;
496+
465497
MVT ElemTy = VecTy.getVectorElementType();
466498
unsigned VecLen = VecTy.getVectorNumElements();
467499
unsigned HwLen = Subtarget.getVectorLength();
@@ -1667,14 +1699,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
16671699
// In case of an MVT::f16 or MVT::bf16 BUILD_VECTOR, since the element type is
16681700
// not a legal type, just bitcast the node to use i16
16691701
// types and bitcast the result back to the original element type
1670-
if (VecTy.getVectorElementType() == MVT::f16) {
1671-
SmallVector<SDValue,64> NewOps;
1702+
if (VecTy.getVectorElementType() == MVT::f16 ||
1703+
VecTy.getVectorElementType() == MVT::bf16) {
1704+
SmallVector<SDValue, 64> NewOps;
16721705
for (unsigned i = 0; i != Size; i++)
16731706
NewOps.push_back(DAG.getBitcast(MVT::i16, Ops[i]));
16741707

1675-
SDValue T0 = DAG.getNode(ISD::BUILD_VECTOR, dl,
1676-
tyVector(VecTy, MVT::i16), NewOps);
1677-
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
1708+
SDValue T0 =
1709+
DAG.getNode(ISD::BUILD_VECTOR, dl, tyVector(VecTy, MVT::i16), NewOps);
1710+
return DAG.getBitcast(tyVector(VecTy, VecTy.getVectorElementType()), T0);
16781711
}
16791712

16801713
// First, split the BUILD_VECTOR for vector pairs. We could generate
@@ -1698,7 +1731,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
16981731
MVT VecTy = ty(Op);
16991732
MVT ArgTy = ty(Op.getOperand(0));
17001733

1701-
if (ArgTy == MVT::f16) {
1734+
if (ArgTy == MVT::f16 || ArgTy == MVT::bf16) {
17021735
MVT SplatTy = MVT::getVectorVT(MVT::i16, VecTy.getVectorNumElements());
17031736
SDValue ToInt16 = DAG.getBitcast(MVT::i16, Op.getOperand(0));
17041737
SDValue ToInt32 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, ToInt16);
@@ -1831,12 +1864,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
18311864
if (ElemTy == MVT::i1)
18321865
return insertHvxElementPred(VecV, IdxV, ValV, dl, DAG);
18331866

1834-
if (ElemTy == MVT::f16) {
1867+
if (ElemTy == MVT::f16 || ElemTy == MVT::bf16) {
18351868
SDValue T0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl,
18361869
tyVector(VecTy, MVT::i16),
18371870
DAG.getBitcast(tyVector(VecTy, MVT::i16), VecV),
18381871
DAG.getBitcast(MVT::i16, ValV), IdxV);
1839-
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
1872+
return DAG.getBitcast(tyVector(VecTy, ElemTy), T0);
18401873
}
18411874

18421875
return insertHvxElementReg(VecV, IdxV, ValV, dl, DAG);
@@ -2334,6 +2367,25 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
23342367
MVT VecTy = ty(Op);
23352368
MVT ArgTy = ty(Op.getOperand(0));
23362369
const SDLoc &dl(Op);
2370+
2371+
if (ArgTy == MVT::v64bf16) {
2372+
MVT HalfTy = typeSplit(VecTy).first;
2373+
SDValue BF16Vec = Op.getOperand(0);
2374+
SDValue Zeroes =
2375+
getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
2376+
// Interleave zero vector with the bf16 vector, with zeroes in the lower
2377+
// half of each 32 bit lane, effectively extending the bf16 values to fp32
2378+
// values.
2379+
SDValue ShuffVec =
2380+
getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
2381+
VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
2382+
SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
2383+
{VecPair.second, VecPair.first,
2384+
DAG.getSignedConstant(-4, dl, MVT::i32)},
2385+
DAG);
2386+
return Result;
2387+
}
2388+
23372389
assert(VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
23382390

23392391
SDValue F16Vec = Op.getOperand(0);

llvm/lib/Target/Hexagon/HexagonPatterns.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def Fptoui: pf1<fp_to_uint>;
391391
def Sitofp: pf1<sint_to_fp>;
392392
def Uitofp: pf1<uint_to_fp>;
393393

394-
395394
// --(1) Immediate -------------------------------------------------------
396395
//
397396

@@ -474,6 +473,18 @@ def: OpR_R_pat<F2_conv_df2uw_chop, pf1<fp_to_uint>, i32, F64>;
474473
def: OpR_R_pat<F2_conv_sf2ud_chop, pf1<fp_to_uint>, i64, F32>;
475474
def: OpR_R_pat<F2_conv_df2ud_chop, pf1<fp_to_uint>, i64, F64>;
476475

476+
// Selects fp_to_bf16 of an f32 operand into an i32 result using scalar
// integer instructions only.
//   - Outer C2_mux: if F2_sfclass($v, 0x10) holds, the result is the constant
//     0x7fff. NOTE(review): 0x10 is presumably the NaN class bit, so NaN
//     inputs yield a fixed bf16 NaN -- confirm against the Hexagon PRM.
//   - Inner C2_mux: if ($v & 0x1FFFF) == 0x08000, the result is
//     asrh($v) & 65535; otherwise it is asrh($v + ($v & 0x8000)) & 65535.
// NOTE(review): assuming A2_asrh shifts right by 16 and C2_mux picks its
// second operand when the predicate is true, this looks like
// round-to-nearest-even on the upper halfword -- TODO confirm.
def: Pat<(i32 (fp_to_bf16 F32:$v)),
477+
(C2_mux (F2_sfclass F32:$v, 0x10), (A2_tfrsi(i32 0x7fff)),
478+
(C2_mux
479+
(C2_cmpeq
480+
(A2_and F32:$v, (A2_tfrsi (i32 0x1FFFF))),
481+
(A2_tfrsi (i32 0x08000))),
482+
(A2_and (A2_asrh F32:$v), (A2_tfrsi (i32 65535))),
483+
(A2_and
484+
(A2_asrh
485+
(A2_add F32:$v, (A2_and F32:$v, (A2_tfrsi (i32 0x8000))))),
486+
(A2_tfrsi (i32 65535))))
487+
)>;
477488
// Bitcast is different than [fp|sint|uint]_to_[sint|uint|fp].
478489
def: Pat<(i32 (bitconvert F32:$v)), (I32:$v)>;
479490
def: Pat<(f32 (bitconvert I32:$v)), (F32:$v)>;

0 commit comments

Comments
 (0)