Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SelectionDAG][RISCV] Move VP_REDUCE* legalization to LegalizeDAG.cpp. #90522

Merged
merged 2 commits into from
Apr 30, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Apr 29, 2024

LegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize.

This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.

This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead.

This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed.

There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up.

LegalizeVectorType is responsible for legalizing nodes that perform an
operation on each element may need to scalarize.

This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR,
SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.

This patch drops any nodes with a scalar result from LegalizeVectorOps
and handles them in LegalizeDAG instead.

This required moving the reduction promotion to LegalizeDAG. I have
removed the support integer promotion as it was incorrect for integer
min/max reductions. Since it was untested, it was best to assert on it
until it was really needed.

There are a couple regressions that can be fixed with a small DAG combine
which I will do as a follow up.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Apr 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 29, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

Changes

LegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize.

This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.

This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead.

This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed.

There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up.


Full diff: https://github.com/llvm/llvm-project/pull/90522.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+63-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+5-68)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll (+14-12)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll (+5-2)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..9dd40531abb005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -180,6 +180,13 @@ class SelectionDAGLegalize {
                              SmallVectorImpl<SDValue> &Results);
   SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl);
 
+  /// Implements vector reduce operation promotion.
+  ///
+  /// All vector operands are promoted to a vector type with larger element
+  /// type, and the start value is promoted to a larger scalar type. Then the
+  /// result is truncated back to the original scalar type.
+  void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+
   SDValue ExpandPARITY(SDValue Op, const SDLoc &dl);
 
   SDValue ExpandExtractFromVectorThroughStack(SDValue Op);
@@ -2979,6 +2986,49 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
   return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT));
 }
 
+void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
+                                            SmallVectorImpl<SDValue> &Results) {
+  MVT VecVT = Node->getOperand(1).getSimpleValueType();
+  MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
+  MVT ScalarVT = Node->getSimpleValueType(0);
+  MVT NewScalarVT = NewVecVT.getVectorElementType();
+
+  SDLoc DL(Node);
+  SmallVector<SDValue, 4> Operands(Node->getNumOperands());
+
+  // promote the initial value.
+  // FIXME: Support integer.
+  assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
+         "Only FP promotion is supported");
+  Operands[0] =
+      DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
+
+  for (unsigned j = 1; j != Node->getNumOperands(); ++j)
+    if (Node->getOperand(j).getValueType().isVector() &&
+        !(ISD::isVPOpcode(Node->getOpcode()) &&
+          ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
+      // promote the vector operand.
+      // FIXME: Support integer.
+      assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
+             "Only FP promotion is supported");
+      Operands[j] =
+          DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
+    } else {
+      Operands[j] = Node->getOperand(j); // Skip VL operand.
+    }
+
+  SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
+                            Node->getFlags());
+
+  if (ScalarVT.isFloatingPoint())
+    Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+  else
+    Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+
+  Results.push_back(Res);
+}
+
 bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   LLVM_DEBUG(dbgs() << "Trying to expand node\n");
   SmallVector<SDValue, 8> Results;
@@ -4955,7 +5005,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_FSETCC ||
-      Node->getOpcode() == ISD::STRICT_FSETCCS)
+      Node->getOpcode() == ISD::STRICT_FSETCCS ||
+      Node->getOpcode() == ISD::VP_REDUCE_FADD ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMUL ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMAX ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMIN ||
+      Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD)
     OVT = Node->getOperand(1).getSimpleValueType();
   if (Node->getOpcode() == ISD::BR_CC ||
       Node->getOpcode() == ISD::SELECT_CC)
@@ -5613,6 +5668,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
                     DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
     break;
   }
+  case ISD::VP_REDUCE_FADD:
+  case ISD::VP_REDUCE_FMUL:
+  case ISD::VP_REDUCE_FMAX:
+  case ISD::VP_REDUCE_FMIN:
+  case ISD::VP_REDUCE_SEQ_FADD:
+    PromoteReduction(Node, Results);
+    break;
   }
 
   // Replace the original node with the legalized result.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..423df9ae6b2a55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -176,13 +176,6 @@ class VectorLegalizer {
   /// truncated back to the original type.
   void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
-  /// Implements vector reduce operation promotion.
-  ///
-  /// All vector operands are promoted to a vector type with larger element
-  /// type, and the start value is promoted to a larger scalar type. Then the
-  /// result is truncated back to the original scalar type.
-  void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
-
   /// Implements vector setcc operation promotion.
   ///
   /// All vector operands are promoted to a vector type with larger element
@@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       if (Action != TargetLowering::Legal)                                     \
         break;                                                                 \
     }                                                                          \
+    /* Defer non-vector results to LegalizeDAG. */                             \
+    if (!Node->getValueType(0).isVector()) {                                   \
+      Action = TargetLowering::Legal;                                          \
+      break;                                                                   \
+    }                                                                          \
     Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT);            \
   } break;
 #include "llvm/IR/VPIntrinsics.def"
@@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
   return true;
 }
 
