diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 5025122db3681..7cf6f203fda89 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -1867,6 +1867,43 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { CurDAG->RemoveDeadNode(Node); return; } + case RISCVISD::PPACK_DH: { + assert(Subtarget->enablePExtCodeGen() && Subtarget->isRV32()); + + SDValue Val0 = Node->getOperand(0); + SDValue Val1 = Node->getOperand(1); + SDValue Val2 = Node->getOperand(2); + SDValue Val3 = Node->getOperand(3); + + SDValue Ops[] = { + CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Val0, + CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Val2, + CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)}; + SDValue RegPair0 = + SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, + MVT::Untyped, Ops), + 0); + SDValue Ops1[] = { + CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Val1, + CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Val3, + CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)}; + SDValue RegPair1 = + SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, + MVT::Untyped, Ops1), + 0); + + MachineSDNode *PackDH = CurDAG->getMachineNode( + RISCV::PPACK_DH, DL, MVT::Untyped, {RegPair0, RegPair1}); + + SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL, + MVT::i32, SDValue(PackDH, 0)); + SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL, + MVT::i32, SDValue(PackDH, 0)); + ReplaceUses(SDValue(Node, 0), Lo); + ReplaceUses(SDValue(Node, 1), Hi); + CurDAG->RemoveDeadNode(Node); + return; + } case ISD::INTRINSIC_WO_CHAIN: { unsigned IntNo = Node->getConstantOperandVal(0); switch (IntNo) { @@ -2696,7 +2733,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { case ISD::SCALAR_TO_VECTOR: if (Subtarget->enablePExtCodeGen()) { MVT SrcVT = Node->getOperand(0).getSimpleValueType(); - if (VT == MVT::v2i32 && SrcVT == MVT::i64) { + if ((VT == MVT::v2i32 && SrcVT == MVT::i64) || + (VT == MVT::v4i8 && SrcVT == MVT::i32)) { ReplaceUses(SDValue(Node, 0), Node->getOperand(0)); CurDAG->RemoveDeadNode(Node); return; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 6020fb6ca16ce..dd9be0de88737 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -519,6 +519,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand); } else { VTs.append({MVT::v2i16, MVT::v4i8}); + setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom); } setOperationAction(ISD::UADDSAT, VTs, Legal); setOperationAction(ISD::SADDSAT, VTs, Legal); @@ -4434,6 +4435,33 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, SDLoc DL(Op); + if (Subtarget.isRV32() && Subtarget.enablePExtCodeGen()) { + if (VT != MVT::v4i8) + return SDValue(); + + // <4 x i8> BUILD_VECTOR a, b, c, d -> PACK(PPACK.DH pair(a, b), pair(c, d)) + SDValue Val0 = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(0)); + SDValue Val1 = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(1)); + SDValue Val2 = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(2)); + SDValue Val3 = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(3)); + SDValue PackDH = + DAG.getNode(RISCVISD::PPACK_DH, DL, {MVT::v2i16, MVT::v2i16}, + {Val0, Val1, Val2, Val3}); + + return DAG.getNode( + ISD::BITCAST, DL, MVT::v4i8, + SDValue( + DAG.getMachineNode( + RISCV::PACK, DL, MVT::i32, + {DAG.getNode(ISD::BITCAST, DL, MVT::i32, PackDH.getValue(0)), + DAG.getNode(ISD::BITCAST, DL, MVT::i32, PackDH.getValue(1))}), + 0)); + } + // Proper support for f16 requires Zvfh. bf16 always requires special // handling. We need to cast the scalar to integer and create an integer // build_vector. diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 51339d66f6de1..de278aa51d19b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -24,6 +24,13 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> { let RenderMethod = "addSImm8UnsignedOperands"; } +// (<2 x i16>, <2 x i16>) PPACK_DH (<4 x i8>, <4 x i8>, <4 x i8>, <4 x i8>) +def SDT_RISCVPPackDH + : SDTypeProfile<2, 4, [SDTCisVT<0, v2i16>, SDTCisSameAs<0, 1>, + SDTCisVT<2, v4i8>, SDTCisSameAs<0, 3>, + SDTCisSameAs<0, 4>, SDTCisSameAs<0, 5>]>; +def riscv_ppack_dh : RVSDNode<"PPACK_DH", SDT_RISCVPPackDH>; + // A 8-bit signed immediate allowing range [-128, 255] // but represented as [-128, 127]. def simm8_unsigned : RISCVOp, ImmLeaf(Imm);"> { @@ -1530,6 +1537,10 @@ let Predicates = [HasStdExtP, IsRV32] in { def : StPat; def : LdPat; def : LdPat; + + // Build vector patterns + def : Pat<(v2i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b))), + (PACK GPR:$a, GPR:$b)>; } // Predicates = [HasStdExtP, IsRV32] let Predicates = [HasStdExtP, IsRV64] in { @@ -1566,4 +1577,29 @@ let Predicates = [HasStdExtP, IsRV64] in { def : LdPat; def : LdPat; def : LdPat; + + // Build vector patterns + def : Pat<(v8i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b), + (XLenVT GPR:$c), (XLenVT GPR:$d), + (XLenVT undef), (XLenVT undef), + (XLenVT undef), (XLenVT undef))), + (PPACK_W (PPACK_H GPR:$a, GPR:$b), (PPACK_H GPR:$c, GPR:$d))>; + + def : Pat<(v8i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b), + (XLenVT GPR:$c), (XLenVT GPR:$d), + (XLenVT GPR:$e), (XLenVT GPR:$f), + (XLenVT GPR:$g), (XLenVT GPR:$h))), + (PACK(PPACK_W (PPACK_H GPR:$a, GPR:$b), (PPACK_H GPR:$c, GPR:$d)), + (PPACK_W (PPACK_H GPR:$e, GPR:$f), (PPACK_H GPR:$g, GPR:$h)))>; + + def : Pat<(v4i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b), + (XLenVT undef), (XLenVT undef))), + (PPACK_W GPR:$a, GPR:$b)>; + + def : Pat<(v4i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b), + (XLenVT GPR:$c), (XLenVT GPR:$d))), + (PACK (PPACK_W GPR:$a, GPR:$b), (PPACK_W GPR:$c, GPR:$d))>; + + def : Pat<(v2i32 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b))), + (PACK GPR:$a, GPR:$b)>; } // Predicates = [HasStdExtP, IsRV64] diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll index bb3e691311cd8..79cf5b7903454 100644 --- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll +++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll @@ -523,6 +523,47 @@ define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) { ret void } +define void @test_build_vector_i8(i8 %a, i8 %c, i8 %b, i8 %d, ptr %ret_ptr) { +; CHECK-RV32-LABEL: test_build_vector_i8: +; CHECK-RV32: # %bb.0: +; CHECK-RV32-NEXT: ppack.dh a0, a0, a2 +; CHECK-RV32-NEXT: pack a0, a0, a1 +; CHECK-RV32-NEXT: sw a0, 0(a4) +; CHECK-RV32-NEXT: ret +; +; CHECK-RV64-LABEL: test_build_vector_i8: +; CHECK-RV64: # %bb.0: +; CHECK-RV64-NEXT: ppack.h a1, a1, a3 +; CHECK-RV64-NEXT: ppack.h a0, a0, a2 +; CHECK-RV64-NEXT: ppack.w a0, a0, a1 +; CHECK-RV64-NEXT: sw a0, 0(a4) +; CHECK-RV64-NEXT: ret + %v0 = insertelement <4 x i8> poison, i8 %a, i32 0 + %v1 = insertelement <4 x i8> %v0, i8 %b, i32 1 + %v2 = insertelement <4 x i8> %v1, i8 %c, i32 2 + %v3 = insertelement <4 x i8> %v2, i8 %d, i32 3 + store <4 x i8> %v3, ptr %ret_ptr + ret void +} + +define void @test_build_vector_i16(ptr %ret_ptr, i16 %a, i16 %b) { +; CHECK-RV32-LABEL: test_build_vector_i16: +; CHECK-RV32: # %bb.0: +; CHECK-RV32-NEXT: pack a1, a1, a2 +; CHECK-RV32-NEXT: sw a1, 0(a0) +; CHECK-RV32-NEXT: ret +; +; CHECK-RV64-LABEL: test_build_vector_i16: +; CHECK-RV64: # %bb.0: +; CHECK-RV64-NEXT: ppack.w a1, a1, a2 +; CHECK-RV64-NEXT: sw a1, 0(a0) +; CHECK-RV64-NEXT: ret + %v0 = insertelement <2 x i16> poison, i16 %a, i32 0 + %v1 = insertelement <2 x i16> %v0, i16 %b, i32 1 + store <2 x i16> %v1, ptr %ret_ptr + ret void +} + ; Intrinsic declarations declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>) declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>) diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll index f989b025a12dc..36996f0ac7ac8 100644 --- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll +++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll @@ -685,6 +685,59 @@ define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) { ret void } +define void @test_build_vector_i8(ptr %ret_ptr, i8 %a, i8 %b, i8 %c, i8 %d, i8 %e, i8 %f, i8 %g, i8 %h) { +; CHECK-LABEL: test_build_vector_i8: +; CHECK: # %bb.0: +; CHECK-NEXT: lbu t0, 0(sp) +; CHECK-NEXT: ppack.h a5, a5, a6 +; CHECK-NEXT: ppack.h a3, a3, a4 +; CHECK-NEXT: ppack.h a1, a1, a2 +; CHECK-NEXT: ppack.h a2, a7, t0 +; CHECK-NEXT: ppack.w a2, a5, a2 +; CHECK-NEXT: ppack.w a1, a1, a3 +; CHECK-NEXT: pack a1, a1, a2 +; CHECK-NEXT: sd a1, 0(a0) +; CHECK-NEXT: ret + %v0 = insertelement <8 x i8> poison, i8 %a, i32 0 + %v1 = insertelement <8 x i8> %v0, i8 %b, i32 1 + %v2 = insertelement <8 x i8> %v1, i8 %c, i32 2 + %v3 = insertelement <8 x i8> %v2, i8 %d, i32 3 + %v4 = insertelement <8 x i8> %v3, i8 %e, i32 4 + %v5 = insertelement <8 x i8> %v4, i8 %f, i32 5 + %v6 = insertelement <8 x i8> %v5, i8 %g, i32 6 + %v7 = insertelement <8 x i8> %v6, i8 %h, i32 7 + store <8 x i8> %v7, ptr %ret_ptr + ret void +} + +define void @test_build_vector_i16(ptr %ret_ptr, i16 %a, i16 %b, i16 %c, i16 %d) { +; CHECK-LABEL: test_build_vector_i16: +; CHECK: # %bb.0: +; CHECK-NEXT: ppack.w a3, a3, a4 +; CHECK-NEXT: ppack.w a1, a1, a2 +; CHECK-NEXT: pack a1, a1, a3 +; CHECK-NEXT: sd a1, 0(a0) +; CHECK-NEXT: ret + %v0 = insertelement <4 x i16> poison, i16 %a, i32 0 + %v1 = insertelement <4 x i16> %v0, i16 %b, i32 1 + %v2 = insertelement <4 x i16> %v1, i16 %c, i32 2 + %v3 = insertelement <4 x i16> %v2, i16 %d, i32 3 + store <4 x i16> %v3, ptr %ret_ptr + ret void +} + +define void @test_build_vector_i32(ptr %ret_ptr, i32 %a, i32 %b) { +; CHECK-LABEL: test_build_vector_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: pack a1, a1, a2 +; CHECK-NEXT: sd a1, 0(a0) +; CHECK-NEXT: ret + %v0 = insertelement <2 x i32> poison, i32 %a, i32 0 + %v1 = insertelement <2 x i32> %v0, i32 %b, i32 1 + store <2 x i32> %v1, ptr %ret_ptr + ret void +} + ; Intrinsic declarations declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>) declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)