[RISCV][WIP] Optimize sum of absolute differences pattern. #82722

Closed
wants to merge 1 commit

Conversation

topperc (Collaborator) commented Feb 23, 2024

This rewrites (abs (sub (zext X), (zext Y))) to
(zext (sub (zext (max X, Y), (min X, Y)))).

This was taken from my downstream and has some overfitting to a particular benchmark.
It only works on i32 vectors and checks that the user can also become a widening instruction.

Posting for discussion.
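
In IR terms, the combine targets roughly the following shape (an illustrative sketch with invented function names, not code from the patch; in the patch the abs must additionally feed an add or reduction add so the final zext folds away):

define <4 x i32> @absdiff_before(<4 x i8> %x, <4 x i8> %y) {
  %xe = zext <4 x i8> %x to <4 x i32>
  %ye = zext <4 x i8> %y to <4 x i32>
  %d = sub nsw <4 x i32> %xe, %ye
  %ad = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %d, i1 true)
  ret <4 x i32> %ad
}

; conceptually what the DAG combine produces: umin/umax at i8, the sub at i16,
; and a trailing zext to i32 that is expected to fold into the consuming
; widening add or widening reduction.
define <4 x i32> @absdiff_after(<4 x i8> %x, <4 x i8> %y) {
  %mx = call <4 x i8> @llvm.umax.v4i8(<4 x i8> %x, <4 x i8> %y)
  %mn = call <4 x i8> @llvm.umin.v4i8(<4 x i8> %x, <4 x i8> %y)
  %mxe = zext <4 x i8> %mx to <4 x i16>
  %mne = zext <4 x i8> %mn to <4 x i16>
  %d16 = sub <4 x i16> %mxe, %mne
  %ad = zext <4 x i16> %d16 to <4 x i32>
  ret <4 x i32> %ad
}

declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
declare <4 x i8> @llvm.umax.v4i8(<4 x i8>, <4 x i8>)
declare <4 x i8> @llvm.umin.v4i8(<4 x i8>, <4 x i8>)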

llvmbot (Collaborator) commented Feb 23, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

This rewrites (abs (sub (zext X), (zext Y))) to
(zext (sub (zext (max X, Y), (min X, Y)))).

This was taken from my downstream and has some overfitting to a particular benchmark.
It only works on i32 vectors and checks that the user can also become a widening instruction.

Posting for discussion.


Full diff: https://github.com/llvm/llvm-project/pull/82722.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+58)
  • (added) llvm/test/CodeGen/RISCV/rvv/sad.ll (+120)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c67aaf6785669..6bd62b79e5a74f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
 }
 
+// Look for (abs (sub (zext X), (zext Y))).
+// Rewrite as (zext (sub (zext (max X, Y), (min X, Y)))) if the user is an add
+// or reduction add. The min/max can be done in parallel and with a lower LMUL
+// than the original code. The two zexts can be folded into widening sub and
+// widening add or widening redsum.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 ||
+      !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue Src = N->getOperand(0);
+  if (Src.getOpcode() != ISD::SUB || !Src.hasOneUse())
+    return SDValue();
+
+  // Make sure the use is an add or reduce add so the zext we create at the end
+  // will be folded.
+  if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD &&
+                          N->use_begin()->getOpcode() != ISD::VECREDUCE_ADD))
+    return SDValue();
+
+  // Inputs to the subtract should be zext.
+  SDValue Op0 = Src.getOperand(0);
+  SDValue Op1 = Src.getOperand(1);
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND || !Op0.hasOneUse() ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND || !Op1.hasOneUse())
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+
+  // Inputs should be i8 vectors.
+  if (Op0.getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getValueType().getVectorElementType() != MVT::i8)
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue Max = DAG.getNode(ISD::UMAX, DL, Op0.getValueType(), Op0, Op1);
+  SDValue Min = DAG.getNode(ISD::UMIN, DL, Op0.getValueType(), Op0, Op1);
+
+  // The intermediate VT should be i16.
+  EVT IntermediateVT =
+      EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorElementCount());
+
+  Max = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Max);
+  Min = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Min);
+
+  SDValue Sub = DAG.getNode(ISD::SUB, DL, IntermediateVT, Max, Min);
+
+  return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Sub);
+}
+
 static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   if (!VT.isVector())
@@ -15698,6 +15753,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                        DAG.getConstant(~SignBit, DL, VT));
   }
   case ISD::ABS: {
+    if (SDValue V = performABSCombine(N, DAG))
+      return V;
+
     EVT VT = N->getValueType(0);
     SDValue N0 = N->getOperand(0);
     // abs (sext) -> zext (abs)
diff --git a/llvm/test/CodeGen/RISCV/rvv/sad.ll b/llvm/test/CodeGen/RISCV/rvv/sad.ll
new file mode 100644
index 00000000000000..ed25431c6f45cc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/sad.ll
@@ -0,0 +1,120 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+define signext i32 @sad(ptr %a, ptr %b) {
+; CHECK-LABEL: sad:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vminu.vv v10, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vwsubu.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, zero
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwredsumu.vs v8, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %0 = load <4 x i8>, ptr %a, align 1
+  %1 = zext <4 x i8> %0 to <4 x i32>
+  %2 = load <4 x i8>, ptr %b, align 1
+  %3 = zext <4 x i8> %2 to <4 x i32>
+  %4 = sub nsw <4 x i32> %1, %3
+  %5 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %4, i1 true)
+  %6 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %5)
+  ret i32 %6
+}
+
+define signext i32 @sad2(ptr %a, ptr %b, i32 signext %stridea, i32 signext %strideb) {
+; CHECK-LABEL: sad2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v10, (a0)
+; CHECK-NEXT:    vle8.v v11, (a1)
+; CHECK-NEXT:    vminu.vv v12, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vwsubu.vv v14, v8, v12
+; CHECK-NEXT:    vminu.vv v8, v10, v11
+; CHECK-NEXT:    vmaxu.vv v9, v10, v11
+; CHECK-NEXT:    vwsubu.vv v12, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v16, (a0)
+; CHECK-NEXT:    vle8.v v17, (a1)
+; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vminu.vv v12, v16, v17
+; CHECK-NEXT:    vmaxu.vv v13, v16, v17
+; CHECK-NEXT:    vwsubu.vv v14, v13, v12
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v12, (a0)
+; CHECK-NEXT:    vle8.v v13, (a1)
+; CHECK-NEXT:    vwaddu.wv v8, v8, v14
+; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vminu.vv v14, v12, v13
+; CHECK-NEXT:    vmaxu.vv v12, v12, v13
+; CHECK-NEXT:    vwsubu.vv v16, v12, v14
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwaddu.wv v8, v8, v16
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %idx.ext8 = sext i32 %strideb to i64
+  %idx.ext = sext i32 %stridea to i64
+  %0 = load <16 x i8>, ptr %a, align 1
+  %1 = zext <16 x i8> %0 to <16 x i32>
+  %2 = load <16 x i8>, ptr %b, align 1
+  %3 = zext <16 x i8> %2 to <16 x i32>
+  %4 = sub nsw <16 x i32> %1, %3
+  %5 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %4, i1 true)
+  %6 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  %add.ptr = getelementptr inbounds i8, ptr %a, i64 %idx.ext
+  %add.ptr9 = getelementptr inbounds i8, ptr %b, i64 %idx.ext8
+  %7 = load <16 x i8>, ptr %add.ptr, align 1
+  %8 = zext <16 x i8> %7 to <16 x i32>
+  %9 = load <16 x i8>, ptr %add.ptr9, align 1
+  %10 = zext <16 x i8> %9 to <16 x i32>
+  %11 = sub nsw <16 x i32> %8, %10
+  %12 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %11, i1 true)
+  %13 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %12)
+  %op.rdx.1 = add i32 %13, %6
+  %add.ptr.1 = getelementptr inbounds i8, ptr %add.ptr, i64 %idx.ext
+  %add.ptr9.1 = getelementptr inbounds i8, ptr %add.ptr9, i64 %idx.ext8
+  %14 = load <16 x i8>, ptr %add.ptr.1, align 1
+  %15 = zext <16 x i8> %14 to <16 x i32>
+  %16 = load <16 x i8>, ptr %add.ptr9.1, align 1
+  %17 = zext <16 x i8> %16 to <16 x i32>
+  %18 = sub nsw <16 x i32> %15, %17
+  %19 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %18, i1 true)
+  %20 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %19)
+  %op.rdx.2 = add i32 %20, %op.rdx.1
+  %add.ptr.2 = getelementptr inbounds i8, ptr %add.ptr.1, i64 %idx.ext
+  %add.ptr9.2 = getelementptr inbounds i8, ptr %add.ptr9.1, i64 %idx.ext8
+  %21 = load <16 x i8>, ptr %add.ptr.2, align 1
+  %22 = zext <16 x i8> %21 to <16 x i32>
+  %23 = load <16 x i8>, ptr %add.ptr9.2, align 1
+  %24 = zext <16 x i8> %23 to <16 x i32>
+  %25 = sub nsw <16 x i32> %22, %24
+  %26 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %25, i1 true)
+  %27 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %26)
+  %op.rdx.3 = add i32 %27, %op.rdx.2
+  ret i32 %op.rdx.3
+}
+
+declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

preames (Collaborator) commented Feb 23, 2024

Posting for discussion.

Well, I'd been planning on posting a similar patch and was trying to figure out how to motivate it cleanly enough to have you approve it. :)

My recent set of narrowing patches was motivated by the same workload, though they do apply a bit more broadly across the different kernels. My thinking was that I'd start with the pieces that have slightly broader applicability, then return to this one once the low-hanging fruit was done.

@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
}

// Look for (abs (sub (zext X), (zext Y))).
// Rewrite as (zext (sub (zext (max X, Y), (min X, Y)))) if the user is an add
Collaborator:

Can't the sub be done at the narrower type as well? (a >=u b) should imply that (a-b) doesn't underflow, and thus the high bits are always zero?

Collaborator Author:

It can, but it would leave a bare vzext.vf2 later. I was trying to carefully create a widening sub and a widening add or widening reduction to minimize the number of individual vector operations.

Collaborator Author:

For the workload in question we can pull that vzext.vf2 through the chain of adds that sums the absolute-difference pieces. We could use an i16 accumulator for the beginning of the chain and switch to an i32 accumulator later in the chain.

Naive use of computeKnownBits could get us some of that by proving the overflows don't happen. Need to check the length of the chain in the workload to see whether it would exceed the computeKnownBits depth limit.
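
A hypothetical IR sketch of that idea (names invented for illustration; assumes the chain is short enough that an i16 accumulator of 8-bit absolute differences cannot overflow):

define <16 x i32> @acc_sketch(<16 x i16> %sad0, <16 x i16> %sad1, <16 x i16> %sad2, <16 x i32> %sad3) {
  ; keep the accumulator at i16 while overflow is provably impossible
  %acc16 = add <16 x i16> %sad0, %sad1
  %acc16.1 = add <16 x i16> %acc16, %sad2
  ; widen once, then continue accumulating at i32
  %acc32 = zext <16 x i16> %acc16.1 to <16 x i32>
  %acc32.1 = add <16 x i32> %acc32, %sad3
  ret <16 x i32> %acc32.1
}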

Collaborator:

(This is mostly a response to myself)

Here's the alive2 proof for the transformation in this patch:

https://alive2.llvm.org/ce/z/XoCBZ5

Note the need for noundef on the source parameters. Alternatively, we could use freeze in the target.

Here's my proposed variant:

https://alive2.llvm.org/ce/z/f7MdJe
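
The proposed variant keeps the sub at the narrow type as well; roughly the following shape (a sketch with an invented function name, not the exact contents of the alive2 link):

define <4 x i32> @absdiff_narrow(<4 x i8> %x, <4 x i8> %y) {
  %mx = call <4 x i8> @llvm.umax.v4i8(<4 x i8> %x, <4 x i8> %y)
  %mn = call <4 x i8> @llvm.umin.v4i8(<4 x i8> %x, <4 x i8> %y)
  ; %mx u>= %mn, so this i8 sub cannot wrap and the result already fits in i8
  %d8 = sub <4 x i8> %mx, %mn
  %ad = zext <4 x i8> %d8 to <4 x i32>
  ret <4 x i32> %ad
}

declare <4 x i8> @llvm.umax.v4i8(<4 x i8>, <4 x i8>)
declare <4 x i8> @llvm.umin.v4i8(<4 x i8>, <4 x i8>)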

Contributor:

Is lowering (abs (sub X, Y)) to (sub (umax x, y), (umin x, y)) worthwhile doing on its own, ignoring pulling through the zexts for now?

Collaborator Author:

Is lowering (abs (sub X, Y)) to (sub (umax x, y), (umin x, y)) worthwhile doing on its own, ignoring pulling through the zexts for now?

I'm not sure that's valid.

Contributor:

I was thinking that we would use computeKnownBits to check that the sub doesn't overflow, e.g. that the upper bits of X and Y are zero: https://alive2.llvm.org/ce/z/MZuw8V
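
That is, something along these lines, where the zexts stand in for whatever computeKnownBits would prove about the high bits (a sketch with invented names, not necessarily the exact proof behind the link):

define <4 x i32> @src(<4 x i8> %x, <4 x i8> %y) {
  %xe = zext <4 x i8> %x to <4 x i32>
  %ye = zext <4 x i8> %y to <4 x i32>
  ; the sub cannot overflow because the upper 24 bits of both inputs are zero
  %d = sub <4 x i32> %xe, %ye
  %ad = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %d, i1 false)
  ret <4 x i32> %ad
}

define <4 x i32> @tgt(<4 x i8> %x, <4 x i8> %y) {
  %xe = zext <4 x i8> %x to <4 x i32>
  %ye = zext <4 x i8> %y to <4 x i32>
  %mx = call <4 x i32> @llvm.umax.v4i32(<4 x i32> %xe, <4 x i32> %ye)
  %mn = call <4 x i32> @llvm.umin.v4i32(<4 x i32> %xe, <4 x i32> %ye)
  %d = sub <4 x i32> %mx, %mn
  ret <4 x i32> %d
}

declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.umin.v4i32(<4 x i32>, <4 x i32>)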


// Make sure the use is an add or reduce add so the zext we create at the end
// will be folded.
if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD &&
Collaborator:

Instead of this, we can focus on the fact that this allows a narrower representation.

EVT VT = N->getValueType(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 ||
Collaborator:

Doesn't need to be fixed-length. Or i32.

Collaborator Author:

Yep it's overfitting a workload.

Op0 = Op0.getOperand(0);
Op1 = Op1.getOperand(0);

// Inputs should be i8 vectors.
Collaborator:

Or i8.

preames (Collaborator) commented Mar 25, 2024

Alternative approach: #86592

topperc closed this May 24, 2024