@@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
8888 addRegisterClass (MVT::v64f32, &Hexagon::HvxWRRegClass);
8989 addRegisterClass (MVT::v128f16, &Hexagon::HvxWRRegClass);
9090 }
91+ if (Subtarget.useHVXV81Ops ()) {
92+ addRegisterClass (MVT::v64bf16, &Hexagon::HvxVRRegClass);
93+ addRegisterClass (MVT::v128bf16, &Hexagon::HvxWRRegClass);
94+ }
9195 }
9296
9397 // Set up operation actions.
@@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
162166 setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
163167 setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);
164168
169+ if (Subtarget.useHVXV81Ops ()) {
170+ setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
171+ setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
172+ setPromoteTo (ISD::SETCC, MVT::v64bf16, MVT::v64f32);
173+ setPromoteTo (ISD::FADD, MVT::v64bf16, MVT::v64f32);
174+ setPromoteTo (ISD::FSUB, MVT::v64bf16, MVT::v64f32);
175+ setPromoteTo (ISD::FMUL, MVT::v64bf16, MVT::v64f32);
176+ setPromoteTo (ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
177+ setPromoteTo (ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);
178+
179+ setOperationAction (ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
180+ setOperationAction (ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
181+ setOperationAction (ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);
182+
183+ setOperationAction (ISD::MLOAD, MVT::v64bf16, Custom);
184+ setOperationAction (ISD::MSTORE, MVT::v64bf16, Custom);
185+ setOperationAction (ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
186+ setOperationAction (ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);
187+
188+ setOperationAction (ISD::SPLAT_VECTOR, MVT::bf16 , Custom);
189+ setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::bf16 , Custom);
190+ setOperationAction (ISD::BUILD_VECTOR, MVT::bf16 , Custom);
191+ }
192+
165193 for (MVT P : FloatW) {
166194 setOperationAction (ISD::LOAD, P, Custom);
167195 setOperationAction (ISD::STORE, P, Custom);
@@ -462,6 +490,10 @@ HexagonTargetLowering::initializeHVXLowering() {
462490
463491unsigned
464492HexagonTargetLowering::getPreferredHvxVectorAction (MVT VecTy) const {
493+ // Early exit for invalid input types
494+ if (!VecTy.isVector ())
495+ return ~0u ;
496+
465497 MVT ElemTy = VecTy.getVectorElementType ();
466498 unsigned VecLen = VecTy.getVectorNumElements ();
467499 unsigned HwLen = Subtarget.getVectorLength ();
@@ -1667,14 +1699,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
16671699 // In case of MVT::f16 BUILD_VECTOR, since MVT::f16 is
16681700 // not a legal type, just bitcast the node to use i16
16691701 // types and bitcast the result back to f16
1670- if (VecTy.getVectorElementType () == MVT::f16 ) {
1671- SmallVector<SDValue,64 > NewOps;
1702+ if (VecTy.getVectorElementType () == MVT::f16 ||
1703+ VecTy.getVectorElementType () == MVT::bf16 ) {
1704+ SmallVector<SDValue, 64 > NewOps;
16721705 for (unsigned i = 0 ; i != Size; i++)
16731706 NewOps.push_back (DAG.getBitcast (MVT::i16 , Ops[i]));
16741707
1675- SDValue T0 = DAG. getNode (ISD::BUILD_VECTOR, dl,
1676- tyVector (VecTy, MVT::i16 ), NewOps);
1677- return DAG.getBitcast (tyVector (VecTy, MVT:: f16 ), T0);
1708+ SDValue T0 =
1709+ DAG. getNode (ISD::BUILD_VECTOR, dl, tyVector (VecTy, MVT::i16 ), NewOps);
1710+ return DAG.getBitcast (tyVector (VecTy, VecTy. getVectorElementType () ), T0);
16781711 }
16791712
16801713 // First, split the BUILD_VECTOR for vector pairs. We could generate
@@ -1698,7 +1731,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
16981731 MVT VecTy = ty (Op);
16991732 MVT ArgTy = ty (Op.getOperand (0 ));
17001733
1701- if (ArgTy == MVT::f16 ) {
1734+ if (ArgTy == MVT::f16 || ArgTy == MVT:: bf16 ) {
17021735 MVT SplatTy = MVT::getVectorVT (MVT::i16 , VecTy.getVectorNumElements ());
17031736 SDValue ToInt16 = DAG.getBitcast (MVT::i16 , Op.getOperand (0 ));
17041737 SDValue ToInt32 = DAG.getNode (ISD::ANY_EXTEND, dl, MVT::i32 , ToInt16);
@@ -1831,12 +1864,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
18311864 if (ElemTy == MVT::i1)
18321865 return insertHvxElementPred (VecV, IdxV, ValV, dl, DAG);
18331866
1834- if (ElemTy == MVT::f16 ) {
1867+ if (ElemTy == MVT::f16 || ElemTy == MVT:: bf16 ) {
18351868 SDValue T0 = DAG.getNode (ISD::INSERT_VECTOR_ELT, dl,
18361869 tyVector (VecTy, MVT::i16 ),
18371870 DAG.getBitcast (tyVector (VecTy, MVT::i16 ), VecV),
18381871 DAG.getBitcast (MVT::i16 , ValV), IdxV);
1839- return DAG.getBitcast (tyVector (VecTy, MVT:: f16 ), T0);
1872+ return DAG.getBitcast (tyVector (VecTy, ElemTy ), T0);
18401873 }
18411874
18421875 return insertHvxElementReg (VecV, IdxV, ValV, dl, DAG);
@@ -2334,6 +2367,25 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
23342367 MVT VecTy = ty (Op);
23352368 MVT ArgTy = ty (Op.getOperand (0 ));
23362369 const SDLoc &dl (Op);
2370+
2371+ if (ArgTy == MVT::v64bf16) {
2372+ MVT HalfTy = typeSplit (VecTy).first ;
2373+ SDValue BF16Vec = Op.getOperand (0 );
2374+ SDValue Zeroes =
2375+ getInstr (Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
2376+ // Interleave zero vector with the bf16 vector, with zeroes in the lower
2377+ // half of each 32 bit lane, effectively extending the bf16 values to fp32
2378+ // values.
2379+ SDValue ShuffVec =
2380+ getInstr (Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
2381+ VectorPair VecPair = opSplit (ShuffVec, dl, DAG);
2382+ SDValue Result = getInstr (Hexagon::V6_vshuffvdd, dl, VecTy,
2383+ {VecPair.second , VecPair.first ,
2384+ DAG.getSignedConstant (-4 , dl, MVT::i32 )},
2385+ DAG);
2386+ return Result;
2387+ }
2388+
23372389 assert (VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
23382390
23392391 SDValue F16Vec = Op.getOperand (0 );
0 commit comments