-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAGCombiner] Add sra-xor-sra pattern fold #166777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-selectiondag Author: guan jian (rez5427) ChangesAdd alive2: https://alive2.llvm.org/ce/z/yxRQf9 Full diff: https://github.com/llvm/llvm-project/pull/166777.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d2ea6525e1116..7ab460bef019e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,6 +10968,26 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
}
}
+ // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+ if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() &&
+ isAllOnesConstant(N0.getOperand(1))) {
+ SDValue Inner = N0.getOperand(0);
+ if (Inner.getOpcode() == ISD::SRA && N1C) {
+ if (ConstantSDNode *InnerShiftAmt = isConstOrConstSplat(Inner.getOperand(1))) {
+ APInt c1 = InnerShiftAmt->getAPIntValue();
+ APInt c2 = N1C->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ APInt Sum = c1 + c2;
+ unsigned ShiftSum =
+ Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+ SDValue NewShift = DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0),
+ DAG.getConstant(ShiftSum, DL, N1.getValueType()));
+ return DAG.getNode(ISD::XOR, DL, VT, NewShift,
+ DAG.getAllOnesConstant(DL, VT));
+ }
+ }
+ }
+
// fold (sra (shl X, m), (sub result_size, n))
// -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
// result_size - n != m.
diff --git a/llvm/test/CodeGen/RISCV/sra-xor-sra.ll b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
new file mode 100644
index 0000000000000..b04f0a29d07f3
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s
+
+; Test folding of: (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+; Original motivating example: should merge sra+sra across xor
+define i16 @not_invert_signbit_splat_mask(i8 %x, i16 %y) {
+; CHECK-LABEL: not_invert_signbit_splat_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 56
+; CHECK-NEXT: srai a0, a0, 62
+; CHECK-NEXT: not a0, a0
+; CHECK-NEXT: and a0, a0, a1
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 6
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
+
+; Edge case
+define i16 @sra_xor_sra_overflow(i8 %x, i16 %y) {
+; CHECK-LABEL: sra_xor_sra_overflow:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 0
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 10
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
|
|
@llvm/pr-subscribers-backend-risc-v Author: guan jian (rez5427) ChangesAdd alive2: https://alive2.llvm.org/ce/z/yxRQf9 Full diff: https://github.com/llvm/llvm-project/pull/166777.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d2ea6525e1116..7ab460bef019e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,6 +10968,26 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
}
}
+ // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+ if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() &&
+ isAllOnesConstant(N0.getOperand(1))) {
+ SDValue Inner = N0.getOperand(0);
+ if (Inner.getOpcode() == ISD::SRA && N1C) {
+ if (ConstantSDNode *InnerShiftAmt = isConstOrConstSplat(Inner.getOperand(1))) {
+ APInt c1 = InnerShiftAmt->getAPIntValue();
+ APInt c2 = N1C->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ APInt Sum = c1 + c2;
+ unsigned ShiftSum =
+ Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+ SDValue NewShift = DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0),
+ DAG.getConstant(ShiftSum, DL, N1.getValueType()));
+ return DAG.getNode(ISD::XOR, DL, VT, NewShift,
+ DAG.getAllOnesConstant(DL, VT));
+ }
+ }
+ }
+
// fold (sra (shl X, m), (sub result_size, n))
// -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
// result_size - n != m.
diff --git a/llvm/test/CodeGen/RISCV/sra-xor-sra.ll b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
new file mode 100644
index 0000000000000..b04f0a29d07f3
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s
+
+; Test folding of: (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+; Original motivating example: should merge sra+sra across xor
+define i16 @not_invert_signbit_splat_mask(i8 %x, i16 %y) {
+; CHECK-LABEL: not_invert_signbit_splat_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 56
+; CHECK-NEXT: srai a0, a0, 62
+; CHECK-NEXT: not a0, a0
+; CHECK-NEXT: and a0, a0, a1
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 6
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
+
+; Edge case
+define i16 @sra_xor_sra_overflow(i8 %x, i16 %y) {
+; CHECK-LABEL: sra_xor_sra_overflow:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 0
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 10
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
Sorry about the commit suggestion won't work, really clicked like a hundred times. |
RKSimon
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Add
fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)The IR like this:
llvm will produce:
56 and 6 can be add up
alive2: https://alive2.llvm.org/ce/z/yxRQf9