From 1d0ea8b887ee3a8b4dd18fb07e072e0ab549a8b2 Mon Sep 17 00:00:00 2001 From: rez5427 Date: Thu, 6 Nov 2025 22:04:01 +0800 Subject: [PATCH 1/6] [DAGCombiner] add sra xor sra pattern fold --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 +++++++++++++ llvm/test/CodeGen/RISCV/sra-xor-sra.ll | 32 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 llvm/test/CodeGen/RISCV/sra-xor-sra.ll diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index d2ea6525e1116..b7e195d44b031 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10968,6 +10968,28 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } } + // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3) + if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() && + isAllOnesConstant(N0.getOperand(1))) { + SDValue Inner = N0.getOperand(0); + if (Inner.getOpcode() == ISD::SRA && N1C) { + if (ConstantSDNode *InnerShiftAmt = + isConstOrConstSplat(Inner.getOperand(1))) { + APInt c1 = InnerShiftAmt->getAPIntValue(); + APInt c2 = N1C->getAPIntValue(); + zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); + APInt Sum = c1 + c2; + unsigned ShiftSum = + Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); + SDValue NewShift = + DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0), + DAG.getConstant(ShiftSum, DL, N1.getValueType())); + return DAG.getNode(ISD::XOR, DL, VT, NewShift, + DAG.getAllOnesConstant(DL, VT)); + } + } + } + // fold (sra (shl X, m), (sub result_size, n)) // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for // result_size - n != m. diff --git a/llvm/test/CodeGen/RISCV/sra-xor-sra.ll b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll new file mode 100644 index 0000000000000..b04f0a29d07f3 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll @@ -0,0 +1,32 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s + +; Test folding of: (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3) +; Original motivating example: should merge sra+sra across xor +define i16 @not_invert_signbit_splat_mask(i8 %x, i16 %y) { +; CHECK-LABEL: not_invert_signbit_splat_mask: +; CHECK: # %bb.0: +; CHECK-NEXT: slli a0, a0, 56 +; CHECK-NEXT: srai a0, a0, 62 +; CHECK-NEXT: not a0, a0 +; CHECK-NEXT: and a0, a0, a1 +; CHECK-NEXT: ret + %a = ashr i8 %x, 6 + %n = xor i8 %a, -1 + %s = sext i8 %n to i16 + %r = and i16 %s, %y + ret i16 %r +} + +; Edge case +define i16 @sra_xor_sra_overflow(i8 %x, i16 %y) { +; CHECK-LABEL: sra_xor_sra_overflow: +; CHECK: # %bb.0: +; CHECK-NEXT: li a0, 0 +; CHECK-NEXT: ret + %a = ashr i8 %x, 10 + %n = xor i8 %a, -1 + %s = sext i8 %n to i16 + %r = and i16 %s, %y + ret i16 %r +} From 9f2e1e381b6b0e75fe17fb924c33f1f6ec13b7a8 Mon Sep 17 00:00:00 2001 From: rez5427 Date: Thu, 6 Nov 2025 22:42:06 +0800 Subject: [PATCH 2/6] Use sd_match --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index b7e195d44b031..e7ab0a94c34fb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10968,26 +10968,22 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } } - // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3) - if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() && - isAllOnesConstant(N0.getOperand(1))) { - SDValue Inner = N0.getOperand(0); - if (Inner.getOpcode() == ISD::SRA && N1C) { - if (ConstantSDNode *InnerShiftAmt = - isConstOrConstSplat(Inner.getOperand(1))) { - APInt c1 = InnerShiftAmt->getAPIntValue(); - APInt c2 = N1C->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - APInt Sum = c1 + c2; - unsigned ShiftSum = - Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); - SDValue NewShift = - DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0), - DAG.getConstant(ShiftSum, DL, N1.getValueType())); - return DAG.getNode(ISD::XOR, DL, VT, NewShift, - DAG.getAllOnesConstant(DL, VT)); - } - } + // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c1+c2), -1) + // This allows merging two arithmetic shifts even when there's a NOT in + // between. + SDValue X; + APInt C1, C2; + if (sd_match(N0, m_OneUse(m_Xor(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))), + m_AllOnes()))) && + sd_match(N1, m_ConstInt(C2))) { + zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */); + APInt Sum = C1 + C2; + unsigned ShiftSum = + Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); + SDValue NewShift = DAG.getNode( + ISD::SRA, DL, VT, X, DAG.getConstant(ShiftSum, DL, N1.getValueType())); + return DAG.getNode(ISD::XOR, DL, VT, NewShift, + DAG.getAllOnesConstant(DL, VT)); } // fold (sra (shl X, m), (sub result_size, n)) From 66e06e70c2f204378ef97c2f5b6e45f3428297fc Mon Sep 17 00:00:00 2001 From: rez5427 Date: Fri, 7 Nov 2025 08:31:51 +0800 Subject: [PATCH 3/6] Use m_Not --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index e7ab0a94c34fb..abf790280c408 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10968,14 +10968,15 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } } - // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c1+c2), -1) + // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c3), -1) // This allows merging two arithmetic shifts even when there's a NOT in // between. SDValue X; - APInt C1, C2; - if (sd_match(N0, m_OneUse(m_Xor(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))), - m_AllOnes()))) && - sd_match(N1, m_ConstInt(C2))) { + APInt C1; + if (sd_match(N0, + m_OneUse(m_Not(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1)))))) && + N1C) { + APInt C2 = N1C->getAPIntValue(); zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */); APInt Sum = C1 + C2; unsigned ShiftSum = From 17eeb2c25095213cc02277283e4fa943c5c14f96 Mon Sep 17 00:00:00 2001 From: rez5427 Date: Fri, 7 Nov 2025 13:16:28 +0800 Subject: [PATCH 4/6] Address code review feedback --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index abf790280c408..34d7e72f72435 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10973,18 +10973,16 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // between. SDValue X; APInt C1; - if (sd_match(N0, - m_OneUse(m_Not(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1)))))) && - N1C) { + if (N1C && sd_match(N0, m_OneUse(m_Not( + m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))))))) { APInt C2 = N1C->getAPIntValue(); - zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */); + zeroExtendToMatch(C1, C2, /*OverflowBit=*/1); APInt Sum = C1 + C2; unsigned ShiftSum = Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); SDValue NewShift = DAG.getNode( ISD::SRA, DL, VT, X, DAG.getConstant(ShiftSum, DL, N1.getValueType())); - return DAG.getNode(ISD::XOR, DL, VT, NewShift, - DAG.getAllOnesConstant(DL, VT)); + return DAG.getNOT(DL, NewShift, VT); } // fold (sra (shl X, m), (sub result_size, n)) From 6cdf7352b16ada00c19551c48b18e99fc8f8ad24 Mon Sep 17 00:00:00 2001 From: rez5427 <785369607@qq.com> Date: Sat, 8 Nov 2025 13:44:54 +0800 Subject: [PATCH 5/6] Use getShiftAmountConstant --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 34d7e72f72435..c6c77c7417a97 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10981,7 +10981,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { unsigned ShiftSum = Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); SDValue NewShift = DAG.getNode( - ISD::SRA, DL, VT, X, DAG.getConstant(ShiftSum, DL, N1.getValueType())); + ISD::SRA, DL, VT, X, DAG.getShiftAmountConstant(ShiftSum, VT, DL)); return DAG.getNOT(DL, NewShift, VT); } From ff9b44fd03adb48c4e01570165f5c99fba2876cf Mon Sep 17 00:00:00 2001 From: rez5427 Date: Mon, 10 Nov 2025 21:49:00 +0800 Subject: [PATCH 6/6] Use getLimitedValue --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c6c77c7417a97..9601206a70acd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10976,10 +10976,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { if (N1C && sd_match(N0, m_OneUse(m_Not( m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))))))) { APInt C2 = N1C->getAPIntValue(); - zeroExtendToMatch(C1, C2, /*OverflowBit=*/1); + zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */); APInt Sum = C1 + C2; - unsigned ShiftSum = - Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); + unsigned ShiftSum = Sum.getLimitedValue(OpSizeInBits - 1); SDValue NewShift = DAG.getNode( ISD::SRA, DL, VT, X, DAG.getShiftAmountConstant(ShiftSum, VT, DL)); return DAG.getNOT(DL, NewShift, VT);