Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GlobalISel] convergence control tokens and intrinsics #67006

Merged
merged 3 commits into from
Mar 18, 2024

Conversation

ssahasra
Copy link
Collaborator

Support for lowering convergence control to GMIR:

  • Introduce new G_CONVERGENCECTRL_* opcodes.
  • In the IR translator, convert the LLVM token type to LLT::token(), which is an alias for the s0 type. These show up as implicit uses on convergent operations.
  • In the machine verifier, enforce the static rules of convergence control.

Note that the lowering of the new GMIR opcodes is entirely target-specific. It is generally expected that the backend will use convergence control to change the CFG of each function, and then discard the tokens. This is currently a work in progress for AMDGPU.

Differential Revision: https://reviews.llvm.org/D158147

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 21, 2023

@llvm/pr-subscribers-llvm-adt
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-globalisel

Changes

Support for lowering convergence control to GMIR:

  • Introduce new G_CONVERGENCECTRL_* opcodes.
  • In the IR translator, convert the LLVM token type to LLT::token(), which is an alias for the s0 type. These show up as implicit uses on convergent operations.
  • In the machine verifier, enforce the static rules of convergence control.

Note that the lowering of the new GMIR opcodes is entirely target-specific. It is generally expected that the backend will use convergence control to change the CFG of each function, and then discard the tokens. This is currently a work in progress for AMDGPU.

Differential Revision: https://reviews.llvm.org/D158147


Patch is 42.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67006.diff

26 Files Affected:

  • (modified) llvm/include/llvm/ADT/GenericConvergenceVerifier.h (+5)
  • (modified) llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h (+4)
  • (modified) llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h (+21)
  • (modified) llvm/include/llvm/CodeGen/LowLevelType.h (+29)
  • (added) llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h (+28)
  • (modified) llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h (+8-22)
  • (modified) llvm/include/llvm/Support/TargetOpcodes.def (+4)
  • (modified) llvm/include/llvm/Target/GenericOpcodes.td (+28)
  • (modified) llvm/lib/CodeGen/CMakeLists.txt (+1)
  • (modified) llvm/lib/CodeGen/GlobalISel/CallLowering.cpp (+14)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+55-6)
  • (modified) llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp (+8)
  • (modified) llvm/lib/CodeGen/LowLevelTypeUtils.cpp (+4)
  • (modified) llvm/lib/CodeGen/MIRParser/MIParser.cpp (+8-5)
  • (added) llvm/lib/CodeGen/MachineConvergenceVerifier.cpp (+97)
  • (modified) llvm/lib/CodeGen/MachineVerifier.cpp (+34)
  • (modified) llvm/lib/IR/ConvergenceVerifier.cpp (+24-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (+6)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+9)
  • (added) llvm/test/CodeGen/AMDGPU/GlobalISel/irtranslator-convergence-tokens.ll (+83)
  • (added) llvm/test/CodeGen/AMDGPU/verify-convergencectrl/basic.mir (+38)
  • (added) llvm/test/CodeGen/AMDGPU/verify-convergencectrl/cycles.mir (+53)
  • (added) llvm/test/CodeGen/AMDGPU/verify-convergencectrl/mixed2.mir (+15)
  • (added) llvm/test/CodeGen/AMDGPU/verify-convergencectrl/region-nesting.mir (+25)
  • (removed) llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid4.mir (-10)
  • (modified) llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid6.mir (+1-1)
diff --git a/llvm/include/llvm/ADT/GenericConvergenceVerifier.h b/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
index 0810a0701322907..24cd0cda3708517 100644
--- a/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
+++ b/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
@@ -60,6 +60,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
     NoConvergence
   } ConvergenceKind = NoConvergence;
 
+  /// The control token operation performed by a convergence control Intrinsic in LLVM IR,
+  /// or by a G_CONVERGENCECTRL* instruction in GMIR.
+  enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };
+
   // Cache token uses found so far. Note that we track the unique definitions
   // and not the token values.
   DenseMap<const InstructionT *, const InstructionT *> Tokens;
@@ -68,6 +72,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
 
   static bool isInsideConvergentFunction(const InstructionT &I);
   static bool isConvergent(const InstructionT &I);