-void VectorLegalizer::PromoteReduction(SDNode *Node,
-                                       SmallVectorImpl<SDValue> &Results) {
-  MVT VecVT = Node->getOperand(1).getSimpleValueType();
-  MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
-  MVT ScalarVT = Node->getSimpleValueType(0);
-  MVT NewScalarVT = NewVecVT.getVectorElementType();
-
-  SDLoc DL(Node);
-  SmallVector<SDValue, 4> Operands(Node->getNumOperands());
-
-  // promote the initial value.
-  if (Node->getOperand(0).getValueType().isFloatingPoint())
-    Operands[0] =
-        DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-  else
-    Operands[0] =
-        DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-
-  for (unsigned j = 1; j != Node->getNumOperands(); ++j)
-    if (Node->getOperand(j).getValueType().isVector() &&
-        !(ISD::isVPOpcode(Node->getOpcode()) &&
-          ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
-      // promote the vector operand.
-      if (Node->getOperand(j).getValueType().isFloatingPoint())
-        Operands[j] =
-            DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
-      else
-        Operands[j] =
-            DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
-    else
-      Operands[j] = Node->getOperand(j); // Skip VL operand.
-
-  SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
-                            Node->getFlags());
-
-  if (ScalarVT.isFloatingPoint())
-    Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
-                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
-  else
-    Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
-
-  Results.push_back(Res);
-}
-
 void VectorLegalizer::PromoteSETCC(SDNode *Node,
                                    SmallVectorImpl<SDValue> &Results) {
   MVT VecVT = Node->getOperand(0).getSimpleValueType();
@@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     // Promote the operation by extending the operand.
     PromoteFP_TO_INT(Node, Results);
     return;
-  case ISD::VP_REDUCE_ADD:
-  case ISD::VP_REDUCE_MUL:
-  case ISD::VP_REDUCE_AND:
-  case ISD::VP_REDUCE_OR:
-  case ISD::VP_REDUCE_XOR:
-  case ISD::VP_REDUCE_SMAX:
-  case ISD::VP_REDUCE_SMIN:
-  case ISD::VP_REDUCE_UMAX:
-  case ISD::VP_REDUCE_UMIN:
-  case ISD::VP_REDUCE_FADD:
-  case ISD::VP_REDUCE_FMUL:
-  case ISD::VP_REDUCE_FMAX:
-  case ISD::VP_REDUCE_FMIN:
-  case ISD::VP_REDUCE_SEQ_FADD:
-    // Promote the operation by extending the operand.
-    PromoteReduction(Node, Results);
-    return;
   case ISD::VP_SETCC:
   case ISD::SETCC:
     // Promote the operation by extending the operand.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index 02a989a9699606..b874a4477f5d17 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
 ; CHECK-LABEL: vpreduce_xor_v64i32:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, ta, ma
-; CHECK-NEXT:    li a3, 32
 ; CHECK-NEXT:    vslidedown.vi v24, v0, 4
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    bltu a1, a3, .LBB49_2
+; CHECK-NEXT:    addi a2, a1, -32
+; CHECK-NEXT:    sltu a3, a1, a2
+; CHECK-NEXT:    addi a3, a3, -1
+; CHECK-NEXT:    li a4, 32
+; CHECK-NEXT:    and a2, a3, a2
+; CHECK-NEXT:    bltu a1, a4, .LBB49_2
 ; CHECK-NEXT:  # %bb.1:
-; CHECK-NEXT:    li a2, 32
+; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:  .LBB49_2:
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmv.s.x v25, a0
-; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vredxor.vs v25, v8, v25, v0.t
-; CHECK-NEXT:    addi a0, a1, -32
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    addi a1, a1, -1
-; CHECK-NEXT:    and a0, a1, a0
-; CHECK-NEXT:    vsetvli zero, a0, e32, m8, ta, ma
-; CHECK-NEXT:    vmv1r.v v0, v24
-; CHECK-NEXT:    vredxor.vs v25, v16, v25, v0.t
 ; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
+; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT:    vmv1r.v v0, v24
+; CHECK-NEXT:    vredxor.vs v8, v16, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
   %r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
   ret i32 %r
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
index 7bcf37b1af3c8f..95b64cb662a614 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
@@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; CHECK-NEXT:    vmv.s.x v25, a0
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vredmaxu.vs v25, v8, v25, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
 ; CHECK-NEXT:    vmv1r.v v0, v24
-; CHECK-NEXT:    vredmaxu.vs v25, v16, v25, v0.t
-; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vredmaxu.vs v8, v16, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
   %r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, <vscale x 32 x i32> %v, <vscale x 32 x i1> %m, i32 %evl)
   ret i32 %r

@topperc topperc requested a review from mshockwave April 29, 2024 21:07
@topperc
Copy link
Collaborator Author

topperc commented Apr 29, 2024

The regressions will be fixed by #90524

Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
else
Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a dead code now?

Copy link
Contributor

@wangpc-pp wangpc-pp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@topperc
Copy link
Collaborator Author

topperc commented Apr 30, 2024

I notice we might be missing support for non-VP reductions for f16 vectors with Zfhmin.

@topperc topperc merged commit 705636a into llvm:main Apr 30, 2024
3 of 4 checks passed
@topperc topperc deleted the pr/reduction-legalization branch April 30, 2024 05:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants