Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAG: Fix ABI lowering with FP promote in strictfp functions #74405

Merged
merged 3 commits into from
Jan 18, 2024

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Dec 5, 2023

This was emitting non-strict casts in ABI contexts for illegal
types.

@arsenm arsenm added floating-point Floating-point math llvm:SelectionDAG SelectionDAGISel as well labels Dec 5, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-selectiondag

Author: Matt Arsenault (arsenm)

Changes

This was emitting non-strict casts in ABI contexts for illegal
types.


Patch is 72.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74405.diff

10 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+63)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+37-22)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll (+3)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fpext.ll (+254-26)
  • (modified) llvm/test/CodeGen/AMDGPU/strict_fptrunc.ll (+227-21)
  • (added) llvm/test/CodeGen/AMDGPU/strictfp_f16_abi_promote.ll (+558)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..d7a688511b726 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2181,6 +2181,20 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16) {
+    return ISD::STRICT_FP16_TO_FP;
+  } else if (RetVT == MVT::f16) {
+    return ISD::STRICT_FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    // return ISD::STRICT_BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    // return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2228,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2293,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1);
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2430,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2638,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index da7d9ace4114a..29684d3372bdb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -153,6 +153,7 @@ static const unsigned MaxParallelChains = 64;
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CC);
 
 /// getCopyFromParts - Create a value that contains the specified legal parts
@@ -163,6 +164,7 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
 static SDValue
 getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
                  unsigned NumParts, MVT PartVT, EVT ValueVT, const Value *V,
+                 SDValue InChain,
                  std::optional<CallingConv::ID> CC = std::nullopt,
                  std::optional<ISD::NodeType> AssertOp = std::nullopt) {
   // Let the target assemble the parts if it wants to
@@ -173,7 +175,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (ValueVT.isVector())
     return getCopyFromPartsVector(DAG, DL, Parts, NumParts, PartVT, ValueVT, V,
-                                  CC);
+                                  InChain, CC);
 
   assert(NumParts > 0 && "No parts to assemble!");
   SDValue Val = Parts[0];
@@ -194,10 +196,10 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), RoundBits/2);
 
       if (RoundParts > 2) {
-        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2,
-                              PartVT, HalfVT, V);
-        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2,
-                              RoundParts / 2, PartVT, HalfVT, V);
+        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2, PartVT, HalfVT, V,
+                              InChain);
+        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2, RoundParts / 2,
+                              PartVT, HalfVT, V, InChain);
       } else {
         Lo = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[0]);
         Hi = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[1]);
@@ -213,7 +215,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
         unsigned OddParts = NumParts - RoundParts;
         EVT OddVT = EVT::getIntegerVT(*DAG.getContext(), OddParts * PartBits);
         Hi = getCopyFromParts(DAG, DL, Parts + RoundParts, OddParts, PartVT,
-                              OddVT, V, CC);
+                              OddVT, V, InChain, CC);
 
         // Combine the round and odd parts.
         Lo = Val;
@@ -243,7 +245,8 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       assert(ValueVT.isFloatingPoint() && PartVT.isInteger() &&
              !PartVT.isVector() && "Unexpected split");
       EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), ValueVT.getSizeInBits());
-      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V, CC);
+      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V,
+                             InChain, CC);
     }
   }
 
@@ -283,10 +286,20 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (PartEVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {
     // FP_ROUND's are always exact here.
-    if (ValueVT.bitsLT(Val.getValueType()))
-      return DAG.getNode(
-          ISD::FP_ROUND, DL, ValueVT, Val,
-          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout())));
+    if (ValueVT.bitsLT(Val.getValueType())) {
+
+      SDValue NoChange =
+          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+
+      if (DAG.getMachineFunction().getFunction().getAttributes().hasFnAttr(
+              llvm::Attribute::StrictFP)) {
+        return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
+                           DAG.getVTList(ValueVT, MVT::Other), InChain, Val,
+                           NoChange);
+      }
+
+      return DAG.getNode(ISD::FP_ROUND, DL, ValueVT, Val, NoChange);
+    }
 
     return DAG.getNode(ISD::FP_EXTEND, DL, ValueVT, Val);
   }