+  static ConvOpKind getConvOp(const InstructionT &I);
   const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
 
   void reportFailure(const Twine &Message, ArrayRef<Printable> Values);
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
index 4a799eec8899a51..ff2721495dc3b97 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
@@ -117,6 +117,9 @@ class CallLowering {
     /// vreg that the swifterror should be copied into after the call.
     Register SwiftErrorVReg;
 
+    /// Valid if the call is a controlled convergent operation.
+    Register ConvergenceCtrlToken;
+
     /// Original IR callsite corresponding to this call, if available.
     const CallBase *CB = nullptr;
 
@@ -583,6 +586,7 @@ class CallLowering {
   bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
                  ArrayRef<Register> ResRegs,
                  ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
+                 Register ConvergenceCtrlToken,
                  std::function<unsigned()> GetCalleeReg) const;
 
   /// For targets which want to use big-endian can enable it with
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
index bffc03ed0187e63..d51daa926c9575d 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
@@ -554,6 +554,10 @@ class IRTranslator : public MachineFunctionPass {
     return false;
   }
 
+  bool translateConvergenceControlIntrinsic(const CallInst &CI,
+                                            Intrinsic::ID ID,
+                                            MachineIRBuilder &MIRBuilder);
+
   /// @}
 
   // Builder for machine instruction a la IRBuilder.
@@ -671,6 +675,23 @@ class IRTranslator : public MachineFunctionPass {
     return Regs[0];
   }
 
+  Register getOrCreateConvergenceTokenVReg(const Value &Token) {
+    assert(Token.getType()->isTokenTy());
+    auto &Regs = *VMap.getVRegs(Token);
+    if (!Regs.empty()) {
+      assert(Regs.size() == 1 &&
+             "Expected a single register for convergence tokens.");
+      return Regs[0];
+    }
+
+    auto Reg = MRI->createGenericVirtualRegister(LLT::token());
+    Regs.push_back(Reg);
+    auto &Offsets = *VMap.getOffsets(Token);
+    if (Offsets.empty())
+      Offsets.push_back(0);
+    return Reg;
+  }
+
   /// Allocate some vregs and offsets in the VMap. Then populate just the
   /// offsets while leaving the vregs empty.
   ValueToVRegInfo::VRegListT &allocateVRegs(const Value &Val);
diff --git a/llvm/include/llvm/CodeGen/LowLevelType.h b/llvm/include/llvm/CodeGen/LowLevelType.h
index b7477834a1ff5e1..e0cd56849909422 100644
--- a/llvm/include/llvm/CodeGen/LowLevelType.h
+++ b/llvm/include/llvm/CodeGen/LowLevelType.h
@@ -45,6 +45,13 @@ class LLT {
                /*AddressSpace=*/0};
   }
 
+  /// Get a low-level token; just a scalar with zero bits (or no size).
+  static constexpr LLT token() {
+    return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true,
+               ElementCount::getFixed(0), /*SizeInBits=*/0,
+               /*AddressSpace=*/0};
+  }
+
   /// Get a low-level pointer in the given address space.
   static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
     assert(SizeInBits > 0 && "invalid pointer size");
@@ -304,6 +311,28 @@ class LLT {
   /// described in static const *Field variables. Each of these variables
   /// is a 2-element array, with the first element describing the bitfield size
   /// and the second element describing the bitfield offset.
+  ///
+  /// +--------+---------+--------+----------+----------------------+
+  /// |isScalar|isPointer|isVector| RawData  |Notes                 |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    0    |   0    |    0     |Invalid               |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    0    |   1    |    0     |Tombstone Key         |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    1    |   0    |    0     |Empty Key             |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   1    |    0    |   0    |    0     |Token                 |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   1    |    0    |   0    | non-zero |Scalar                |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    1    |   0    | non-zero |Pointer               |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    0    |   1    | non-zero |Vector of non-pointer |
+  /// +--------+---------+--------+----------+----------------------+
+  /// |   0    |    1    |   1    | non-zero |Vector of pointer     |
+  /// +--------+---------+--------+----------+----------------------+
+  ///
+  /// Everything else is reserved.
   typedef int BitFieldInfo[2];
   ///
   /// This is how the bitfields are packed per Kind:
diff --git a/llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h b/llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h
new file mode 100644
index 000000000000000..fd13546f2ae32f0
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h
@@ -0,0 +1,28 @@
+//===- MachineConvergenceVerifier.h - Verify convergenctrl ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file declares the GMIR IR specialization of the
+/// GenericConvergenceVerifier template.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
+#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
+
+#include "llvm/ADT/GenericConvergenceVerifier.h"
+#include "llvm/CodeGen/MachineSSAContext.h"
+
+namespace llvm {
+
+using MachineConvergenceVerifier =
+    GenericConvergenceVerifier<MachineSSAContext>;
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
diff --git a/llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h b/llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h
index 2ba81015cb7b641..037e406fc57d38b 100644
--- a/llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h
+++ b/llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h
@@ -49,17 +49,6 @@ using namespace llvm;
     }                                                                          \
   } while (false)
 
