Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAG: Fix chain mismanagement in SoftenFloatRes_FP_EXTEND #74558

Merged
merged 3 commits into from
Jan 18, 2024

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Dec 6, 2023

This would result in nodes not getting appropriately re-legalized
in the strictfp case.

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

This would result in nodes not getting appropriately re-legalized
in the strictfp case.


Patch is 87.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74558.diff

11 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+74-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+37-22)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1)
  • (added) llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll (+344)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll (+3)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fpext.ll (+254-26)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fptrunc.ll (+227-21)
  • (added) llvm/test/CodeGen/AMDGPU/strictfp_f16_abi_promote.ll (+558)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..1985af7bca5c4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -522,8 +522,11 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     Op = GetPromotedFloat(Op);
     // If the promotion did the FP_EXTEND to the destination type for us,
     // there's nothing left to do here.
-    if (Op.getValueType() == N->getValueType(0))
+    if (Op.getValueType() == N->getValueType(0)) {
+      if (IsStrict)
+        ReplaceValueWith(SDValue(N, 1), Chain);
       return BitConvertToInteger(Op);
+    }
   }
 
   // There's only a libcall for f16 -> f32 and shifting is only valid for bf16
@@ -541,8 +544,10 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     }
   }
 
-  if (Op.getValueType() == MVT::bf16)
+  if (Op.getValueType() == MVT::bf16) {
+    // FIXME: Need ReplaceValueWith on chain in strict case
     return SoftenFloatRes_BF16_TO_FP(N);
+  }
 
   RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
@@ -2181,6 +2186,24 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16)
+    return ISD::STRICT_FP16_TO_FP;
+
+  if (RetVT == MVT::f16)
+    return ISD::STRICT_FP_TO_FP16;
+
+  if (OpVT == MVT::bf16) {
+    // TODO: return ISD::STRICT_BF16_TO_FP;
+  }
+
+  if (RetVT == MVT::bf16) {
+    // TODO: return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2237,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2302,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1 && "Promoting unpromotable operand");
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2439,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2647,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index da7d9ace4114a..29684d3372bdb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -153,6 +153,7 @@ static const unsigned MaxParallelChains = 64;
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CC);
 
 /// getCopyFromParts - Create a value that contains the specified legal parts
@@ -163,6 +164,7 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
 static SDValue
 getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
                  unsigned NumParts, MVT PartVT, EVT ValueVT, const Value *V,
+                 SDValue InChain,
                  std::optional<CallingConv::ID> CC = std::nullopt,
                  std::optional<ISD::NodeType> AssertOp = std::nullopt) {
   // Let the target assemble the parts if it wants to
@@ -173,7 +175,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (ValueVT.isVector())
     return getCopyFromPartsVector(DAG, DL, Parts, NumParts, PartVT, ValueVT, V,
-                                  CC);
+                                  InChain, CC);
 
   assert(NumParts > 0 && "No parts to assemble!");
   SDValue Val = Parts[0];
@@ -194,10 +196,10 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), RoundBits/2);
 
       if (RoundParts > 2) {
-        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2,
-                              PartVT, HalfVT, V);
-        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2,
-                              RoundParts / 2, PartVT, HalfVT, V);
+        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2, PartVT, HalfVT, V,
+                              InChain);
+        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2, RoundParts / 2,
+                              PartVT, HalfVT, V, InChain);
       } else {
         Lo = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[0]);
         Hi = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[1]);
@@ -213,7 +215,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
         unsigned OddParts = NumParts - RoundParts;
         EVT OddVT = EVT::getIntegerVT(*DAG.getContext(), OddParts * PartBits);
         Hi = getCopyFromParts(DAG, DL, Parts + RoundParts, OddParts, PartVT,
-                              OddVT, V, CC);
+                              OddVT, V, InChain, CC);
 
         // Combine the round and odd parts.
         Lo = Val;
@@ -243,7 +245,8 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       assert(ValueVT.isFloatingPoint() && PartVT.isInteger() &&
              !PartVT.isVector() && "Unexpected split");
       EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), ValueVT.getSizeInBits());
-      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V, CC);
+      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V,
+                             InChain, CC);
     }
   }
 
@@ -283,10 +286,20 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (PartEVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {
     // FP_ROUND's are always exact here.
-    if (ValueVT.bitsLT(Val.getValueType()))
-      return DAG.getNode(
-          ISD::FP_ROUND, DL, ValueVT, Val,
-          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout())));
+    if (ValueVT.bitsLT(Val.getValueType())) {
+
+      SDValue NoChange =
+          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+
+      if (DAG.getMachineFunction().getFunction().getAttributes().hasFnAttr(
+              llvm::Attribute::StrictFP)) {
+        return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
+                           DAG.getVTList(ValueVT, MVT::Other), InChain, Val,
+                           NoChange);
+      }
+
+      return DAG.getNode(ISD::FP_ROUND, DL, ValueVT, Val, NoChange);
+    }
 
     return DAG.getNode(ISD::FP_EXTEND, DL, ValueVT, Val);
   }
@@ -324,6 +337,7 @@ static void diagnosePossiblyInvalidConstraint(LLVMContext &Ctx, const Value *V,
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CallConv) {
   assert(ValueVT.isVector() && "Not a vector value");
   assert(NumParts > 0 && "No parts to assemble!");
@@ -362,8 +376,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
       // If the register was not expanded, truncate or copy the value,
       // as appropriate.
       for (unsigned i = 0; i != NumParts; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1, PartVT, IntermediateVT,
+                                  V, InChain, CallConv);
     } else if (NumParts > 0) {
       // If the intermediate type was expanded, build the intermediate
       // operands from the parts.
@@ -371,8 +385,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
              "Must expand into a divisible number of parts!");
       unsigned Factor = NumParts / NumIntermediates;
       for (unsigned i = 0; i != NumIntermediates; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor, PartVT,
+                                  IntermediateVT, V, InChain, CallConv);
     }
 
     // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
@@ -926,7 +940,7 @@ SDValue RegsForValue::getCopyFromRegs(SelectionDAG &DAG,
     }
 
     Values[Value] = getCopyFromParts(DAG, dl, Parts.begin(), NumRegs,
-                                     RegisterVT, ValueVT, V, CallConv);
+                                     RegisterVT, ValueVT, V, Chain, CallConv);
     Part += NumRegs;
     Parts.clear();
   }
@@ -10641,9 +10655,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
                                                        CLI.CallConv, VT);
 
-      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                              NumRegs, RegisterVT, VT, nullptr,
-                                              CLI.CallConv, AssertOp));
+      ReturnValues.push_back(getCopyFromParts(
+          CLI.DAG, CLI.DL, &InVals[CurReg], NumRegs, RegisterVT, VT, nullptr,
+          CLI.Chain, CLI.CallConv, AssertOp));
       CurReg += NumRegs;
     }
 
@@ -11122,8 +11136,9 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
     MVT VT = ValueVTs[0].getSimpleVT();
     MVT RegVT = TLI->getRegisterType(*CurDAG->getContext(), VT);
     std::optional<ISD::NodeType> AssertOp;
-    SDValue ArgValue = getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT,
-                                        nullptr, F.getCallingConv(), AssertOp);
+    SDValue ArgValue =
+        getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT, nullptr, NewRoot,
+                         F.getCallingConv(), AssertOp);
 
     MachineFunction& MF = SDB->DAG.getMachineFunction();
     MachineRegisterInfo& RegInfo = MF.getRegInfo();
@@ -11195,7 +11210,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
           AssertOp = ISD::AssertZext;
 
         ArgValues.push_back(getCopyFromParts(DAG, dl, &InVals[i], NumParts,
-                                             PartVT, VT, nullptr,
+                                             PartVT, VT, nullptr, NewRoot,
                                              F.getCallingConv(), AssertOp));
       }
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll b/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll
new file mode 100644
index 0000000000000..37186cf22ccc7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll
@@ -...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-backend-aarch64

Author: Matt Arsenault (arsenm)

Changes

This would result in nodes not getting appropriately re-legalized
in the strictfp case.


Patch is 87.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74558.diff

