[llvm][RISCV] Support rounding mulh for P extension codegen #171593
Open
4vtomat wants to merge 2 commits into llvm:main from 4vtomat:p_ext_codegen_mul2
+474
−36
Conversation
For the mulh pattern where both operands are signed, or both are unsigned, the combine is performed automatically. However, for mulh where one operand is signed and the other is unsigned, we need to combine manually, the same as we did for PASUB*.
In the P extension spec, rounding is performed by adding 1 << (elt_bits - 1) to the product before the shift.
Stacked on: llvm#171581
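
A minimal per-lane sketch of the rounding semantics the combine matches, assuming 16-bit elements (hypothetical helper name, not part of the patch; the 32-bit form adds 1 << 31 and shifts by 32, and the U/SU variants differ only in how the operands are extended):

#include <cstdint>

// Widen both operands with sign extension, multiply, add the rounding
// constant 1 << (elt_bits - 1), then keep the high half of the sum.
int16_t pmulhr_elt(int16_t a, int16_t b) {
  int32_t prod = int32_t(a) * int32_t(b); // fits in 32 bits for i16 inputs
  prod += 1 << 15;                        // rounding bias from the spec
  return int16_t(uint32_t(prod) >> 16);   // logical shift, then truncate
}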
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
Changes
Patch is 22.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171593.diff
4 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f28772a74d433..c75ac76415714 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15210,18 +15210,26 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
break;
}
case RISCVISD::PASUB:
- case RISCVISD::PASUBU: {
+ case RISCVISD::PASUBU:
+ case RISCVISD::PMULHSU:
+ case RISCVISD::PMULHR:
+ case RISCVISD::PMULHRU:
+ case RISCVISD::PMULHRSU: {
MVT VT = N->getSimpleValueType(0);
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
- assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+ unsigned Opcode = N->getOpcode();
+ // PMULH* variants don't support i8
+ bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+ Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+ assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
NewVT = MVT::v8i8;
SDValue Undef = DAG.getUNDEF(VT);
Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
- Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+ Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
return;
}
case ISD::EXTRACT_VECTOR_ELT: {
@@ -16331,9 +16339,10 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
}
-// Handle P extension averaging subtraction pattern:
-// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
-// -> PASUB/PASUBU
+// Handle P extension truncate patterns:
+// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
+// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul (sext a), (zext b)), round_const), EltBits))
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
@@ -16346,7 +16355,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
return SDValue();
- // Check if shift amount is 1
+ // Check if shift amount is a splat constant
SDValue ShAmt = N0.getOperand(1);
if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();
@@ -16360,44 +16369,84 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
if (!C)
return SDValue();
- if (C->getZExtValue() != 1)
- return SDValue();
- // Check for SUB operation
- SDValue Sub = N0.getOperand(0);
- if (Sub.getOpcode() != ISD::SUB)
- return SDValue();
+ SDValue Op = N0.getOperand(0);
+ unsigned ShAmtVal = C->getZExtValue();
+ unsigned EltBits = VecVT.getScalarSizeInBits();
+
+ // Check for rounding pattern: (add (mul ...), round_const)
+ bool IsRounding = false;
+ if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+ SDValue AddRHS = Op.getOperand(1);
+ if (AddRHS.getOpcode() == ISD::BUILD_VECTOR) {
+ if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+ if (auto *RndC =
+ dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+ uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+ if (RndC->getZExtValue() == ExpectedRnd &&
+ Op.getOperand(0).getOpcode() == ISD::MUL) {
+ Op = Op.getOperand(0);
+ IsRounding = true;
+ }
+ }
+ }
+ }
+ }
- SDValue LHS = Sub.getOperand(0);
- SDValue RHS = Sub.getOperand(1);
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
- // Check if both operands are sign/zero extends from the target
- // type
- bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
- RHS.getOpcode() == ISD::SIGN_EXTEND;
- bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
- RHS.getOpcode() == ISD::ZERO_EXTEND;
+ bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
+ bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
+ bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
+ bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
- if (!IsSignExt && !IsZeroExt)
+ if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
return SDValue();
SDValue A = LHS.getOperand(0);
SDValue B = RHS.getOperand(0);
- // Check if the extends are from our target vector type
if (A.getValueType() != VT || B.getValueType() != VT)
return SDValue();
- // Determine the instruction based on type and signedness
unsigned Opc;
- if (IsSignExt)
- Opc = RISCVISD::PASUB;
- else if (IsZeroExt)
- Opc = RISCVISD::PASUBU;
- else
+ switch (Op.getOpcode()) {
+ default:
return SDValue();
+ case ISD::SUB:
+ // PASUB/PASUBU: shift amount must be 1
+ if (ShAmtVal != 1)
+ return SDValue();
+ if (LHSIsSExt && RHSIsSExt)
+ Opc = RISCVISD::PASUB;
+ else if (LHSIsZExt && RHSIsZExt)
+ Opc = RISCVISD::PASUBU;
+ else
+ return SDValue();
+ break;
+ case ISD::MUL:
+ // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
+ if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
+ return SDValue();
+ if (IsRounding) {
+ if (LHSIsSExt && RHSIsSExt)
+ Opc = RISCVISD::PMULHR;
+ else if (LHSIsZExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHRU;
+ else if (LHSIsSExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHRSU;
+ else
+ return SDValue();
+ } else {
+ if (LHSIsSExt && RHSIsZExt)
+ Opc = RISCVISD::PMULHSU;
+ else
+ return SDValue();
+ }
+ break;
+ }
- // Create the machine node directly
return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
}
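
Note: for the rounding variants, the only splat constant the combine accepts is 1 << (EltBits - 1). A standalone sanity check (hypothetical, not part of the patch) confirms these values match the splat constants used in the tests below:

#include <cassert>

int main() {
  auto rnd = [](unsigned EltBits) { return 1ULL << (EltBits - 1); };
  assert(rnd(16) == 32768ULL);      // the <i32 32768, ...> splats in the i16 tests
  assert(rnd(32) == 2147483648ULL); // the <i64 2147483648, ...> splats in the i32 tests
  return 0;
}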
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index bba9f961b9639..587ac8ee238f4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1463,12 +1463,16 @@ let Predicates = [HasStdExtP, IsRV32] in {
def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
-def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
- SDTCisInt<0>,
- SDTCisSameAs<0, 1>,
- SDTCisSameAs<0, 2>]>;
-def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
-def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
+def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
+ SDTCisInt<0>,
+ SDTCisSameAs<0, 1>,
+ SDTCisSameAs<0, 2>]>;
+def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
+def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
+def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
let Predicates = [HasStdExtP] in {
def : PatGpr<abs, ABS>;
@@ -1513,6 +1517,16 @@ let Predicates = [HasStdExtP] in {
def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
+ // 16-bit multiply high patterns
+ def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+
+ // 16-bit multiply high rounding patterns
+ def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
// 8-bit logical shift left patterns
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1580,6 +1594,16 @@ let Predicates = [HasStdExtP, IsRV64] in {
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
+ // 32-bit multiply high patterns
+ def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+
+ // 32-bit multiply high rounding patterns
+ def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
// 32-bit logical shift left
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
(PSLL_WS GPR:$rs1, GPR:$rs2)>;
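
Note: there is no generic ISD node for a mixed signed-by-unsigned high multiply (mulhs and mulhu cover only the same-signedness cases), which is why riscv_pmulhsu and the rounding nodes are custom. A per-lane scalar model of what pmulhsu.h computes, assuming 16-bit elements (hypothetical helper name, not part of the patch):

#include <cstdint>

// Sign-extend the first operand, zero-extend the second, multiply in
// 32 bits (wrapping, like the IR mul), then keep the high 16 bits.
int16_t pmulhsu_elt(int16_t a, uint16_t b) {
  uint32_t prod = uint32_t(int32_t(a)) * uint32_t(b);
  return int16_t(prod >> 16); // logical shift, then truncate
}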
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index cd59aa03597e2..33127c3d140fa 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -746,3 +746,126 @@ define void @test_psll_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <4 x i8> %res, ptr %ret_ptr
ret void
}
+
+; Test packed multiply high signed for v2i16
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulh.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high unsigned for v2i16
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high signed-unsigned for v2i16
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhsu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %shift = lshr <2 x i32> %mul, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index c7fb891cdd996..8a741f5821b70 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -841,3 +841,245 @@ define void @test_psll_ws_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
store <2 x i32> %res, ptr %ret_ptr
ret void
}
+
+; Test packed multiply high signed
+define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulh.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulh_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulh_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulh.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+
+; Test packed multiply high unsigned
+define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhu.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhu_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhu.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = zext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+
+; Test packed multiply high signed-unsigned
+define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhsu.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %shift = lshr <4 x i32> %mul, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhsu_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhsu.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %shift = lshr <2 x i64> %mul, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ s...
[truncated]