Skip to content

Commit

Permalink
[CodeGen]: Emit a more efficient magic number multiplication for exac…
Browse files Browse the repository at this point in the history
…t udivs

Have simpler lowering for exact udivs in both SelectionDAG and GlobalISel.
  • Loading branch information
AtariDreams committed Apr 22, 2024
1 parent f5cf417 commit c031895
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 142 deletions.
64 changes: 56 additions & 8 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5071,8 +5071,36 @@ MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) {
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
auto &MIB = Builder;

bool UseSRL = false;
bool UseNPQ = false;
SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
SmallVector<Register, 16> Shifts, Factors;
auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();

auto BuildExactUDIVPattern = [&](const Constant *C) {
// Don't recompute inverses for each splat element.
if (IsSplat && !Factors.empty()) {
Shifts.push_back(Shifts[0]);
Factors.push_back(Factors[0]);
return true;
}

auto *CI = cast<ConstantInt>(C);
APInt Divisor = CI->getValue();
unsigned Shift = Divisor.countr_zero();
if (Shift) {
Divisor.lshrInPlace(Shift);
UseSRL = true;
}

// Calculate the multiplicative inverse modulo BW.
// 2^W requires W + 1 bits, so we have to extend and then truncate.
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
return true;
};

auto BuildUDIVPattern = [&](const Constant *C) {
auto *CI = cast<ConstantInt>(C);
Expand Down Expand Up @@ -5115,6 +5143,29 @@ MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) {
return true;
};

if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
// Collect all magic values from the build vector.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactUDIVPattern);
(void)Matched;
assert(Matched && "Expected unary predicate match to succeed");

Register Shift, Factor;
if (Ty.isVector()) {
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
} else {
Shift = Shifts[0];
Factor = Factors[0];
}

Register Res = LHS;

if (UseSRL)
Res = MIB.buildLShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);

return MIB.buildMul(Ty, Res, Factor);
}

// Collect the shifts/magic values from each element.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern);
(void)Matched;
Expand Down Expand Up @@ -5168,7 +5219,8 @@ bool CombinerHelper::matchUDivByConst(MachineInstr &MI) {
Register RHS = MI.getOperand(2).getReg();
LLT DstTy = MRI.getType(Dst);
auto *RHSDef = MRI.getVRegDef(RHS);
if (!isConstantOrConstantVector(*RHSDef, MRI))
if (!MI.getFlag(MachineInstr::MIFlag::IsExact) &&
!isConstantOrConstantVector(*RHSDef, MRI))
return false;

auto &MF = *MI.getMF();
Expand Down Expand Up @@ -5197,12 +5249,8 @@ bool CombinerHelper::matchUDivByConst(MachineInstr &MI) {
return false;
}

auto CheckEltValue = [&](const Constant *C) {
if (auto *CI = dyn_cast_or_null<ConstantInt>(C))
return !CI->isZero();
return false;
};
return matchUnaryPredicate(MRI, RHS, CheckEltValue);
return matchUnaryPredicate(
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

void CombinerHelper::applyUDivByConst(MachineInstr &MI) {
Expand Down Expand Up @@ -5232,7 +5280,7 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) {
// If the sdiv has an 'exact' flag we can use a simpler lowering.
if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
return matchUnaryPredicate(
MRI, RHS, [](const Constant *C) { return C && !C->isZeroValue(); });
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

// Don't support the general case for now.
Expand Down
69 changes: 68 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6116,7 +6116,6 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,

// Shift the value upfront if it is even, so the LSB is one.
if (UseSRA) {
// TODO: For UDIV use SRL instead of SRA.
SDNodeFlags Flags;
Flags.setExact(true);
Res = DAG.getNode(ISD::SRA, dl, VT, Res, Shift, Flags);
Expand All @@ -6126,6 +6125,70 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
}

/// Given an exact UDIV by a constant, create a multiplication
/// with the multiplicative inverse of the constant.
static SDValue BuildExactUDIV(const TargetLowering &TLI, SDNode *N,
const SDLoc &dl, SelectionDAG &DAG,
SmallVectorImpl<SDNode *> &Created) {
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
EVT VT = N->getValueType(0);
EVT SVT = VT.getScalarType();
EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
EVT ShSVT = ShVT.getScalarType();

bool UseSRL = false;
SmallVector<SDValue, 16> Shifts, Factors;

auto BuildUDIVPattern = [&](ConstantSDNode *C) {
if (C->isZero())
return false;
APInt Divisor = C->getAPIntValue();
unsigned Shift = Divisor.countr_zero();
if (Shift) {
Divisor.lshrInPlace(Shift);
UseSRL = true;
}
// Calculate the multiplicative inverse modulo BW.
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
Factors.push_back(DAG.getConstant(Factor, dl, SVT));
return true;
};

// Collect all magic values from the build vector.
if (!ISD::matchUnaryPredicate(Op1, BuildUDIVPattern))
return SDValue();

SDValue Shift, Factor;
if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
Shift = DAG.getBuildVector(ShVT, dl, Shifts);
Factor = DAG.getBuildVector(VT, dl, Factors);
} else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
assert(Shifts.size() == 1 && Factors.size() == 1 &&
"Expected matchUnaryPredicate to return one element for scalable "
"vectors");
Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
Factor = DAG.getSplatVector(VT, dl, Factors[0]);
} else {
assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
Shift = Shifts[0];
Factor = Factors[0];
}

SDValue Res = Op0;

// Shift the value upfront if it is even, so the LSB is one.
if (UseSRL) {
SDNodeFlags Flags;
Flags.setExact(true);
Res = DAG.getNode(ISD::SRL, dl, VT, Res, Shift, Flags);
Created.push_back(Res.getNode());
}

return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
}

SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
SelectionDAG &DAG,
SmallVectorImpl<SDNode *> &Created) const {
Expand Down Expand Up @@ -6385,6 +6448,10 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// If the udiv has an 'exact' bit we can use a simpler lowering.
if (N->getFlags().hasExact())
return BuildExactUDIV(*this, N, dl, DAG, Created);

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);