11 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+74-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+37-22)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1)
  • (added) llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll (+344)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll (+3)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fpext.ll (+254-26)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fptrunc.ll (+227-21)
  • (added) llvm/test/CodeGen/AMDGPU/strictfp_f16_abi_promote.ll (+558)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..1985af7bca5c4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -522,8 +522,11 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     Op = GetPromotedFloat(Op);
     // If the promotion did the FP_EXTEND to the destination type for us,
     // there's nothing left to do here.
-    if (Op.getValueType() == N->getValueType(0))
+    if (Op.getValueType() == N->getValueType(0)) {
+      if (IsStrict)
+        ReplaceValueWith(SDValue(N, 1), Chain);
       return BitConvertToInteger(Op);
+    }
   }
 
   // There's only a libcall for f16 -> f32 and shifting is only valid for bf16
@@ -541,8 +544,10 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     }
   }
 
-  if (Op.getValueType() == MVT::bf16)
+  if (Op.getValueType() == MVT::bf16) {
+    // FIXME: Need ReplaceValueWith on chain in strict case
     return SoftenFloatRes_BF16_TO_FP(N);
+  }
 
   RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
@@ -2181,6 +2186,24 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16)
+    return ISD::STRICT_FP16_TO_FP;
+
+  if (RetVT == MVT::f16)
+    return ISD::STRICT_FP_TO_FP16;
+
+  if (OpVT == MVT::bf16) {
+    // TODO: return ISD::STRICT_BF16_TO_FP;
+  }
+
+  if (RetVT == MVT::bf16) {
+    // TODO: return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2237,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2302,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1 && "Promoting unpromotable operand");
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2439,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2647,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index da7d9ace4114a..29684d3372bdb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -153,6 +153,7 @@ static const unsigned MaxParallelChains = 64;
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CC);
 
 /// getCopyFromParts - Create a value that contains the specified legal parts
@@ -163,6 +164,7 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
 static SDValue
 getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
                  unsigned NumParts, MVT PartVT, EVT ValueVT, const Value *V,
+                 SDValue InChain,
                  std::optional<CallingConv::ID> CC = std::nullopt,
                  std::optional<ISD::NodeType> AssertOp = std::nullopt) {
   // Let the target assemble the parts if it wants to
@@ -173,7 +175,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (ValueVT.isVector())
     return getCopyFromPartsVector(DAG, DL, Parts, NumParts, PartVT, ValueVT, V,
-                                  CC);
+                                  InChain, CC);
 
   assert(NumParts > 0 && "No parts to assemble!");
   SDValue Val = Parts[0];
@@ -194,10 +196,10 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), RoundBits/2);
 
       if (RoundParts > 2) {
-        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2,
-                              PartVT, HalfVT, V);
-        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2,
-                              RoundParts / 2, PartVT, HalfVT, V);
+        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2, PartVT, HalfVT, V,
+                              InChain);
+        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2, RoundParts / 2,
+                              PartVT, HalfVT, V, InChain);
       } else {
         Lo = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[0]);
         Hi = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[1]);
@@ -213,7 +215,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
         unsigned OddParts = NumParts - RoundParts;
         EVT OddVT = EVT::getIntegerVT(*DAG.getContext(), OddParts * PartBits);
         Hi = getCopyFromParts(DAG, DL, Parts + RoundParts, OddParts, PartVT,
-                              OddVT, V, CC);
+                              OddVT, V, InChain, CC);
 
         // Combine the round and odd parts.
         Lo = Val;
@@ -243,7 +245,8 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       assert(ValueVT.isFloatingPoint() && PartVT.isInteger() &&
              !PartVT.isVector() && "Unexpected split");
       EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), ValueVT.getSizeInBits());
-      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V, CC);
+      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V,
+                             InChain, CC);
     }
   }
 
@@ -283,10 +286,20 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (PartEVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {
     // FP_ROUND's are always exact here.
-    if (ValueVT.bitsLT(Val.getValueType()))
-      return DAG.getNode(
-          ISD::FP_ROUND, DL, ValueVT, Val,
-          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout())));
+    if (ValueVT.bitsLT(Val.getValueType())) {
+
+      SDValue NoChange =
+          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+
+      if (DAG.getMachineFunction().getFunction().getAttributes().hasFnAttr(
+              llvm::Attribute::StrictFP)) {
+        return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
+                           DAG.getVTList(ValueVT, MVT::Other), InChain, Val,
+                           NoChange);
+      }
+
+      return DAG.getNode(ISD::FP_ROUND, DL, ValueVT, Val, NoChange);
+    }
 
     return DAG.getNode(ISD::FP_EXTEND, DL, ValueVT, Val);
   }
