Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Move strength reduction of mul X, 3/5/9*2^N to combine #89966

Merged
merged 4 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13426,10 +13426,27 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
if (MulAmt % Divisor != 0)
continue;
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
// Matched in tablegen, avoid perturbing patterns.
if (isPowerOf2_64(MulAmt2))
return SDValue();
// 3/5/9 * 2^N -> shl (shXadd X, X), N
if (isPowerOf2_64(MulAmt2)) {
SDLoc DL(N);
SDValue X = N->getOperand(0);
// Put the shift first if we can fold a zext into the
// shift forming a slli.uw.
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
Comment on lines +13435 to +13436
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use SDPatternMatch?

Suggested change
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
if (sd_match(X, m_And(m_Value(), m_SpecificInt(0xffffffff)))) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to leave this as is, and then doing a single NFC to replace several usage examples.

SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), Shl);
}
// Otherwise, put rhe shl second so that it can fold with following
// instructions (e.g. sext or add).
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
}

// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
Expand Down
29 changes: 0 additions & 29 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
Original file line number Diff line number Diff line change
Expand Up @@ -549,40 +549,11 @@ def : Pat<(add_non_imm12 sh2add_op:$rs1, (XLenVT GPR:$rs2)),
def : Pat<(add_non_imm12 sh3add_op:$rs1, (XLenVT GPR:$rs2)),
(TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;

def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 1)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 1)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 1)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 2)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 2)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 2)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 3)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 3)>;
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
(TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 3)>;

def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy4:$i),
(TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy2XForm CSImm12MulBy4:$i))), 2)>;
def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy8:$i),
(TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy3XForm CSImm12MulBy8:$i))), 3)>;

def : Pat<(mul (XLenVT GPR:$r), C3LeftShift:$i),
(SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 1)),
(TrailingZeros C3LeftShift:$i))>;
def : Pat<(mul (XLenVT GPR:$r), C5LeftShift:$i),
(SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
(TrailingZeros C5LeftShift:$i))>;
def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
(SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)),
(TrailingZeros C9LeftShift:$i))>;

def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
(SLLI (XLenVT (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
(XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)), 3)>;
Expand Down
74 changes: 0 additions & 74 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
Original file line number Diff line number Diff line change
Expand Up @@ -173,42 +173,6 @@ def BCLRIANDIMaskLow : SDNodeXForm<imm, [{
SDLoc(N), N->getValueType(0));
}]>;

def C3LeftShift : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
return C > 3 && (C >> llvm::countr_zero(C)) == 3;
}]>;

def C5LeftShift : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
return C > 5 && (C >> llvm::countr_zero(C)) == 5;
}]>;

def C9LeftShift : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
return C > 9 && (C >> llvm::countr_zero(C)) == 9;
}]>;

// Constant of the form (3 << C) where C is less than 32.
def C3LeftShiftUW : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
unsigned Shift = llvm::countr_zero(C);
return 1 <= Shift && Shift < 32 && (C >> Shift) == 3;
}]>;

// Constant of the form (5 << C) where C is less than 32.
def C5LeftShiftUW : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
unsigned Shift = llvm::countr_zero(C);
return 1 <= Shift && Shift < 32 && (C >> Shift) == 5;
}]>;

// Constant of the form (9 << C) where C is less than 32.
def C9LeftShiftUW : PatLeaf<(imm), [{
uint64_t C = N->getZExtValue();
unsigned Shift = llvm::countr_zero(C);
return 1 <= Shift && Shift < 32 && (C >> Shift) == 9;
}]>;

