[SelectionDAG] Add support for the 3-way comparison intrinsics [US]CMP #91871

Merged
nikic merged 4 commits into llvm:main on Jun 17, 2024

Contributor

@Poseydon42 Poseydon42 commented May 11, 2024

This PR adds initial support for the scmp/ucmp 3-way comparison intrinsics in the SelectionDAG.

What works as of now:

  • An invocation of the intrinsic in the IR gets properly lowered into SelectionDAG
  • A node with opcode UCMP or SCMP gets properly expanded into two comparisons and two selects
  • Narrow scalar arguments and return types are properly handled (e.g. i3 or i51)
  • Wide scalar arguments and return types are properly handled (e.g. i87 or i139)
  • Vector arguments and return types are properly widened/split where necessary
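
For readers unfamiliar with the intrinsics, the "two comparisons and two selects" expansion can be modelled in plain C++. This is only a sketch of the semantics, not the SelectionDAG code: `threeWayCompare`, `scmp32`, `ucmp32` and `checkExamples` are hypothetical names, and the exact select order used by `expandCMP` is an assumption because that part of the patch is truncated in this thread.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical reference model, not an LLVM API: a 3-way comparison built
// from two comparisons and two selects. The signedness of OpT decides
// whether this behaves like scmp or like ucmp.
template <typename ResT, typename OpT>
constexpr ResT threeWayCompare(OpT lhs, OpT rhs) {
  const bool isLess = lhs < rhs;     // first comparison
  const bool isGreater = lhs > rhs;  // second comparison
  const ResT oneOrZero = isGreater ? ResT(1) : ResT(0);  // first select
  return isLess ? ResT(-1) : oneOrZero;                  // second select
}

// Signed operands model llvm.scmp, unsigned operands model llvm.ucmp.
// The result type (i8 here) is independent of the operand type (i32 here).
constexpr std::int8_t scmp32(std::int32_t lhs, std::int32_t rhs) {
  return threeWayCompare<std::int8_t>(lhs, rhs);
}
constexpr std::int8_t ucmp32(std::uint32_t lhs, std::uint32_t rhs) {
  return threeWayCompare<std::int8_t>(lhs, rhs);
}

// The same bit pattern (all ones) orders differently under the two
// interpretations: it is -1 when signed and the maximum value when unsigned.
inline void checkExamples() {
  assert(scmp32(-1, 1) == -1);
  assert(ucmp32(0xFFFFFFFFu, 1u) == 1);
  assert(scmp32(7, 7) == 0);
}
```

A target with a cheaper sequence would presumably mark the opcodes Custom or Legal instead of relying on this generic form.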

github-actions bot commented May 11, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff d98a78590f4f9e43fdfb69fde7d154a985e4560f 8ed9aaf3c608482570bfe57b3cbaa4a496323ffd -- llvm/include/llvm/CodeGen/ISDOpcodes.h llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/CodeGen/TargetLoweringBase.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index e82ab9db4c..c963ffa998 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1887,7 +1887,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::ROTR: Res = PromoteIntOp_Shift(N); break;
 
   case ISD::SCMP:
-  case ISD::UCMP: Res = PromoteIntOp_CMP(N); break;
+  case ISD::UCMP:
+    Res = PromoteIntOp_CMP(N);
+    break;
 
   case ISD::FSHL:
   case ISD::FSHR: Res = PromoteIntOp_FunnelShift(N); break;
@@ -2768,7 +2770,9 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SMIN: ExpandIntRes_MINMAX(N, Lo, Hi); break;
 
   case ISD::SCMP:
-  case ISD::UCMP: ExpandIntRes_CMP(N, Lo, Hi); break;
+  case ISD::UCMP:
+    ExpandIntRes_CMP(N, Lo, Hi);
+    break;
 
   case ISD::ADD:
   case ISD::SUB: ExpandIntRes_ADDSUB(N, Lo, Hi); break;
@@ -5172,7 +5176,9 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FRAMEADDR:         Res = ExpandIntOp_RETURNADDR(N); break;
 
   case ISD::SCMP:
-  case ISD::UCMP:              Res = ExpandIntOp_CMP(N); break;
+  case ISD::UCMP:
+    Res = ExpandIntOp_CMP(N);
+    break;
 
   case ISD::ATOMIC_STORE:      Res = ExpandIntOp_ATOMIC_STORE(N); break;
   case ISD::STACKMAP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index db2cd5e40a..2aaf16d2a1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -459,7 +459,7 @@ private:
 
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
-  void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index d52004553e..62cf740194 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1221,7 +1221,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_TernaryOp(N, Lo, Hi);
     break;
 
-  case ISD::SCMP: case ISD::UCMP:
+  case ISD::SCMP:
+  case ISD::UCMP:
     SplitVecRes_CMP(N, Lo, Hi);
     break;
 

Contributor

@nikic nikic left a comment

build/bin/llc -mtriple=aarch64-unknown-linux-gnu < llvm/test/CodeGen/X86/uscmp.ll gives a selection error.

@Poseydon42 Poseydon42 changed the title from "[WIP][SelectionDAG] Add support for the 3-way comparison intrinsics [US]CMP" to "[SelectionDAG] Add support for the 3-way comparison intrinsics [US]CMP" May 27, 2024
@Poseydon42 Poseydon42 marked this pull request as ready for review May 27, 2024 16:18
@llvmbot llvmbot added the backend:X86 and llvm:SelectionDAG (SelectionDAGISel as well) labels May 27, 2024
@llvmbot
Collaborator

llvmbot commented May 27, 2024

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-x86

Author: None (Poseydon42)

Changes

This PR adds initial support for the scmp/ucmp 3-way comparison intrinsics in the SelectionDAG.

What works as of now:

  • An invocation of the intrinsic in the IR gets properly lowered into SelectionDAG
  • A node with opcode UCMP or SCMP gets properly expanded into two comparisons and two selects
  • Narrow scalar arguments and return types are properly handled (e.g. i3 or i51)
  • Wide scalar arguments and return types are properly handled (e.g. i87 or i139)
  • Vector arguments and return types are properly widened/split where necessary
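
The narrow-scalar case depends on how the operands are extended when they are promoted: in this patch `PromoteIntOp_CMP` zero-extends the operands of UCMP and sign-extends those of SCMP. A minimal sketch of why the extension kind matters, using an i3 value held in a wider integer. All names here are hypothetical and none of this is LLVM code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of an i3 value living in a wider register: only the low
// three bits are meaningful, the upper bits may hold anything.
constexpr std::int32_t signExtendI3(std::uint8_t bits) {
  const std::int32_t low = bits & 0x7;
  return (low & 0x4) != 0 ? low - 8 : low;  // bit 2 is the i3 sign bit
}
constexpr std::int32_t zeroExtendI3(std::uint8_t bits) { return bits & 0x7; }

constexpr int threeWay(std::int32_t lhs, std::int32_t rhs) {
  return lhs < rhs ? -1 : (lhs > rhs ? 1 : 0);
}

// scmp on i3: sign-extend both operands, then compare in the wide type.
constexpr int scmpI3(std::uint8_t lhs, std::uint8_t rhs) {
  return threeWay(signExtendI3(lhs), signExtendI3(rhs));
}
// ucmp on i3: zero-extend both operands, then compare in the wide type.
constexpr int ucmpI3(std::uint8_t lhs, std::uint8_t rhs) {
  return threeWay(zeroExtendI3(lhs), zeroExtendI3(rhs));
}

inline void checkPromotionExamples() {
  assert(scmpI3(0b111, 0b001) == -1);  // 0b111 is -1 as a signed i3
  assert(ucmpI3(0b111, 0b001) == 1);   // 0b111 is 7 as an unsigned i3
  assert(scmpI3(0xF9, 0x01) == 0);     // garbage in the upper bits is ignored
}
```

Using the wrong extension would silently flip results for any operand whose top bit is set, which is why the two opcodes cannot share one promotion path.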

Patch is 95.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91871.diff

14 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+6)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+41)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+142)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+21)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+3)
  • (added) llvm/test/CodeGen/X86/scmp.ll (+776)
  • (added) llvm/test/CodeGen/X86/ucmp.ll (+1010)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 6429947958ee9..7d36f582244b0 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -677,6 +677,12 @@ enum NodeType {
   UMIN,
   UMAX,
 
+  /// [US]CMP - 3-way comparison of signed or unsigned integers. Returns -1, 0,
+  /// or 1 depending on whether Op0 <, ==, or > Op1. The operands can have type
+  /// different to the result.
+  SCMP,
+  UCMP,
+
   /// Bitwise operators - logical and, logical or, logical xor.
   AND,
   OR,
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 7ed08cfa8a202..82c450afcbcf6 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5402,6 +5402,10 @@ class TargetLowering : public TargetLoweringBase {
   /// method accepts integers as its arguments.
   SDValue expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Method for building the DAG expansion of ISD::[US]CMP. This
+  /// method accepts integers as its arguments
+  SDValue expandCMP(SDNode *Node, SelectionDAG &DAG) const;
+
   /// Method for building the DAG expansion of ISD::[US]SHLSAT. This
   /// method accepts integers as its arguments.
   SDValue expandShlSat(SDNode *Node, SelectionDAG &DAG) const;
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 1684b424e3b44..6d771521aa739 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -434,6 +434,11 @@ def umin       : SDNode<"ISD::UMIN"      , SDTIntBinOp,
 def umax       : SDNode<"ISD::UMAX"      , SDTIntBinOp,
                                   [SDNPCommutative, SDNPAssociative]>;
 
+def scmp       : SDNode<"ISD::SCMP"      , SDTIntBinOp,
+                                  []>;
+def ucmp       : SDNode<"ISD::UCMP"      , SDTIntBinOp,
+                                  []>;
+
 def saddsat    : SDNode<"ISD::SADDSAT"   , SDTIntBinOp, [SDNPCommutative]>;
 def uaddsat    : SDNode<"ISD::UADDSAT"   , SDTIntBinOp, [SDNPCommutative]>;
 def ssubsat    : SDNode<"ISD::SSUBSAT"   , SDTIntBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index bfc3e08c1632d..f7da195e03bd1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1148,6 +1148,8 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
   case ISD::USUBSAT:
   case ISD::SSHLSAT:
   case ISD::USHLSAT:
+  case ISD::SCMP:
+  case ISD::UCMP:
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
@@ -3864,6 +3866,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::USUBSAT:
     Results.push_back(TLI.expandAddSubSat(Node, DAG));
     break;
+  case ISD::SCMP:
+  case ISD::UCMP:
+    Results.push_back(TLI.expandCMP(Node, DAG));
+    break;
   case ISD::SSHLSAT:
   case ISD::USHLSAT:
     Results.push_back(TLI.expandShlSat(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 0aa36deda79dc..e82ab9db4c090 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -232,6 +232,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     Res = PromoteIntRes_ADDSUBSHLSAT<VPMatchContext>(N);
     break;
 
+  case ISD::SCMP:
+  case ISD::UCMP:
+    Res = PromoteIntRes_CMP(N);
+    break;
+
   case ISD::SMULFIX:
   case ISD::SMULFIXSAT:
   case ISD::UMULFIX:
@@ -1246,6 +1251,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo) {
   return Res;
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_CMP(SDNode *N) {
+  EVT PromotedResultTy =
+      TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  return DAG.getNode(N->getOpcode(), SDLoc(N), PromotedResultTy,
+                     N->getOperand(0), N->getOperand(1));
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_Select(SDNode *N) {
   SDValue Mask = N->getOperand(0);
 
@@ -1874,6 +1886,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::ROTL:
   case ISD::ROTR: Res = PromoteIntOp_Shift(N); break;
 
+  case ISD::SCMP:
+  case ISD::UCMP: Res = PromoteIntOp_CMP(N); break;
+
   case ISD::FSHL:
   case ISD::FSHR: Res = PromoteIntOp_FunnelShift(N); break;
 
@@ -2184,6 +2199,17 @@ SDValue DAGTypeLegalizer::PromoteIntOp_Shift(SDNode *N) {
                                 ZExtPromotedInteger(N->getOperand(1))), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_CMP(SDNode *N) {
+  SDValue LHS = N->getOpcode() == ISD::UCMP
+                    ? ZExtPromotedInteger(N->getOperand(0))
+                    : SExtPromotedInteger(N->getOperand(0));
+  SDValue RHS = N->getOpcode() == ISD::UCMP
+                    ? ZExtPromotedInteger(N->getOperand(1))
+                    : SExtPromotedInteger(N->getOperand(1));
+
+  return SDValue(DAG.UpdateNodeOperands(N, LHS, RHS), 0);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_FunnelShift(SDNode *N) {
   return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0), N->getOperand(1),
                                 ZExtPromotedInteger(N->getOperand(2))), 0);
@@ -2741,6 +2767,9 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::UMIN:
   case ISD::SMIN: ExpandIntRes_MINMAX(N, Lo, Hi); break;
 
+  case ISD::SCMP:
+  case ISD::UCMP: ExpandIntRes_CMP(N, Lo, Hi); break;
+
   case ISD::ADD:
   case ISD::SUB: ExpandIntRes_ADDSUB(N, Lo, Hi); break;
 
@@ -3233,6 +3262,11 @@ void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
   SplitInteger(Result, Lo, Hi);
 }
 
+void DAGTypeLegalizer::ExpandIntRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDValue ExpandedCMP = TLI.expandCMP(N, DAG);
+  SplitInteger(ExpandedCMP, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_ADDSUB(SDNode *N,
                                            SDValue &Lo, SDValue &Hi) {
   SDLoc dl(N);
@@ -5137,6 +5171,9 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::RETURNADDR:
   case ISD::FRAMEADDR:         Res = ExpandIntOp_RETURNADDR(N); break;
 
+  case ISD::SCMP:
+  case ISD::UCMP:              Res = ExpandIntOp_CMP(N); break;
+
   case ISD::ATOMIC_STORE:      Res = ExpandIntOp_ATOMIC_STORE(N); break;
   case ISD::STACKMAP:
     Res = ExpandIntOp_STACKMAP(N, OpNo);
@@ -5398,6 +5435,10 @@ SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) {
   return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0), Lo), 0);
 }
 
+SDValue DAGTypeLegalizer::ExpandIntOp_CMP(SDNode *N) {
+  return TLI.expandCMP(N, DAG);
+}
+
 SDValue DAGTypeLegalizer::ExpandIntOp_RETURNADDR(SDNode *N) {
   // The argument of RETURNADDR / FRAMEADDR builtin is 32 bit contant.  This
   // surely makes pretty nice problems on 8/16 bit targets. Just truncate this
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4b06e19656ce6..74ab2f44149fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,6 +324,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_Overflow(SDNode *N);
   SDValue PromoteIntRes_FFREXP(SDNode *N);
   SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo);
+  SDValue PromoteIntRes_CMP(SDNode *N);
   SDValue PromoteIntRes_Select(SDNode *N);
   SDValue PromoteIntRes_SELECT_CC(SDNode *N);
   SDValue PromoteIntRes_SETCC(SDNode *N);
@@ -375,6 +376,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_SETCC(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_Shift(SDNode *N);
+  SDValue PromoteIntOp_CMP(SDNode *N);
   SDValue PromoteIntOp_FunnelShift(SDNode *N);
   SDValue PromoteIntOp_SIGN_EXTEND(SDNode *N);
   SDValue PromoteIntOp_VP_SIGN_EXTEND(SDNode *N);
@@ -457,6 +459,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
 
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
+  void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_XMULO             (SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -485,6 +489,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue ExpandIntOp_SETCC(SDNode *N);
   SDValue ExpandIntOp_SETCCCARRY(SDNode *N);
   SDValue ExpandIntOp_Shift(SDNode *N);
+  SDValue ExpandIntOp_CMP(SDNode *N);
   SDValue ExpandIntOp_STORE(StoreSDNode *N, unsigned OpNo);
   SDValue ExpandIntOp_TRUNCATE(SDNode *N);
   SDValue ExpandIntOp_XINT_TO_FP(SDNode *N);
@@ -779,6 +784,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ScalarizeVectorResult(SDNode *N, unsigned ResNo);
   SDValue ScalarizeVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo);
   SDValue ScalarizeVecRes_BinOp(SDNode *N);
+  SDValue ScalarizeVecRes_CMP(SDNode *N);
   SDValue ScalarizeVecRes_TernaryOp(SDNode *N);
   SDValue ScalarizeVecRes_UnaryOp(SDNode *N);
   SDValue ScalarizeVecRes_StrictFPOp(SDNode *N);
@@ -852,6 +858,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVectorResult(SDNode *N, unsigned ResNo);
   void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -915,6 +922,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_VSETCC(SDNode *N);
   SDValue SplitVecOp_FP_ROUND(SDNode *N);
   SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
+  SDValue SplitVecOp_CMP(SDNode *N);
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
   SDValue SplitVecOp_VP_CttzElements(SDNode *N);
 
@@ -982,6 +990,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
 
   SDValue WidenVecRes_Ternary(SDNode *N);
   SDValue WidenVecRes_Binary(SDNode *N);
+  SDValue WidenVecRes_CMP(SDNode *N);
   SDValue WidenVecRes_BinaryCanTrap(SDNode *N);
   SDValue WidenVecRes_BinaryWithExtraScalarOp(SDNode *N);
   SDValue WidenVecRes_StrictFP(SDNode *N);
@@ -1001,6 +1010,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_BITCAST(SDNode *N);
   SDValue WidenVecOp_CONCAT_VECTORS(SDNode *N);
   SDValue WidenVecOp_EXTEND(SDNode *N);
+  SDValue WidenVecOp_CMP(SDNode *N);
   SDValue WidenVecOp_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue WidenVecOp_INSERT_SUBVECTOR(SDNode *N);
   SDValue WidenVecOp_EXTRACT_SUBVECTOR(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 423df9ae6b2a5..3a1c11da24075 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -442,6 +442,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
   case ISD::MGATHER:
+  case ISD::SCMP:
+  case ISD::UCMP:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index cab4dc5f3c156..0fe6e9e0252a4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -164,6 +164,12 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ROTR:
     R = ScalarizeVecRes_BinOp(N);
     break;
+
+  case ISD::SCMP:
+  case ISD::UCMP:
+    R = ScalarizeVecRes_CMP(N);
+    break;
+
   case ISD::FMA:
   case ISD::FSHL:
   case ISD::FSHR:
@@ -213,6 +219,27 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_BinOp(SDNode *N) {
                      LHS.getValueType(), LHS, RHS, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::ScalarizeVecRes_CMP(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  if (getTypeAction(LHS.getValueType()) ==
+      TargetLowering::TypeScalarizeVector) {
+    LHS = GetScalarizedVector(LHS);
+    RHS = GetScalarizedVector(RHS);
+  } else {
+    EVT VT = LHS.getValueType().getVectorElementType();
+    LHS = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, LHS,
+                      DAG.getVectorIdxConstant(0, DL));
+    RHS = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, RHS,
+                      DAG.getVectorIdxConstant(0, DL));
+  }
+
+  return DAG.getNode(N->getOpcode(), SDLoc(N),
+                     N->getValueType(0).getVectorElementType(), LHS, RHS);
+}
+
 SDValue DAGTypeLegalizer::ScalarizeVecRes_TernaryOp(SDNode *N) {
   SDValue Op0 = GetScalarizedVector(N->getOperand(0));
   SDValue Op1 = GetScalarizedVector(N->getOperand(1));
@@ -1184,6 +1211,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_TernaryOp(N, Lo, Hi);
     break;
 
+  case ISD::SCMP: case ISD::UCMP:
+    SplitVecRes_CMP(N, Lo, Hi);
+    break;
+
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
   case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
@@ -1327,6 +1358,21 @@ void DAGTypeLegalizer::SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo,
                    {Op0Hi, Op1Hi, Op2Hi, MaskHi, EVLHi}, Flags);
 }
 
+void DAGTypeLegalizer::SplitVecRes_CMP(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  LLVMContext &Ctxt = *DAG.getContext();
+  SDLoc dl(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT SplitResVT = N->getValueType(0).getHalfNumVectorElementsVT(Ctxt);
+
+  auto [LHSLo, LHSHi] = DAG.SplitVector(LHS, dl);
+  auto [RHSLo, RHSHi] = DAG.SplitVector(RHS, dl);
+
+  Lo = DAG.getNode(N->getOpcode(), dl, SplitResVT, LHSLo, RHSLo);
+  Hi = DAG.getNode(N->getOpcode(), dl, SplitResVT, LHSHi, RHSHi);
+}
+
 void DAGTypeLegalizer::SplitVecRes_FIX(SDNode *N, SDValue &Lo, SDValue &Hi) {
   SDValue LHSLo, LHSHi;
   GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
@@ -3054,6 +3100,11 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     Res = SplitVecOp_FPOpDifferentTypes(N);
     break;
 
+  case ISD::SCMP:
+  case ISD::UCMP:
+    Res = SplitVecOp_CMP(N);
+    break;
+
   case ISD::ANY_EXTEND_VECTOR_INREG:
   case ISD::SIGN_EXTEND_VECTOR_INREG:
   case ISD::ZERO_EXTEND_VECTOR_INREG:
@@ -4043,6 +4094,25 @@ SDValue DAGTypeLegalizer::SplitVecOp_FPOpDifferentTypes(SDNode *N) {
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Lo, Hi);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_CMP(SDNode *N) {
+  LLVMContext &Ctxt = *DAG.getContext();
+  SDLoc dl(N);
+
+  SDValue LHSLo, LHSHi, RHSLo, RHSHi;
+  GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
+  GetSplitVector(N->getOperand(1), RHSLo, RHSHi);
+
+  EVT ResVT = N->getValueType(0);
+  ElementCount SplitOpEC = LHSLo.getValueType().getVectorElementCount();
+  EVT NewResVT =
+      EVT::getVectorVT(Ctxt, ResVT.getVectorElementType(), SplitOpEC);
+
+  SDValue Lo = DAG.getNode(N->getOpcode(), dl, NewResVT, LHSLo, RHSLo);
+  SDValue Hi = DAG.getNode(N->getOpcode(), dl, NewResVT, LHSHi, RHSHi);
+
+  return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
   EVT ResVT = N->getValueType(0);
   SDValue Lo, Hi;
@@ -4220,6 +4290,11 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     Res = WidenVecRes_Binary(N);
     break;
 
+  case ISD::SCMP:
+  case ISD::UCMP:
+    Res = WidenVecRes_CMP(N);
+    break;
+
   case ISD::FPOW:
   case ISD::FREM:
     if (unrollExpandedOp())
@@ -4426,6 +4501,42 @@ SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N) {
                      {InOp1, InOp2, Mask, N->getOperand(3)}, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::WidenVecRes_CMP(SDNode *N) {
+  LLVMContext &Ctxt = *DAG.getContext();
+  SDLoc dl(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT OpVT = LHS.getValueType();
+  EVT OpElementVT = OpVT.getVectorElementType();
+  EVT TransformedOpVT = TLI.getTypeToTransformTo(Ctxt, OpVT);
+  if (TransformedOpVT.getVectorNumElements() > OpVT.getVectorNumElements()) {
+    LHS = GetWidenedVector(LHS);
+    RHS = GetWidenedVector(RHS);
+  }
+
+  EVT WidenResVT = TLI.getTypeToTransformTo(Ctxt, N->getValueType(0));
+  ElementCount WidenResEC = WidenResVT.getVectorElementCount();
+  EVT WidenResElementVT = WidenResVT.getVectorElementType();
+
+  SDValue CMP = DAG.getNode(N->getOpcode(), dl, LHS.getValueType(), LHS, RHS);
+
+  EVT WideUndefVectorVT = EVT::getVectorVT(Ctxt, OpElementVT, WidenResEC);
+  SDValue WideUndefValue = DAG.getUNDEF(WideUndefVectorVT);
+
+  SDValue WideResult =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideUndefVectorVT, WideUndefValue,
+                  CMP, DAG.getVectorIdxConstant(0, dl));
+
+  ISD::NodeType ExtendCode;
+  if (OpElementVT.getSizeInBits() > WidenResElementVT.getSizeInBits()) {
+    ExtendCode = ISD::TRUNCATE;
+  } else {
+    ExtendCode = (N->getOpcode() == ISD::SCMP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND);
+  }
+  return DAG.getNode(ExtendCode, dl, WidenResVT, WideResult);
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_BinaryWithExtraScalarOp(SDNode *N) {
   // Binary op widening, but with an extra operand that shouldn't be widened.
   SDLoc dl(N);
@@ -6129,6 +6240,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
     Res = WidenVecOp_EXTEND(N);
     break;
 
+  case ISD::SCMP:
+  case ISD::UCMP:
+    Res = WidenVecOp_CMP(N);
+    break;
+
   case ISD::FP_EXTEND:
   case ISD::STRICT_FP_EXTEND:
   case ISD::FP_ROUND:
@@ -6273,6 +6389,32 @@ SDValue DAGTypeLegalizer::WidenVecOp_EXTEND(SDNode *N) {
   }
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_CMP(SDNode *N) {
+  SDLoc dl(N);
+
+  EVT OpVT = N->getOperand(0).getValueType();
+  EVT ResVT = N->getValueType(0);
+  SDValue LHS = GetWidenedVector(N->getOperand(0));
+  SDValue RHS = GetWidenedVector(N->getOperand(1));
+
+  // 1. EXTRACT_SUBVECTOR
+  // 2. SIGN_EXTEND/ZERO_EXTEND
+  // 3. CMP
+  LHS = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, LHS,
+                    DAG.getVectorIdxConstant(0, dl));
+  RHS = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, RHS,
+                    DAG.getVectorIdxConstant(0, dl));
+
+  // At this point the result type is guaranteed to be valid, so we can use it
+  // as the operand type by extending it appropriately
+  ISD::NodeType ExtendOpcode =
+      N->getOpcode() == ISD::SCMP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+  LHS = DAG.getNode(ExtendOpcode, dl, ResVT, LHS);
+  RHS = DAG.getNode(ExtendOpcode, dl, ResVT, RHS);
+
+  return DAG.getNode(N->getOpcode(), dl, ResVT, LHS, RHS);
+}
+
 SDValue DAGTypeLegalizer::WidenVecOp_UnrollVectorOp(SDNode *N) {
   // The result (and first input) is legal, but the second input is illegal.
   // We can't do much to fix that, so just unroll and let the extracts off of
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index cfd82a342433f..16d8a9816b013 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7143,6 +7143,22 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     setValue(&I, DAG.getNode(ISD::ABS, sdl, Op1.getValueType(), Op1));
     return;
   }
+  case Intrinsic::scmp: {
+    SDValue Op1 = getValue(I.getArgOperand(0));
+    SDValue Op2 = getValue(I.getArgOperand(1));
+    EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
+                                                          I.getType());
+    setValue(&I, DAG.getNode(ISD::SCMP, sdl, DestVT, Op1, Op2));
+    break;
+  }
+  case Intrinsic::ucmp: {
+    SDValue Op1 = getValue(I.getArgOperand(0));
+    SDValue Op2 = getValue(I.getArgOperand(1));
+    EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG...
[truncated]

; CHECK-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
; CHECK-NEXT: movdqa %xmm2, %xmm0
; CHECK-NEXT: retq
%1 = call <4 x i32> @llvm.scmp(<4 x i32> %x, <4 x i32> %y)
Member

It is surprising that we can omit mangling suffix here :)

@nikic
Contributor

nikic commented May 31, 2024

define <2 x i16> @test_ucmp.2.8.16(<2 x i8> %x, <2 x i8> %y) {
  %1 = call <2 x i16> @llvm.ucmp(<2 x i8> %x, <2 x i8> %y)
  ret <2 x i16> %1
}

Asserts:

llc: /home/npopov/repos/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp:7434: SDValue llvm::SelectionDAG::getNode(unsigned int, const SDLoc &, EVT, SDValue, SDValue, SDValue, const SDNodeFlags): Assertion `(VT.isScalableVector() != N2VT.isScalableVector() || VT.getVectorMinNumElements() >= N2VT.getVectorMinNumElements()) && "Insert subvector must be from smaller vector to larger vector!"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: build/bin/llc test.ll
1.	Running pass 'Function Pass Manager' on module 'test.ll'.
2.	Running pass 'X86 DAG->DAG Instruction Selection' on function '@test_ucmp.2.8.16'
 #0 0x0000000006474828 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (build/bin/llc+0x6474828)
 #1 0x00000000064723ee llvm::sys::RunSignalHandlers() (build/bin/llc+0x64723ee)
 #2 0x0000000006474ed8 SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f2303c53710 __restore_rt (/lib64/libc.so.6+0x40710)
 #4 0x00007f2303cab144 __pthread_kill_implementation (/lib64/libc.so.6+0x98144)
 #5 0x00007f2303c5365e gsignal (/lib64/libc.so.6+0x4065e)
 #6 0x00007f2303c3b902 abort (/lib64/libc.so.6+0x28902)
 #7 0x00007f2303c3b81e _nl_load_domain.cold (/lib64/libc.so.6+0x2881e)
 #8 0x00007f2303c4b977 (/lib64/libc.so.6+0x38977)
 #9 0x0000000006249fd7 llvm::SelectionDAG::getNode(unsigned int, llvm::SDLoc const&, llvm::EVT, llvm::SDValue, llvm::SDValue, llvm::SDValue, llvm::SDNodeFlags) (build/bin/llc+0x6249fd7)
#10 0x00000000062f4a41 llvm::DAGTypeLegalizer::WidenVecRes_CMP(llvm::SDNode*) LegalizeVectorTypes.cpp:0:0
#11 0x00000000062ed2bd llvm::DAGTypeLegalizer::WidenVectorResult(llvm::SDNode*, unsigned int) LegalizeVectorTypes.cpp:0:0
#12 0x0000000006294fdb llvm::DAGTypeLegalizer::run() LegalizeTypes.cpp:0:0

I'm using this script to generate test cases:

<?php

function makeType(int $numElems, int $bw): string {
  $ty = 'i' . $bw;
  if ($numElems != 0) {
    return '<' . $numElems . ' x ' . $ty . '>';
  }
  return $ty;
}

$bws = [1, 2, 3, 4, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256, 257];

foreach (['ucmp', 'scmp'] as $intrin) {
  for ($numElems = 0; $numElems < 100; ++$numElems) {
    foreach ($bws as $argBW) {
      foreach ($bws as $resBW) {
        if ($resBW == 1) {
          continue;
        }
        $argType = makeType($numElems, $argBW);
        $resType = makeType($numElems, $resBW);
        echo "define $resType @test_$intrin.$numElems.$argBW.$resBW($argType %x, $argType %y) {\n";
        echo "  %1 = call $resType @llvm.$intrin($argType %x, $argType %y)\n";
        echo "  ret $resType %1\n";
        echo "}\n";
        echo "\n";
      }
    }
  }
}

@Poseydon42
Contributor Author

Poseydon42 commented May 31, 2024

Thank you for providing the script. I've tested the current revision of the code against the generated tests and it seems to work on both x86_64 and AArch64. I've reduced the maximum size of vector types from 100 to 40, as it was taking ages otherwise, but I don't think that should change anything.

Contributor

@nikic nikic left a comment

This looks mostly good to me -- main uncertainty is around the vector widening implementations, which are pretty unusual. It would be good if an SDAG expert could look over that part at least (@RKSimon maybe?)

@nikic
Contributor

nikic commented Jun 1, 2024

It would also be good to add some test coverage for other targets -- for those cases, I don't think we need the full legalization test coverage, but mainly some scalar tests with common types, so we can see what the codegen looks like and have a baseline when improving it later. (The currently used expansion is probably not the best default expansion, but it will be easier to judge that with baseline test coverage.)

@Poseydon42
Contributor Author

Added the tests that were suggested (for AArch64 and x86 only for now, but if any other architectures are suggested I'd be happy to add those too). Also fixed the slightly misleading comment in WidenVecRes_CMP and addressed the comment about checking whether operands should also be split in SplitVecRes_CMP.

@@ -677,6 +677,12 @@ enum NodeType {
UMIN,
UMAX,

/// [US]CMP - 3-way comparison of signed or unsigned integers. Returns -1, 0,
/// or 1 depending on whether Op0 <, ==, or > Op1. The operands can have type
/// different to the result.
Collaborator

Should we enforce that for vector types the result type should match the operand type? Treating these as close to regular (non-commutable) binops as possible makes sense to me.

Contributor Author

I feel like allowing them to be different isn't that much of an issue, while potentially allowing us to emit more optimal code (maybe?). After all, if you're comparing two vectors of i128s you don't need the result to be a vector of i128 as well, and I'm not sure whether forcing the return type to be the same as operands in this scenario would be handled optimally by other optimizations in SDAG pipeline.
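
To make the size argument concrete, here is a small sketch in plain C++ rather than SelectionDAG code. `ucmpLanes` is a hypothetical helper, `std::array` stands in for a vector value, and 64-bit lanes are used instead of i128 to keep the example portable. It only illustrates that a narrow result type carries the same information in less space; whether the backend produces better code for it is not something this sketch can show.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical lane-wise unsigned 3-way comparison whose result element type
// is chosen independently of the operand element type.
template <typename ResT, typename OpT, std::size_t N>
std::array<ResT, N> ucmpLanes(const std::array<OpT, N>& lhs,
                              const std::array<OpT, N>& rhs) {
  std::array<ResT, N> result{};
  for (std::size_t i = 0; i < N; ++i) {
    // Each lane is -1, 0 or 1, so even the narrowest signed type suffices.
    result[i] = lhs[i] < rhs[i] ? ResT(-1) : (lhs[i] > rhs[i] ? ResT(1) : ResT(0));
  }
  return result;
}

inline void checkNarrowResult() {
  const std::array<std::uint64_t, 2> lhs{1, 9};
  const std::array<std::uint64_t, 2> rhs{2, 3};
  const auto narrow = ucmpLanes<std::int8_t>(lhs, rhs);
  assert(narrow[0] == -1 && narrow[1] == 1);
}
```

Forcing the result type to match the operands would correspond to calling `ucmpLanes<std::int64_t>` here and truncating afterwards.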

@nikic
Contributor

nikic commented Jun 12, 2024

This one asserts with -mattr=+avx512f:

define <33 x i2> @test_ucmp.33.4.2(<33 x i4> %x, <33 x i4> %y) {
  %1 = call <33 x i2> @llvm.ucmp(<33 x i4> %x, <33 x i4> %y)
  ret <33 x i2> %1
}

@Poseydon42
Contributor Author

Fixed the crash with AVX-512. WidenVecRes_CMP now unrolls the operation into a bunch of scalar operations and then rebuilds a vector whenever the widened operand and result have a different number of elements. There is probably a way to avoid doing this in every situation, but I think it's worth adding that later and in a separate PR.
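
The unrolling fallback described above can be pictured as follows. This is a sketch in plain C++, not the SelectionDAG implementation: `unrollUcmp` is a hypothetical name, `std::vector` stands in for DAG vector values, and the padding lanes of the widened result, whose contents should not matter, are filled with 0 here as an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of "unroll into scalar operations, then rebuild a
// vector": every original lane is compared on its own and the results are
// placed into a result that has `widenedLanes` elements.
std::vector<std::int8_t> unrollUcmp(const std::vector<std::uint8_t>& lhs,
                                    const std::vector<std::uint8_t>& rhs,
                                    std::size_t widenedLanes) {
  assert(lhs.size() == rhs.size());
  assert(lhs.size() <= widenedLanes);

  // Padding lanes beyond the original element count keep the placeholder 0.
  std::vector<std::int8_t> result(widenedLanes, 0);
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    // One scalar 3-way comparison per original lane.
    result[i] = lhs[i] < rhs[i] ? -1 : (lhs[i] > rhs[i] ? 1 : 0);
  }
  return result;
}
```

For example, comparing the 3-lane inputs {1, 2, 3} and {3, 2, 1} with `widenedLanes` set to 4 yields {-1, 0, 1} followed by one padding lane. The cost is one scalar comparison per lane, which is why a vectorized path for the common cases is worth a follow-up.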

Contributor

@nikic nikic left a comment

LGTM

@nikic nikic merged commit 995835f into llvm:main Jun 17, 2024
6 of 7 checks passed
tschuett added a commit to tschuett/llvm-project that referenced this pull request Jul 21, 2024
sgundapa pushed a commit to sgundapa/upstream_effort that referenced this pull request Jul 23, 2024
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
#91871
#98774

Differential Revision: https://phabricator.intern.facebook.com/D60251536