-static bool isConvergenceControlIntrinsic(unsigned IntrinsicID) {
-  switch (IntrinsicID) {
-  default:
-    return false;
-  case Intrinsic::experimental_convergence_anchor:
-  case Intrinsic::experimental_convergence_entry:
-  case Intrinsic::experimental_convergence_loop:
-    return true;
-  }
-}
-
 namespace llvm {
 template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
   Tokens.clear();
@@ -74,12 +63,11 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {
 
 template <class ContextT>
 void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
-  auto ID = ContextT::getIntrinsicID(I);
+  auto ConvOp = getConvOp(I);
   auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
-  bool IsCtrlIntrinsic = true;
 
-  switch (ID) {
-  case Intrinsic::experimental_convergence_entry:
+  switch (ConvOp) {
+  case CONV_ENTRY:
     Check(isInsideConvergentFunction(I),
           "Entry intrinsic can occur only in a convergent function.",
           {Context.print(&I)});
@@ -91,13 +79,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
           "same basic block.",
           {Context.print(&I)});
     LLVM_FALLTHROUGH;
-  case Intrinsic::experimental_convergence_anchor:
+  case CONV_ANCHOR:
     Check(!TokenDef,
           "Entry or anchor intrinsic cannot have a convergencectrl token "
           "operand.",
           {Context.print(&I)});
     break;
-  case Intrinsic::experimental_convergence_loop:
+  case CONV_LOOP:
     Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
           {Context.print(&I)});
     Check(!SeenFirstConvOp,
@@ -106,14 +94,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
           {Context.print(&I)});
     break;
   default:
-    IsCtrlIntrinsic = false;
     break;
   }
 
   if (isConvergent(I))
     SeenFirstConvOp = true;
 
-  if (TokenDef || IsCtrlIntrinsic) {
+  if (TokenDef || ConvOp != CONV_NONE) {
     Check(isConvergent(I),
           "Convergence control token can only be used in a convergent call.",
           {Context.print(&I)});
@@ -174,8 +161,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
       return;
     }
 
-    Check(ContextT::getIntrinsicID(*User) ==
-              Intrinsic::experimental_convergence_loop,
+    Check(getConvOp(*User) == CONV_LOOP,
           "Convergence token used by an instruction other than "
           "llvm.experimental.convergence.loop in a cycle that does "
           "not contain the token's definition.",
@@ -212,7 +198,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
     for (auto &I : *BB) {
       if (auto *Token = Tokens.lookup(&I))
         checkToken(Token, &I, LiveTokens);
-      if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
+      if (getConvOp(I) != CONV_NONE)
         LiveTokens.push_back(&I);
     }
 
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index e02b1a1d01cfec0..04bdc395b302235 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -436,6 +436,10 @@ HANDLE_TARGET_OPCODE(G_INTRINSIC_CONVERGENT)
 /// Generic intrinsic use (with side effects).
 HANDLE_TARGET_OPCODE(G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS)
 
+HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_ENTRY)
+HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_ANCHOR)
+HANDLE_TARGET_OPCODE(G_CONVERGENCECTRL_LOOP)
+
 /// Generic extension allowing rubbish in high bits.
 HANDLE_TARGET_OPCODE(G_ANYEXT)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 0a4fbaa12f96cec..27d4b8a4759ff92 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1276,6 +1276,34 @@ def G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS : GenericInstruction {
   let isConvergent = true;
 }
 
+//------------------------------------------------------------------------------
+// Convergence control operations.
+//------------------------------------------------------------------------------
+
+// Capture the set of threads that are converged on entry to a function.
+def G_CONVERGENCECTRL_ENTRY : GenericInstruction {
+  let InOperandList = (ins);
+  let OutOperandList = (outs type0:$dst);
+  let isConvergent = true;
+  let hasSideEffects = false;
+}
+
+// Capture an implementation-defined subset of converged threads.
+def G_CONVERGENCECTRL_ANCHOR : GenericInstruction {
+  let InOperandList = (ins);
+  let OutOperandList = (outs type0:$dst);
+  let isConvergent = true;
+  let hasSideEffects = false;
+}
+
+// Capture the convergence of threads in a cycle.
+def G_CONVERGENCECTRL_LOOP : GenericInstruction {
+  let InOperandList = (ins type0:$src);
+  let OutOperandList = (outs type0:$dst);
+  let isConvergent = true;
+  let hasSideEffects = false;
+}
+
 //------------------------------------------------------------------------------
 // Branches.
 //------------------------------------------------------------------------------
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index 389c70d04f17ba3..41f6612b68a1f12 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -121,6 +121,7 @@ add_llvm_component_library(LLVMCodeGen
   MachineBranchProbabilityInfo.cpp
   MachineCFGPrinter.cpp
   MachineCombiner.cpp
+  MachineConvergenceVerifier.cpp
   MachineCopyPropagation.cpp
   MachineCSE.cpp
   MachineCheckDebugify.cpp
diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index 0b1f151135be9fc..45738cc636cb7af 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -21,6 +21,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Target/TargetMachine.h"
@@ -87,10 +88,20 @@ void CallLowering::addArgFlagsFromAttributes(ISD::ArgFlagsTy &Flags,
   });
 }
 
+static bool hasConvergenceEntryToken(const CallBase &CB) {
+  auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl);
+  if (!Bundle)
+    return true;
+  auto *Token = Bundle->Inputs[0].get();
+  auto *Def = cast<IntrinsicInst>(Token);
+  return Def->getIntrinsicID() == Intrinsic::experimental_convergence_entry;
+}
+
 bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
                              ArrayRef<Register> ResRegs,
                              ArrayRef<ArrayRef<Register>> ArgRegs,
                              Register SwiftErrorVReg,
+                             Register ConvergenceCtrlToken,
                              std::function<unsigned()> GetCalleeReg) const {
   CallLoweringInfo Info;
   const DataLayout &DL = MIRBuilder.getDataLayout();
@@ -121,6 +132,8 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
     CanBeTailCalled = false;
   }
 
+  if (!hasConvergenceEntryToken(CB))
+    CanBeTailCalled = false;
 
   // First step is to marshall all the function's parameters into the correct
   // physregs and memory locations. Gather the sequence of argument types that
@@ -176,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
   Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
   Info.CallConv = CallConv;
   Info.SwiftErrorVReg = SwiftErrorVReg;
+  Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
   Info.IsMustTailCall = CB.isMustTailCall();
   Info.IsTailCall = CanBeTailCalled;
   Info.IsVarArg = IsVarArg;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 9dac3d083994f9f..692e2fbdb313add 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -210,8 +210,9 @@ ArrayRef<Register> IRTranslator::getOrCreateVRegs(const Value &Val) {
   auto *VRegs = VMap.getVRegs(Val);
   auto *Offsets = VMap.getOffsets(Val);
 
-  assert(Val.getType()->isSized() &&
-         "Don't know how to create an empty vreg");
+  if (!Val.getType()->isTokenTy())
+    assert(Val.getType()->isSized() &&
+           "Don't know how to create an empty vreg");
 
   SmallVector<LLT, 4> SplitTys;
   computeValueLLTs(*DL, *Val.getType(), SplitTys,
@@ -1953,6 +1954,37 @@ bool IRTranslator::translateIfEntryValueArgument(
   return true;
 }
 
+static unsigned getConvOpcode(Intrinsic::ID ID) {
+  switch (ID) {
+  default:
+    llvm_unreachable("Unexpected intrinsic");
+    return 0;
+  case Intrinsic::experimental_convergence_anchor:
+    return TargetOpcode::G_CONVERGENCECTRL_ANCHOR;
+  case Intrinsic::experimental_convergence_entry:
+    return TargetOpcode::G_CONVERGENCECTRL_ENTRY;
+  case Intrinsic::experimental_convergence_loop:
+    return TargetOpcode::G_CONVERGENCECTRL_LOOP;
+  }
+}
+
+bool IRTranslator::translateConvergenceControlIntrinsic(
+    const CallInst &CI, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder) {
+  MachineInstrBuilder MIB = MIRBuilder.buildInstr(getConvOpcode(ID));
+  Register OutputReg = getOrCreateConvergenceTokenVReg(CI);
+  MIB.addDef(OutputReg);
+
+  if (ID == Intrinsic::experimental_convergence_loop) {
+    auto Bundle = CI.getOperandBundle(LLVMContext::OB_convergencectrl);
+    assert(Bundle && "Expected a convergence control token.");
+    Register InputReg =
+        getOrCreateConvergenceTokenVReg(*Bundle->Inputs[0].get());
+    MIB.addUse(InputReg);
+  }
+
+  return true;
+}
+
 bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
                                            MachineIRBuilder &MIRBuilder) {
   if (auto *MI = dyn_cast<AnyMemIntrinsic>(&CI)) {
@@ -2410,7 +2442,10 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
 #include "llvm/IR/ConstrainedOps.def"
     return translateConstrainedFPIntrinsic(cast<ConstrainedFPIntrinsic>(CI),
                                            MIRBuilder);
-
+  case Intrinsic::experimental_convergence_anchor:
+  case Intrinsic::experimental_convergence_entry:
+  case Intrinsic::experimental_convergence_loop:
+    return translateConvergenceControlIntrinsic(CI, ID, MIRBuilder);
   }
   return false;
 }
@@ -2461,12 +2496,18 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
     }
   }
 
+  Register ConvergenceCtrlToken = 0;
+  if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
+    const auto &Token = *Bundle->Inputs[0].get();
+    ConvergenceCtrlToken = getOrCreateConvergenceTokenVReg(Token);
+  }
+
   // We don't set HasCalls on MFI here yet because call lowering may decide to
   // optimize into tail calls. Instead, we defer that to selection where a final
   // scan is done to check if any instructions are calls.
-  bool Success =
-      CLI->lowerCall(MIRBuilder, CB, Res, Args, SwiftErrorVReg,
-                     [&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
+  bool Success = CLI->lowerCall(
+      MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
+      [&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
 
   // Check if we just inserted a tail call.
   if (Success) {
@@ -2581,6 +2622,14 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
         MF->getMachineMemOperand(MPI, Info.flags, MemTy, Alignment, CI.getAAMetadata()));
   }
 
+  if (CI.isConvergent()) {
+    if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_convergencectrl)) {
+      auto *Token = Bundle->Inputs[0].get();
+      Register TokenReg = getOrCreateVReg(*Token);
+      MIB.addUse(TokenReg, RegState::Implicit);
+    }
+  }
+
   return true;
 }
 
diff --git a/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp b/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
index 00dba57fcb80227..2dfaa7381ae3355 100644
--- a/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
@@ -592,6 +592,14 @@ bool InlineAsmLowering::lowerInlineAsm(
     }
   }
 
+  if (auto Bundle = Call.getOperandBundle(LLVMContext::OB_convergencectrl)) {
+    auto *Token = Bundle->Inputs[0].get();
+    ArrayRef<Register> SourceRegs = GetOrCreateVRegs(*Token);
+    assert(SourceRegs.size() == 1 &&
+           "Expected the control token to f...
[truncated]

Copy link
Collaborator

@nhaehnle nhaehnle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't had the time to review this in detail, but the overall approach seems sound to me.

// that since the token type may have other implicit uses. Instead we use it
// as a way to identify convergence control token operands.
const auto *Def = MRI.getUniqueVRegDef(MO.getReg());
if (!Def)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't happen for G_* instructions (and ideally would be impossible for all of SSA)

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Mar 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@ssahasra
Copy link
Collaborator Author

ssahasra commented Mar 7, 2024

  • Rebased on top of main, which includes the SelectionDAG changes for convergence tokens.
  • Eliminated the G_* opcodes in favour of the opcodes introduced for SelectionDAG

The new token type is used in llvm#67006 for implementing convergence control tokens
in GMIR.
In the IR translator, convert the LLVM token type to LLT::token(), which is an
alias for the s0 type. These show up as implicit uses on convergent operations.

Differential Revision: https://reviews.llvm.org/D158147
@ssahasra ssahasra force-pushed the ssahasra/gmir-convergence-tokens branch from 951b543 to ff7f4ef Compare March 14, 2024 09:29
@ssahasra
Copy link
Collaborator Author

  • rebased to upstream/main
  • split out LLT::token() as a separate commit
  • added support for tail calls along with a test
  • removed redundant return (after an unreachable)

ssahasra added a commit that referenced this pull request Mar 15, 2024
The new token type is used in #67006 for implementing convergence
control tokens in GMIR.
@ssahasra ssahasra merged commit ec34699 into llvm:main Mar 18, 2024
4 checks passed
@ssahasra ssahasra deleted the ssahasra/gmir-convergence-tokens branch March 18, 2024 05:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants