diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..183fc763cd2e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19964,6 +19964,254 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
                      CSel0.getOperand(1), getCondCode(DAG, CC1), CCmp);
 }
 
+// Fold lsl + lsr + orr to rev for half-width shifts.
+// Pattern: orr(lsl(x, shift), lsr(x, shift)) -> rev(x) when shift == half_bitwidth
+static SDValue performLSL_LSR_ORRCombine(SDNode *N, SelectionDAG &DAG,
+                                         const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // Check if one operand is LSL and the other is LSR.
+  SDValue LSL, LSR;
+  if (LHS.getOpcode() == ISD::SHL && RHS.getOpcode() == ISD::SRL) {
+    LSL = LHS;
+    LSR = RHS;
+  } else if (LHS.getOpcode() == ISD::SRL && RHS.getOpcode() == ISD::SHL) {
+    LSL = RHS;
+    LSR = LHS;
+  } else {
+    return SDValue();
+  }
+
+  // Check that both shifts operate on the same source.
+  SDValue Src = LSL.getOperand(0);
+  if (Src != LSR.getOperand(0))
+    return SDValue();
+
+  // Check that both shifts have the same constant amount.
+  if (!isa<ConstantSDNode>(LSL.getOperand(1)) ||
+      !isa<ConstantSDNode>(LSR.getOperand(1)))
+    return SDValue();
+
+  uint64_t ShiftAmt = LSL.getConstantOperandVal(1);
+  if (ShiftAmt != LSR.getConstantOperandVal(1))
+    return SDValue();
+
+  // Check if the shift amount equals half the element bitwidth.
+  EVT EltVT = VT.getVectorElementType();
+  if (!EltVT.isSimple())
+    return SDValue();
+
+  unsigned EltSize = EltVT.getSizeInBits();
+  if (ShiftAmt != EltSize / 2)
+    return SDValue();
+
+  // Determine the appropriate REV instruction based on element size and shift
+  // amount.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    if (ShiftAmt == 8)
+      RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU; // 16-bit elements, 8-bit shift -> revb
+    else
+      return SDValue();
+    break;
+  case 32:
+    if (ShiftAmt == 16)
+      RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift -> revh
+    else
+      return SDValue();
+    break;
+  case 64:
+    if (ShiftAmt == 32)
+      RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift -> revw
+    else
+      return SDValue();
+    break;
+  default:
+    return SDValue();
+  }
+
+  // Create the REV instruction.
+  SDLoc DL(N);
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  SDValue Undef = DAG.getUNDEF(VT);
+
+  return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
+// Fold bswap to the correct rev instruction for scalable vectors.
+// DAGCombiner converts lsl+lsr+orr with an 8-bit shift to BSWAP, but for
+// scalable vectors we need to use the correct REV instruction based on the
+// element size.
+static SDValue performBSWAPCombine(SDNode *N, SelectionDAG &DAG,
+                                   const AArch64Subtarget *Subtarget) {
+  LLVM_DEBUG(dbgs() << "BSWAP combine called\n");
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  LLVM_DEBUG(dbgs() << "BSWAP combine called for scalable vector\n");
+
+  EVT EltVT = VT.getVectorElementType();
+  if (!EltVT.isSimple())
+    return SDValue();
+
+  unsigned EltSize = EltVT.getSizeInBits();
+
+  // For scalable vectors with 16-bit elements, BSWAP should use REVB, not
+  // REVH. REVH is not available for 16-bit elements, only for 32-bit and
+  // 64-bit elements; for 16-bit elements, REVB (byte reverse) is equivalent
+  // to a halfword reverse.
+  if (EltSize != 16)
+    return SDValue(); // Use the default BSWAP lowering for other sizes.
+
+  // The current BSWAP lowering is already correct for 16-bit elements:
+  // BSWAP_MERGE_PASSTHRU maps to REVB, which is correct for 16-bit elements.
+  return SDValue();
+}
+
+// Fold rotl to a rev instruction for half-width rotations on scalable vectors.
+// Pattern: rotl(x, half_bitwidth) -> rev(x) for scalable vectors
+static SDValue performROTLCombine(SDNode *N, SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  LLVM_DEBUG(dbgs() << "ROTL combine called\n");
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  // Check that the rotation amount is a constant.
+  if (!isa<ConstantSDNode>(N->getOperand(1)))
+    return SDValue();
+
+  uint64_t RotAmt = N->getConstantOperandVal(1);
+
+  // Check if the rotation amount equals half the element bitwidth.
+  EVT EltVT = VT.getVectorElementType();
+  if (!EltVT.isSimple())
+    return SDValue();
+
+  unsigned EltSize = EltVT.getSizeInBits();
+  if (RotAmt != EltSize / 2)
+    return SDValue();
+
+  // Determine the appropriate REV instruction based on element size. Rotating
+  // an element by half its width swaps its two halves: a halfword reverse for
+  // 32-bit elements and a word reverse for 64-bit elements.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    return SDValue(); // 16-bit case handled by BSWAP.
+  case 32:
+    RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit rotation -> revh
+    break;
+  case 64:
+    RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit rotation -> revw
+    break;
+  default:
+    return SDValue();
+  }
+
+  // Create the REV instruction.
+  SDLoc DL(N);
+  SDValue Src = N->getOperand(0);
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  SDValue Undef = DAG.getUNDEF(VT);
+
+  return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
+// Fold predicated shl + srl + orr to rev for half-width shifts on scalable
+// vectors.
+// Pattern: orr(AArch64ISD::SHL_PRED(pg, x, shift),
+//              AArch64ISD::SRL_PRED(pg, x, shift))
+//   -> rev(x) when shift == half_bitwidth
+static SDValue performSVE_SHL_SRL_ORRCombine(SDNode *N, SelectionDAG &DAG,
+                                             const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // Check if one operand is a predicated SHL and the other a predicated SRL.
+  SDValue SHL, SRL;
+  if (LHS.getOpcode() == AArch64ISD::SHL_PRED &&
+      RHS.getOpcode() == AArch64ISD::SRL_PRED) {
+    SHL = LHS;
+    SRL = RHS;
+  } else if (LHS.getOpcode() == AArch64ISD::SRL_PRED &&
+             RHS.getOpcode() == AArch64ISD::SHL_PRED) {
+    SHL = RHS;
+    SRL = LHS;
+  } else {
+    return SDValue();
+  }
+
+  // Check that both shifts use the same predicate, source and shift amount.
+  SDValue SHLPred = SHL.getOperand(0);
+  SDValue SHLSrc = SHL.getOperand(1);
+  SDValue SHLAmt = SHL.getOperand(2);
+
+  SDValue SRLPred = SRL.getOperand(0);
+  SDValue SRLSrc = SRL.getOperand(1);
+  SDValue SRLAmt = SRL.getOperand(2);
+
+  if (SHLPred != SRLPred || SHLSrc != SRLSrc || SHLAmt != SRLAmt)
+    return SDValue();
+
+  // The shift amount of a predicated shift is a splat vector; check that it
+  // splats a constant.
+  if (SHLAmt.getOpcode() != ISD::SPLAT_VECTOR ||
+      !isa<ConstantSDNode>(SHLAmt.getOperand(0)))
+    return SDValue();
+
+  uint64_t ShiftAmt =
+      cast<ConstantSDNode>(SHLAmt.getOperand(0))->getZExtValue();
+
+  // Check if the shift amount equals half the element bitwidth.
+  EVT EltVT = VT.getVectorElementType();
+  if (!EltVT.isSimple())
+    return SDValue();
+
+  unsigned EltSize = EltVT.getSizeInBits();
+  if (ShiftAmt != EltSize / 2)
+    return SDValue();
+
+  // Determine the appropriate REV instruction based on element size and shift
+  // amount.
+  unsigned RevOp;
+  switch (EltSize) {
+  case 16:
+    return SDValue(); // 16-bit case handled by BSWAP.
+  case 32:
+    if (ShiftAmt == 16)
+      RevOp = AArch64ISD::REVH_MERGE_PASSTHRU; // 32-bit elements, 16-bit shift -> revh
+    else
+      return SDValue();
+    break;
+  case 64:
+    if (ShiftAmt == 32)
+      RevOp = AArch64ISD::REVW_MERGE_PASSTHRU; // 64-bit elements, 32-bit shift -> revw
+    else
+      return SDValue();
+    break;
+  default:
+    return SDValue();
+  }
+
+  // Create the REV instruction.
+  SDLoc DL(N);
+  SDValue Pg = SHLPred;
+  SDValue Src = SHLSrc;
+  SDValue Undef = DAG.getUNDEF(VT);
+
+  return DAG.getNode(RevOp, DL, VT, Pg, Src, Undef);
+}
+
 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                 const AArch64Subtarget *Subtarget,
                                 const AArch64TargetLowering &TLI) {
@@ -19972,6 +20220,13 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   if (SDValue R = performANDORCSELCombine(N, DAG))
     return R;
 
+  if (SDValue R = performLSL_LSR_ORRCombine(N, DAG, Subtarget))
+    return R;
+
+  // Try the predicated shift combine for SVE.
+  if (SDValue R = performSVE_SHL_SRL_ORRCombine(N, DAG, Subtarget))
+    return R;
+
   return SDValue();
 }
@@ -27592,6 +27847,12 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performFpToIntCombine(N, DAG, DCI, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget, *this);
+  case ISD::BSWAP:
+    return performBSWAPCombine(N, DAG, Subtarget);
+  case AArch64ISD::BSWAP_MERGE_PASSTHRU:
+    return performBSWAPCombine(N, DAG, Subtarget);
+  case ISD::ROTL:
+    return performROTLCombine(N, DAG, Subtarget);
   case ISD::AND:
     return performANDCombine(N, DCI);
   case ISD::FADD:
diff --git a/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
new file mode 100644
index 0000000000000..8abfbdcc3edcb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-lsl-lsr-orr-rev-combine.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck --check-prefixes=CHECK,CHECK-SVE %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -force-streaming < %s | FileCheck --check-prefixes=CHECK,CHECK-SME %s
+
+; Test the optimization that folds lsl + lsr + orr to rev for half-width shifts.
+
+; Test case 1: 16-bit elements with 8-bit shift -> revb
+define <vscale x 8 x i16> @lsl_lsr_orr_revh_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_revh_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    revb z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+  %orr = or <vscale x 8 x i16> %lsl, %lsr
+  ret <vscale x 8 x i16> %orr
+}
+
+; Test case 2: 32-bit elements with 16-bit shift -> revh
+define <vscale x 4 x i32> @lsl_lsr_orr_revw_i32(<vscale x 4 x i32> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revw_i32:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    ptrue p0.s
+; CHECK-SVE-NEXT:    revh z0.s, p0/m, z0.s
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revw_i32:
+; CHECK-SME:       // %bb.0:
+; CHECK-SME-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT:    xar z0.s, z0.s, z1.s, #16
+; CHECK-SME-NEXT:    ret
+  %lsl = shl <vscale x 4 x i32> %x, splat(i32 16)
+  %lsr = lshr <vscale x 4 x i32> %x, splat(i32 16)
+  %orr = or <vscale x 4 x i32> %lsl, %lsr
+  ret <vscale x 4 x i32> %orr
+}
+
+; Test case 3: 64-bit elements with 32-bit shift -> revw
+define <vscale x 2 x i64> @lsl_lsr_orr_revd_i64(<vscale x 2 x i64> %x) {
+; CHECK-SVE-LABEL: lsl_lsr_orr_revd_i64:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    ptrue p0.d
+; CHECK-SVE-NEXT:    revw z0.d, p0/m, z0.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SME-LABEL: lsl_lsr_orr_revd_i64:
+; CHECK-SME:       // %bb.0:
+; CHECK-SME-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-SME-NEXT:    xar z0.d, z0.d, z1.d, #32
+; CHECK-SME-NEXT:    ret
+  %lsl = shl <vscale x 2 x i64> %x, splat(i64 32)
+  %lsr = lshr <vscale x 2 x i64> %x, splat(i64 32)
+  %orr = or <vscale x 2 x i64> %lsl, %lsr
+  ret <vscale x 2 x i64> %orr
+}
+
+; Test case 4: Order doesn't matter - lsr + lsl + orr -> revb
+define <vscale x 8 x i16> @lsr_lsl_orr_revh_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsr_lsl_orr_revh_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    revb z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 8)
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %orr = or <vscale x 8 x i16> %lsr, %lsl
+  ret <vscale x 8 x i16> %orr
+}
+
+; Test case 5: Non-half-width shift should not be optimized
+define <vscale x 8 x i16> @lsl_lsr_orr_no_opt_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_no_opt_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.h, z0.h, #4
+; CHECK-NEXT:    lsr z0.h, z0.h, #4
+; CHECK-NEXT:    orr z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 4)
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 4)
+  %orr = or <vscale x 8 x i16> %lsl, %lsr
+  ret <vscale x 8 x i16> %orr
+}
+
+; Test case 6: Different shift amounts should not be optimized
+define <vscale x 8 x i16> @lsl_lsr_orr_different_shifts_i16(<vscale x 8 x i16> %x) {
+; CHECK-LABEL: lsl_lsr_orr_different_shifts_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.h, z0.h, #8
+; CHECK-NEXT:    lsr z0.h, z0.h, #4
+; CHECK-NEXT:    orr z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %lsl = shl <vscale x 8 x i16> %x, splat(i16 8)
+  %lsr = lshr <vscale x 8 x i16> %x, splat(i16 4)
+  %orr = or <vscale x 8 x i16> %lsl, %lsr
+  ret <vscale x 8 x i16> %orr
+}