diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 8457f6178fdc2..9508e63630669 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1801,6 +1801,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32); setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); } + + if (Subtarget->hasBF16() && + (Subtarget->hasSVE() || Subtarget->hasSME())) { + for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) + setOperationAction(ISD::FMUL, VT, Custom); + } } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); @@ -7529,6 +7535,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, EndOfTrmp); } +SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + auto &Subtarget = DAG.getSubtarget(); + if (VT.getScalarType() != MVT::bf16 || + !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME()))) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + + SDLoc DL(Op); + SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32); + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Ops...); + }; + + SDValue Pg = + getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1, + AArch64SVEPredPattern::all); + // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom + // instructions. These result in two f32 vectors, which can be converted back + // to bf16 with FCVT and FCVNT. + SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32, + Zero, LHS, RHS); + SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT, + DAG.getPOISON(VT), Pg, BottomF32); + if (VT == MVT::nxv8bf16) { + // Note: nxv2bf16 and nxv4bf16 only use even lanes. + SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32, + Zero, LHS, RHS); + return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT, + BottomBF16, Pg, TopF32); + } + return BottomBF16; +} + SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Custom lowering: "); @@ -7603,7 +7646,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::FSUB: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED); case ISD::FMUL: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + return LowerFMUL(Op, DAG); case ISD::FMA: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); case ISD::FDIV: diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 70bfae717fb76..a926a4822d9cc 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -609,6 +609,7 @@ class AArch64TargetLowering : public TargetLowering { SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const; SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll index 0580f5e0b019a..d8f1ec0241a17 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll @@ -1,7 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16 +; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING ; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16 -; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16 +; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING ; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16 target triple = "aarch64-unknown-linux-gnu" @@ -520,64 +520,82 @@ define @fmla_nxv8bf16( %a, @fmul_nxv2bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv2bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: ptrue p0.d -; NOB16B16-NEXT: fmul z0.s, p0/m, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.d +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv2bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.d +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res } define @fmul_nxv4bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv4bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: ptrue p0.s -; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv4bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.s +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res } define @fmul_nxv8bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv8bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: uunpkhi z2.s, z1.h -; NOB16B16-NEXT: uunpkhi z3.s, z0.h -; NOB16B16-NEXT: uunpklo z1.s, z1.h -; NOB16B16-NEXT: uunpklo z0.s, z0.h -; NOB16B16-NEXT: ptrue p0.s -; NOB16B16-NEXT: lsl z2.s, z2.s, #16 -; NOB16B16-NEXT: lsl z3.s, z3.s, #16 -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s -; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv8bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.s +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res }