@@ -324,6 +337,7 @@ static void diagnosePossiblyInvalidConstraint(LLVMContext &Ctx, const Value *V,
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CallConv) {
   assert(ValueVT.isVector() && "Not a vector value");
   assert(NumParts > 0 && "No parts to assemble!");
@@ -362,8 +376,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
       // If the register was not expanded, truncate or copy the value,
       // as appropriate.
       for (unsigned i = 0; i != NumParts; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1, PartVT, IntermediateVT,
+                                  V, InChain, CallConv);
     } else if (NumParts > 0) {
       // If the intermediate type was expanded, build the intermediate
       // operands from the parts.
@@ -371,8 +385,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
              "Must expand into a divisible number of parts!");
       unsigned Factor = NumParts / NumIntermediates;
       for (unsigned i = 0; i != NumIntermediates; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor, PartVT,
+                                  IntermediateVT, V, InChain, CallConv);
     }
 
     // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
@@ -926,7 +940,7 @@ SDValue RegsForValue::getCopyFromRegs(SelectionDAG &DAG,
     }
 
     Values[Value] = getCopyFromParts(DAG, dl, Parts.begin(), NumRegs,
-                                     RegisterVT, ValueVT, V, CallConv);
+                                     RegisterVT, ValueVT, V, Chain, CallConv);
     Part += NumRegs;
     Parts.clear();
   }
@@ -10641,9 +10655,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
                                                        CLI.CallConv, VT);
 
-      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                              NumRegs, RegisterVT, VT, nullptr,
-                                              CLI.CallConv, AssertOp));
+      ReturnValues.push_back(getCopyFromParts(
+          CLI.DAG, CLI.DL, &InVals[CurReg], NumRegs, RegisterVT, VT, nullptr,
+          CLI.Chain, CLI.CallConv, AssertOp));
       CurReg += NumRegs;
     }
 
@@ -11122,8 +11136,9 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
     MVT VT = ValueVTs[0].getSimpleVT();
     MVT RegVT = TLI->getRegisterType(*CurDAG->getContext(), VT);
     std::optional<ISD::NodeType> AssertOp;
-    SDValue ArgValue = getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT,
-                                        nullptr, F.getCallingConv(), AssertOp);
+    SDValue ArgValue =
+        getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT, nullptr, NewRoot,
+                         F.getCallingConv(), AssertOp);
 
     MachineFunction& MF = SDB->DAG.getMachineFunction();
     MachineRegisterInfo& RegInfo = MF.getRegInfo();
@@ -11195,7 +11210,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
           AssertOp = ISD::AssertZext;
 
         ArgValues.push_back(getCopyFromParts(DAG, dl, &InVals[i], NumParts,
-                                             PartVT, VT, nullptr,
+                                             PartVT, VT, nullptr, NewRoot,
                                              F.getCallingConv(), AssertOp));
       }
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
index d74948a460c98..20d175ecfd909 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
@@ -1316,6 +1316,9 @@ define i1 @isnan_f16_strictfp(half %x) strictfp nounwind {
 ; GFX7SELDAG-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX7SELDAG-NEXT:    v_cvt_f16_f32_e32 v0, v0
 ; GFX7SELDAG-NEXT:    s_movk_i32 s4, 0x7c00
+; GFX7SELDAG-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7SELDAG-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7SELDAG-NEXT:    v_cvt_f16_f32_e32 v0, v0
 ; GFX7SELDAG-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
 ; GFX7SELDAG-NEXT:    v_cmp_lt_i32_e32 vcc, s4, v0
 ; GFX7SELDAG-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
diff --git a/llvm/test/CodeGen/AMDGPU/strict_fpext.ll b/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
index 22bebb7ad26f5..fe59a8491c91a 100644
--- a/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
+++ b/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
@@ -1,22 +1,47 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
 ; FIXME: Missing operand promote for f16
-; XUN: llc -mtriple=amdgcn-mesa-mesa3d -mcpu=tahiti < %s | FileChe...
[truncated]

@arsenm
Copy link
Contributor Author

arsenm commented Dec 22, 2023

ping

1 similar comment
@arsenm
Copy link
Contributor Author

arsenm commented Jan 3, 2024

ping

This was emitting non-strict casts in ABI contexts for illegal
types.
@arsenm
Copy link
Contributor Author

arsenm commented Jan 9, 2024

ping

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@arsenm arsenm merged commit 11bf02e into llvm:main Jan 18, 2024
3 of 4 checks passed
@arsenm arsenm deleted the dag-strictfp-promote-abi branch January 18, 2024 03:57
ampandey-1995 pushed a commit to ampandey-1995/llvm-project that referenced this pull request Jan 19, 2024
This was emitting non-strict casts in ABI contexts for illegal
types.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU floating-point Floating-point math llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants