[AMDGPU] Add CodeGen support for GFX12 s_mul_u64 #75825

Merged
merged 8 commits into llvm:main on Jan 8, 2024

Conversation

jayfoad
Contributor

@jayfoad jayfoad commented Dec 18, 2023

No description provided.

@llvmbot
Collaborator

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-amdgpu

Author: Jay Foad (jayfoad)

Changes

Patch is 204.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75825.diff

18 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCombine.td (+8-1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+17-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp (+35)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+195-2)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h (+2)
  • (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.h (+2)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+65-1)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+194)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.h (+6)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+12)
  • (modified) llvm/lib/Target/AMDGPU/SOPInstructions.td (+10)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-mul.mir (+315-472)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/mul.ll (+820)
  • (added) llvm/test/CodeGen/AMDGPU/GlobalISel/postlegalizercombiner-mul.mir (+60)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-mul.mir (+127)
  • (modified) llvm/test/CodeGen/AMDGPU/atomic_optimizations_global_pointer.ll (+62-68)
  • (modified) llvm/test/CodeGen/AMDGPU/mul.ll (+755-76)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index 8d4cad4c07bc74..0c77fe72595880 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -104,6 +104,13 @@ def foldable_fneg : GICombineRule<
          [{ return Helper.matchFoldableFneg(*${ffn}, ${matchinfo}); }]),
   (apply [{ Helper.applyFoldableFneg(*${ffn}, ${matchinfo}); }])>;
 
+// Detects s_mul_u64 instructions whose higher bits are zero/sign extended.
+def smulu64 : GICombineRule<
+  (defs root:$smul, unsigned_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_MUL):$smul,
+         [{ return matchCombine_s_mul_u64(*${smul}, ${matchinfo}); }]),
+  (apply [{ applyCombine_s_mul_u64(*${smul}, ${matchinfo}); }])>;
+
 def sign_exension_in_reg_matchdata : GIDefMatchData<"MachineInstr *">;
 
 def sign_extension_in_reg : GICombineRule<
@@ -149,7 +156,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
   "AMDGPUPostLegalizerCombinerImpl",
   [all_combines, gfx6gfx7_combines, gfx8_combines,
    uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
-   rcp_sqrt_to_rsq, sign_extension_in_reg]> {
+   rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
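
Aside (not part of the patch): the thresholds this patch uses later (at least 32 known leading zero bits for the unsigned form, at least 33 known sign bits for the signed form) are exactly the conditions under which a 64-bit operand round-trips through 32 bits, so the 64-bit product can be recovered from a single 32x32->64 multiply. A minimal standalone C++ sketch:

```cpp
// Standalone sketch (not LLVM code) of why the combine is sound.
#include <cassert>
#include <cstdint>

int main() {
  // At least 32 leading zero bits: the operand fits in uint32_t, so the
  // 64 x 64 -> 64 multiply equals one 32 x 32 -> 64 unsigned multiply.
  uint64_t ux = 0x00000000FFFFFFFFull;
  uint64_t uy = 0x0000000012345678ull;
  assert(ux * uy == static_cast<uint64_t>(static_cast<uint32_t>(ux)) *
                        static_cast<uint32_t>(uy));

  // At least 33 sign bits: bits [63:31] all copy the sign bit, so the operand
  // lies in [-2^31, 2^31) and fits in int32_t (32 sign bits would only bound
  // it to 33 bits, which does not fit).
  int64_t sx = -2;         // 63 sign bits
  int64_t sy = 0x7FFFFFFF; // 33 sign bits (INT32_MAX)
  assert(sx * sy == static_cast<int64_t>(static_cast<int32_t>(sx)) *
                        static_cast<int32_t>(sy));
  return 0;
}
```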
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index fbee2888945185..8f1ae693c4e453 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -701,13 +701,23 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
           .maxScalar(0, S32);
     }
 
-    getActionDefinitionsBuilder(G_MUL)
-      .legalFor({S32, S16, V2S16})
-      .clampMaxNumElementsStrict(0, S16, 2)
-      .scalarize(0)
-      .minScalar(0, S16)
-      .widenScalarToNextMultipleOf(0, 32)
-      .custom();
+    if (ST.hasScalarSMulU64()) {
+      getActionDefinitionsBuilder(G_MUL)
+          .legalFor({S64, S32, S16, V2S16})
+          .clampMaxNumElementsStrict(0, S16, 2)
+          .scalarize(0)
+          .minScalar(0, S16)
+          .widenScalarToNextMultipleOf(0, 32)
+          .custom();
+    } else {
+      getActionDefinitionsBuilder(G_MUL)
+          .legalFor({S32, S16, V2S16})
+          .clampMaxNumElementsStrict(0, S16, 2)
+          .scalarize(0)
+          .minScalar(0, S16)
+          .widenScalarToNextMultipleOf(0, 32)
+          .custom();
+    }
     assert(ST.hasMad64_32());
 
     getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT, G_SADDSAT, G_SSUBSAT})
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index 7b18e1f805d8f9..a0c86adde9370e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -104,6 +104,14 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
   void applyCombineSignExtendInReg(MachineInstr &MI,
                                    MachineInstr *&MatchInfo) const;
 
+  // Find the s_mul_u64 instructions where the higher bits are either
+  // zero-extended or sign-extended.
+  bool matchCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;
+  // Replace the s_mul_u64 instructions with S_MUL_I64_I32_PSEUDO if the higher
+  // 33 bits are sign extended and with S_MUL_U64_U32_PSEUDO if the higher 32
+  // bits are zero extended.
+  void applyCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;
+
 private:
 #define GET_GICOMBINER_CLASS_MEMBERS
 #define AMDGPUSubtarget GCNSubtarget
@@ -419,6 +427,33 @@ void AMDGPUPostLegalizerCombinerImpl::applyCombineSignExtendInReg(
   MI.eraseFromParent();
 }
 
+bool AMDGPUPostLegalizerCombinerImpl::matchCombine_s_mul_u64(
+    MachineInstr &MI, unsigned &NewOpcode) const {
+  Register Src0 = MI.getOperand(1).getReg();
+  Register Src1 = MI.getOperand(2).getReg();
+  KnownBits Op0KnownBits = KB->getKnownBits(Src0);
+  unsigned Op0LeadingZeros = Op0KnownBits.countMinLeadingZeros();
+  KnownBits Op1KnownBits = KB->getKnownBits(Src1);
+  unsigned Op1LeadingZeros = Op1KnownBits.countMinLeadingZeros();
+  if (Op0LeadingZeros >= 32 && Op1LeadingZeros >= 32) {
+    NewOpcode = AMDGPU::G_AMDGPU_S_MUL_U64_U32_PSEUDO;
+    return true;
+  }
+
+  unsigned Op0SignBits = KB->computeNumSignBits(Src0);
+  unsigned Op1SignBits = KB->computeNumSignBits(Src1);
+  if (Op0SignBits >= 33 && Op1SignBits >= 33) {
+    NewOpcode = AMDGPU::G_AMDGPU_S_MUL_I64_I32_PSEUDO;
+    return true;
+  }
+  return false;
+}
+
+void AMDGPUPostLegalizerCombinerImpl::applyCombine_s_mul_u64(
+    MachineInstr &MI, unsigned &NewOpcode) const {
+  MI.setDesc(TII.get(NewOpcode));
+}
+
 // Pass boilerplate
 // ================
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index c9412f720c62ec..a32e2b1621f53c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -2094,6 +2094,121 @@ bool AMDGPURegisterBankInfo::foldInsertEltToCmpSelect(
   return true;
 }
 
+// Break s_mul_u64 into 32-bit vector operations.
+void AMDGPURegisterBankInfo::applyMappingSMULU64(
+    const OperandsMapper &OpdMapper) const {
+
+  MachineInstr &MI = OpdMapper.getMI();
+  MachineRegisterInfo &MRI = OpdMapper.getMRI();
+  Register DstReg = MI.getOperand(0).getReg();
+
+  // Insert basic copies.
+  applyDefaultMapping(OpdMapper);
+
+  Register SrcReg0 = MI.getOperand(1).getReg();
+  Register SrcReg1 = MI.getOperand(2).getReg();
+  assert(MRI.getRegBankOrNull(SrcReg0) == &AMDGPU::VGPRRegBank &&
+         MRI.getRegBankOrNull(SrcReg1) == &AMDGPU::VGPRRegBank &&
+         "Source operands should be in vector registers.");
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  // Extract subregisters from the first operand
+  Register NewSrcReg0 = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
+  MRI.setRegClass(NewSrcReg0, &AMDGPU::VReg_64RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), NewSrcReg0)
+      .addReg(SrcReg0, 0, MI.getOperand(1).getSubReg());
+  Register Op0L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op0L, &AMDGPU::VGPR_32RegClass);
+  Register Op0H = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op0H, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), Op0L)
+      .addReg(NewSrcReg0, 0, AMDGPU::sub0);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), Op0H)
+      .addReg(NewSrcReg0, 0, AMDGPU::sub1);
+
+  // Extract subregisters from the second operand.
+  Register NewSrcReg1 = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
+  MRI.setRegClass(NewSrcReg1, &AMDGPU::VReg_64RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), NewSrcReg1)
+      .addReg(SrcReg1, 0, MI.getOperand(2).getSubReg());
+  Register Op1L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op1L, &AMDGPU::VGPR_32RegClass);
+  Register Op1H = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op1H, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), Op1L)
+      .addReg(NewSrcReg1, 0, AMDGPU::sub0);
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), Op1H)
+      .addReg(NewSrcReg1, 0, AMDGPU::sub1);
+
+  // Split s_mul_u64 in 32-bit multiplications.
+  Register NewDestReg = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
+  MRI.setRegClass(NewDestReg, &AMDGPU::VReg_64RegClass);
+  Register DestSub0 = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(DestSub0, &AMDGPU::VGPR_32RegClass);
+  Register DestSub1 = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(DestSub1, &AMDGPU::VGPR_32RegClass);
+
+  // The multiplication is done as follows:
+  //
+  //                            Op1H  Op1L
+  //                          * Op0H  Op0L
+  //                       --------------------
+  //                       Op1H*Op0L  Op1L*Op0L
+  //          + Op1H*Op0H  Op1L*Op0H
+  // -----------------------------------------
+  // (Op1H*Op0L + Op1L*Op0H + carry)  Op1L*Op0L
+  //
+  //  We drop Op1H*Op0H because the result of the multiplication is a 64-bit
+  //  value and that would overflow.
+  //  The low 32-bit value is Op1L*Op0L.
+  //  The high 32-bit value is Op1H*Op0L + Op1L*Op0H + carry (from
+  //  Op1L*Op0L).
+
+  Register Op1L_Op0H_Reg = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op1L_Op0H_Reg, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_MUL_LO_U32_e64), Op1L_Op0H_Reg)
+      .addReg(Op1L)
+      .addReg(Op0H);
+
+  Register Op1H_Op0L_Reg = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(Op1H_Op0L_Reg, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_MUL_LO_U32_e64), Op1H_Op0L_Reg)
+      .addReg(Op1H)
+      .addReg(Op0L);
+
+  Register CarryReg = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(CarryReg, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_MUL_HI_U32_e64), CarryReg)
+      .addReg(Op1L)
+      .addReg(Op0L);
+
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_MUL_LO_U32_e64), DestSub0)
+      .addReg(Op1L)
+      .addReg(Op0L);
+
+  Register AddReg = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+  MRI.setRegClass(AddReg, &AMDGPU::VGPR_32RegClass);
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_ADD_U32_e32), AddReg)
+      .addReg(Op1L_Op0H_Reg)
+      .addReg(Op1H_Op0L_Reg);
+
+  BuildMI(*MBB, MI, DL, TII->get(AMDGPU::V_ADD_U32_e32), DestSub1)
+      .addReg(AddReg)
+      .addReg(CarryReg);
+
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::REG_SEQUENCE), NewDestReg)
+      .addReg(DestSub0)
+      .addImm(AMDGPU::sub0)
+      .addReg(DestSub1)
+      .addImm(AMDGPU::sub1);
+
+  BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), DstReg)
+      .addReg(NewDestReg);
+
+  MI.eraseFromParent();
+}
+
 void AMDGPURegisterBankInfo::applyMappingImpl(
     MachineIRBuilder &B, const OperandsMapper &OpdMapper) const {
   MachineInstr &MI = OpdMapper.getMI();
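
Aside (not part of the patch): the column diagram in applyMappingSMULU64 above can be checked with ordinary integer arithmetic. A standalone C++ sketch that mirrors the generated V_MUL_LO_U32 / V_MUL_HI_U32 / V_ADD_U32 sequence:

```cpp
// Standalone check of the 32-bit decomposition used by applyMappingSMULU64.
#include <cassert>
#include <cstdint>

uint64_t mul64_via_32bit_ops(uint64_t x, uint64_t y) {
  uint32_t Op0L = static_cast<uint32_t>(x);       // sub0 of first operand
  uint32_t Op0H = static_cast<uint32_t>(x >> 32); // sub1 of first operand
  uint32_t Op1L = static_cast<uint32_t>(y);       // sub0 of second operand
  uint32_t Op1H = static_cast<uint32_t>(y >> 32); // sub1 of second operand

  // V_MUL_LO_U32: low 32 bits of the 32 x 32 product.
  uint32_t Lo = Op1L * Op0L;
  // V_MUL_HI_U32: high 32 bits of the same product (the "carry" term).
  uint32_t Carry =
      static_cast<uint32_t>((static_cast<uint64_t>(Op1L) * Op0L) >> 32);
  // The two cross terms; Op1H*Op0H is dropped because it only affects bits >= 64.
  uint32_t Hi = Op1L * Op0H + Op1H * Op0L + Carry;

  return (static_cast<uint64_t>(Hi) << 32) | Lo; // REG_SEQUENCE sub1:sub0
}

int main() {
  assert(mul64_via_32bit_ops(0x123456789ABCDEF0ull, 0x0FEDCBA987654321ull) ==
         0x123456789ABCDEF0ull * 0x0FEDCBA987654321ull);
  return 0;
}
```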
@@ -2393,14 +2508,24 @@ void AMDGPURegisterBankInfo::applyMappingImpl(
   case AMDGPU::G_UMAX: {
     Register DstReg = MI.getOperand(0).getReg();
     LLT DstTy = MRI.getType(DstReg);
+    const LLT S64 = LLT::scalar(64);
+    const RegisterBank *DstBank =
+        OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
+
+    // Special case for s_mul_u64. There is not a vector equivalent of
+    // s_mul_u64. Hence, we have to break down s_mul_u64 into 32-bit vector
+    // multiplications.
+    if (Opc == AMDGPU::G_MUL && DstTy == S64 &&
+        DstBank == &AMDGPU::VGPRRegBank) {
+      applyMappingSMULU64(OpdMapper);
+      return;
+    }
 
     // 16-bit operations are VALU only, but can be promoted to 32-bit SALU.
     // Packed 16-bit operations need to be scalarized and promoted.
     if (DstTy != LLT::scalar(16) && DstTy != LLT::fixed_vector(2, 16))
       break;
 
-    const RegisterBank *DstBank =
-      OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
     if (DstBank == &AMDGPU::VGPRRegBank)
       break;
 
@@ -2451,6 +2576,72 @@ void AMDGPURegisterBankInfo::applyMappingImpl(
 
     return;
   }
+  case AMDGPU::G_AMDGPU_S_MUL_I64_I32_PSEUDO:
+  case AMDGPU::G_AMDGPU_S_MUL_U64_U32_PSEUDO: {
+    // This is a special case for s_mul_u64. We use
+    // G_AMDGPU_S_MUL_I64_I32_PSEUDO opcode to represent an s_mul_u64 operation
+    // where the 33 higher bits are sign-extended and
+    // G_AMDGPU_S_MUL_U64_U32_PSEUDO opcode to represent an s_mul_u64 operation
+    // where the 32 higher bits are zero-extended. In case scalar registers are
+    // selected, both opcodes are lowered as s_mul_u64. If the vector registers
+    // are selected, then G_AMDGPU_S_MUL_I64_I32_PSEUDO and
+    // G_AMDGPU_S_MUL_U64_U32_PSEUDO are lowered with a vector mad instruction.
+
+    // Insert basic copies.
+    applyDefaultMapping(OpdMapper);
+
+    Register DstReg = MI.getOperand(0).getReg();
+    Register SrcReg0 = MI.getOperand(1).getReg();
+    Register SrcReg1 = MI.getOperand(2).getReg();
+    const LLT S32 = LLT::scalar(32);
+    const LLT S64 = LLT::scalar(64);
+    assert(MRI.getType(DstReg) == S64 && "This is a special case for s_mul_u64 "
+                                         "that handles only 64-bit operands.");
+    const RegisterBank *DstBank =
+        OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
+
+    // Replace G_AMDGPU_S_MUL_I64_I32_PSEUDO and G_AMDGPU_S_MUL_U64_U32_PSEUDO
+    // with s_mul_u64 operation.
+    if (DstBank == &AMDGPU::SGPRRegBank) {
+      MI.setDesc(TII->get(AMDGPU::S_MUL_U64));
+      MRI.setRegClass(DstReg, &AMDGPU::SGPR_64RegClass);
+      MRI.setRegClass(SrcReg0, &AMDGPU::SGPR_64RegClass);
+      MRI.setRegClass(SrcReg1, &AMDGPU::SGPR_64RegClass);
+      return;
+    }
+
+    // Replace G_AMDGPU_S_MUL_I64_I32_PSEUDO and G_AMDGPU_S_MUL_U64_U32_PSEUDO
+    // with a vector mad.
+    assert(MRI.getRegBankOrNull(DstReg) == &AMDGPU::VGPRRegBank &&
+           "The destination operand should be in vector registers.");
+
+    DebugLoc DL = MI.getDebugLoc();
+
+    // Extract the lower subregister from the first operand.
+    Register Op0L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+    MRI.setRegClass(Op0L, &AMDGPU::VGPR_32RegClass);
+    MRI.setType(Op0L, S32);
+    B.buildTrunc(Op0L, SrcReg0);
+
+    // Extract the lower subregister from the second operand.
+    Register Op1L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+    MRI.setRegClass(Op1L, &AMDGPU::VGPR_32RegClass);
+    MRI.setType(Op1L, S32);
+    B.buildTrunc(Op1L, SrcReg1);
+
+    unsigned NewOpc = Opc == AMDGPU::G_AMDGPU_S_MUL_U64_U32_PSEUDO
+                          ? AMDGPU::G_AMDGPU_MAD_U64_U32
+                          : AMDGPU::G_AMDGPU_MAD_I64_I32;
+
+    MachineIRBuilder B(MI);
+    Register Zero64 = B.buildConstant(S64, 0).getReg(0);
+    MRI.setRegClass(Zero64, &AMDGPU::VReg_64RegClass);
+    Register CarryOut = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
+    MRI.setRegClass(CarryOut, &AMDGPU::VReg_64RegClass);
+    B.buildInstr(NewOpc, {DstReg, CarryOut}, {Op0L, Op1L, Zero64});
+    MI.eraseFromParent();
+    return;
+  }
   case AMDGPU::G_SEXT_INREG: {
     SmallVector<Register, 2> SrcRegs(OpdMapper.getVRegs(1));
     if (SrcRegs.empty())
@@ -3753,6 +3944,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case AMDGPU::G_SHUFFLE_VECTOR:
   case AMDGPU::G_SBFX:
   case AMDGPU::G_UBFX:
+  case AMDGPU::G_AMDGPU_S_MUL_I64_I32_PSEUDO:
+  case AMDGPU::G_AMDGPU_S_MUL_U64_U32_PSEUDO:
     if (isSALUMapping(MI))
       return getDefaultMappingSOP(MI);
     return getDefaultMappingVOP(MI);
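
Aside (not part of the patch): on the VGPR path above, both pseudos are rewritten into G_AMDGPU_MAD_U64_U32 / G_AMDGPU_MAD_I64_I32 with a zero addend. A standalone C++ sketch of why that is sufficient (the mad_* helpers below are illustrative models, not LLVM APIs):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative models of the 32 x 32 -> 64 multiply-add the pseudos map to.
uint64_t mad_u64_u32(uint32_t a, uint32_t b, uint64_t c) {
  return static_cast<uint64_t>(a) * b + c;
}
int64_t mad_i64_i32(int32_t a, int32_t b, int64_t c) {
  return static_cast<int64_t>(a) * b + c;
}

int main() {
  // Zero-extended operands: the truncated low halves plus a zero addend
  // reproduce the full 64-bit product.
  uint64_t ux = 0xFFFFFFFFull, uy = 0x87654321ull;
  assert(ux * uy ==
         mad_u64_u32(static_cast<uint32_t>(ux), static_cast<uint32_t>(uy), 0));

  // Sign-extended operands (at least 33 sign bits): same idea with the
  // signed multiply-add.
  int64_t sx = -123456789, sy = 40000;
  assert(sx * sy ==
         mad_i64_i32(static_cast<int32_t>(sx), static_cast<int32_t>(sy), 0));
  return 0;
}
```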
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h
index b5d16e70ab23a2..2b69a562f6a687 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h
@@ -84,6 +84,8 @@ class AMDGPURegisterBankInfo final : public AMDGPUGenRegisterBankInfo {
   bool applyMappingMAD_64_32(MachineIRBuilder &B,
                              const OperandsMapper &OpdMapper) const;
 
+  void applyMappingSMULU64(const OperandsMapper &OpdMapper) const;
+
   Register handleD16VData(MachineIRBuilder &B, MachineRegisterInfo &MRI,
                           Register Reg) const;
 
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index 91a70930326955..15a2f4b3875954 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -682,6 +682,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
 
   bool hasScalarAddSub64() const { return getGeneration() >= GFX12; }
 
+  bool hasScalarSMulU64() const { return getGeneration() >= GFX12; }
+
   bool hasUnpackedD16VMem() const {
     return HasUnpackedD16VMem;
   }
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 34826809c1a6bf..109d5e4999c75a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -760,6 +760,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
 
   setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
 
+  if (Subtarget->hasScalarSMulU64())
+    setOperationAction(ISD::MUL, MVT::i64, Custom);
+
   if (Subtarget->hasMad64_32())
     setOperationAction({ISD::SMUL_LOHI, ISD::UMUL_LOHI}, MVT::i32, Custom);
 
@@ -5426,7 +5429,6 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::SRL:
   case ISD::ADD:
   case ISD::SUB:
-  case ISD::MUL:
   case ISD::SMIN:
   case ISD::SMAX:
   case ISD::UMIN:
@@ -5440,6 +5442,8 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::SADDSAT:
   case ISD::SSUBSAT:
     return splitBinaryVectorOp(Op, DAG);
+  case ISD::MUL:
+    return lowerMUL(Op, DAG);
   case ISD::SMULO:
   case ISD::UMULO:
     return lowerXMULO(Op, DAG);
@@ -6092,6 +6096,66 @@ SDValue SITargetLowering::lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::FLDEXP, DL, VT, Op.getOperand(0), TruncExp);
 }
 
+// Custom lowering for vector multiplications and s_mul_u64.
+SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+
+  // Split vector operands.
+  if (VT.isVector())
+    return splitBinaryVectorOp(Op, DAG);
+
+  assert(VT == MVT::i64 && "The following code is a special case for s_mul_u64");
+
+  // There are four ways to lower s_mul_u64:
+  //
+  // 1. If all the operands are uniform, then we lower it as it is.
+  //
+  // 2. If the operands are divergent, then we have to split s_mul_u64 in 32-bit
+  //    multiplications because there is not a vector equivalent of s_mul_u64.
+  //
+  // 3. If the cost model decides that it is more efficient to use vector
+  //    registers, then we have to split s_mul_u64 in 32-bit multiplications.
+  //    This happens in splitScalarSMULU64() in SIInstrInfo.cpp .
+  //
+  // 4. If the cost model decides to use vector registers and both of the
+  //    operands are zero-extended/sign-extended from 32-bits, then we split the
+  //    s_mul_u64 in two 32-bit multiplications. The problem is that it is not
+  //    possible to check if the operands are zero-extended or sign-extended in
+  //    SIInstrInfo.cpp. For this reason, here, we replace s_mul_u64 with
+  //    s_mul_u64_u32_pseudo if both operands are zero-extended and we replace
+  //    s_mul_u64 with s_mul_i64_i32_pseudo if both operands are sign-extended.
+  //    If the cost model decides that we have to use vector registers, then
+  //    splitScalarSMulPseudo() (in SIInstrInfo.cpp) split s_mul_u64_u32/
+  //    s_mul_i64_i32_pseudo in two vector multiplications. If the cost model
+  //    decides that we should use scalar registers, then s_mul_u64_u32_pseudo/
+  //    s_mul_i64_i32_pseudo is lowered as s_mul_u64 in expandPostRAPseudo() in
+  //    SIInstrInfo.cpp .
+
+  if (Op->isDivergent())
+    return SDValue();
+
+  SDValue Op0 = Op.getOperand(0);
+  SDValue Op1 = Op.getOperand(1);
+  // If all the operands are zero-extended to 32-bits, then we replace s_mul_u64
+  // with s_mul_u64_u32_pseudo. If all the operands are sign-extended to
+  // 32-bits, then we replace s_mul_u64 with s_mul_i64_i32_pseudo.
+  KnownBits Op0KnownBits = DAG.computeKnownBits(Op0);
+  unsigned Op0LeadingZeros = Op0KnownBits.countMinLeadingZeros();
+  KnownBits Op1KnownBits = DAG.computeKnownBits(Op1);
+  unsigned Op1LeadingZeros = Op1KnownBits.countMinLeadingZeros();
+  SDLoc SL(Op);
+  if (Op0LeadingZeros >= 32 && Op1LeadingZeros >= 32)
+    return SDValue(
+        DAG.getMachineNode(AMDGPU::S_MUL_U64_U32_PSEUDO, SL, VT, Op0, Op1), 0);
+  unsigned Op0SignBits = DAG.ComputeNumSignBits(Op0);
+  unsigned Op1SignBits = DAG.ComputeNumSignBits(Op1);
+  if (Op0SignBits >= 33 && Op1SignBits >= 33)
+    return SDValue(
+        DAG.getMachineNode(AMDGPU::S_MUL_I64_I32_PSEUDO, SL, VT, Op0, Op1), 0);
+  // If all the operands are uniform, then we lower s_mul_u64 as it is.
+  return Op;
+}
+
 SDValue SITargetLowering::lowerXMULO(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.ge...
[truncated]
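
Aside (not part of the patch): the four lowering paths documented in lowerMUL above boil down to a small decision. A hedged restatement in plain C++ (names are illustrative, not LLVM code; vector types are split before this point):

```cpp
// Illustrative restatement (not LLVM code) of the decision made in lowerMUL.
enum class MulLowering {
  KeepSMulU64,     // lower as s_mul_u64; SIInstrInfo may still split it later
  SplitLater,      // divergent: leave it to the generic path, which splits it
  MulU64U32Pseudo, // both operands zero-extended from 32 bits
  MulI64I32Pseudo  // both operands sign-extended from 32 bits (>= 33 sign bits)
};

MulLowering chooseMulLowering(bool IsDivergent, unsigned Op0LeadingZeros,
                              unsigned Op1LeadingZeros, unsigned Op0SignBits,
                              unsigned Op1SignBits) {
  if (IsDivergent)
    return MulLowering::SplitLater;
  if (Op0LeadingZeros >= 32 && Op1LeadingZeros >= 32)
    return MulLowering::MulU64U32Pseudo;
  if (Op0SignBits >= 33 && Op1SignBits >= 33)
    return MulLowering::MulI64I32Pseudo;
  return MulLowering::KeepSMulU64;
}
```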

@jayfoad
Contributor Author

jayfoad commented Jan 2, 2024

Ping!

@jayfoad
Contributor Author

jayfoad commented Jan 8, 2024

Ping!

Collaborator

@rampitec rampitec left a comment

LGTM

@jayfoad jayfoad merged commit daa4728 into llvm:main Jan 8, 2024
4 checks passed
@jayfoad jayfoad deleted the gfx12-mulu64 branch January 8, 2024 19:13
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024