
Correctly round FP -> BF16 when SDAG expands such nodes #82399

Merged
1 commit merged into llvm:main on Feb 21, 2024

Conversation

majnemer
Contributor

We did something pretty naive:

  • round FP64 -> BF16 by first rounding to FP32,
  • skip FP32 -> BF16 rounding entirely, and
  • take the top 16 bits of the FP32, which can turn some NaNs into infinities.

Let's do this in a more principled way: round types with more precision than FP32 to FP32 using round-inexact-to-odd, which avoids double-rounding issues.
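
For readers who want the arithmetic spelled out, the sketch below is a minimal scalar C++ model of the conversion the description above is aiming for, assuming IEEE-754 layouts for float and double. The helper names FloatToBF16Model and DoubleToBF16Model are made up for illustration; this is not the SelectionDAG expansion itself.

#include <cmath>
#include <cstdint>
#include <cstring>

// float -> bfloat16 with round-to-nearest-even. NaNs are replaced by a
// canonical quiet NaN instead of being truncated, since truncating a NaN
// whose payload lives entirely in the low 16 mantissa bits would leave an
// all-zero mantissa, i.e. an infinity pattern.
static uint16_t FloatToBF16Model(float F) {
  if (std::isnan(F))
    return 0x7fc0;
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  uint32_t Lsb = (Bits >> 16) & 1; // the bit that becomes the new LSB
  return static_cast<uint16_t>((Bits + 0x7fff + Lsb) >> 16);
}

// double -> bfloat16: first narrow |x| to float with round-inexact-to-odd so
// that the final bfloat16 rounding cannot be perturbed by double rounding,
// then reattach the sign and reuse the float path above.
static uint16_t DoubleToBF16Model(double Wide) {
  double AbsWide = std::fabs(Wide);
  float AbsNarrow = static_cast<float>(AbsWide); // rounds to nearest-even
  uint32_t Bits;
  std::memcpy(&Bits, &AbsNarrow, sizeof(Bits));

  // Keep the narrow value if the narrowing was exact, the input was NaN, or
  // the narrow result is already odd; otherwise move one ulp back toward the
  // value we rounded away from, which makes the result odd.
  bool Keep = std::isnan(Wide) || static_cast<double>(AbsNarrow) == AbsWide ||
              (Bits & 1);
  if (!Keep)
    Bits += AbsWide > static_cast<double>(AbsNarrow) ? 1u : -1u;

  // Reattach the sign bit of the original double.
  uint64_t WideBits;
  std::memcpy(&WideBits, &Wide, sizeof(WideBits));
  Bits |= static_cast<uint32_t>(WideBits >> 32) & 0x80000000u;

  float Narrow;
  std::memcpy(&Narrow, &Bits, sizeof(Narrow));
  return FloatToBF16Model(Narrow);
}

Rounding the intermediate FP32 to odd preserves the inexactness of the FP64 -> FP32 step in the low mantissa bit, so the final round-to-nearest-even to bfloat16 cannot be skewed by double rounding; this is the property from the Boldo–Melquiond paper cited in the patch.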

@majnemer majnemer requested a review from d0k February 20, 2024 18:34
@llvmbot llvmbot added the backend:AMDGPU and llvm:SelectionDAG labels Feb 20, 2024
@llvmbot
Collaborator

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-selectiondag

Author: David Majnemer (majnemer)

Changes


Patch is 1.09 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82399.diff

11 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+92-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+53)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+5-5)
  • (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+12511-2859)
  • (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
  • (modified) llvm/test/CodeGen/AMDGPU/global-atomics-fp.ll (+174-112)
  • (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+987-475)
  • (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+66-38)
  • (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+213-80)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+1-1)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 252b6e9997a710..3426956a41b3d2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3219,8 +3219,98 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::FP_ROUND: {
     EVT VT = Node->getValueType(0);
     if (VT.getScalarType() == MVT::bf16) {
-      Results.push_back(
-          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      if (Node->getConstantOperandVal(1) == 1) {
+        Results.push_back(
+            DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+        break;
+      }
+      SDValue Op = Node->getOperand(0);
+      SDValue IsNaN = DAG.getSetCC(dl, getSetCCResultType(Op.getValueType()),
+                                   Op, Op, ISD::SETUO);
+      if (Op.getValueType() != MVT::f32) {
+        // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
+        // can induce double-rounding which may alter the results. We can
+        // correct for this using a trick explained in: Boldo, Sylvie, and
+        // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
+        // World Congress. 2005.
+        FloatSignAsInt ValueAsInt;
+        getSignAsIntValue(ValueAsInt, dl, Op);
+        EVT WideIntVT = ValueAsInt.IntValue.getValueType();
+        SDValue SignMask = DAG.getConstant(ValueAsInt.SignMask, dl, WideIntVT);
+        SDValue SignBit =
+            DAG.getNode(ISD::AND, dl, WideIntVT, ValueAsInt.IntValue, SignMask);
+        SDValue AbsWide;
+        if (TLI.isOperationLegalOrCustom(ISD::FABS, ValueAsInt.FloatVT)) {
+          AbsWide = DAG.getNode(ISD::FABS, dl, ValueAsInt.FloatVT, Op);
+        } else {
+          SDValue ClearSignMask =
+              DAG.getConstant(~ValueAsInt.SignMask, dl, WideIntVT);
+          SDValue ClearedSign = DAG.getNode(ISD::AND, dl, WideIntVT,
+                                            ValueAsInt.IntValue, ClearSignMask);
+          AbsWide = modifySignAsInt(ValueAsInt, dl, ClearedSign);
+        }
+        SDValue AbsNarrow =
+            DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, AbsWide,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
+        SDValue AbsNarrowAsWide =
+            DAG.getNode(ISD::FP_EXTEND, dl, ValueAsInt.FloatVT, AbsNarrow);
+
+        // We can keep the narrow value as-is if narrowing was exact (no
+        // rounding error), the wide value was NaN (the narrow value is also
+        // NaN and should be preserved) or if we rounded to the odd value.
+        SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, MVT::i32, AbsNarrow);
+        SDValue One = DAG.getConstant(1, dl, MVT::i32);
+        SDValue NegativeOne = DAG.getConstant(-1, dl, MVT::i32);
+        SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, NarrowBits, One);
+        EVT I32CCVT = getSetCCResultType(And.getValueType());
+        SDValue Zero = DAG.getConstant(0, dl, MVT::i32);
+        SDValue AlreadyOdd = DAG.getSetCC(dl, I32CCVT, And, Zero, ISD::SETNE);
+
+        EVT WideSetCCVT = getSetCCResultType(AbsWide.getValueType());
+        SDValue KeepNarrow = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETUEQ);
+        KeepNarrow =
+            DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
+        // We morally performed a round-down if `abs_narrow` is smaller than
+        // `abs_wide`.
+        SDValue NarrowIsRd = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETOGT);
+        // If the narrow value is odd or exact, pick it.
+        // Otherwise, narrow is even and corresponds to either the rounded-up
+        // or rounded-down value. If narrow is the rounded-down value, we want
+        // the rounded-up value as it will be odd.
+        SDValue Adjust =
+            DAG.getSelect(dl, MVT::i32, NarrowIsRd, One, NegativeOne);
+        Adjust = DAG.getSelect(dl, MVT::i32, KeepNarrow, Zero, Adjust);
+        int ShiftAmount = ValueAsInt.SignBit - 31;
+        SDValue ShiftCnst = DAG.getConstant(
+            ShiftAmount, dl,
+            TLI.getShiftAmountTy(WideIntVT, DAG.getDataLayout()));
+        SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
+        SignBit = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, SignBit);
+        Op = DAG.getNode(ISD::OR, dl, MVT::i32, Adjust, SignBit);
+      } else {
+        Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op);
+      }
+
+      SDValue One = DAG.getConstant(1, dl, MVT::i32);
+      SDValue Lsb = DAG.getNode(
+          ISD::SRL, dl, MVT::i32, Op,
+          DAG.getConstant(16, dl,
+                          TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+      Lsb = DAG.getNode(ISD::AND, dl, MVT::i32, Lsb, One);
+      SDValue RoundingBias = DAG.getNode(
+          ISD::ADD, dl, MVT::i32, DAG.getConstant(0x7fff, dl, MVT::i32), Lsb);
+      SDValue Add = DAG.getNode(ISD::ADD, dl, MVT::i32, Op, RoundingBias);
+      Op = DAG.getNode(
+          ISD::SRL, dl, MVT::i32, Add,
+          DAG.getConstant(16, dl,
+                          TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+      Op = DAG.getSelect(dl, MVT::i32, IsNaN,
+                         DAG.getConstant(0x00007fc0, dl, MVT::i32), Op);
+
+      Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::i16, Op);
+      Results.push_back(DAG.getNode(ISD::BITCAST, dl, MVT::bf16, Op));
       break;
     }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7f58b312e7a201..e75799ca13b0bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -776,6 +776,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
+    setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+  }
+
   // sm_80 only has conversions between f32 and bf16. Custom lower all other
   // bf16 conversions.
   if (STI.hasBF16Math() &&
@@ -2465,6 +2470,50 @@ SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
   return Op;
 }
 
+SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  if (Op.getValueType() == MVT::bf16) {
+    if (Op.getOperand(0).getValueType() == MVT::f32 &&
+        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+    }
+    if (Op.getOperand(0).getValueType() == MVT::f64 &&
+        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+    }
+  }
+
+  // Everything else is considered legal.
+  return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  if (Op.getOperand(0).getValueType() == MVT::bf16) {
+    if (Op.getValueType() == MVT::f32 &&
+        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+                         Op.getOperand(0));
+    }
+    if (Op.getValueType() == MVT::f64 &&
+        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+      SDLoc Loc(Op);
+      if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
+        Op = DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0));
+        return DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f64, Op);
+      }
+      return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+                         Op.getOperand(0));
+    }
+  }
+
+  // Everything else is considered legal.
+  return Op;
+}
+
 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   if (Op.getValueType() != MVT::v2i16)
@@ -2527,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
     return LowerFP_TO_INT(Op, DAG);
+  case ISD::FP_ROUND:
+    return LowerFP_ROUND(Op, DAG);
+  case ISD::FP_EXTEND:
+    return LowerFP_EXTEND(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5d3fd992812ef9..cf1d4580766918 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -618,6 +618,9 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 631136ad621464..40d82ebecbed35 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -662,7 +662,7 @@ let hasSideEffects = false in {
                    // bf16->f32 was introduced early.
                    [hasPTX<71>, hasSM<80>],
                    // bf16->everything else needs sm90/ptx78
-                   [hasPTX<78>, hasSM<90>])>; 
+                   [hasPTX<78>, hasSM<90>])>;
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
@@ -3647,7 +3647,7 @@ def : Pat<(f16 (fpround Float32Regs:$a)),
 
 // fpround f32 -> bf16
 def : Pat<(bf16 (fpround Float32Regs:$a)),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>;
 
 // fpround f64 -> f16
 def : Pat<(f16 (fpround Float64Regs:$a)),
@@ -3655,7 +3655,7 @@ def : Pat<(f16 (fpround Float64Regs:$a)),
 
 // fpround f64 -> bf16
 def : Pat<(bf16 (fpround Float64Regs:$a)),
-          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
+          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 // fpround f64 -> f32
 def : Pat<(f32 (fpround Float64Regs:$a)),
           (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3671,7 +3671,7 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
 def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
           (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
 def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
-          (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
+          (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>;
 
 // fpextend f16 -> f64
 def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
@@ -3679,7 +3679,7 @@ def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
 
 // fpextend bf16 -> f64
 def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
-          (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
+          (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>;
 
 // fpextend f32 -> f64
 def : Pat<(f64 (fpextend Float32Regs:$a)),
diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll
index 387c4a16a008ae..39cb0a768701c0 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16.ll
@@ -1918,8 +1918,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX8:       ; %bb.0:
 ; GFX8-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX8-NEXT:    flat_load_dword v0, v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v1, 0x7fc0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
-; GFX8-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT:    v_bfe_u32 v4, v0, 16, 1
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, v4, v0
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT:    v_cmp_o_f32_e32 vcc, v0, v0
+; GFX8-NEXT:    v_cndmask_b32_e32 v0, v1, v4, vcc
 ; GFX8-NEXT:    flat_store_short v[2:3], v0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
 ; GFX8-NEXT:    s_setpc_b64 s[30:31]
@@ -1928,8 +1934,15 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX9-NEXT:    global_load_dword v0, v[0:1], off
+; GFX9-NEXT:    s_movk_i32 s4, 0x7fff
+; GFX9-NEXT:    v_mov_b32_e32 v1, 0x7fc0
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
-; GFX9-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT:    v_bfe_u32 v4, v0, 16, 1
+; GFX9-NEXT:    v_add3_u32 v4, v4, v0, s4
+; GFX9-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT:    v_cmp_o_f32_e32 vcc, v0, v0
+; GFX9-NEXT:    v_cndmask_b32_e32 v0, v1, v4, vcc
+; GFX9-NEXT:    global_store_short v[2:3], v0, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
@@ -1938,7 +1951,12 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX10-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX10-NEXT:    global_load_dword v0, v[0:1], off
 ; GFX10-NEXT:    s_waitcnt vmcnt(0)
-; GFX10-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT:    v_bfe_u32 v1, v0, 16, 1
+; GFX10-NEXT:    v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX10-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX10-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX10-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX10-NEXT:    global_store_short v[2:3], v0, off
 ; GFX10-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX11-LABEL: test_load_store_f32_to_bf16:
@@ -1946,7 +1964,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX11-NEXT:    global_load_b32 v0, v[0:1], off
 ; GFX11-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-NEXT:    global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT:    v_bfe_u32 v1, v0, 16, 1
+; GFX11-NEXT:    v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX11-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX11-NEXT:    global_store_b16 v[2:3], v0, off
 ; GFX11-NEXT:    s_setpc_b64 s[30:31]
   %val = load float, ptr addrspace(1) %in
   %val.bf16 = fptrunc float %val to bfloat
@@ -1989,9 +2014,25 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX8:       ; %bb.0:
 ; GFX8-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX8-NEXT:    flat_load_dwordx2 v[0:1], v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v7, 0x7fc0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
-; GFX8-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX8-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX8-NEXT:    v_and_b32_e32 v8, 0x80000000, v1
+; GFX8-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX8-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX8-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX8-NEXT:    v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX8-NEXT:    v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX8-NEXT:    s_or_b64 s[4:5], s[4:5], vcc
+; GFX8-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX8-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX8-NEXT:    v_or_b32_e32 v5, v4, v8
+; GFX8-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, v4, v5
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT:    v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX8-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT:    v_cndmask_b32_e32 v0, v7, v4, vcc
 ; GFX8-NEXT:    flat_store_short v[2:3], v0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
 ; GFX8-NEXT:    s_setpc_b64 s[30:31]
@@ -2000,9 +2041,26 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX9-NEXT:    global_load_dwordx2 v[0:1], v[0:1], off
+; GFX9-NEXT:    s_brev_b32 s8, 1
+; GFX9-NEXT:    s_movk_i32 s9, 0x7fff
+; GFX9-NEXT:    v_mov_b32_e32 v7, 0x7fc0
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
-; GFX9-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX9-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX9-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX9-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX9-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX9-NEXT:    v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX9-NEXT:    v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX9-NEXT:    s_or_b64 s[4:5], s[4:5], vcc
+; GFX9-NEXT:    v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX9-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX9-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX9-NEXT:    v_and_or_b32 v5, v1, s8, v4
+; GFX9-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX9-NEXT:    v_add3_u32 v4, v4, v5, s9
+; GFX9-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT:    v_cndmask_b32_e32 v0, v7, v4, vcc
+; GFX9-NEXT:    global_store_short v[2:3], v0, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
@@ -2011,8 +2069,22 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX10-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX10-NEXT:    global_load_dwordx2 v[0:1], v[0:1], off
 ; GFX10-NEXT:    s_waitcnt vmcnt(0)
-; GFX10-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX10-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX10-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX10-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX10-NEXT:    v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX10-NEXT:    v_cmp_gt_f64_e64 s5, |v[0:1]|, v[4:5]
+; GFX10-NEXT:    v_cmp_nlg_f64_e64 s4, |v[0:1]|, v[4:5]
+; GFX10-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s5
+; GFX10-NEXT:    s_or_b32 s4, s4, vcc_lo
+; GFX10-NEXT:    v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX10-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s4
+; GFX10-NEXT:    v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX10-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX10-NEXT:    v_add3_u32 v4, v4, v5, 0x7fff
+; GFX10-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX10-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX10-NEXT:    global_store_short v[2:3], v0, off
 ; GFX10-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX11-LABEL: test_load_store_f64_to_bf16:
@@ -2020,8 +2092,27 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX11-NEXT:    global_load_b64 v[0:1], v[0:1], off
 ; GFX11-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX11-NEXT:    global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX11-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX11-NEXT:    v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(SKIP_1) | instid1(VALU_DEP_2)
+; GFX11-NEXT:    v_cmp_nlg_f64_e64 s0, |v[0:1]|, v[4:5]
+; GFX11-NEXT:    v_cmp_gt_f64_e64 s1, |v[0:1]|, v[4:5]
+; GFX11-NEXT:    s_or_b32 s0, s0, vcc_lo
+; GFX11-NEXT:    v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s1
+; GFX11-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s0
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX11-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX11-NEXT:    v_add3_u32 v4, v4, v5, 0x7fff
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX11-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX11-NEXT:    global_store_b16 v[2:3], v0, off
 ; GFX11-NEXT:    s_setpc_b64 s[30:31]
   %val = load double, ptr addrspace(1) %in
   %val.bf16 = fptrunc double %val to bfloat
@@ -8487,7 +8578,13 @@ define bfloat ...
[truncated]

@majnemer majnemer force-pushed the do-cvt branch 2 times, most recently from 0573396 to f82d958 Compare February 21, 2024 16:17

github-actions bot commented Feb 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

We did something pretty naive:
- round FP64 -> BF16 by first rounding to FP32
- skip FP32 -> BF16 rounding entirely
- taking the top 16 bits of a FP32 which will turn some NaNs into infinities

Let's do this in a more principled way by rounding types with more
precision than FP32 to FP32 using round-inexact-to-odd which will
negate double rounding issues.
@majnemer majnemer merged commit cc13f3b into llvm:main Feb 21, 2024
4 checks passed
Labels
backend:AMDGPU, llvm:SelectionDAG

4 participants