
[LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to nxvf32/nxvf64. #107253

Merged: paulwalker-arm merged 2 commits into llvm:main from sve-bfloat-lowering on Sep 5, 2024

Conversation

paulwalker-arm (Collaborator)

NOTE: There are no dedicated SVE instructions, but bf16->f32 is just a left shift because the two types share the same exponent range; from there, other convert instructions can be used.

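For illustration, here is a minimal scalar C++ sketch (not part of the patch) of why the extension is exact: bf16 keeps f32's sign bit and 8-bit exponent and truncates the mantissa to 7 bits, so shifting the raw bits left by 16 zero-fills the missing low mantissa bits and reproduces the original value.

```cpp
#include <cstdint>
#include <cstring>

// Scalar model of the lowering: bf16 is the top 16 bits of an IEEE-754 f32
// (1 sign + 8 exponent + 7 mantissa bits), so extending to f32 is exact.
// The shift below is the per-lane equivalent of the SVE
// "lsl z0.s, z0.s, #16" emitted by the patterns in this patch.
float bf16_to_f32(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float result;
  std::memcpy(&result, &f32_bits, sizeof(result)); // reinterpret the bits
  return result;
}
```
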
llvmbot (Collaborator) commented Sep 4, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

NOTE: There are no dedicated SVE instructions, but bf16->f32 is just a left shift because the two types share the same exponent range; from there, other convert instructions can be used.


Full diff: https://github.com/llvm/llvm-project/pull/107253.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+22-1)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+6-1)
  • (added) llvm/test/CodeGen/AArch64/sve-bf16-converts.ll (+89)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e3f9364ac3e12..a57878d18b2b7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,6 +1663,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+      setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4298,8 +4299,28 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
                                               SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
+  if (VT.isScalableVector()) {
+    SDValue SrcVal = Op.getOperand(0);
+
+    if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
+      // bf16 and f32 share the same exponent range so the conversion requires
+      // them to be aligned with the new mantissa bits zeroed, which is just a
+      // left shift that is best to isel directly.
+      if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
+        return Op;
+
+      if (VT != MVT::nxv2f64)
+        return SDValue();
+
+      // Break other conversions in two with the first part converting to f32
+      // and the second using native f32->VT instructions.
+      SDLoc DL(Op);
+      return DAG.getNode(ISD::FP_EXTEND, DL, VT,
+                         DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));
+    }
+
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
+  }
 
   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPExtendToSVE(Op, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index af8ddb49b0ac66..ef006be9d02354 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2320,7 +2320,12 @@ let Predicates = [HasSVEorSME] in {
   def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
             (FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
 
-  // Signed integer -> Floating-point 
+  def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
+            (LSL_ZZI_S $op, (i32 16))>;
+  def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
+            (LSL_ZZI_S $op, (i32 16))>;
+
+  // Signed integer -> Floating-point
   def : Pat<(nxv2f16 (AArch64scvtf_mt (nxv2i1 (SVEAllActive):$Pg),
                       (sext_inreg nxv2i64:$Zs, nxv2i16), nxv2f16:$Zd)),
             (SCVTF_ZPmZ_HtoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
new file mode 100644
index 00000000000000..d72f92c1dac1ff
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve                  < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x float>
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 4 x float> @fpext_nxv4bf16_to_nxv4f32(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x float>
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 8 x float> @fpext_nxv8bf16_to_nxv8f32(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    lsl z0.s, z1.s, #16
+; CHECK-NEXT:    lsl z1.s, z2.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x float>
+  ret <vscale x 8 x float> %res
+}
+
+define <vscale x 2 x double> @fpext_nxv2bf16_to_nxv2f64(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcvt z0.d, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x double>
+  ret <vscale x 2 x double> %res
+}
+
+define <vscale x 4 x double> @fpext_nxv4bf16_to_nxv4f64(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z0.s, #16
+; CHECK-NEXT:    movprfx z0, z1
+; CHECK-NEXT:    fcvt z0.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z1, z2
+; CHECK-NEXT:    fcvt z1.d, p0/m, z2.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x double>
+  ret <vscale x 4 x double> %res
+}
+
+define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z2.s, #16
+; CHECK-NEXT:    lsl z3.s, z3.s, #16
+; CHECK-NEXT:    lsl z4.s, z0.s, #16
+; CHECK-NEXT:    fcvt z1.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z0, z2
+; CHECK-NEXT:    fcvt z0.d, p0/m, z2.s
+; CHECK-NEXT:    movprfx z2, z3
+; CHECK-NEXT:    fcvt z2.d, p0/m, z3.s
+; CHECK-NEXT:    movprfx z3, z4
+; CHECK-NEXT:    fcvt z3.d, p0/m, z4.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
+  ret <vscale x 8 x double> %res
+}
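
The nxv2f64 tests above show the split lowering in action: there is no direct bf16->f64 conversion, so the DAG first extends bf16->f32 with the left-shift trick and then uses the native f32->f64 convert. A rough scalar C++ model of that two-step path (illustration only, hypothetical helper name):

```cpp
#include <cstdint>
#include <cstring>

// Sketch of the two-step bf16 -> f64 path: bf16 -> f32 is a 16-bit left
// shift of the raw bits (the "lsl" in the tests above), and f32 -> f64 uses
// the ordinary widening convert (the "fcvt z0.d, p0/m, z0.s" above).
double bf16_to_f64(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16; // bf16 -> f32
  float f32_value;
  std::memcpy(&f32_value, &f32_bits, sizeof(f32_value));
  return static_cast<double>(f32_value); // native f32 -> f64 extension
}
```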

sdesmalen-arm (Collaborator) left a comment:
That's a nice improvement!

paulwalker-arm merged commit be1958f into llvm:main on Sep 5, 2024
5 of 8 checks passed
paulwalker-arm deleted the sve-bfloat-lowering branch on September 5, 2024 at 16:10
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request on Sep 12, 2024: [LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to nxvf32/nxvf64. (llvm#107253)