[RISCV] Use masked pseudo peephole for reduction pseudos #71508

Merged (1 commit), Nov 8, 2023

lukel97 (Contributor) commented Nov 7, 2023

After #71483 we now have a way of marking a masked pseudo as having an unmasked
equivalent while recording that its mask affects the result, so the mask must
not be folded unless it's all ones.

This patch uses that to mark the pseudos for vredsum and friends, which in turn
allows us to remove the unmasked patterns and catch some other forms of vmerge.
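
To illustrate why a reduction's mask "affects the result", here is a minimal Python sketch of a tail-undisturbed `vredsum.vs`. It is a toy model, not LLVM code: the function name `vredsum_vs`, the list-based registers, and the sample values are all assumptions made for illustration, and integer wraparound and SEW are ignored.

```python
def vredsum_vs(passthru, vs2, vs1, vl, mask=None):
    """Toy model of a tail-undisturbed vredsum.vs (hypothetical helper).

    Element 0 of the result is vs1[0] plus every active element of
    vs2[0:vl]; every other element keeps its passthru value.
    mask=None models the unmasked form.
    """
    result = list(passthru)
    if vl == 0:
        return result  # with vl == 0 the destination is left untouched
    acc = vs1[0]
    for i in range(vl):
        # An inactive element simply does not take part in the sum.
        if mask is None or mask[i]:
            acc += vs2[i]
    result[0] = acc
    return result


passthru = [9, 9, 9, 9]
vs2 = [1, 2, 3, 4]
vs1 = [10, 0, 0, 0]

unmasked = vredsum_vs(passthru, vs2, vs1, vl=4)
all_ones = vredsum_vs(passthru, vs2, vs1, vl=4, mask=[1, 1, 1, 1])
partial = vredsum_vs(passthru, vs2, vs1, vl=4, mask=[1, 0, 1, 0])

print(unmasked == all_ones)  # an all-ones mask can be dropped
print(unmasked == partial)   # any other mask changes the scalar result
```

For an elementwise operation, an inactive lane only changes which lanes are written, so a later vmerge's mask can take its place. For a reduction, masking off a lane changes the scalar sum itself, which is why only an all-ones mask is safe to fold.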

After llvm#71483 we now have a way of marking a masked pseudo as having an
unmasked equivalent while recording that its mask affects the result, so the
mask must not be folded unless it's all ones.

This patch uses that to mark the pseudos for vredsum and friends, which in turn
allows us to remove the unmasked patterns and remove a vmerge entirely if it's
known to have an all-ones mask.
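
The rule described above can be sketched as a small decision function. This is a simplified Python reading of the description only, not the actual peephole: `MaskedPseudoInfo` and `may_fold_vmerge_mask` are hypothetical names, and the real pass checks many more conditions (VL, policy, passthru operands) that are left out here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MaskedPseudoInfo:
    """Hypothetical stand-in for one masked-pseudo table entry."""
    mask_idx: int
    mask_affects_result: bool = False


def may_fold_vmerge_mask(info, vmerge_mask_is_all_ones):
    """Return True if a vmerge's mask may be folded into the pseudo.

    If the pseudo's result depends on its mask (as for a reduction),
    folding is only allowed when the vmerge's mask is all ones.
    """
    if info.mask_affects_result and not vmerge_mask_is_all_ones:
        return False
    return True


# Assumed example entries; the operand index is illustrative only.
elementwise = MaskedPseudoInfo(mask_idx=3)
reduction = MaskedPseudoInfo(mask_idx=3, mask_affects_result=True)

print(may_fold_vmerge_mask(elementwise, vmerge_mask_is_all_ones=False))
print(may_fold_vmerge_mask(reduction, vmerge_mask_is_all_ones=False))
print(may_fold_vmerge_mask(reduction, vmerge_mask_is_all_ones=True))
```

Under these assumptions, marking the reduction pseudos with the flag is what keeps the fold restricted to the all-ones case while still letting the unmasked selection patterns go away.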

llvmbot (Collaborator) commented Nov 7, 2023

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/71508.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td (+4-2)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (-63)
  • (modified) llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll (+4-10)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 83faa5bbef7931f..01a425298c9da28 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
     defvar mx = MInfo.MX;
     let isCommutable = Commutable in
     def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+    def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
+                                          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 
@@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
                                                      Op2Class, Constraint>;
     def "_" # mx # "_E" # sew # "_MASK"
         : VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
-                                               Op2Class, Constraint>;
+                                               Op2Class, Constraint>,
+          RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index d92d3975d12f533..a27719455642a71 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
   foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
     defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
     let Predicates = GetVTypePredicates<vti>.Predicates in {
-      def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
-                                   (vti.Vector vti.RegClass:$rs1), VR:$rs2,
-                                   (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-          (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-              (vti_m1.Vector VR:$merge),
-              (vti.Vector vti.RegClass:$rs1),
-              (vti_m1.Vector VR:$rs2),
-              // Value to indicate no rounding mode change in
-              // RISCVInsertReadWriteCSR
-              FRM_DYN,
-              GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
-
       def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
                                    (vti.Vector vti.RegClass:$rs1), VR:$rs2,
                                    (vti.Mask V0), VLOpFrag,
@@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
     defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
-                                   VR:$rs2, (vti.Mask true_mask), VLOpFrag,
-                                   (XLenVT timm:$policy))),
-               (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
-                  (wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
-                  (wti_m1.Vector VR:$rs2),
-                  // Value to indicate no rounding mode change in
-                  // RISCVInsertReadWriteCSR
-                  FRM_DYN,
-                  GPR:$vl, vti.Log2SEW,
-                  (XLenVT timm:$policy))>;
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
                                    (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
index 450ab3cbb0dc369..c639f092444fc43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
@@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
 define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
 ; CHECK-LABEL: vredsum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vredsum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
+; CHECK-NEXT:    vredsum.vs v8, v9, v10
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
   %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
@@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
 define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
 ; CHECK-LABEL: vfredusum_allones_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, ma
 ; CHECK-NEXT:    fsrmi a0, 0
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vfredusum.vs v11, v9, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.v.v v8, v11
+; CHECK-NEXT:    vfredusum.vs v8, v9, v10
 ; CHECK-NEXT:    fsrm a0
 ; CHECK-NEXT:    ret
   %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0

yetingk (Contributor) left a comment

LGTM.

lukel97 merged commit 11c1827 into llvm:main on Nov 8, 2023
4 checks passed