@@ -324,6 +337,7 @@ static void diagnosePossiblyInvalidConstraint(LLVMContext &Ctx, const Value *V,
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CallConv) {
   assert(ValueVT.isVector() && "Not a vector value");
   assert(NumParts > 0 && "No parts to assemble!");
@@ -362,8 +376,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
       // If the register was not expanded, truncate or copy the value,
       // as appropriate.
       for (unsigned i = 0; i != NumParts; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1, PartVT, IntermediateVT,
+                                  V, InChain, CallConv);
     } else if (NumParts > 0) {
       // If the intermediate type was expanded, build the intermediate
       // operands from the parts.
@@ -371,8 +385,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
              "Must expand into a divisible number of parts!");
       unsigned Factor = NumParts / NumIntermediates;
       for (unsigned i = 0; i != NumIntermediates; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor, PartVT,
+                                  IntermediateVT, V, InChain, CallConv);
     }
 
     // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
@@ -926,7 +940,7 @@ SDValue RegsForValue::getCopyFromRegs(SelectionDAG &DAG,
     }
 
     Values[Value] = getCopyFromParts(DAG, dl, Parts.begin(), NumRegs,
-                                     RegisterVT, ValueVT, V, CallConv);
+                                     RegisterVT, ValueVT, V, Chain, CallConv);
     Part += NumRegs;
     Parts.clear();
   }
@@ -10641,9 +10655,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
                                                        CLI.CallConv, VT);
 
-      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                              NumRegs, RegisterVT, VT, nullptr,
-                                              CLI.CallConv, AssertOp));
+      ReturnValues.push_back(getCopyFromParts(
+          CLI.DAG, CLI.DL, &InVals[CurReg], NumRegs, RegisterVT, VT, nullptr,
+          CLI.Chain, CLI.CallConv, AssertOp));
       CurReg += NumRegs;
     }
 
@@ -11122,8 +11136,9 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
     MVT VT = ValueVTs[0].getSimpleVT();
     MVT RegVT = TLI->getRegisterType(*CurDAG->getContext(), VT);
     std::optional<ISD::NodeType> AssertOp;
-    SDValue ArgValue = getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT,
-                                        nullptr, F.getCallingConv(), AssertOp);
+    SDValue ArgValue =
+        getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT, nullptr, NewRoot,
+                         F.getCallingConv(), AssertOp);
 
     MachineFunction& MF = SDB->DAG.getMachineFunction();
     MachineRegisterInfo& RegInfo = MF.getRegInfo();
@@ -11195,7 +11210,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
           AssertOp = ISD::AssertZext;
 
         ArgValues.push_back(getCopyFromParts(DAG, dl, &InVals[i], NumParts,
-                                             PartVT, VT, nullptr,
+                                             PartVT, VT, nullptr, NewRoot,
                                              F.getCallingConv(), AssertOp));
       }
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll b/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll
new file mode 100644
index 0000000000000..37186cf22ccc7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/strictfp_f16_abi_promote.ll
@@ -...
[truncated]

Copy link

github-actions bot commented Dec 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@arsenm
Copy link
Contributor Author

arsenm commented Jan 4, 2024

ping

This was emitting non-strict casts in ABI contexts for illegal
types.
This would result in nodes not getting appropriately re-legalized
in the strictfp case.
@arsenm arsenm force-pushed the dag-soften-float-res-strict-fpextend-chain-2 branch from 1b598cd to 9757953 Compare January 5, 2024 01:49
@arsenm
Copy link
Contributor Author

arsenm commented Jan 9, 2024

ping

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@arsenm arsenm merged commit c8007f9 into llvm:main Jan 18, 2024
3 of 4 checks passed
@arsenm arsenm deleted the dag-soften-float-res-strict-fpextend-chain-2 branch January 19, 2024 03:15
ampandey-1995 pushed a commit to ampandey-1995/llvm-project that referenced this pull request Jan 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants