diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a81de5c5adc34..132afc27135e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19476,6 +19476,61 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
                      Op1 ? Op1 : Mul->getOperand(1));
 }
 
+// Multiplying an RDSVL value by a constant can sometimes be done more cheaply
+// by folding a power-of-two factor of the constant into the RDSVL immediate
+// and compensating with an extra shift.
+//
+// We rewrite:
+//   (mul (srl (rdsvl 1), w), x)
+// to one of:
+//   (shl (rdsvl y), z)       if z > 0
+//   (srl (rdsvl y), abs(z))  if z < 0
+// where the integers y, z satisfy x = y * 2^(w + z) and y ∈ [-32, 31].
+static SDValue performMulRdsvlCombine(SDNode *Mul, SelectionDAG &DAG) {
+  SDLoc DL(Mul);
+  EVT VT = Mul->getValueType(0);
+  SDValue MulOp0 = Mul->getOperand(0);
+  int ConstMultiplier =
+      cast<ConstantSDNode>(Mul->getOperand(1))->getSExtValue();
+  if ((MulOp0->getOpcode() != ISD::SRL) ||
+      (MulOp0->getOperand(0).getOpcode() != AArch64ISD::RDSVL))
+    return SDValue();
+
+  unsigned AbsConstValue = abs(ConstMultiplier);
+  unsigned OperandShift =
+      cast<ConstantSDNode>(MulOp0->getOperand(1))->getZExtValue();
+
+  // z ≤ ctz(|x|) - w (the largest extra shift we can take while keeping y
+  // integral).
+  int UpperBound = llvm::countr_zero(AbsConstValue) - OperandShift;
+
+  // To keep y in range, with B = 31 for x > 0 and B = 32 for x < 0, we need:
+  //   2^(w + z) ≥ ceil(|x| / B)  ⇒  z ≥ ceil_log2(ceil(|x| / B)) - w (LowerBound).
+  unsigned B = ConstMultiplier < 0 ? 32 : 31;
+  unsigned CeilAxOverB = (AbsConstValue + (B - 1)) / B; // ceil(|x| / B)
+  int LowerBound = llvm::Log2_32_Ceil(CeilAxOverB) - OperandShift;
+
+  // No valid solution found.
+  if (LowerBound > UpperBound)
+    return SDValue();
+
+  // Any value of z in [LowerBound, UpperBound] is valid. Prefer no extra
+  // shift if possible.
+  int Shift = std::min(std::max(/*prefer*/ 0, LowerBound), UpperBound);
+
+  // y = x / 2^(w + z)
+  int32_t RdsvlMul = (AbsConstValue >> (OperandShift + Shift)) *
+                     (ConstMultiplier < 0 ? -1 : 1);
+  auto Rdsvl = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                           DAG.getSignedConstant(RdsvlMul, DL, MVT::i32));
+
+  if (Shift == 0)
+    return Rdsvl;
+  return DAG.getNode(Shift < 0 ? ISD::SRL : ISD::SHL, DL, VT, Rdsvl,
+                     DAG.getConstant(abs(Shift), DL, MVT::i32),
+                     SDNodeFlags::Exact);
+}
+
 // Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz
 // Same for other types with equivalent constants.
 static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
@@ -19604,6 +19659,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   if (!isa<ConstantSDNode>(N1))
     return SDValue();
 
+  if (SDValue Ext = performMulRdsvlCombine(N, DAG))
+    return Ext;
+
   ConstantSDNode *C = cast<ConstantSDNode>(N1);
   const APInt &ConstValue = C->getAPIntValue();
 
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 06c53d8070781..b17b48c7e04d3 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -86,4 +86,111 @@ define i64 @sme_cntsd_mul() {
   ret i64 %res
 }
 
-declare i64 @llvm.aarch64.sme.cntsd()
+define i64 @sme_cntsb_mul_pos() {
+; CHECK-LABEL: sme_cntsb_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #24
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul nuw nsw i64 %shl, 96
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul_pos() {
+; CHECK-LABEL: sme_cntsh_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #3
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul nuw nsw i64 %shl, 3
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul_pos() {
+; CHECK-LABEL: sme_cntsw_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #31
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul nuw nsw i64 %shl, 62
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul_pos() {
+; CHECK-LABEL: sme_cntsd_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #31
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul nuw nsw i64 %v, 992
+  ret i64 %res
+}
+
+define i64 @sme_cntsb_mul_neg() {
+; CHECK-LABEL: sme_cntsb_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-24
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul nuw nsw i64 %shl, -96
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul_neg() {
+; CHECK-LABEL: sme_cntsh_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-3
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul nuw nsw i64 %shl, -3
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul_neg() {
+; CHECK-LABEL: sme_cntsw_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-31
+; CHECK-NEXT:    lsl x0, x8, #3
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul nuw nsw i64 %shl, -992
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul_neg() {
+; CHECK-LABEL: sme_cntsd_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-3
+; CHECK-NEXT:    lsr x0, x8, #3
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul nuw nsw i64 %v, -3
+  ret i64 %res
+}
+
+; Negative test: the multiplier cannot be folded into the RDSVL immediate.
+define i64 @sme_cntsd_mul_fail() {
+; CHECK-LABEL: sme_cntsd_mul_fail:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov w9, #993 // =0x3e1
+; CHECK-NEXT:    lsr x8, x8, #3
+; CHECK-NEXT:    mul x0, x8, x9
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul nuw nsw i64 %v, 993
+  ret i64 %res
+}
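
Not part of the patch: as a quick cross-check of the bound derivation in performMulRdsvlCombine, below is a minimal standalone C++20 sketch of the same arithmetic. Given the constant multiplier x and the right-shift w already applied to (rdsvl #1), it picks the (y, z) pair the combine would choose, or reports that no fold exists. The helper name pickRdsvlFactor, the Fold struct, and the example (x, w) pairs are illustrative only; the pairs mirror the multipliers exercised by sme_cntsd_mul_pos, sme_cntsh_mul_pos and sme_cntsd_mul_fail, assuming the DAG has already folded the cnt intrinsic's shift into w.

#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>
#include <optional>

struct Fold {
  int32_t RdsvlImm; // y, must stay within the RDSVL immediate range [-32, 31]
  int Shift;        // z: extra left shift if > 0, right shift if < 0
};

// x: constant multiplier, w: right-shift already applied to (rdsvl #1).
// Mirrors the bound computation in performMulRdsvlCombine.
static std::optional<Fold> pickRdsvlFactor(int64_t x, unsigned w) {
  if (x == 0)
    return std::nullopt; // multiply-by-zero is handled by other combines
  uint64_t absX = x < 0 ? -static_cast<uint64_t>(x) : static_cast<uint64_t>(x);

  // Largest z that keeps y integral: z <= ctz(|x|) - w.
  int upper = std::countr_zero(absX) - static_cast<int>(w);

  // Smallest z that keeps y in range: with B = 31 (x > 0) or B = 32 (x < 0),
  // we need 2^(w + z) >= ceil(|x| / B), i.e. z >= ceil_log2(ceil(|x| / B)) - w.
  uint64_t B = x < 0 ? 32 : 31;
  uint64_t ceilAxOverB = (absX + B - 1) / B;
  int lower =
      static_cast<int>(std::bit_width(ceilAxOverB - 1)) - static_cast<int>(w);

  if (lower > upper)
    return std::nullopt; // no (y, z) pair exists

  // Prefer z == 0 (no extra shift) when it lies within [lower, upper].
  int z = std::min(std::max(0, lower), upper);
  int32_t y = static_cast<int32_t>(absX >> (w + z)) * (x < 0 ? -1 : 1);
  return Fold{y, z};
}

int main() {
  // Example (x, w) pairs; see the lead-in above for what they correspond to.
  struct { int64_t x; unsigned w; } cases[] = {{992, 3}, {3, 1}, {993, 3}};
  for (auto [x, w] : cases) {
    if (auto f = pickRdsvlFactor(x, w))
      std::printf("x=%lld w=%u -> rdsvl #%d, shift %d\n", (long long)x, w,
                  (int)f->RdsvlImm, f->Shift);
    else
      std::printf("x=%lld w=%u -> no fold\n", (long long)x, w);
  }
  return 0;
}

Compiled with -std=c++20, this should print rdsvl #31 with shift 2, rdsvl #3 with shift -1, and no fold for 993, in line with the lsl/lsr sequences the tests above expect.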