-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV][WIP] Optimize sum of absolute differences pattern. #82722
Conversation
This rewrites (abs (sub (zext X), (zext Y))) to (zext (sub (zext (max X, Y), (min X, Y)))). This was taken from my downstream and has some overfitting to a particular benchmark. Posting for discussion.
@llvm/pr-subscribers-backend-risc-v Author: Craig Topper (topperc) Changes: This rewrites (abs (sub (zext X), (zext Y))) to (zext (sub (zext (max X, Y), (min X, Y)))). This was taken from my downstream and has some overfitting to a particular benchmark. Posting for discussion. Full diff: https://github.com/llvm/llvm-project/pull/82722.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c67aaf6785669..6bd62b79e5a74f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
}
+// Look for (abs (sub (zext X), (zext Y))).
+// Rewrite as (zext (sub (zext (max X, Y), (min X, Y)))) if the user is an add
+// or reduction add. The min/max can be done in parallel and with a lower LMUL
+// than the original code. The two zexts can be folded into widening sub and
+// widening add or widening redsum.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+ if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 ||
+ !TLI.isTypeLegal(VT))
+ return SDValue();
+
+ SDValue Src = N->getOperand(0);
+ if (Src.getOpcode() != ISD::SUB || !Src.hasOneUse())
+ return SDValue();
+
+ // Make sure the use is an add or reduce add so the zext we create at the end
+ // will be folded.
+ if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD &&
+ N->use_begin()->getOpcode() != ISD::VECREDUCE_ADD))
+ return SDValue();
+
+ // Inputs to the subtract should be zext.
+ SDValue Op0 = Src.getOperand(0);
+ SDValue Op1 = Src.getOperand(1);
+ if (Op0.getOpcode() != ISD::ZERO_EXTEND || !Op0.hasOneUse() ||
+ Op1.getOpcode() != ISD::ZERO_EXTEND || !Op1.hasOneUse())
+ return SDValue();
+
+ Op0 = Op0.getOperand(0);
+ Op1 = Op1.getOperand(0);
+
+ // Inputs should be i8 vectors.
+ if (Op0.getValueType().getVectorElementType() != MVT::i8 ||
+ Op1.getValueType().getVectorElementType() != MVT::i8)
+ return SDValue();
+
+ SDLoc DL(N);
+
+ SDValue Max = DAG.getNode(ISD::UMAX, DL, Op0.getValueType(), Op0, Op1);
+ SDValue Min = DAG.getNode(ISD::UMIN, DL, Op0.getValueType(), Op0, Op1);
+
+ // The intermediate VT should be i16.
+ EVT IntermediateVT =
+ EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorElementCount());
+
+ Max = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Max);
+ Min = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Min);
+
+ SDValue Sub = DAG.getNode(ISD::SUB, DL, IntermediateVT, Max, Min);
+
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Sub);
+}
+
static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (!VT.isVector())
@@ -15698,6 +15753,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAG.getConstant(~SignBit, DL, VT));
}
case ISD::ABS: {
+ if (SDValue V = performABSCombine(N, DAG))
+ return V;
+
EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
// abs (sext) -> zext (abs)
diff --git a/llvm/test/CodeGen/RISCV/rvv/sad.ll b/llvm/test/CodeGen/RISCV/rvv/sad.ll
new file mode 100644
index 00000000000000..ed25431c6f45cc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/sad.ll
@@ -0,0 +1,120 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+define signext i32 @sad(ptr %a, ptr %b) {
+; CHECK-LABEL: sad:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT: vle8.v v8, (a0)
+; CHECK-NEXT: vle8.v v9, (a1)
+; CHECK-NEXT: vminu.vv v10, v8, v9
+; CHECK-NEXT: vmaxu.vv v8, v8, v9
+; CHECK-NEXT: vwsubu.vv v9, v8, v10
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vmv.s.x v8, zero
+; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwredsumu.vs v8, v9, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %0 = load <4 x i8>, ptr %a, align 1
+ %1 = zext <4 x i8> %0 to <4 x i32>
+ %2 = load <4 x i8>, ptr %b, align 1
+ %3 = zext <4 x i8> %2 to <4 x i32>
+ %4 = sub nsw <4 x i32> %1, %3
+ %5 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %4, i1 true)
+ %6 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %5)
+ ret i32 %6
+}
+
+define signext i32 @sad2(ptr %a, ptr %b, i32 signext %stridea, i32 signext %strideb) {
+; CHECK-LABEL: sad2:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT: vle8.v v8, (a0)
+; CHECK-NEXT: vle8.v v9, (a1)
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: add a1, a1, a3
+; CHECK-NEXT: vle8.v v10, (a0)
+; CHECK-NEXT: vle8.v v11, (a1)
+; CHECK-NEXT: vminu.vv v12, v8, v9
+; CHECK-NEXT: vmaxu.vv v8, v8, v9
+; CHECK-NEXT: vwsubu.vv v14, v8, v12
+; CHECK-NEXT: vminu.vv v8, v10, v11
+; CHECK-NEXT: vmaxu.vv v9, v10, v11
+; CHECK-NEXT: vwsubu.vv v12, v9, v8
+; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: add a1, a1, a3
+; CHECK-NEXT: vle8.v v16, (a0)
+; CHECK-NEXT: vle8.v v17, (a1)
+; CHECK-NEXT: vwaddu.vv v8, v12, v14
+; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT: vminu.vv v12, v16, v17
+; CHECK-NEXT: vmaxu.vv v13, v16, v17
+; CHECK-NEXT: vwsubu.vv v14, v13, v12
+; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: add a1, a1, a3
+; CHECK-NEXT: vle8.v v12, (a0)
+; CHECK-NEXT: vle8.v v13, (a1)
+; CHECK-NEXT: vwaddu.wv v8, v8, v14
+; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT: vminu.vv v14, v12, v13
+; CHECK-NEXT: vmaxu.vv v12, v12, v13
+; CHECK-NEXT: vwsubu.vv v16, v12, v14
+; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT: vwaddu.wv v8, v8, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vmv.s.x v12, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v12
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %idx.ext8 = sext i32 %strideb to i64
+ %idx.ext = sext i32 %stridea to i64
+ %0 = load <16 x i8>, ptr %a, align 1
+ %1 = zext <16 x i8> %0 to <16 x i32>
+ %2 = load <16 x i8>, ptr %b, align 1
+ %3 = zext <16 x i8> %2 to <16 x i32>
+ %4 = sub nsw <16 x i32> %1, %3
+ %5 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %4, i1 true)
+ %6 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+ %add.ptr = getelementptr inbounds i8, ptr %a, i64 %idx.ext
+ %add.ptr9 = getelementptr inbounds i8, ptr %b, i64 %idx.ext8
+ %7 = load <16 x i8>, ptr %add.ptr, align 1
+ %8 = zext <16 x i8> %7 to <16 x i32>
+ %9 = load <16 x i8>, ptr %add.ptr9, align 1
+ %10 = zext <16 x i8> %9 to <16 x i32>
+ %11 = sub nsw <16 x i32> %8, %10
+ %12 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %11, i1 true)
+ %13 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %12)
+ %op.rdx.1 = add i32 %13, %6
+ %add.ptr.1 = getelementptr inbounds i8, ptr %add.ptr, i64 %idx.ext
+ %add.ptr9.1 = getelementptr inbounds i8, ptr %add.ptr9, i64 %idx.ext8
+ %14 = load <16 x i8>, ptr %add.ptr.1, align 1
+ %15 = zext <16 x i8> %14 to <16 x i32>
+ %16 = load <16 x i8>, ptr %add.ptr9.1, align 1
+ %17 = zext <16 x i8> %16 to <16 x i32>
+ %18 = sub nsw <16 x i32> %15, %17
+ %19 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %18, i1 true)
+ %20 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %19)
+ %op.rdx.2 = add i32 %20, %op.rdx.1
+ %add.ptr.2 = getelementptr inbounds i8, ptr %add.ptr.1, i64 %idx.ext
+ %add.ptr9.2 = getelementptr inbounds i8, ptr %add.ptr9.1, i64 %idx.ext8
+ %21 = load <16 x i8>, ptr %add.ptr.2, align 1
+ %22 = zext <16 x i8> %21 to <16 x i32>
+ %23 = load <16 x i8>, ptr %add.ptr9.2, align 1
+ %24 = zext <16 x i8> %23 to <16 x i32>
+ %25 = sub nsw <16 x i32> %22, %24
+ %26 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %25, i1 true)
+ %27 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %26)
+ %op.rdx.3 = add i32 %27, %op.rdx.2
+ ret i32 %op.rdx.3
+}
+
+declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
|
Well, I'd been planning on posting a similar patch and was trying to figure out how to motivate it cleanly enough to have you approve it. :) My recent set of narrowing patches were motivated by the same workload, though they do apply a bit more broadly across the different kernels. My thinking was that I'd start with the slightly broader applicability pieces, then return to this one once the low hanging fruit was done. |
@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG, | |||
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget); | |||
} | |||
|
|||
// Look for (abs (sub (zext X), (zext Y))). | |||
// Rewrite as (zext (sub (zext (max X, Y), (min X, Y)))) if the user is an add |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't the sub be done at the narrower type as well? (a >=u b) should imply that (a-b) doesn't underflow, and thus the high bits are always zero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can, but it would leave a bare vzext.vf2 later. I was trying to carefully create a widening sub and a widening add or widening reduction to minimize the number of individual vector operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the workload in question we can pull that vext.vf2 through the chain of adds that sums the absolute difference pieces. We could use an i16 accumulator for the beginning of the chain and switch to an i32 accumulator later in the chain.
Naive use of computeKnownBits could get us some of that to prove the overflows don't happen. Need to check the length of the chain in the workload to see if that would exceed the computeKnownBits depth limit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(This is mostly a response to myself)
Here's the alive2 proof for the transformation in this patch:
https://alive2.llvm.org/ce/z/XoCBZ5
Note the need for noundef on the source parameters. Alternatively, we could use freeze in the target.
Here's my proposed variant:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is lowering (abs (sub X, Y)) to (sub (umax x, y), (umin x, y)) worthwhile doing on its own, ignoring pulling through the zexts for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is lowering (abs (sub X, Y)) to (sub (umax x, y), (umin x, y)) worthwhile doing on its own, ignoring pulling through the zexts for now?
I'm not sure that's valid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking that we would check the sub doesn't overflow with computeKnownBits, e.g. that the upper bits of X and Y are zero: https://alive2.llvm.org/ce/z/MZuw8V
|
||
// Make sure the use is an add or reduce add so the zext we create at the end | ||
// will be folded. | ||
if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of this, we can focus on the fact this allows a narrower representation.
EVT VT = N->getValueType(0); | ||
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); | ||
|
||
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't need to be fixed. Or i32.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep it's overfitting a workload.
Op0 = Op0.getOperand(0); | ||
Op1 = Op1.getOperand(0); | ||
|
||
// Inputs should be i8 vectors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or i8.
Alternative approach: #86592 |
This rewrites (abs (sub (zext X), (zext Y))) to
(zext (sub (zext (max X, Y), (min X, Y)))).
This was taken from my downstream and has some overfitting to a particular benchmark.
It only works on i32 vectors and checks that the user can also become a widening instruction.
Posting for discussion.