
[RISCV] Support select/merge like ops for bf16 vectors when have Zvfbfmin #91936

Merged

1 commit merged into llvm:main from the bf16-select branch on Jun 6, 2024

Conversation

jacquesguan
Contributor

No description provided.

@llvmbot
Collaborator

llvmbot commented May 13, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Jianjian Guan (jacquesguan)

Changes

Patch is 65.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91936.diff

11 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28-4)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td (+14-1)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (+5-3)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+3-2)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll (+124-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll (+140-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll (+52-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/select-fp.ll (+184-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll (+304-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vselect-fp.ll (+220-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll (+76-4)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 60985edd9420e..451d870cca0b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1101,6 +1101,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                             ISD::EXTRACT_SUBVECTOR},
                            VT, Custom);
         setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+        if (Subtarget.hasStdExtZfbfmin()) {
+          if (Subtarget.hasVInstructionsF16())
+            setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+          else if (Subtarget.hasVInstructionsF16Minimal())
+            setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+        }
+        setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
+                           Custom);
+        setOperationAction(ISD::SELECT_CC, VT, Expand);
         // TODO: Promote to fp32.
       }
     }
@@ -1329,6 +1338,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                               ISD::EXTRACT_SUBVECTOR},
                              VT, Custom);
           setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+          if (Subtarget.hasStdExtZfbfmin()) {
+            if (Subtarget.hasVInstructionsF16())
+              setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+            else if (Subtarget.hasVInstructionsF16Minimal())
+              setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+          }
+          setOperationAction(
+              {ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
+              Custom);
           // TODO: Promote to fp32.
           continue;
         }
@@ -6700,10 +6718,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::BUILD_VECTOR:
     return lowerBUILD_VECTOR(Op, DAG, Subtarget);
   case ISD::SPLAT_VECTOR:
-    if (Op.getValueType().getScalarType() == MVT::f16 &&
-        (Subtarget.hasVInstructionsF16Minimal() &&
-         !Subtarget.hasVInstructionsF16())) {
-      if (Op.getValueType() == MVT::nxv32f16)
+    if ((Op.getValueType().getScalarType() == MVT::f16 &&
+         (Subtarget.hasVInstructionsF16Minimal() &&
+          Subtarget.hasStdExtZfhminOrZhinxmin() &&
+          !Subtarget.hasVInstructionsF16())) ||
+        (Op.getValueType().getScalarType() == MVT::bf16 &&
+         (Subtarget.hasVInstructionsBF16() && Subtarget.hasStdExtZfbfmin() &&
+          Subtarget.hasVInstructionsF16Minimal() &&
+          !Subtarget.hasVInstructionsF16()))) {
+      if (Op.getValueType() == MVT::nxv32f16 ||
+          Op.getValueType() == MVT::nxv32bf16)
         return SplitVectorOp(Op, DAG);
       SDLoc DL(Op);
       SDValue NewScalar =
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 4adc26f628914..e86eceed95710 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -382,7 +382,20 @@ class GetIntVTypeInfo<VTypeInfo vti> {
   // Equivalent integer vector type. Eg.
   //   VI8M1 → VI8M1 (identity)
   //   VF64M4 → VI64M4
-  VTypeInfo Vti = !cast<VTypeInfo>(!subst("VF", "VI", !cast<string>(vti)));
+  VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VI",
+                                          !subst("VF", "VI",
+                                                 !cast<string>(vti))));
+}
+
+// This functor is used to obtain the fp vector type that has the same SEW and
+// multiplier as the input parameter type.
+class GetFpVTypeInfo<VTypeInfo vti> {
+  // Equivalent integer vector type. Eg.
+  //   VF16M1 → VF16M1 (identity)
+  //   VBF16M1 → VF16M1
+  VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VF",
+                                          !subst("VI", "VF",
+                                                 !cast<string>(vti))));
 }
 
 class MTypeInfo<ValueType Mas, LMULInfo M, string Bx> {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 714f8cff7b637..db4ca958989ec 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1433,7 +1433,7 @@ defm : VPatFPSetCCSDNode_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">;
 // Floating-point vselects:
 // 11.15. Vector Integer Merge Instructions
 // 13.15. Vector Floating-Point Merge Instruction
-foreach fvti = AllFloatVectors in {
+foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
   defvar ivti = GetIntVTypeInfo<fvti>.Vti;
   let Predicates = GetVTypePredicates<ivti>.Predicates in {
     def : Pat<(fvti.Vector (vselect (fvti.Mask V0), fvti.RegClass:$rs1,
@@ -1451,7 +1451,9 @@ foreach fvti = AllFloatVectors in {
                    fvti.RegClass:$rs2, 0, (fvti.Mask V0), fvti.AVL, fvti.Log2SEW)>;
 
   }
-  let Predicates = GetVTypePredicates<fvti>.Predicates in 
+
+  let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
+                               GetVTypeScalarPredicates<fvti>.Predicates) in 
     def : Pat<(fvti.Vector (vselect (fvti.Mask V0),
                                     (SplatFPOp fvti.ScalarRegClass:$rs1),
                                     fvti.RegClass:$rs2)),
@@ -1514,7 +1516,7 @@ foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
 //===----------------------------------------------------------------------===//
 
 foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
-  let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
+  let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
                                GetVTypeScalarPredicates<fvti>.Predicates) in
     def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl undef, fvti.ScalarRegClass:$rs1, srcvalue)),
               (!cast<Instruction>("PseudoVFMV_V_"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index e10b8bf2767b8..4f6a84ab15143 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2556,7 +2556,7 @@ foreach vti = AllFloatVectors in {
   }
 }
 
-foreach fvti = AllFloatVectors in {
+foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
   // Floating-point vselects:
   // 11.15. Vector Integer Merge Instructions
   // 13.15. Vector Floating-Point Merge Instruction
@@ -2591,7 +2591,8 @@ foreach fvti = AllFloatVectors in {
                  GPR:$vl, fvti.Log2SEW)>;
   }
 
-  let Predicates = GetVTypePredicates<fvti>.Predicates in {
+  let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
+                               GetVTypeScalarPredicates<fvti>.Predicates) in {
     def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
                                             (SplatFPOp fvti.ScalarRegClass:$rs1),
                                             fvti.RegClass:$rs2,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll
index d945cf5616981..7a96aad31f084 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll
@@ -1,11 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d \
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d \
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d -riscv-v-vector-bits-min=128 \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s
 
 define <2 x half> @select_v2f16(i1 zeroext %c, <2 x half> %a, <2 x half> %b) {
@@ -343,3 +343,123 @@ define <16 x double> @selectcc_v16f64(double %a, double %b, <16 x double> %c, <1
   %v = select i1 %cmp, <16 x double> %c, <16 x double> %d
   ret <16 x double> %v
 }
+
+define <2 x bfloat> @select_v2bf16(i1 zeroext %c, <2 x bfloat> %a, <2 x bfloat> %b) {
+; CHECK-LABEL: select_v2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %v = select i1 %c, <2 x bfloat> %a, <2 x bfloat> %b
+  ret <2 x bfloat> %v
+}
+
+define <2 x bfloat> @selectcc_v2bf16(bfloat %a, bfloat %b, <2 x bfloat> %c, <2 x bfloat> %d) {
+; CHECK-LABEL: selectcc_v2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
+; CHECK-NEXT:    feq.s a0, fa4, fa5
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %cmp = fcmp oeq bfloat %a, %b
+  %v = select i1 %cmp, <2 x bfloat> %c, <2 x bfloat> %d
+  ret <2 x bfloat> %v
+}
+
+define <4 x bfloat> @select_v4bf16(i1 zeroext %c, <4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: select_v4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %v = select i1 %c, <4 x bfloat> %a, <4 x bfloat> %b
+  ret <4 x bfloat> %v
+}
+
+define <4 x bfloat> @selectcc_v4bf16(bfloat %a, bfloat %b, <4 x bfloat> %c, <4 x bfloat> %d) {
+; CHECK-LABEL: selectcc_v4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
+; CHECK-NEXT:    feq.s a0, fa4, fa5
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %cmp = fcmp oeq bfloat %a, %b
+  %v = select i1 %cmp, <4 x bfloat> %c, <4 x bfloat> %d
+  ret <4 x bfloat> %v
+}
+
+define <8 x bfloat> @select_v8bf16(i1 zeroext %c, <8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: select_v8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %v = select i1 %c, <8 x bfloat> %a, <8 x bfloat> %b
+  ret <8 x bfloat> %v
+}
+
+define <8 x bfloat> @selectcc_v8bf16(bfloat %a, bfloat %b, <8 x bfloat> %c, <8 x bfloat> %d) {
+; CHECK-LABEL: selectcc_v8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
+; CHECK-NEXT:    feq.s a0, fa4, fa5
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %cmp = fcmp oeq bfloat %a, %b
+  %v = select i1 %cmp, <8 x bfloat> %c, <8 x bfloat> %d
+  ret <8 x bfloat> %v
+}
+
+define <16 x bfloat> @select_v16bf16(i1 zeroext %c, <16 x bfloat> %a, <16 x bfloat> %b) {
+; CHECK-LABEL: select_v16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmsne.vi v0, v12, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v10, v8, v0
+; CHECK-NEXT:    ret
+  %v = select i1 %c, <16 x bfloat> %a, <16 x bfloat> %b
+  ret <16 x bfloat> %v
+}
+
+define <16 x bfloat> @selectcc_v16bf16(bfloat %a, bfloat %b, <16 x bfloat> %c, <16 x bfloat> %d) {
+; CHECK-LABEL: selectcc_v16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa1
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa0
+; CHECK-NEXT:    feq.s a0, fa4, fa5
+; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmsne.vi v0, v12, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vmerge.vvm v8, v10, v8, v0
+; CHECK-NEXT:    ret
+  %cmp = fcmp oeq bfloat %a, %b
+  %v = select i1 %cmp, <16 x bfloat> %c, <16 x bfloat> %d
+  ret <16 x bfloat> %v
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll
index 466448a7a05a2..69a266e060bbf 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll
@@ -1,11 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+m -target-abi=ilp32d \
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVFH,RV32
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+m -target-abi=lp64d \
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVFH,RV64
-; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=ilp32d \
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVFHMIN,RV32ZVFHMIN
-; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=lp64d \
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVFHMIN,RV64ZVFHMIN
 
 declare <4 x i1> @llvm.vp.merge.v4i1(<4 x i1>, <4 x i1>, <4 x i1>, i32)
@@ -1240,3 +1240,139 @@ define <32 x double> @vpmerge_vf_v32f64(double %a, <32 x double> %vb, <32 x i1>
   %v = call <32 x double> @llvm.vp.merge.v32f64(<32 x i1> %m, <32 x double> %va, <32 x double> %vb, i32 %evl)
   ret <32 x double> %v
 }
+
+declare <2 x bfloat> @llvm.vp.merge.v2bf16(<2 x i1>, <2 x bfloat>, <2 x bfloat>, i32)
+
+define <2 x bfloat> @vpmerge_vv_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, <2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpmerge_vv_v2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vmerge.vvm v9, v9, v8, v0
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %v = call <2 x bfloat> @llvm.vp.merge.v2bf16(<2 x i1> %m, <2 x bfloat> %va, <2 x bfloat> %vb, i32 %evl)
+  ret <2 x bfloat> %v
+}
+
+define <2 x bfloat> @vpmerge_vf_v2bf16(bfloat %a, <2 x bfloat> %vb, <2 x i1> %m, i32 zeroext %evl) {
+; ZVFH-LABEL: vpmerge_vf_v2bf16:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; ZVFH-NEXT:    vfmerge.vfm v8, v8, fa0, v0
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vpmerge_vf_v2bf16:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    fcvt.s.bf16 fa5, fa0
+; ZVFHMIN-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfmv.v.f v9, fa5
+; ZVFHMIN-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; ZVFHMIN-NEXT:    vfncvtbf16.f.f.w v8, v9, v0.t
+; ZVFHMIN-NEXT:    ret
+  %elt.head = insertelement <2 x bfloat> poison, bfloat %a, i32 0
+  %va = shufflevector <2 x bfloat> %elt.head, <2 x bfloat> poison, <2 x i32> zeroinitializer
+  %v = call <2 x bfloat> @llvm.vp.merge.v2bf16(<2 x i1> %m, <2 x bfloat> %va, <2 x bfloat> %vb, i32 %evl)
+  ret <2 x bfloat> %v
+}
+
+declare <4 x bfloat> @llvm.vp.merge.v4bf16(<4 x i1>, <4 x bfloat>, <4 x bfloat>, i32)
+
+define <4 x bfloat> @vpmerge_vv_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpmerge_vv_v4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, tu, ma
+; CHECK-NEXT:    vmerge.vvm v9, v9, v8, v0
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %v = call <4 x bfloat> @llvm.vp.merge.v4bf16(<4 x i1> %m, <4 x bfloat> %va, <4 x bfloat> %vb, i32 %evl)
+  ret <4 x bfloat> %v
+}
+
+define <4 x bfloat> @vpmerge_vf_v4bf16(bfloat %a, <4 x bfloat> %vb, <4 x i1> %m, i32 zeroext %evl) {
+; ZVFH-LABEL: vpmerge_vf_v4bf16:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf2, tu, ma
+; ZVFH-NEXT:    vfmerge.vfm v8, v8, fa0, v0
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vpmerge_vf_v4bf16:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    fcvt.s.bf16 fa5, fa0
+; ZVFHMIN-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT:    vfmv.v.f v9, fa5
+; ZVFHMIN-NEXT:    vsetvli zero, a0, e16, mf2, tu, mu
+; ZVFHMIN-NEXT:    vfncvtbf16.f.f.w v8, v9, v0.t
+; ZVFHMIN-NEXT:    ret
+  %elt.head = insertelement <4 x bfloat> poison, bfloat %a, i32 0
+  %va = shufflevector <4 x bfloat> %elt.head, <4 x bfloat> poison, <4 x i32> zeroinitializer
+  %v = call <4 x bfloat> @llvm.vp.merge.v4bf16(<4 x i1> %m, <4 x bfloat> %va, <4 x bfloat> %vb, i32 %evl)
+  ret <4 x bfloat> %v
+}
+
+declare <8 x bfloat> @llvm.vp.merge.v8bf16(<8 x i1>, <8 x bfloat>, <8 x bfloat>, i32)
+
+define <8 x bfloat> @vpmerge_vv_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, <8 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpmerge_vv_v8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, m1, tu, ma
+; CHECK-NEXT:    vmerge.vvm v9, v9, v8, v0
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %v = call <8 x bfloat> @llvm.vp.merge.v8bf16(<8 x i1> %m, <8 x bfloat> %va, <8 x bfloat> %vb, i32 %evl)
+  ret <8 x bfloat> %v
+}
+
+define <8 x bfloat> @vpmerge_vf_v8bf16(bfloat %a, <8 x bfloat> %vb, <8 x i1> %m, i32 zeroext %evl) {
+; ZVFH-LABEL: vpmerge_vf_v8bf16:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli zero, a0, e16, m1, tu, ma
+; ZVFH-NEXT:    vfmerge.vfm v8, v8, fa0, v0
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vpmerge_vf_v8bf16:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    fcvt.s.bf16 fa5, fa0
+; ZVFHMIN-NEXT:    vsetvli a1, zero, e32, m2, ta, ma
+; ZVFHMIN-NEXT:    vfmv.v.f v10, fa5
+; ZVFHMIN-NEXT:    vsetvli zero, a0, e16, m1, tu, mu
+; ZVFHMIN-NEXT:    vfncvtbf16.f.f.w v8, v10, v0.t
+; ZVFHMIN-NEXT:    ret
+  %elt.head = insertelement <8 x bfloat> poison, bfloat %a, i32 0
+  %va = shufflevector <8 x bfloat> %elt.head, <8 x bfloat> poison, <8 x i32> zeroinitializer
+  %v = call <8 x bfloat> @llvm.vp.merge.v8bf16(<8 x i1> %m, <8 x bfloat> %va, <8 x bfloat> %vb, i32 %evl)
+  ret <8 x bfloat> %v
+}
+
+declare <16 x bfloat> @llvm.vp.merge.v16bf16(<16 x i1>, <16 x bfloat>, <16 x bfloat>, i32)
+
+define <16 x bfloat> @vpmerge_vv_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, <16 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpmerge_vv_v16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, m2, tu, ma
+; CHECK-NEXT:    vmerge.vvm v10, v10, v8, v0
+; CHECK-NEXT:    vmv2r.v v8, v10
+; CHECK-NEXT:    ret
+  %v = call <16 x bfloat> @llvm.vp.merge.v16bf16(<16 x i1> %m, <16 x bfloat> %va, <16 x bfloat> %vb, i32 %evl)
+  ret <16 x bfloat> %v
+}
+
+define <16 x bfloat> @vpmerge_vf_v16bf16(bfloat %a, <16 x bfloat> %vb, <16 x i1> %m, i32 zeroext %evl) {
+; ZVFH-LABEL: vpmerge_vf_v16bf16:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli zero, a0, e16, m2, tu, ma
+; ZVFH-NEXT:    vfmerge.vfm v8, v8, fa0, v0
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vpmerge_vf_v16bf16:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    fcvt.s.bf16 fa5, fa0
...
[truncated]

@jacquesguan
Contributor Author

ping.

@topperc (Collaborator) left a comment

LGTM

jacquesguan merged commit d5ab38f into llvm:main on Jun 6, 2024
6 checks passed
@nico
Contributor

nico commented Jun 6, 2024

Looks like this breaks tests: http://45.33.8.238/linux/139855/step_11.txt

Please take a look and revert for now if it takes a while to fix.

@joker-eph
Collaborator

Reverted in #94565

jacquesguan deleted the bf16-select branch on June 11, 2024 at 02:13
Inline review comment on llvm/lib/Target/RISCV/RISCVISelLowering.cpp:

@@ -1101,6 +1101,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                            ISD::EXTRACT_SUBVECTOR},
                           VT, Custom);
        setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
        if (Subtarget.hasStdExtZfbfmin()) {
          if (Subtarget.hasVInstructionsF16())
            setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
A Collaborator commented:

We can't use vfmv.v.f to splat a bf16 value. If the scalar F register isn't properly nan-boxed, the vector would need to be filled with a bf16 nan. The vfmv.v.f instruction would create an fp16 nan.
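
To make the NaN-boxing concern concrete, here is a small self-contained C++ sketch of my own (not code from this patch, and only an illustration under the RISC-V NaN-boxing rule): it prints why the fp16 canonical NaN that an fp16 consumer substitutes for an improperly boxed scalar is not even a NaN once the destination elements are interpreted as bf16. The same reasoning applies to the SplatFPOp pattern comment further below.

// Bit-level illustration (editor's sketch, not part of the PR). Per the
// RISC-V NaN-boxing rule, a scalar f register whose upper bits are not all
// ones is consumed by fp16 instructions as the fp16 canonical NaN (0x7E00).
// A bf16 vector element would instead need the bf16 canonical NaN (0x7FC0).
#include <cstdint>
#include <cstdio>

int main() {
  const uint16_t Fp16CanonicalNaN = 0x7E00; // fp16: exponent 0x1F, quiet bit
  const uint16_t Bf16CanonicalNaN = 0x7FC0; // bf16: exponent 0xFF, quiet bit

  // Reinterpret the fp16 canonical NaN bits as a bf16 element. The bf16
  // exponent field lives in bits 14..7; here it is 0xFC, not all ones, so
  // the element is a large finite value rather than a NaN.
  unsigned Bf16ExpField = (Fp16CanonicalNaN >> 7) & 0xFF;
  std::printf("fp16 canonical NaN 0x%04X -> bf16 exponent field 0x%02X"
              " (NaN requires 0xFF)\n",
              (unsigned)Fp16CanonicalNaN, Bf16ExpField);
  std::printf("expected bf16 NaN element: 0x%04X\n",
              (unsigned)Bf16CanonicalNaN);
  return 0;
}
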

Inline review comment on llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td:

  let Predicates = GetVTypePredicates<fvti>.Predicates in

  let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
                               GetVTypeScalarPredicates<fvti>.Predicates) in
    def : Pat<(fvti.Vector (vselect (fvti.Mask V0),
                                    (SplatFPOp fvti.ScalarRegClass:$rs1),
@topperc (Collaborator) commented on Jun 13, 2024:

This pattern isn't valid for bf16. If the scalar F register isn't properly nan-boxed it will create the wrong nan value in the vector domain.