Expand Down
45 changes: 22 additions & 23 deletions llvm/test/CodeGen/AArch64/GlobalISel/combine-udiv.mir
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ body: |
; CHECK: liveins: $w0
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1321528399
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
; CHECK-NEXT: [[UMULH:%[0-9]+]]:_(s32) = G_UMULH [[COPY]], [[C]]
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[UMULH]], [[C1]](s32)
; CHECK-NEXT: $w0 = COPY [[LSHR]](s32)
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -991146299
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = exact G_LSHR [[COPY]], [[C]](s32)
; CHECK-NEXT: [[MUL:%[0-9]+]]:_(s32) = G_MUL [[LSHR]], [[C1]]
; CHECK-NEXT: $w0 = COPY [[MUL]](s32)
; CHECK-NEXT: RET_ReallyLR implicit $w0
%0:_(s32) = COPY $w0
%1:_(s32) = G_CONSTANT i32 104
Expand Down Expand Up @@ -361,11 +361,11 @@ body: |
; CHECK: liveins: $w0
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1321528399
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
; CHECK-NEXT: [[UMULH:%[0-9]+]]:_(s32) = G_UMULH [[COPY]], [[C]]
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[UMULH]], [[C1]](s32)
; CHECK-NEXT: $w0 = COPY [[LSHR]](s32)
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -991146299
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = exact G_LSHR [[COPY]], [[C]](s32)
; CHECK-NEXT: [[MUL:%[0-9]+]]:_(s32) = G_MUL [[LSHR]], [[C1]]
; CHECK-NEXT: $w0 = COPY [[MUL]](s32)
; CHECK-NEXT: RET_ReallyLR implicit $w0
%0:_(s32) = COPY $w0
%1:_(s32) = G_CONSTANT i32 104
Expand All @@ -384,15 +384,14 @@ body: |
; CHECK: liveins: $q0
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(<4 x s32>) = COPY $q0
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1321528399
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -991146299
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 954437177
; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 4
; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C]](s32), [[C2]](s32), [[C]](s32), [[C2]](s32)
; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C1]](s32), [[C3]](s32), [[C1]](s32), [[C3]](s32)
; CHECK-NEXT: [[UMULH:%[0-9]+]]:_(<4 x s32>) = G_UMULH [[COPY]], [[BUILD_VECTOR]]
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(<4 x s32>) = G_LSHR [[UMULH]], [[BUILD_VECTOR1]](<4 x s32>)
; CHECK-NEXT: $q0 = COPY [[LSHR]](<4 x s32>)
; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C]](s32), [[C]](s32), [[C]](s32), [[C]](s32)
; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C1]](s32), [[C2]](s32), [[C1]](s32), [[C2]](s32)
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(<4 x s32>) = exact G_LSHR [[COPY]], [[BUILD_VECTOR]](<4 x s32>)
; CHECK-NEXT: [[MUL:%[0-9]+]]:_(<4 x s32>) = G_MUL [[LSHR]], [[BUILD_VECTOR1]]
; CHECK-NEXT: $q0 = COPY [[MUL]](<4 x s32>)
; CHECK-NEXT: RET_ReallyLR implicit $q0
%0:_(<4 x s32>) = COPY $q0
%c1:_(s32) = G_CONSTANT i32 104
Expand All @@ -413,13 +412,13 @@ body: |
; CHECK: liveins: $q0
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(<4 x s32>) = COPY $q0
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1321528399
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -991146299
; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C]](s32), [[C]](s32), [[C]](s32), [[C]](s32)
; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[C1]](s32), [[C1]](s32), [[C1]](s32), [[C1]](s32)
; CHECK-NEXT: [[UMULH:%[0-9]+]]:_(<4 x s32>) = G_UMULH [[COPY]], [[BUILD_VECTOR]]
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(<4 x s32>) = G_LSHR [[UMULH]], [[BUILD_VECTOR1]](<4 x s32>)
; CHECK-NEXT: $q0 = COPY [[LSHR]](<4 x s32>)
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(<4 x s32>) = exact G_LSHR [[COPY]], [[BUILD_VECTOR]](<4 x s32>)
; CHECK-NEXT: [[MUL:%[0-9]+]]:_(<4 x s32>) = G_MUL [[LSHR]], [[BUILD_VECTOR1]]
; CHECK-NEXT: $q0 = COPY [[MUL]](<4 x s32>)
; CHECK-NEXT: RET_ReallyLR implicit $q0
%0:_(<4 x s32>) = COPY $q0
%c1:_(s32) = G_CONSTANT i32 104
Expand Down
Loading

0 comments on commit c031895

Please sign in to comment.