[AMDGPU] Add CodeGen support for GFX12 s_mul_u64 #75825

Merged
Merged 8 commits on Jan 8, 2024
9 changes: 8 additions & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -104,6 +104,13 @@ def foldable_fneg : GICombineRule<
[{ return Helper.matchFoldableFneg(*${ffn}, ${matchinfo}); }]),
(apply [{ Helper.applyFoldableFneg(*${ffn}, ${matchinfo}); }])>;

// Detect s_mul_u64 candidates: 64-bit multiplies whose operands' high bits
// are zero- or sign-extended.
def smulu64 : GICombineRule<
(defs root:$smul, unsigned_matchinfo:$matchinfo),
(match (wip_match_opcode G_MUL):$smul,
[{ return matchCombine_s_mul_u64(*${smul}, ${matchinfo}); }]),
(apply [{ applyCombine_s_mul_u64(*${smul}, ${matchinfo}); }])>;

def sign_exension_in_reg_matchdata : GIDefMatchData<"MachineInstr *">;

def sign_extension_in_reg : GICombineRule<
@@ -149,7 +156,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
"AMDGPUPostLegalizerCombinerImpl",
[all_combines, gfx6gfx7_combines, gfx8_combines,
uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
rcp_sqrt_to_rsq, sign_extension_in_reg]> {
rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

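For context, a minimal sketch (not part of the patch) of the source-level patterns whose 64-bit multiplies the smulu64 rule above is meant to catch: both operands are 32-bit values that were zero- or sign-extended to 64 bits, so known-bits analysis can prove the extension. Function names are illustrative only.

```cpp
#include <cstdint>

// Both operands have their upper 32 bits known to be zero.
uint64_t mul_zext(uint32_t A, uint32_t B) {
  return uint64_t(A) * uint64_t(B);
}

// Both operands have at least 33 known sign bits.
int64_t mul_sext(int32_t A, int32_t B) {
  return int64_t(A) * int64_t(B);
}
```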
24 changes: 17 additions & 7 deletions llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -701,13 +701,23 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
.maxScalar(0, S32);
}

getActionDefinitionsBuilder(G_MUL)
.legalFor({S32, S16, V2S16})
.clampMaxNumElementsStrict(0, S16, 2)
.scalarize(0)
.minScalar(0, S16)
.widenScalarToNextMultipleOf(0, 32)
.custom();
if (ST.hasScalarSMulU64()) {
getActionDefinitionsBuilder(G_MUL)
.legalFor({S64, S32, S16, V2S16})
.clampMaxNumElementsStrict(0, S16, 2)
.scalarize(0)
.minScalar(0, S16)
.widenScalarToNextMultipleOf(0, 32)
.custom();
} else {
getActionDefinitionsBuilder(G_MUL)
.legalFor({S32, S16, V2S16})
.clampMaxNumElementsStrict(0, S16, 2)
.scalarize(0)
.minScalar(0, S16)
.widenScalarToNextMultipleOf(0, 32)
.custom();
}
assert(ST.hasMad64_32());

getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT, G_SADDSAT, G_SSUBSAT})
34 changes: 34 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -104,6 +104,14 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
void applyCombineSignExtendInReg(MachineInstr &MI,
MachineInstr *&MatchInfo) const;

// Match a 64-bit G_MUL whose operands' high bits are known to be either
// zero-extended or sign-extended.
bool matchCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;
// Replace the 64-bit G_MUL with G_AMDGPU_S_MUL_I64_I32 if the upper 33 bits
// of both operands are sign bits, or with G_AMDGPU_S_MUL_U64_U32 if the upper
// 32 bits of both operands are zero.
void applyCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#define AMDGPUSubtarget GCNSubtarget
@@ -419,6 +427,32 @@ void AMDGPUPostLegalizerCombinerImpl::applyCombineSignExtendInReg(
MI.eraseFromParent();
}

bool AMDGPUPostLegalizerCombinerImpl::matchCombine_s_mul_u64(
MachineInstr &MI, unsigned &NewOpcode) const {
Register Src0 = MI.getOperand(1).getReg();
Register Src1 = MI.getOperand(2).getReg();
if (MRI.getType(Src0) != LLT::scalar(64))
return false;

if (KB->getKnownBits(Src1).countMinLeadingZeros() >= 32 &&
KB->getKnownBits(Src0).countMinLeadingZeros() >= 32) {
NewOpcode = AMDGPU::G_AMDGPU_S_MUL_U64_U32;
return true;
}

if (KB->computeNumSignBits(Src1) >= 33 &&
KB->computeNumSignBits(Src0) >= 33) {
NewOpcode = AMDGPU::G_AMDGPU_S_MUL_I64_I32;
return true;
}
return false;
}

void AMDGPUPostLegalizerCombinerImpl::applyCombine_s_mul_u64(
MachineInstr &MI, unsigned &NewOpcode) const {
Helper.replaceOpcodeWith(MI, NewOpcode);
}

// Pass boilerplate
// ================

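A minimal host-side sketch (not part of the patch) of why the matcher's conditions are sufficient: when both 64-bit operands have at least 32 known leading zeros, the full product equals the unsigned 32x32->64 product of the low halves, and when both have at least 33 sign bits it equals the signed 32x32->64 product of the low halves. Function names are illustrative only.

```cpp
#include <cassert>
#include <cstdint>

// Unsigned case: with >= 32 leading zeros in each operand, only the low
// 32 bits contribute, so a 32x32->64 unsigned multiply is exact.
static uint64_t mulU64U32(uint64_t A, uint64_t B) {
  return uint64_t(uint32_t(A)) * uint64_t(uint32_t(B));
}

// Signed case: with >= 33 sign bits in each operand, each value fits in a
// signed 32-bit integer, so a 32x32->64 signed multiply is exact.
static int64_t mulI64I32(int64_t A, int64_t B) {
  return int64_t(int32_t(A)) * int64_t(int32_t(B));
}

int main() {
  assert(mulU64U32(0xFFFFFFFFULL, 3) == 0xFFFFFFFFULL * 3);
  assert(mulI64I32(-5, 7) == -35);
  assert(mulI64I32(INT32_MIN, 2) == int64_t(INT32_MIN) * 2);
  return 0;
}
```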
150 changes: 147 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -2094,6 +2094,74 @@ bool AMDGPURegisterBankInfo::foldInsertEltToCmpSelect(
return true;
}

// Break s_mul_u64 into 32-bit vector operations.
void AMDGPURegisterBankInfo::applyMappingSMULU64(
MachineIRBuilder &B, const OperandsMapper &OpdMapper) const {
SmallVector<Register, 2> DefRegs(OpdMapper.getVRegs(0));
SmallVector<Register, 2> Src0Regs(OpdMapper.getVRegs(1));
SmallVector<Register, 2> Src1Regs(OpdMapper.getVRegs(2));

// All inputs are SGPRs, nothing special to do.
if (DefRegs.empty()) {
assert(Src0Regs.empty() && Src1Regs.empty());
applyDefaultMapping(OpdMapper);
return;
}

assert(DefRegs.size() == 2);
assert(Src0Regs.size() == Src1Regs.size() &&
(Src0Regs.empty() || Src0Regs.size() == 2));

MachineRegisterInfo &MRI = OpdMapper.getMRI();
MachineInstr &MI = OpdMapper.getMI();
Register DstReg = MI.getOperand(0).getReg();
LLT HalfTy = LLT::scalar(32);

// Depending on where the source registers came from, the generic code may
// have decided to split the inputs already or not. If not, we still need to
// extract the values.

if (Src0Regs.empty())
split64BitValueForMapping(B, Src0Regs, HalfTy, MI.getOperand(1).getReg());
else
setRegsToType(MRI, Src0Regs, HalfTy);

if (Src1Regs.empty())
split64BitValueForMapping(B, Src1Regs, HalfTy, MI.getOperand(2).getReg());
else
setRegsToType(MRI, Src1Regs, HalfTy);

setRegsToType(MRI, DefRegs, HalfTy);

// The multiplication is done as follows:
//
// Op1H Op1L
// * Op0H Op0L
// --------------------
// Op1H*Op0L Op1L*Op0L
// + Op1H*Op0H Op1L*Op0H
// -----------------------------------------
// (Op1H*Op0L + Op1L*Op0H + carry) Op1L*Op0L
//
// We drop Op1H*Op0H because the result of the multiplication is a 64-bit
// value and that would overflow.
// The low 32-bit value is Op1L*Op0L.
// The high 32-bit value is Op1H*Op0L + Op1L*Op0H + carry (from
// Op1L*Op0L).

ApplyRegBankMapping ApplyBank(B, *this, MRI, &AMDGPU::VGPRRegBank);

Register Hi = B.buildUMulH(HalfTy, Src0Regs[0], Src1Regs[0]).getReg(0);
Register MulLoHi = B.buildMul(HalfTy, Src0Regs[0], Src1Regs[1]).getReg(0);
Register Add = B.buildAdd(HalfTy, Hi, MulLoHi).getReg(0);
Register MulHiLo = B.buildMul(HalfTy, Src0Regs[1], Src1Regs[0]).getReg(0);
B.buildAdd(DefRegs[1], Add, MulHiLo);
B.buildMul(DefRegs[0], Src0Regs[0], Src1Regs[0]);

MRI.setRegBank(DstReg, AMDGPU::VGPRRegBank);
MI.eraseFromParent();
}
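A host-side sketch (not part of the patch) of the decomposition built above, using the same operand names as the comment; the buildMul/buildUMulH/buildAdd annotations indicate which builder call each step mirrors.

```cpp
#include <cassert>
#include <cstdint>

// Low 64 bits of a 64x64 multiply computed from 32-bit halves, mirroring
// applyMappingSMULU64. Op0H*Op1H is dropped: it only feeds bits >= 64.
static uint64_t mul64ViaHalves(uint64_t Op0, uint64_t Op1) {
  uint32_t Op0L = uint32_t(Op0), Op0H = uint32_t(Op0 >> 32);
  uint32_t Op1L = uint32_t(Op1), Op1H = uint32_t(Op1 >> 32);

  uint32_t Lo = uint32_t(uint64_t(Op0L) * Op1L);          // buildMul(DefRegs[0], ...)
  uint32_t Hi = uint32_t((uint64_t(Op0L) * Op1L) >> 32);  // buildUMulH
  uint32_t MulLoHi = uint32_t(uint64_t(Op0L) * Op1H);     // buildMul
  uint32_t MulHiLo = uint32_t(uint64_t(Op0H) * Op1L);     // buildMul
  uint32_t HiSum = Hi + MulLoHi + MulHiLo;                // two buildAdd calls
  return (uint64_t(HiSum) << 32) | Lo;
}

int main() {
  uint64_t A = 0x123456789ABCDEF0ULL, B = 0x0FEDCBA987654321ULL;
  assert(mul64ViaHalves(A, B) == A * B);
  return 0;
}
```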

void AMDGPURegisterBankInfo::applyMappingImpl(
MachineIRBuilder &B, const OperandsMapper &OpdMapper) const {
MachineInstr &MI = OpdMapper.getMI();
@@ -2394,13 +2462,21 @@ void AMDGPURegisterBankInfo::applyMappingImpl(
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

// Special case for s_mul_u64. There is no vector equivalent of s_mul_u64,
// so we have to break it down into 32-bit vector multiplications.
if (Opc == AMDGPU::G_MUL && DstTy.getSizeInBits() == 64) {
applyMappingSMULU64(B, OpdMapper);
return;
}

// 16-bit operations are VALU only, but can be promoted to 32-bit SALU.
// Packed 16-bit operations need to be scalarized and promoted.
if (DstTy != LLT::scalar(16) && DstTy != LLT::fixed_vector(2, 16))
break;

const RegisterBank *DstBank =
OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
if (DstBank == &AMDGPU::VGPRRegBank)
break;

@@ -2451,6 +2527,72 @@ void AMDGPURegisterBankInfo::applyMappingImpl(

return;
}
case AMDGPU::G_AMDGPU_S_MUL_I64_I32:
case AMDGPU::G_AMDGPU_S_MUL_U64_U32: {
// This is a special case for s_mul_u64. The G_AMDGPU_S_MUL_I64_I32 opcode
// represents an s_mul_u64 operation whose operands have their upper 33 bits
// sign-extended, and G_AMDGPU_S_MUL_U64_U32 represents one whose operands
// have their upper 32 bits zero-extended. If scalar registers are selected,
// both opcodes are lowered to s_mul_u64. If vector registers are selected,
// both are lowered to a vector MAD instruction.

// Insert basic copies.
applyDefaultMapping(OpdMapper);

Register DstReg = MI.getOperand(0).getReg();
Register SrcReg0 = MI.getOperand(1).getReg();
Register SrcReg1 = MI.getOperand(2).getReg();
const LLT S32 = LLT::scalar(32);
const LLT S64 = LLT::scalar(64);
assert(MRI.getType(DstReg) == S64 && "This is a special case for s_mul_u64 "
"that handles only 64-bit operands.");
const RegisterBank *DstBank =
OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;

// Replace G_AMDGPU_S_MUL_I64_I32 and G_AMDGPU_S_MUL_U64_U32
// with s_mul_u64 operation.
if (DstBank == &AMDGPU::SGPRRegBank) {
MI.setDesc(TII->get(AMDGPU::S_MUL_U64));
MRI.setRegClass(DstReg, &AMDGPU::SGPR_64RegClass);
MRI.setRegClass(SrcReg0, &AMDGPU::SGPR_64RegClass);
MRI.setRegClass(SrcReg1, &AMDGPU::SGPR_64RegClass);
return;
}

// Replace G_AMDGPU_S_MUL_I64_I32 and G_AMDGPU_S_MUL_U64_U32
// with a vector mad.
assert(MRI.getRegBankOrNull(DstReg) == &AMDGPU::VGPRRegBank &&
"The destination operand should be in vector registers.");

DebugLoc DL = MI.getDebugLoc();

// Extract the lower subregister from the first operand.
Register Op0L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
MRI.setRegClass(Op0L, &AMDGPU::VGPR_32RegClass);
MRI.setType(Op0L, S32);
B.buildTrunc(Op0L, SrcReg0);

// Extract the lower subregister from the second operand.
Register Op1L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
MRI.setRegClass(Op1L, &AMDGPU::VGPR_32RegClass);
MRI.setType(Op1L, S32);
B.buildTrunc(Op1L, SrcReg1);

unsigned NewOpc = Opc == AMDGPU::G_AMDGPU_S_MUL_U64_U32
? AMDGPU::G_AMDGPU_MAD_U64_U32
: AMDGPU::G_AMDGPU_MAD_I64_I32;

MachineIRBuilder B(MI);
Register Zero64 = B.buildConstant(S64, 0).getReg(0);
MRI.setRegClass(Zero64, &AMDGPU::VReg_64RegClass);
Register CarryOut = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
MRI.setRegClass(CarryOut, &AMDGPU::VReg_64RegClass);
B.buildInstr(NewOpc, {DstReg, CarryOut}, {Op0L, Op1L, Zero64});
MI.eraseFromParent();
return;
}
case AMDGPU::G_SEXT_INREG: {
SmallVector<Register, 2> SrcRegs(OpdMapper.getVRegs(1));
if (SrcRegs.empty())
@@ -3667,7 +3809,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {

case AMDGPU::G_AND:
case AMDGPU::G_OR:
case AMDGPU::G_XOR: {
case AMDGPU::G_XOR:
case AMDGPU::G_MUL: {
unsigned Size = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
if (Size == 1) {
const RegisterBank *DstBank
@@ -3735,7 +3878,6 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case AMDGPU::G_PTRMASK:
case AMDGPU::G_ADD:
case AMDGPU::G_SUB:
case AMDGPU::G_MUL:
case AMDGPU::G_SHL:
case AMDGPU::G_LSHR:
case AMDGPU::G_ASHR:
@@ -3753,6 +3895,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case AMDGPU::G_SHUFFLE_VECTOR:
case AMDGPU::G_SBFX:
case AMDGPU::G_UBFX:
case AMDGPU::G_AMDGPU_S_MUL_I64_I32:
case AMDGPU::G_AMDGPU_S_MUL_U64_U32:
if (isSALUMapping(MI))
return getDefaultMappingSOP(MI);
return getDefaultMappingVOP(MI);
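A minimal sketch (not part of the patch) of why the VGPR path above can fall back to a MAD: with a zero 64-bit addend, a mad_64_32-style operation (32x32 multiply accumulated into a 64-bit addend) reduces to the full 32x32->64 product of the truncated operands. The function name is illustrative only.

```cpp
#include <cassert>
#include <cstdint>

// 32x32 multiply accumulated into a 64-bit addend, as in G_AMDGPU_MAD_U64_U32.
static uint64_t madU64U32(uint32_t A, uint32_t B, uint64_t C) {
  return uint64_t(A) * uint64_t(B) + C;
}

int main() {
  uint32_t A = 0xDEADBEEFu, B = 0x12345678u;
  // With a zero addend, the MAD is exactly the 32x32->64 multiply that
  // G_AMDGPU_S_MUL_U64_U32 needs once its operands are truncated to 32 bits.
  assert(madU64U32(A, B, 0) == uint64_t(A) * uint64_t(B));
  return 0;
}
```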
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h
@@ -84,6 +84,9 @@ class AMDGPURegisterBankInfo final : public AMDGPUGenRegisterBankInfo {
bool applyMappingMAD_64_32(MachineIRBuilder &B,
const OperandsMapper &OpdMapper) const;

void applyMappingSMULU64(MachineIRBuilder &B,
const OperandsMapper &OpdMapper) const;

Register handleD16VData(MachineIRBuilder &B, MachineRegisterInfo &MRI,
Register Reg) const;

2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -682,6 +682,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,

bool hasScalarAddSub64() const { return getGeneration() >= GFX12; }

bool hasScalarSMulU64() const { return getGeneration() >= GFX12; }

bool hasUnpackedD16VMem() const {
return HasUnpackedD16VMem;
}