[LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to nxvf32/nxvf64. #107253
Conversation
NOTE: There are no dedicated SVE instructions for this, but bf16->f32 is just a left shift because the two formats share the same exponent range; from there the existing convert instructions can be used.
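Because bf16 is exactly the top 16 bits of an IEEE f32 (same sign bit and 8-bit exponent, with the mantissa truncated), the widening is purely a bit move. A minimal scalar sketch of the trick, with a hypothetical helper name that is not part of this patch:

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical scalar analogue of the lowering: a bf16 value widens to f32
// by moving its 16 bits into the top half of a 32-bit word, i.e. a left
// shift by 16 that leaves the new low mantissa bits zeroed.
float bf16_to_f32(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float result;
  std::memcpy(&result, &f32_bits, sizeof(result)); // bit-cast, no rounding
  return result;
}
```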
@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Full diff: https://github.com/llvm/llvm-project/pull/107253.diff

3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e3f9364ac3e12..a57878d18b2b7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,6 +1663,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+ setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4298,8 +4299,28 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
- if (VT.isScalableVector())
+ if (VT.isScalableVector()) {
+ SDValue SrcVal = Op.getOperand(0);
+
+ if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
+ // bf16 and f32 share the same exponent range so the conversion only
+ // requires the values to be realigned with the new mantissa bits
+ // zeroed, which is just a left shift that is best to isel directly.
+ if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
+ return Op;
+
+ if (VT != MVT::nxv2f64)
+ return SDValue();
+
+ // Split other conversions in two, with the first part converting to
+ // f32 and the second using native f32->VT instructions.
+ SDLoc DL(Op);
+ return DAG.getNode(ISD::FP_EXTEND, DL, VT,
+ DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));
+ }
+
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
+ }
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPExtendToSVE(Op, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index af8ddb49b0ac66..ef006be9d02354 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2320,7 +2320,12 @@ let Predicates = [HasSVEorSME] in {
def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
(FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
- // Signed integer -> Floating-point
+ def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
+ (LSL_ZZI_S $op, (i32 16))>;
+ def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
+ (LSL_ZZI_S $op, (i32 16))>;
+
+ // Signed integer -> Floating-point
def : Pat<(nxv2f16 (AArch64scvtf_mt (nxv2i1 (SVEAllActive):$Pg),
(sext_inreg nxv2i64:$Zs, nxv2i16), nxv2f16:$Zd)),
(SCVTF_ZPmZ_HtoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
new file mode 100644
index 00000000000000..d72f92c1dac1ff
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x float>
+ ret <vscale x 2 x float> %res
+}
+
+define <vscale x 4 x float> @fpext_nxv4bf16_to_nxv4f32(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x float>
+ ret <vscale x 4 x float> %res
+}
+
+define <vscale x 8 x float> @fpext_nxv8bf16_to_nxv8f32(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z2.s, z0.h
+; CHECK-NEXT: lsl z0.s, z1.s, #16
+; CHECK-NEXT: lsl z1.s, z2.s, #16
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x float>
+ ret <vscale x 8 x float> %res
+}
+
+define <vscale x 2 x double> @fpext_nxv2bf16_to_nxv2f64(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fcvt z0.d, p0/m, z0.s
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x double>
+ ret <vscale x 2 x double> %res
+}
+
+define <vscale x 4 x double> @fpext_nxv4bf16_to_nxv4f64(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z1.d, z0.s
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z2.s, z0.s, #16
+; CHECK-NEXT: movprfx z0, z1
+; CHECK-NEXT: fcvt z0.d, p0/m, z1.s
+; CHECK-NEXT: movprfx z1, z2
+; CHECK-NEXT: fcvt z1.d, p0/m, z2.s
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x double>
+ ret <vscale x 4 x double> %res
+}
+
+define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z2.d, z1.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpklo z3.d, z0.s
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z2.s, z2.s, #16
+; CHECK-NEXT: lsl z3.s, z3.s, #16
+; CHECK-NEXT: lsl z4.s, z0.s, #16
+; CHECK-NEXT: fcvt z1.d, p0/m, z1.s
+; CHECK-NEXT: movprfx z0, z2
+; CHECK-NEXT: fcvt z0.d, p0/m, z2.s
+; CHECK-NEXT: movprfx z2, z3
+; CHECK-NEXT: fcvt z2.d, p0/m, z3.s
+; CHECK-NEXT: movprfx z3, z4
+; CHECK-NEXT: fcvt z3.d, p0/m, z4.s
+; CHECK-NEXT: ret
+ %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
+ ret <vscale x 8 x double> %res
+}
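The f64 cases above show the two-step shape of the lowering: the shift gets the value to f32, and the existing predicated FCVT handles f32->f64. A hypothetical scalar analogue, reusing the bf16_to_f32 sketch from the description (again, illustration only, not part of this patch):

```cpp
#include <cstdint>

// Hypothetical scalar analogue of the nxv2f64 lowering: widen bf16 to f32
// with the left shift, then rely on the machine's native f32 -> f64
// conversion (the predicated fcvt in the tests above).
double bf16_to_f64(uint16_t bf16_bits) {
  return static_cast<double>(bf16_to_f32(bf16_bits));
}
```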
That's a nice improvement!