Skip to content

Commit

Permalink
[RISCV] Fold (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg …
Browse files Browse the repository at this point in the history
…(add X, C1), C)

Similar for a subtract with a constant left hand side.

(sra (add (shl X, 32), C1<<32), 32) is the canonical IR from InstCombine
for (sext (add (trunc X to i32), C1) to i64).

For RISCV, we should lower this as addiw which means turning it into
(sext_inreg (add X, C1)).

There is an existing DAG combine to convert back to (sext (add (trunc X
to i32), C1) to i64), but it requires isTruncateFree to return true
and i32 to be a legal type, as it uses sign_extend and truncate
nodes. So that doesn't work for RISCV.

If the outer sra happens to be used by a shl by constant, it will be
folded and the shift amount of the sra will be changed before we
can do our own DAG combine. This requires us to match the more
general pattern and restore the shl.

I had wanted to do this as a separate (add (shl X, 32), C1<<32) ->
(shl (add X, C1), 32) combine, but that hit an infinite loop for some
values of C1.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D128869
  • Loading branch information
topperc committed Jun 30, 2022
1 parent 9ace5af commit 51d6729
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 33 deletions.
64 changes: 55 additions & 9 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -8532,28 +8532,74 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {

// Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
// FIXME: Should this be a generic combine? There's a similar combine on X86.
//
// Also try these folds where an add or sub is in the middle.
// (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1 >> 32), i32), C)
// (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1 >> 32, X), i32), C)
// In both cases C1 must have at least 32 trailing zero bits.
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
// NOTE(review): this listing looks like a diff whose +/- markers were lost, so
// removed and added lines are interleaved. Lines using `C` (and the first
// N0-based SHL check) are presumably the removed version; lines using
// `ShAmtC`, `Shl` and `AddC` are presumably the added version. Confirm against
// the upstream commit before treating this as compilable code.
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");

// Bail out unless this is an i64 shift on a 64-bit subtarget.
if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
return SDValue();

// Presumably removed: required a constant shift amount strictly below 32.
auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!C || C->getZExtValue() >= 32)
// Presumably added: the shift amount must be a constant no greater than 32,
// so an amount of exactly 32 is accepted as well.
auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!ShAmtC || ShAmtC->getZExtValue() > 32)
return SDValue();

SDValue N0 = N->getOperand(0);
// Presumably removed: required the SRA operand itself to be a one-use
// (shl X, 32).
if (N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
!isa<ConstantSDNode>(N0.getOperand(1)) ||
N0.getConstantOperandVal(1) != 32)

SDValue Shl;
// Stays null unless an ADD/SUB with a usable constant operand is found.
ConstantSDNode *AddC = nullptr;

// We might have an ADD or SUB between the SRA and SHL.
bool IsAdd = N0.getOpcode() == ISD::ADD;
if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
// Give up if the ADD/SUB has other users.
if (!N0.hasOneUse())
return SDValue();
// Other operand needs to be a constant we can modify.
// The constant is read from operand 1 of an ADD and operand 0 of a SUB.
AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
if (!AddC)
return SDValue();

// AddC needs to have at least 32 trailing zeros.
if (AddC->getAPIntValue().countTrailingZeros() < 32)
return SDValue();

// The remaining operand is the candidate SHL.
Shl = N0.getOperand(IsAdd ? 0 : 1);
} else {
// Not an ADD or SUB.
Shl = N0;
}

// Look for a shift left by 32.
if (Shl.getOpcode() != ISD::SHL || !Shl.hasOneUse() ||
!isa<ConstantSDNode>(Shl.getOperand(1)) ||
Shl.getConstantOperandVal(1) != 32)
return SDValue();

SDLoc DL(N);
// Presumably removed: built the result directly from N0's input and `C`.
SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64,
N0.getOperand(0), DAG.getValueType(MVT::i32));
return DAG.getNode(ISD::SHL, DL, MVT::i64, SExt,
DAG.getConstant(32 - C->getZExtValue(), DL, MVT::i64));
// The value that was being shifted left by 32.
SDValue In = Shl.getOperand(0);

// If we looked through an ADD or SUB, we need to rebuild it with the shifted
// constant.
if (AddC) {
// The low 32 bits of AddC were checked to be zero above, so this logical
// shift moves its upper half down into the low 32 bits.
SDValue ShiftedAddC =
DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
// Keep the original operand order: X + C for ADD, C - X for SUB.
if (IsAdd)
In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
else
In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
}

// Sign-extend the low 32 bits of the (possibly rebuilt) value.
SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
DAG.getValueType(MVT::i32));
// With a shift amount of 32 the left shift below would be by 32 - 32 = 0, so
// the sign-extended value is returned as is.
if (ShAmtC->getZExtValue() == 32)
return SExt;

// Otherwise restore the remaining left shift of (32 - shift amount).
return DAG.getNode(
ISD::SHL, DL, MVT::i64, SExt,
DAG.getConstant(32 - ShAmtC->getZExtValue(), DL, MVT::i64));
}

SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
Expand Down
35 changes: 11 additions & 24 deletions llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
Expand Up @@ -84,11 +84,7 @@ define i64 @test6(i32 signext %a, i32 signext %b) nounwind {
define i64 @test7(i32* %0, i64 %1) {
; RV64I-LABEL: test7:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, 1
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: addiw a0, a1, 1
; RV64I-NEXT: ret
%3 = shl i64 %1, 32
%4 = add i64 %3, 4294967296
Expand All @@ -102,11 +98,8 @@ define i64 @test7(i32* %0, i64 %1) {
define i64 @test8(i32* %0, i64 %1) {
; RV64I-LABEL: test8:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, 1
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: sub a0, a1, a0
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: li a0, 1
; RV64I-NEXT: subw a0, a0, a1
; RV64I-NEXT: ret
%3 = mul i64 %1, -4294967296
%4 = add i64 %3, 4294967296
Expand All @@ -119,11 +112,10 @@ define i64 @test8(i32* %0, i64 %1) {
define signext i32 @test9(i32* %0, i64 %1) {
; RV64I-LABEL: test9:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: lui a2, 4097
; RV64I-NEXT: slli a2, a2, 20
; RV64I-NEXT: add a1, a1, a2
; RV64I-NEXT: srai a1, a1, 30
; RV64I-NEXT: lui a2, 1
; RV64I-NEXT: addiw a2, a2, 1
; RV64I-NEXT: addw a1, a1, a2
; RV64I-NEXT: slli a1, a1, 2
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: lw a0, 0(a0)
; RV64I-NEXT: ret
Expand All @@ -140,12 +132,10 @@ define signext i32 @test9(i32* %0, i64 %1) {
define signext i32 @test10(i32* %0, i64 %1) {
; RV64I-LABEL: test10:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: lui a2, 30141
; RV64I-NEXT: addiw a2, a2, -747
; RV64I-NEXT: slli a2, a2, 32
; RV64I-NEXT: sub a1, a2, a1
; RV64I-NEXT: srai a1, a1, 30
; RV64I-NEXT: subw a1, a2, a1
; RV64I-NEXT: slli a1, a1, 2
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: lw a0, 0(a0)
; RV64I-NEXT: ret
Expand All @@ -160,11 +150,8 @@ define signext i32 @test10(i32* %0, i64 %1) {
define i64 @test11(i32* %0, i64 %1) {
; RV64I-LABEL: test11:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, -1
; RV64I-NEXT: slli a1, a1, 63
; RV64I-NEXT: sub a0, a1, a0
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: lui a0, 524288
; RV64I-NEXT: subw a0, a0, a1
; RV64I-NEXT: ret
%3 = mul i64 %1, -4294967296
%4 = add i64 %3, 9223372036854775808 ;0x8000'0000'0000'0000
Expand Down

0 comments on commit 51d6729

Please sign in to comment.