From a55a3f172d8cf6d01aabe8ae459d8f9192605882 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 11 Nov 2025 13:14:00 +0000 Subject: [PATCH 1/8] [AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) This lowers an SVE FMUL of bf16 using the BFMLAL top/bottom instructions rather than extending to an f32 mul. This does require zeroing the accumulator, but requires fewer extends/unpacking. --- .../Target/AArch64/AArch64ISelLowering.cpp | 45 +++++++++- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1 + llvm/test/CodeGen/AArch64/sve-bf16-arith.ll | 88 +++++++++++-------- 3 files changed, 98 insertions(+), 36 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 8f41f230b5521..26ea2c91a08b2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1815,6 +1815,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32); setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); } + + if (Subtarget->hasBF16() && + (Subtarget->hasSVE() || Subtarget->hasSME())) { + for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) + setOperationAction(ISD::FMUL, VT, Custom); + } } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); @@ -7641,6 +7647,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, EndOfTrmp); } +SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + auto &Subtarget = DAG.getSubtarget(); + if (VT.getScalarType() != MVT::bf16 || + !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME()))) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + + SDLoc DL(Op); + SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32); + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Ops...); + }; + + SDValue Pg = + getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1, + AArch64SVEPredPattern::all); + // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom + // instructions. These result in two f32 vectors, which can be converted back + // to bf16 with FCVT and FCVNT. + SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32, + Zero, LHS, RHS); + SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT, + DAG.getPOISON(VT), Pg, BottomF32); + if (VT == MVT::nxv8bf16) { + // Note: nxv2bf16 and nxv4bf16 only use even lanes. + SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32, + Zero, LHS, RHS); + return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT, + BottomBF16, Pg, TopF32); + } + return BottomBF16; +} + SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Custom lowering: "); @@ -7715,7 +7758,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::FSUB: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED); case ISD::FMUL: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + return LowerFMUL(Op, DAG); case ISD::FMA: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); case ISD::FDIV: diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index be198e54cbcbf..ca08eb40c956a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -614,6 +614,7 @@ class AArch64TargetLowering : public TargetLowering { SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const; SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll index 582e8456c05b3..95db2666dbbba 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll @@ -1,7 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16 +; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING ; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16 -; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16 +; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING ; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16 target triple = "aarch64-unknown-linux-gnu" @@ -514,64 +514,82 @@ define @fmla_nxv8bf16( %a, @fmul_nxv2bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv2bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: ptrue p0.d -; NOB16B16-NEXT: fmul z0.s, p0/m, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.d +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv2bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.d +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res } define @fmul_nxv4bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv4bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: ptrue p0.s -; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv4bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.s +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res } define @fmul_nxv8bf16( %a, %b) { -; NOB16B16-LABEL: fmul_nxv8bf16: -; NOB16B16: // %bb.0: -; NOB16B16-NEXT: uunpkhi z2.s, z1.h -; NOB16B16-NEXT: uunpkhi z3.s, z0.h -; NOB16B16-NEXT: uunpklo z1.s, z1.h -; NOB16B16-NEXT: uunpklo z0.s, z0.h -; NOB16B16-NEXT: ptrue p0.s -; NOB16B16-NEXT: lsl z2.s, z2.s, #16 -; NOB16B16-NEXT: lsl z3.s, z3.s, #16 -; NOB16B16-NEXT: lsl z1.s, z1.s, #16 -; NOB16B16-NEXT: lsl z0.s, z0.s, #16 -; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s -; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s -; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s -; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s -; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h -; NOB16B16-NEXT: ret +; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16: +; NOB16B16-NONSTREAMING: // %bb.0: +; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000 +; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s +; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h +; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s +; NOB16B16-NONSTREAMING-NEXT: ret ; ; B16B16-LABEL: fmul_nxv8bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret +; +; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16: +; NOB16B16-STREAMING: // %bb.0: +; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0 +; NOB16B16-STREAMING-NEXT: ptrue p0.s +; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h +; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s +; NOB16B16-STREAMING-NEXT: ret %res = fmul %a, %b ret %res } From a59d6a06ac22bd25b299855e585af332a79d1c32 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 14 Nov 2025 15:19:40 +0000 Subject: [PATCH 2/8] Fixups --- .../Target/AArch64/AArch64ISelLowering.cpp | 71 +++++++++++-------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 26ea2c91a08b2..40b14db16658a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1809,17 +1809,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (!Subtarget->hasSVEB16B16() || !Subtarget->isNonStreamingSVEorSME2Available()) { - for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM, - ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) { - setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32); - setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32); - setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); - } - - if (Subtarget->hasBF16() && - (Subtarget->hasSVE() || Subtarget->hasSME())) { - for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) + for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) { + MVT PromotedVT = VT.changeVectorElementType(MVT::f32); + setOperationPromotedToType(ISD::FADD, VT, PromotedVT); + setOperationPromotedToType(ISD::FMA, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FSUB, VT, PromotedVT); + + if (Subtarget->hasBF16()) setOperationAction(ISD::FMUL, VT, Custom); + else + setOperationPromotedToType(ISD::FMUL, VT, PromotedVT); } } @@ -7648,40 +7651,46 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, } SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); EVT VT = Op.getValueType(); auto &Subtarget = DAG.getSubtarget(); if (VT.getScalarType() != MVT::bf16 || - !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME()))) + (Subtarget.hasSVEB16B16() && + Subtarget.isNonStreamingSVEorSME2Available())) return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); - SDLoc DL(Op); + assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering"); + + auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { + return [&, IID](EVT VT, auto... Ops) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Ops...); + }; + }; + + // Create helpers for building intrinsic calls. + auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb); + auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt); + auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); + auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); + SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); - auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) { - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, - DAG.getConstant(IID, DL, MVT::i32), Ops...); - }; - SDValue Pg = - getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1, - AArch64SVEPredPattern::all); + DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1); + // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom // instructions. These result in two f32 vectors, which can be converted back // to bf16 with FCVT and FCVNT. - SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32, - Zero, LHS, RHS); - SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT, - DAG.getPOISON(VT), Pg, BottomF32); - if (VT == MVT::nxv8bf16) { - // Note: nxv2bf16 and nxv4bf16 only use even lanes. - SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32, - Zero, LHS, RHS); - return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT, - BottomBF16, Pg, TopF32); - } - return BottomBF16; + SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS); + SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32); + // Note: nxv2bf16 and nxv4bf16 only use even lanes. + if (VT != MVT::nxv8bf16) + return BottomBF16; + SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS); + return FCVNT(VT, BottomBF16, Pg, TopF32); } SDValue AArch64TargetLowering::LowerOperation(SDValue Op, From 91f71a60e10f5d6385b411e96c24a420bd905f71 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 14 Nov 2025 16:30:09 +0000 Subject: [PATCH 3/8] Use DAG.getNeutralElement() --- .../Target/AArch64/AArch64ISelLowering.cpp | 3 ++- llvm/test/CodeGen/AArch64/sve-bf16-arith.ll | 25 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 40b14db16658a..cd70ae99e4ae2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7674,10 +7674,11 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); - SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); + SDValue Zero = + DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags()); SDValue Pg = DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1); diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll index 95db2666dbbba..7d5aed2898bfd 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll @@ -534,7 +534,7 @@ define @fmul_nxv2bf16( %a, %a, %b + %res = fmul nsz %a, %b ret %res } @@ -559,7 +559,7 @@ define @fmul_nxv4bf16( %a, %a, %b + %res = fmul nsz %a, %b ret %res } @@ -590,6 +590,27 @@ define @fmul_nxv8bf16( %a, %a, %b + ret %res +} + +define @fmul_nxv8bf16_no_nsz( %a, %b) { +; NOB16B16-LABEL: fmul_nxv8bf16_no_nsz: +; NOB16B16: // %bb.0: +; NOB16B16-NEXT: mov w8, #-2147483648 // =0x80000000 +; NOB16B16-NEXT: ptrue p0.s +; NOB16B16-NEXT: mov z2.s, w8 +; NOB16B16-NEXT: mov z3.d, z2.d +; NOB16B16-NEXT: bfmlalb z2.s, z0.h, z1.h +; NOB16B16-NEXT: bfmlalt z3.s, z0.h, z1.h +; NOB16B16-NEXT: bfcvt z0.h, p0/m, z2.s +; NOB16B16-NEXT: bfcvtnt z0.h, p0/m, z3.s +; NOB16B16-NEXT: ret +; +; B16B16-LABEL: fmul_nxv8bf16_no_nsz: +; B16B16: // %bb.0: +; B16B16-NEXT: bfmul z0.h, z0.h, z1.h +; B16B16-NEXT: ret %res = fmul %a, %b ret %res } From af0c523feace91deca8cf4a510931a9b038982bd Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 14 Nov 2025 17:17:24 +0000 Subject: [PATCH 4/8] Update bf16-combines checks --- .../test/CodeGen/AArch64/sve-bf16-combines.ll | 198 +++++++----------- 1 file changed, 80 insertions(+), 118 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll index fc3e018f2ec7a..9f4716c733709 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll @@ -414,28 +414,21 @@ define @fsub_sel_negzero_nxv8bf16( %a define @fadd_sel_fmul_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fadd_sel_fmul_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s -; SVE-NEXT: movi v3.2d, #0000000000000000 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s -; SVE-NEXT: uzp1 z1.h, z1.h, z2.h -; SVE-NEXT: sel z1.h, p0, z1.h, z3.h +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: movi v2.2d, #0000000000000000 +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h ; SVE-NEXT: uunpklo z0.s, z0.h -; SVE-NEXT: uunpkhi z2.s, z1.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s ; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z0.s, z0.s, #16 +; SVE-NEXT: sel z1.h, p0, z1.h, z2.h +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z2.s, z2.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fadd z2.s, z3.s, z2.s @@ -461,24 +454,20 @@ define @fadd_sel_fmul_nxv8bf16( %a, < define @fsub_sel_fmul_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fsub_sel_fmul_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fsub z2.s, z3.s, z2.s ; SVE-NEXT: fsub z1.s, z4.s, z1.s @@ -503,24 +492,20 @@ define @fsub_sel_fmul_nxv8bf16( %a, < define @fadd_sel_fmul_nsz_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fadd_sel_fmul_nsz_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fadd z2.s, z3.s, z2.s ; SVE-NEXT: fadd z1.s, z4.s, z1.s @@ -545,24 +530,20 @@ define @fadd_sel_fmul_nsz_nxv8bf16( % define @fsub_sel_fmul_nsz_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fsub_sel_fmul_nsz_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fsub z2.s, z3.s, z2.s ; SVE-NEXT: fsub z1.s, z4.s, z1.s @@ -587,24 +568,20 @@ define @fsub_sel_fmul_nsz_nxv8bf16( % define @fadd_sel_fmul_negzero_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fadd_sel_fmul_negzero_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fadd z2.s, z3.s, z2.s ; SVE-NEXT: fadd z1.s, z4.s, z1.s @@ -630,28 +607,21 @@ define @fadd_sel_fmul_negzero_nxv8bf16( @fsub_sel_fmul_negzero_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fsub_sel_fmul_negzero_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s -; SVE-NEXT: dupm z3.h, #0x8000 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s -; SVE-NEXT: uzp1 z1.h, z1.h, z2.h -; SVE-NEXT: sel z1.h, p0, z1.h, z3.h +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: dupm z2.h, #0x8000 +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h ; SVE-NEXT: uunpklo z0.s, z0.h -; SVE-NEXT: uunpkhi z2.s, z1.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s ; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z0.s, z0.s, #16 +; SVE-NEXT: sel z1.h, p0, z1.h, z2.h +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z2.s, z2.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fsub z2.s, z3.s, z2.s @@ -678,24 +648,20 @@ define @fsub_sel_fmul_negzero_nxv8bf16( @fadd_sel_fmul_negzero_nsz_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fadd_sel_fmul_negzero_nsz_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fadd z2.s, z3.s, z2.s ; SVE-NEXT: fadd z1.s, z4.s, z1.s @@ -721,24 +687,20 @@ define @fadd_sel_fmul_negzero_nsz_nxv8bf16( @fsub_sel_fmul_negzero_nsz_nxv8bf16( %a, %b, %c, %mask) { ; SVE-LABEL: fsub_sel_fmul_negzero_nsz_nxv8bf16: ; SVE: // %bb.0: -; SVE-NEXT: uunpkhi z3.s, z2.h -; SVE-NEXT: uunpkhi z4.s, z1.h -; SVE-NEXT: uunpklo z2.s, z2.h -; SVE-NEXT: uunpklo z1.s, z1.h +; SVE-NEXT: mov z3.s, #0x80000000 +; SVE-NEXT: mov z4.s, #0x80000000 ; SVE-NEXT: ptrue p1.s -; SVE-NEXT: lsl z3.s, z3.s, #16 -; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z1.s, z1.s, #16 -; SVE-NEXT: fmul z3.s, z4.s, z3.s -; SVE-NEXT: uunpklo z4.s, z0.h -; SVE-NEXT: fmul z1.s, z1.s, z2.s -; SVE-NEXT: bfcvt z2.h, p1/m, z3.s +; SVE-NEXT: bfmlalb z3.s, z1.h, z2.h +; SVE-NEXT: bfmlalt z4.s, z1.h, z2.h +; SVE-NEXT: bfcvt z1.h, p1/m, z3.s ; SVE-NEXT: uunpkhi z3.s, z0.h +; SVE-NEXT: bfcvtnt z1.h, p1/m, z4.s +; SVE-NEXT: uunpklo z4.s, z0.h +; SVE-NEXT: lsl z3.s, z3.s, #16 +; SVE-NEXT: uunpkhi z2.s, z1.h +; SVE-NEXT: uunpklo z1.s, z1.h ; SVE-NEXT: lsl z4.s, z4.s, #16 -; SVE-NEXT: bfcvt z1.h, p1/m, z1.s ; SVE-NEXT: lsl z2.s, z2.s, #16 -; SVE-NEXT: lsl z3.s, z3.s, #16 ; SVE-NEXT: lsl z1.s, z1.s, #16 ; SVE-NEXT: fsub z2.s, z3.s, z2.s ; SVE-NEXT: fsub z1.s, z4.s, z1.s From ce51297edb439aa9e88fcdc43a1e46468e7e3534 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 14 Nov 2025 21:00:59 +0000 Subject: [PATCH 5/8] Fix typo --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index cd70ae99e4ae2..c5e0311d6fdc8 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7672,7 +7672,7 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb); auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt); auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); - auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); + auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); @@ -7684,14 +7684,14 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom // instructions. These result in two f32 vectors, which can be converted back - // to bf16 with FCVT and FCVNT. + // to bf16 with FCVT and FCVTNT. SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS); SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32); // Note: nxv2bf16 and nxv4bf16 only use even lanes. if (VT != MVT::nxv8bf16) return BottomBF16; SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS); - return FCVNT(VT, BottomBF16, Pg, TopF32); + return FCVTNT(VT, BottomBF16, Pg, TopF32); } SDValue AArch64TargetLowering::LowerOperation(SDValue Op, From ecd2d0f0ca8f712e23c4e38541f9084cb7c0f3f3 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 19 Nov 2025 13:20:43 +0000 Subject: [PATCH 6/8] Fixups --- .../Target/AArch64/AArch64ISelLowering.cpp | 32 +++++++++++-------- llvm/test/CodeGen/AArch64/sve-bf16-arith.ll | 23 +++++-------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index c5e0311d6fdc8..37ab333c100f4 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1819,7 +1819,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT); setOperationPromotedToType(ISD::FSUB, VT, PromotedVT); - if (Subtarget->hasBF16()) + if (VT != MVT::nxv2bf16 && Subtarget->hasBF16()) setOperationAction(ISD::FMUL, VT, Custom); else setOperationPromotedToType(ISD::FMUL, VT, PromotedVT); @@ -7653,13 +7653,12 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); EVT VT = Op.getValueType(); - auto &Subtarget = DAG.getSubtarget(); if (VT.getScalarType() != MVT::bf16 || - (Subtarget.hasSVEB16B16() && - Subtarget.isNonStreamingSVEorSME2Available())) + (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available())) return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); - assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering"); + assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering"); auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { return [&, IID](EVT VT, auto... Ops) { @@ -7668,28 +7667,35 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { }; }; + auto ReinterpretCast = [&](SDValue Value, EVT VT) { + if (VT == Value.getValueType()) + return Value; + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value); + }; + // Create helpers for building intrinsic calls. auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb); auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt); auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); - SDValue LHS = Op.getOperand(0); - SDValue RHS = Op.getOperand(1); + // All intrinsics expect to operate on full bf16 vector types. + SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16); + SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16); SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags()); - SDValue Pg = - DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1); + SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1); // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom // instructions. These result in two f32 vectors, which can be converted back // to bf16 with FCVT and FCVTNT. SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS); - SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32); - // Note: nxv2bf16 and nxv4bf16 only use even lanes. - if (VT != MVT::nxv8bf16) - return BottomBF16; + SDValue BottomBF16 = + FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32); + // Note: nxv4bf16 only uses even lanes. + if (VT == MVT::nxv4bf16) + return ReinterpretCast(BottomBF16, VT); SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS); return FCVTNT(VT, BottomBF16, Pg, TopF32); } diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll index 7d5aed2898bfd..f9441e0151f86 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll @@ -514,26 +514,19 @@ define @fmla_nxv8bf16( %a, @fmul_nxv2bf16( %a, %b) { -; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16: -; NOB16B16-NONSTREAMING: // %bb.0: -; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000 -; NOB16B16-NONSTREAMING-NEXT: ptrue p0.d -; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h -; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s -; NOB16B16-NONSTREAMING-NEXT: ret +; NOB16B16-LABEL: fmul_nxv2bf16: +; NOB16B16: // %bb.0: +; NOB16B16-NEXT: lsl z1.s, z1.s, #16 +; NOB16B16-NEXT: lsl z0.s, z0.s, #16 +; NOB16B16-NEXT: ptrue p0.d +; NOB16B16-NEXT: fmul z0.s, p0/m, z0.s, z1.s +; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s +; NOB16B16-NEXT: ret ; ; B16B16-LABEL: fmul_nxv2bf16: ; B16B16: // %bb.0: ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h ; B16B16-NEXT: ret -; -; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16: -; NOB16B16-STREAMING: // %bb.0: -; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0 -; NOB16B16-STREAMING-NEXT: ptrue p0.d -; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h -; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s -; NOB16B16-STREAMING-NEXT: ret %res = fmul nsz %a, %b ret %res } From 45c9d88d255899160979626f9035f625d88e106c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 19 Nov 2025 13:29:13 +0000 Subject: [PATCH 7/8] Add extra assert --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 37ab333c100f4..4eee9bf1b2405 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7659,6 +7659,7 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering"); + assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT"); auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { return [&, IID](EVT VT, auto... Ops) { From cb123012e69fd53b3420f9e159c5cc9fea100f1d Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 19 Nov 2025 17:04:35 +0000 Subject: [PATCH 8/8] Update sve-bf16-arith.ll --- llvm/test/CodeGen/AArch64/sve-bf16-arith.ll | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll index f9441e0151f86..eaf03af462279 100644 --- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll +++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll @@ -590,10 +590,9 @@ define @fmul_nxv8bf16( %a, @fmul_nxv8bf16_no_nsz( %a, %b) { ; NOB16B16-LABEL: fmul_nxv8bf16_no_nsz: ; NOB16B16: // %bb.0: -; NOB16B16-NEXT: mov w8, #-2147483648 // =0x80000000 +; NOB16B16-NEXT: mov z2.s, #0x80000000 +; NOB16B16-NEXT: mov z3.s, #0x80000000 ; NOB16B16-NEXT: ptrue p0.s -; NOB16B16-NEXT: mov z2.s, w8 -; NOB16B16-NEXT: mov z3.d, z2.d ; NOB16B16-NEXT: bfmlalb z2.s, z0.h, z1.h ; NOB16B16-NEXT: bfmlalt z3.s, z0.h, z1.h ; NOB16B16-NEXT: bfcvt z0.h, p0/m, z2.s