def CSImm12MulBy4 : PatLeaf<(imm), [{
if (!N->hasOneUse())
return false;
Expand Down Expand Up @@ -693,42 +657,13 @@ foreach i = {1,2,3} in {
(shxadd pat:$rs1, GPR:$rs2)>;
}

def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
(SH1ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
(SH1ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
(SH1ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
(SH2ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
(SH2ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
(SH2ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
(SH3ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
(SH3ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
(SH3ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;

def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy4:$i),
(SH2ADD (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy2XForm CSImm12MulBy4:$i))),
GPR:$r)>;
def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy8:$i),
(SH3ADD (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy3XForm CSImm12MulBy8:$i))),
GPR:$r)>;

def : Pat<(mul (XLenVT GPR:$r), C3LeftShift:$i),
(SLLI (XLenVT (SH1ADD GPR:$r, GPR:$r)),
(TrailingZeros C3LeftShift:$i))>;
def : Pat<(mul (XLenVT GPR:$r), C5LeftShift:$i),
(SLLI (XLenVT (SH2ADD GPR:$r, GPR:$r)),
(TrailingZeros C5LeftShift:$i))>;
def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
(SLLI (XLenVT (SH3ADD GPR:$r, GPR:$r)),
(TrailingZeros C9LeftShift:$i))>;

} // Predicates = [HasStdExtZba]

let Predicates = [HasStdExtZba, IsRV64] in {
Expand Down Expand Up @@ -780,15 +715,6 @@ def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x3FFFFFFFC), (XLenVT GPR:$rs2
def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD_UW (XLenVT (SRLI GPR:$rs1, 3)), GPR:$rs2)>;

def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
(SH1ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C3LeftShiftUW:$i))),
(XLenVT (SLLI_UW GPR:$r, (TrailingZeros C3LeftShiftUW:$i))))>;
def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C5LeftShiftUW:$i)),
(SH2ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C5LeftShiftUW:$i))),
(XLenVT (SLLI_UW GPR:$r, (TrailingZeros C5LeftShiftUW:$i))))>;
def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C9LeftShiftUW:$i)),
(SH3ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C9LeftShiftUW:$i))),
(XLenVT (SLLI_UW GPR:$r, (TrailingZeros C9LeftShiftUW:$i))))>;
} // Predicates = [HasStdExtZba, IsRV64]

let Predicates = [HasStdExtZbcOrZbkc] in {
Expand Down
9 changes: 5 additions & 4 deletions llvm/test/CodeGen/RISCV/addimm-mulimm.ll
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,9 @@ define i64 @add_mul_combine_infinite_loop(i64 %x) {
; RV32IMB-NEXT: sh3add a1, a1, a2
; RV32IMB-NEXT: sh1add a0, a0, a0
; RV32IMB-NEXT: slli a2, a0, 3
; RV32IMB-NEXT: addi a0, a2, 2047
; RV32IMB-NEXT: addi a0, a0, 1
; RV32IMB-NEXT: li a3, 1
; RV32IMB-NEXT: slli a3, a3, 11
; RV32IMB-NEXT: sh3add a0, a0, a3
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case was previously not caught by the "add mul X, 24, Y" pattern because there's two multiplies by 24, and thus it failed the one use check. Instead, it went through generic "mul X, 3 << 2" expansion, and thus ended with the shift/add.

With the change, we hit the "(add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)" pattern - which critically doesn't check if the immediate could be split across two addi. We should probably adjust this, but it seems a) minor, and b) very very separate. (And if we had zbb, this would be a bseti anyways.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the change, we hit the "(add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)" pattern - which critically doesn't check if the immediate could be split across two addi. We should probably adjust this, but it seems a) minor, and b) very very separate. (And if we had zbb, this would be a bseti anyways.)

This explanation makes sense to me.

; RV32IMB-NEXT: sltu a2, a0, a2
; RV32IMB-NEXT: add a1, a1, a2
; RV32IMB-NEXT: ret
Expand All @@ -610,8 +611,8 @@ define i64 @add_mul_combine_infinite_loop(i64 %x) {
; RV64IMB: # %bb.0:
; RV64IMB-NEXT: addi a0, a0, 86
; RV64IMB-NEXT: sh1add a0, a0, a0
; RV64IMB-NEXT: li a1, -16
; RV64IMB-NEXT: sh3add a0, a0, a1
; RV64IMB-NEXT: slli a0, a0, 3
; RV64IMB-NEXT: addi a0, a0, -16
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case was previously picked up by the add_like (mul_one_use X, 24), Y pattern which didn't check whether Y is an immediate or not.

; RV64IMB-NEXT: ret
%tmp0 = mul i64 %x, 24
%tmp1 = add i64 %tmp0, 2048
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,8 @@ define i64 @zext_mul12884901888(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul12884901888:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 12884901888
Expand All @@ -667,8 +667,8 @@ define i64 @zext_mul21474836480(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul21474836480:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 21474836480
Expand All @@ -688,8 +688,8 @@ define i64 @zext_mul38654705664(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul38654705664:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 38654705664
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/RISCV/rv64zba.ll
Original file line number Diff line number Diff line change
Expand Up @@ -865,8 +865,8 @@ define i64 @zext_mul12884901888(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul12884901888:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 12884901888
Expand All @@ -886,8 +886,8 @@ define i64 @zext_mul21474836480(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul21474836480:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 21474836480
Expand All @@ -907,8 +907,8 @@ define i64 @zext_mul38654705664(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul38654705664:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 38654705664
Expand Down
Loading