From c4fc70691c4db460a39036b091d16639a42ca20a Mon Sep 17 00:00:00 2001
From: David Green
Date: Wed, 26 Nov 2025 10:31:12 +0000
Subject: [PATCH 1/2] [AArch64] Combine vector add(trunc(shift))

This adds a combine for add(trunc(ashr(A, C)), trunc(lshr(A, 2*BW-1))),
with C >= BW, where BW is the element width of the truncated type:
  -> X = trunc(ashr(A, C)); add(X, lshr(X, BW-1))

The original form converts into ashr+lshr+xtn+xtn+add, while the rewritten
form becomes ashr+xtn+usra. The first has lower total latency due to more
parallelism, but uses more micro-ops and seems to be slower in practice.
---
 .../Target/AArch64/AArch64ISelLowering.cpp  | 35 ++++++++-
 llvm/test/CodeGen/AArch64/addtruncshift.ll  | 72 ++++++++++++-------
 2 files changed, 79 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 83ce39fa314d1..d4f489a7fc171 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -105,6 +105,7 @@
 #include
 
 using namespace llvm;
+using namespace llvm::SDPatternMatch;
 
 #define DEBUG_TYPE "aarch64-lower"
 
@@ -22586,6 +22587,37 @@ static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) {
                      Flags);
 }
 
+// add(trunc(ashr(A, C)), trunc(lshr(A, 2*BW-1))), with C >= BW
+// ->
+// X = trunc(ashr(A, C)); add(X, lshr(X, BW-1))
+// The original converts into ashr+lshr+xtn+xtn+add. The rewritten form
+// becomes ashr+xtn+usra. The first has lower total latency due to more
+// parallelism, but uses more micro-ops and seems to be slower in practice.
+static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8)
+    return SDValue();
+
+  SDValue AShr, LShr;
+  if (!sd_match(N, m_Add(m_Trunc(m_Value(AShr)), m_Trunc(m_Value(LShr)))))
+    return SDValue();
+  if (AShr.getOpcode() != AArch64ISD::VASHR)
+    std::swap(AShr, LShr);
+  if (AShr.getOpcode() != AArch64ISD::VASHR ||
+      LShr.getOpcode() != AArch64ISD::VLSHR ||
+      AShr.getOperand(0) != LShr.getOperand(0) ||
+      AShr.getConstantOperandVal(1) < VT.getScalarSizeInBits() ||
+      LShr.getConstantOperandVal(1) != VT.getScalarSizeInBits() * 2 - 1)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, AShr);
+  SDValue Shift = DAG.getNode(
+      AArch64ISD::VLSHR, DL, VT, Trunc,
+      DAG.getTargetConstant(VT.getScalarSizeInBits() - 1, DL, MVT::i32));
+  return DAG.getNode(ISD::ADD, DL, VT, Trunc, Shift);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   // Try to change sum of two reductions.
@@ -22609,6 +22641,8 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG))
     return Val;
+  if (SDValue Val = performAddTruncShiftCombine(N, DCI.DAG))
+    return Val;
   if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
     return Val;
 
@@ -28116,7 +28150,6 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) {
 static SDValue performCTPOPCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG) {
-  using namespace llvm::SDPatternMatch;
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
diff --git a/llvm/test/CodeGen/AArch64/addtruncshift.ll b/llvm/test/CodeGen/AArch64/addtruncshift.ll
index f3af50ec8cf3e..6dbe0b3d80b9a 100644
--- a/llvm/test/CodeGen/AArch64/addtruncshift.ll
+++ b/llvm/test/CodeGen/AArch64/addtruncshift.ll
@@ -3,14 +3,21 @@
 ; RUN: llc -mtriple=aarch64-none-elf -global-isel < %s | FileCheck %s --check-prefixes=CHECK,CHECK-GI
 
 define <2 x i32> @test_v2i64(<2 x i64> %n) {
-; CHECK-LABEL: test_v2i64:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushr v1.2d, v0.2d, #63
-; CHECK-NEXT:    sshr v0.2d, v0.2d, #35
-; CHECK-NEXT:    xtn v1.2s, v1.2d
-; CHECK-NEXT:    xtn v0.2s, v0.2d
-; CHECK-NEXT:    add v0.2s, v1.2s, v0.2s
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: test_v2i64:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sshr v0.2d, v0.2d, #35
+; CHECK-SD-NEXT:    xtn v0.2s, v0.2d
+; CHECK-SD-NEXT:    usra v0.2s, v0.2s, #31
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_v2i64:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ushr v1.2d, v0.2d, #63
+; CHECK-GI-NEXT:    sshr v0.2d, v0.2d, #35
+; CHECK-GI-NEXT:    xtn v1.2s, v1.2d
+; CHECK-GI-NEXT:    xtn v0.2s, v0.2d
+; CHECK-GI-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-GI-NEXT:    ret
 entry:
   %shr = lshr <2 x i64> %n, splat (i64 63)
   %vmovn.i4 = trunc nuw nsw <2 x i64> %shr to <2 x i32>
@@ -21,14 +28,21 @@ entry:
 }
 
 define <4 x i16> @test_v4i32(<4 x i32> %n) {
-; CHECK-LABEL: test_v4i32:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushr v1.4s, v0.4s, #31
-; CHECK-NEXT:    sshr v0.4s, v0.4s, #17
-; CHECK-NEXT:    xtn v1.4h, v1.4s
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    add v0.4h, v1.4h, v0.4h
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: test_v4i32:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sshr v0.4s, v0.4s, #17
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    usra v0.4h, v0.4h, #15
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_v4i32:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ushr v1.4s, v0.4s, #31
+; CHECK-GI-NEXT:    sshr v0.4s, v0.4s, #17
+; CHECK-GI-NEXT:    xtn v1.4h, v1.4s
+; CHECK-GI-NEXT:    xtn v0.4h, v0.4s
+; CHECK-GI-NEXT:    add v0.4h, v1.4h, v0.4h
+; CHECK-GI-NEXT:    ret
 entry:
   %shr = lshr <4 x i32> %n, splat (i32 31)
   %vmovn.i4 = trunc nuw nsw <4 x i32> %shr to <4 x i16>
@@ -39,14 +53,21 @@ entry:
 }
 
 define <8 x i8> @test_v8i16(<8 x i16> %n) {
-; CHECK-LABEL: test_v8i16:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushr v1.8h, v0.8h, #15
-; CHECK-NEXT:    sshr v0.8h, v0.8h, #9
-; CHECK-NEXT:    xtn v1.8b, v1.8h
-; CHECK-NEXT:    xtn v0.8b, v0.8h
-; CHECK-NEXT:    add v0.8b, v1.8b, v0.8b
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: test_v8i16:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sshr v0.8h, v0.8h, #9
+; CHECK-SD-NEXT:    xtn v0.8b, v0.8h
+; CHECK-SD-NEXT:    usra v0.8b, v0.8b, #7
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_v8i16:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ushr v1.8h, v0.8h, #15
+; CHECK-GI-NEXT:    sshr v0.8h, v0.8h, #9
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    xtn v0.8b, v0.8h
+; CHECK-GI-NEXT:    add v0.8b, v1.8b, v0.8b
+; CHECK-GI-NEXT:    ret
 entry:
   %shr = lshr <8 x i16> %n, splat (i16 15)
   %vmovn.i4 = trunc nuw nsw <8 x i16> %shr to <8 x i8>
@@ -91,6 +112,3 @@ entry:
   ret <2 x i32> %add
 }
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; CHECK-GI: {{.*}}
-; CHECK-SD: {{.*}}

From d01e4cc2d0a2ed9d051280ad588eed31a81f920b Mon Sep 17 00:00:00 2001
From: David Green
Date: Wed, 26 Nov 2025 13:51:03 +0000
Subject: [PATCH 2/2] Fix clang build

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d4f489a7fc171..d4099b56b6d6e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -105,7 +105,6 @@
 #include
 
 using namespace llvm;
-using namespace llvm::SDPatternMatch;
 
 #define DEBUG_TYPE "aarch64-lower"
 
@@ -22594,6 +22593,7 @@ static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) {
 // becomes ashr+xtn+usra. The first has lower total latency due to more
 // parallelism, but uses more micro-ops and seems to be slower in practice.
 static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) {
+  using namespace llvm::SDPatternMatch;
   EVT VT = N->getValueType(0);
   if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8)
     return SDValue();
@@ -28150,6 +28150,7 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) {
 static SDValue performCTPOPCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG) {
+  using namespace llvm::SDPatternMatch;
   if (!DCI.isBeforeLegalize())
     return SDValue();
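
For reference, a rough IR-level sketch of the equivalence the combine relies on,
written against @test_v2i64 above. This is illustrative only: the combine itself
runs on AArch64ISD::VASHR/VLSHR nodes during DAG combining, and the function
name @test_v2i64_rewritten is made up for this sketch.

; Because the ashr amount (35) is at least the narrow element width (32),
; bit 31 of the truncated value is still the sign bit of the original i64
; lane, so the wide lshr-by-63 can be replaced by a narrow lshr-by-31 of the
; already-truncated value, which then selects to usra.
define <2 x i32> @test_v2i64_rewritten(<2 x i64> %n) {
entry:
  %ashr = ashr <2 x i64> %n, splat (i64 35)
  %x = trunc <2 x i64> %ashr to <2 x i32>
  %sign = lshr <2 x i32> %x, splat (i32 31)
  %add = add <2 x i32> %x, %sign
  ret <2 x i32> %add
}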