From 8e875e13afda56300ce19503139f5ba8cffbf8d4 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Wed, 21 May 2025 08:32:49 -0700 Subject: [PATCH 1/3] [RISCV] Support scalable vectors for the zvqdotq lowering paths This was an oversight in the original patch series. Without this change, the newly added tests fail assertions. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 21 +- llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 589 ++++++++++++++++++ 2 files changed, 601 insertions(+), 9 deletions(-) create mode 100644 llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1158499718737..73798b899e9ff 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1, assert(VT == Op1.getSimpleValueType() && VT.getVectorElementType() == MVT::i32); - assert(VT.isFixedLengthVector()); - MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); - SDValue Passthru = convertToScalableVector( - ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget); - Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget); - Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget); - + SDValue Passthru = DAG.getConstant(0, DL, VT); + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget); + Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget); + Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget); + } auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT, {Op0, Op1, Passthru, Mask, VL}); - return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget); + if (VT.isFixedLengthVector()) + return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget); + return LocalAccum; } static MVT getQDOTXResultType(MVT OpVT) { @@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B, EVT AVT = A.getValueType(); EVT BVT = B.getValueType(); assert(AVT.getVectorElementType() == BVT.getVectorElementType()); - if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) { + if (AVT.getVectorMinNumElements() > BVT.getVectorMinNumElements()) { std::swap(A, B); std::swap(AVT, BVT); } diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll new file mode 100644 index 0000000000000..d811cdb5e444d --- /dev/null +++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll @@ -0,0 +1,589 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT +; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT +; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32 +; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64 + +define i32 @vqdot_vv( %a, %b) { +; NODOT-LABEL: vqdot_vv: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vsext.vf2 v20, v10 +; NODOT-NEXT: 
vwmul.vv v8, v16, v20 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, zero +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdot_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdot.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdot_vx_constant( %a) { +; CHECK-LABEL: vqdot_vx_constant: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v16, v8 +; CHECK-NEXT: li a0, 23 +; CHECK-NEXT: vwmul.vx v8, v16, a0 +; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v16, zero +; CHECK-NEXT: vredsum.vs v8, v8, v16 +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %mul = mul nuw nsw %a.sext, splat (i32 23) + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdot_vx_constant_swapped( %a) { +; CHECK-LABEL: vqdot_vx_constant_swapped: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v16, v8 +; CHECK-NEXT: li a0, 23 +; CHECK-NEXT: vwmul.vx v8, v16, a0 +; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v16, zero +; CHECK-NEXT: vredsum.vs v8, v8, v16 +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %mul = mul nuw nsw splat (i32 23), %a.sext + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdotu_vv( %a, %b) { +; NODOT-LABEL: vqdotu_vv: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; NODOT-NEXT: vwmulu.vv v12, v8, v10 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma +; NODOT-NEXT: vwredsumu.vs v8, v12, v8 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotu_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotu.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.zext = zext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.zext, %b.zext + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdotu_vx_constant( %a) { +; CHECK-LABEL: vqdotu_vx_constant: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vzext.vf2 v16, v8 +; CHECK-NEXT: li a0, 123 +; CHECK-NEXT: vwmulu.vx v8, v16, a0 +; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v16, zero +; CHECK-NEXT: vredsum.vs v8, v8, v16 +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: ret +entry: + %a.zext = zext %a to + %mul = mul nuw nsw %a.zext, splat (i32 123) + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdotsu_vv( %a, %b) { +; NODOT-LABEL: vqdotsu_vv: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vzext.vf2 v20, v10 +; NODOT-NEXT: vwmulsu.vv v8, v16, v20 +; NODOT-NEXT: vsetvli zero, 
zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, zero +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotsu.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.sext, %b.zext + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vqdotsu_vv_swapped( %a, %b) { +; NODOT-LABEL: vqdotsu_vv_swapped: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vzext.vf2 v20, v10 +; NODOT-NEXT: vwmulsu.vv v8, v16, v20 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, zero +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv_swapped: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotsu.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %b.zext, %a.sext + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vdotqsu_vx_constant( %a) { +; CHECK-LABEL: vdotqsu_vx_constant: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v16, v8 +; CHECK-NEXT: li a0, 123 +; CHECK-NEXT: vwmul.vx v8, v16, a0 +; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v16, zero +; CHECK-NEXT: vredsum.vs v8, v8, v16 +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %mul = mul nuw nsw %a.sext, splat (i32 123) + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @vdotqus_vx_constant( %a) { +; CHECK-LABEL: vdotqus_vx_constant: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vzext.vf2 v16, v8 +; CHECK-NEXT: li a0, -23 +; CHECK-NEXT: vmv.v.x v20, a0 +; CHECK-NEXT: vwmulsu.vv v8, v20, v16 +; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v16, zero +; CHECK-NEXT: vredsum.vs v8, v8, v16 +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: ret +entry: + %a.zext = zext %a to + %mul = mul nuw nsw %a.zext, splat (i32 -23) + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + ret i32 %res +} + +define i32 @reduce_of_sext( %a) { +; NODOT-LABEL: reduce_of_sext: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma +; NODOT-NEXT: vsext.vf4 v16, v8 +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v16, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT32-LABEL: reduce_of_sext: +; DOT32: # %bb.0: # %entry +; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT32-NEXT: vmv.v.i v10, 0 +; DOT32-NEXT: lui a0, 4112 +; DOT32-NEXT: addi a0, a0, 257 +; DOT32-NEXT: vqdot.vx v10, v8, a0 +; DOT32-NEXT: vmv.s.x v8, zero +; DOT32-NEXT: vredsum.vs v8, v10, v8 +; DOT32-NEXT: vmv.x.s a0, v8 +; DOT32-NEXT: ret +; +; DOT64-LABEL: reduce_of_sext: +; DOT64: # %bb.0: # %entry +; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT64-NEXT: vmv.v.i v10, 0 +; DOT64-NEXT: lui a0, 4112 +; DOT64-NEXT: addiw a0, a0, 257 +; DOT64-NEXT: vqdot.vx 
v10, v8, a0 +; DOT64-NEXT: vmv.s.x v8, zero +; DOT64-NEXT: vredsum.vs v8, v10, v8 +; DOT64-NEXT: vmv.x.s a0, v8 +; DOT64-NEXT: ret +entry: + %a.ext = sext %a to + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %a.ext) + ret i32 %res +} + +define i32 @reduce_of_zext( %a) { +; NODOT-LABEL: reduce_of_zext: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma +; NODOT-NEXT: vzext.vf4 v16, v8 +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v16, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT32-LABEL: reduce_of_zext: +; DOT32: # %bb.0: # %entry +; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT32-NEXT: vmv.v.i v10, 0 +; DOT32-NEXT: lui a0, 4112 +; DOT32-NEXT: addi a0, a0, 257 +; DOT32-NEXT: vqdotu.vx v10, v8, a0 +; DOT32-NEXT: vmv.s.x v8, zero +; DOT32-NEXT: vredsum.vs v8, v10, v8 +; DOT32-NEXT: vmv.x.s a0, v8 +; DOT32-NEXT: ret +; +; DOT64-LABEL: reduce_of_zext: +; DOT64: # %bb.0: # %entry +; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT64-NEXT: vmv.v.i v10, 0 +; DOT64-NEXT: lui a0, 4112 +; DOT64-NEXT: addiw a0, a0, 257 +; DOT64-NEXT: vqdotu.vx v10, v8, a0 +; DOT64-NEXT: vmv.s.x v8, zero +; DOT64-NEXT: vredsum.vs v8, v10, v8 +; DOT64-NEXT: vmv.x.s a0, v8 +; DOT64-NEXT: ret +entry: + %a.ext = zext %a to + %res = tail call i32 @llvm.vector.reduce.add.v16i32( %a.ext) + ret i32 %res +} + +define i32 @vqdot_vv_accum( %a, %b, %x) { +; NODOT-LABEL: vqdot_vv_accum: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v12, v8 +; NODOT-NEXT: vsext.vf2 v24, v10 +; NODOT-NEXT: vwmacc.vv v16, v12, v24 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v16, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdot_vv_accum: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdot.vv v12, v8, v10 +; DOT-NEXT: vadd.vv v16, v12, v16 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma +; DOT-NEXT: vredsum.vs v8, v16, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %add = add %mul, %x + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %add) + ret i32 %sum +} + +define i32 @vqdotu_vv_accum( %a, %b, %x) { +; NODOT-LABEL: vqdotu_vv_accum: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; NODOT-NEXT: vwmulu.vv v12, v8, v10 +; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma +; NODOT-NEXT: vwaddu.wv v16, v16, v12 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v16, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotu_vv_accum: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotu.vv v12, v8, v10 +; DOT-NEXT: vadd.vv v16, v12, v16 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma +; DOT-NEXT: vredsum.vs v8, v16, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.zext = zext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.zext, %b.zext + %add = add %mul, %x + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %add) + ret i32 %sum +} + +define i32 @vqdotsu_vv_accum( %a, %b, %x) { +; NODOT-LABEL: vqdotsu_vv_accum: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v12, v8 
+; NODOT-NEXT: vzext.vf2 v24, v10 +; NODOT-NEXT: vwmaccsu.vv v16, v12, v24 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v16, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv_accum: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotsu.vv v12, v8, v10 +; DOT-NEXT: vadd.vv v16, v12, v16 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma +; DOT-NEXT: vredsum.vs v8, v16, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.sext, %b.zext + %add = add %mul, %x + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %add) + ret i32 %sum +} + +define i32 @vqdot_vv_scalar_add( %a, %b, i32 %x) { +; NODOT-LABEL: vqdot_vv_scalar_add: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a1, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vsext.vf2 v20, v10 +; NODOT-NEXT: vwmul.vv v8, v16, v20 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, a0 +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdot_vv_scalar_add: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdot.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, a0 +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + %add = add i32 %sum, %x + ret i32 %add +} + +define i32 @vqdotu_vv_scalar_add( %a, %b, i32 %x) { +; NODOT-LABEL: vqdotu_vv_scalar_add: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a1, zero, e8, m2, ta, ma +; NODOT-NEXT: vwmulu.vv v12, v8, v10 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v8, a0 +; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma +; NODOT-NEXT: vwredsumu.vs v8, v12, v8 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotu_vv_scalar_add: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotu.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, a0 +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.zext = zext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.zext, %b.zext + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %mul) + %add = add i32 %sum, %x + ret i32 %add +} + +define i32 @vqdotsu_vv_scalar_add( %a, %b, i32 %x) { +; NODOT-LABEL: vqdotsu_vv_scalar_add: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a1, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vzext.vf2 v20, v10 +; NODOT-NEXT: vwmulsu.vv v8, v16, v20 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, a0 +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv_scalar_add: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v12, 0 +; DOT-NEXT: vqdotsu.vv v12, v8, v10 +; DOT-NEXT: vmv.s.x v8, a0 +; DOT-NEXT: vredsum.vs v8, v12, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.zext = zext %b to + %mul = mul nuw nsw %a.sext, %b.zext + %sum = tail call i32 
@llvm.vector.reduce.add.v16i32( %mul) + %add = add i32 %sum, %x + ret i32 %add +} + +define i32 @vqdot_vv_split( %a, %b, %c, %d) { +; NODOT-LABEL: vqdot_vv_split: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; NODOT-NEXT: vsext.vf2 v16, v8 +; NODOT-NEXT: vsext.vf2 v20, v10 +; NODOT-NEXT: vsext.vf2 v24, v12 +; NODOT-NEXT: vsext.vf2 v28, v14 +; NODOT-NEXT: vwmul.vv v8, v16, v20 +; NODOT-NEXT: vwmacc.vv v8, v24, v28 +; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma +; NODOT-NEXT: vmv.s.x v16, zero +; NODOT-NEXT: vredsum.vs v8, v8, v16 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdot_vv_split: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; DOT-NEXT: vmv.v.i v16, 0 +; DOT-NEXT: vmv.v.i v18, 0 +; DOT-NEXT: vqdot.vv v16, v8, v10 +; DOT-NEXT: vqdot.vv v18, v12, v14 +; DOT-NEXT: vadd.vv v8, v16, v18 +; DOT-NEXT: vmv.s.x v10, zero +; DOT-NEXT: vredsum.vs v8, v8, v10 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %c.sext = sext %c to + %d.sext = sext %d to + %mul2 = mul nuw nsw %c.sext, %d.sext + %add = add %mul, %mul2 + %sum = tail call i32 @llvm.vector.reduce.add.v16i32( %add) + ret i32 %sum +} + + +define @vqdot_vv_partial_reduce( %a, %b) { +; CHECK-LABEL: vqdot_vv_partial_reduce: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v16, v8 +; CHECK-NEXT: vsext.vf2 v20, v10 +; CHECK-NEXT: vwmul.vv v8, v16, v20 +; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; CHECK-NEXT: vadd.vv v8, v14, v8 +; CHECK-NEXT: vadd.vv v10, v10, v12 +; CHECK-NEXT: vadd.vv v8, v10, v8 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %res = call @llvm.experimental.vector.partial.reduce.add( zeroinitializer, %mul) + ret %res +} + +define @vqdot_vv_partial_reduce2( %a, %b, %accum) { +; CHECK-LABEL: vqdot_vv_partial_reduce2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v24, v8 +; CHECK-NEXT: vsext.vf2 v28, v10 +; CHECK-NEXT: vwmul.vv v16, v24, v28 +; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; CHECK-NEXT: vadd.vv v8, v18, v20 +; CHECK-NEXT: vadd.vv v10, v12, v16 +; CHECK-NEXT: vadd.vv v10, v22, v10 +; CHECK-NEXT: vadd.vv v8, v8, v10 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %res = call @llvm.experimental.vector.partial.reduce.add( %accum, %mul) + ret %res +} + +define @vqdot_vv_partial_reduce3( %a, %b) { +; CHECK-LABEL: vqdot_vv_partial_reduce3: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma +; CHECK-NEXT: vsext.vf2 v16, v8 +; CHECK-NEXT: vsext.vf2 v20, v10 +; CHECK-NEXT: vwmul.vv v8, v16, v20 +; CHECK-NEXT: ret +entry: + %a.sext = sext %a to + %b.sext = sext %b to + %mul = mul nuw nsw %a.sext, %b.sext + %res = call @llvm.experimental.vector.partial.reduce.add.nvx8i32.nvx16i32.nvx16i32( %mul, zeroinitializer) + ret %res +} From aa122051a8405bf00922309004c76e49c79bd3f8 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Wed, 21 May 2025 08:46:37 -0700 Subject: [PATCH 2/3] [RISCV] Support scalable vectors in zvqdotq accumulator folding (This part is a missed optimization, not a correctness issue.) 
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 +++++++++---- llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 32 +++++++------------ 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 73798b899e9ff..d69e04a9912a2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18644,7 +18644,7 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(N->getOpcode() == RISCVISD::ADD_VL); + assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD); if (!N->getValueType(0).isVector()) return SDValue(); @@ -18652,9 +18652,11 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG, SDValue Addend = N->getOperand(0); SDValue DotOp = N->getOperand(1); - SDValue AddPassthruOp = N->getOperand(2); - if (!AddPassthruOp.isUndef()) - return SDValue(); + if (N->getOpcode() == RISCVISD::ADD_VL) { + SDValue AddPassthruOp = N->getOperand(2); + if (!AddPassthruOp.isUndef()) + return SDValue(); + } auto IsVqdotqOpc = [](unsigned Opc) { switch (Opc) { @@ -18673,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG, if (!IsVqdotqOpc(DotOp.getOpcode())) return SDValue(); - SDValue AddMask = N->getOperand(3); - SDValue AddVL = N->getOperand(4); + auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (N->getOpcode() == ISD::ADD) { + SDLoc DL(N); + return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG, + Subtarget); + } + return std::make_pair(N->getOperand(3), N->getOperand(4)); + }(N, DAG, Subtarget); SDValue MulVL = DotOp.getOperand(4); if (AddVL != MulVL) @@ -19312,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return V; if (SDValue V = combineToVWMACC(N, DAG, Subtarget)) return V; + if (SDValue V = combineVqdotAccum(N, DAG, Subtarget)) + return V; return performADDCombine(N, DCI, Subtarget); } case ISD::SUB: { diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll index d811cdb5e444d..a56ef0cd75d6a 100644 --- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll @@ -314,12 +314,10 @@ define i32 @vqdot_vv_accum( %a, %b, %a, %b, %a, %b, %a, %b, Date: Wed, 21 May 2025 10:13:40 -0700 Subject: [PATCH 3/3] Address review comment --- llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll index a56ef0cd75d6a..34084459edcd9 100644 --- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll @@ -576,6 +576,6 @@ entry: %a.sext = sext %a to %b.sext = sext %b to %mul = mul nuw nsw %a.sext, %b.sext - %res = call @llvm.experimental.vector.partial.reduce.add.nvx8i32.nvx16i32.nvx16i32( %mul, zeroinitializer) + %res = call @llvm.experimental.vector.partial.reduce.add.nvx16i32.nvx16i32( %mul, zeroinitializer) ret %res }
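
Note on the new test file: the angle-bracket vector types in the .ll snippets above were stripped when the patch was pasted, so the scalable operand types do not show. As a hedged reconstruction only (assuming <vscale x 16 x i8> sources widening to a <vscale x 16 x i32> product, which is consistent with the e8/m2 and e32/m8 vsetvli sequences in the CHECK lines; the function name is illustrative), a representative input to the now-supported scalable path in lowerVQDOT looks like:

define i32 @vqdot_vv_sketch(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  ; widen both operands to i32, multiply, then reduce to a scalar
  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
  %res = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %mul)
  ret i32 %res
}

Per the first commit message, inputs of this shape previously tripped the isFixedLengthVector() assertion in lowerVQDOT; with the change, scalable types are already container types and the fixed-length conversions are simply skipped.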
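
The second patch teaches combineVqdotAccum to also match plain ISD::ADD nodes, which is how the accumulate step appears for scalable vectors. A sketch of the pattern it now folds, under the same assumed types as above (names illustrative, not from the test file):

define i32 @vqdot_vv_accum_sketch(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
entry:
  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
  ; the add of the dot-product result into %x is the node the combine folds
  %add = add <vscale x 16 x i32> %mul, %x
  %sum = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %add)
  ret i32 %sum
}

As the updated CHECK lines show, the separate zero splat and vadd.vv disappear and the accumulator register is used directly as the vqdot.vv destination.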