Skip to content

Commit

Permalink
Implement convergence control in MIR using SelectionDAG (#71785)
Browse files Browse the repository at this point in the history
LLVM function calls carry convergence control tokens as operand bundles, where
the tokens themselves are produced by convergence control intrinsics. This patch
implements convergence control tokens in MIR as follows:

1. Introduce target-independent ISD opcodes and MIR opcodes for convergence
   control intrinsics.
2. Model token values as untyped virtual registers in MIR.

The change also introduces an additional ISD opcode CONVERGENCECTRL_GLUE and a
corresponding machine opcode with the same spelling. This glues the convergence
control token to SDNodes that represent calls to intrinsics. The glued token is
later translated to an implicit argument in the MIR.

The lowering of calls to user-defined functions is target-specific. On AMDGPU,
the convergence control operand bundle at a non-intrinsic call is translated to
an explicit argument to the SI_CALL_ISEL instruction. Post-selection adjustment
converts this explicit argument to an implicit argument on the SI_CALL
instruction.
  • Loading branch information
ssahasra committed Feb 21, 2024
1 parent 03203b7 commit 7988973
Show file tree
Hide file tree
Showing 52 changed files with 831 additions and 162 deletions.
9 changes: 8 additions & 1 deletion llvm/include/llvm/ADT/GenericConvergenceVerifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ template <typename ContextT> class GenericConvergenceVerifier {

void initialize(raw_ostream *OS,
function_ref<void(const Twine &Message)> FailureCB,
const FunctionT &F) {
const FunctionT &F, bool _IsSSA) {
clear();
this->OS = OS;
this->FailureCB = FailureCB;
Context = ContextT(&F);
IsSSA = _IsSSA;
}

void clear();
Expand All @@ -52,6 +53,7 @@ template <typename ContextT> class GenericConvergenceVerifier {
DominatorTreeT *DT;
CycleInfoT CI;
ContextT Context;
bool IsSSA;

/// Whether the current function has convergencectrl operand bundles.
enum {
Expand All @@ -60,6 +62,10 @@ template <typename ContextT> class GenericConvergenceVerifier {
NoConvergence
} ConvergenceKind = NoConvergence;

/// The control token operation performed by a convergence control Intrinsic
/// in LLVM IR, or by a CONVERGENCECTRL* instruction in MIR
enum ConvOpKind { CONV_ANCHOR, CONV_ENTRY, CONV_LOOP, CONV_NONE };

// Cache token uses found so far. Note that we track the unique definitions
// and not the token values.
DenseMap<const InstructionT *, const InstructionT *> Tokens;
Expand All @@ -68,6 +74,7 @@ template <typename ContextT> class GenericConvergenceVerifier {

static bool isInsideConvergentFunction(const InstructionT &I);
static bool isConvergent(const InstructionT &I);
static ConvOpKind getConvOp(const InstructionT &I);
const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);

void reportFailure(const Twine &Message, ArrayRef<Printable> Values);
Expand Down
10 changes: 1 addition & 9 deletions llvm/include/llvm/CodeGen/FunctionLoweringInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,7 @@ class FunctionLoweringInfo {

Register CreateRegs(Type *Ty, bool isDivergent = false);

Register InitializeRegForValue(const Value *V) {
// Tokens never live in vregs.
if (V->getType()->isTokenTy())
return 0;
Register &R = ValueMap[V];
assert(R == 0 && "Already initialized this value register!");
assert(VirtReg2Value.empty());
return R = CreateRegs(V);
}
Register InitializeRegForValue(const Value *V);

/// GetLiveOutRegInfo - Gets LiveOutInfo for a register, returning NULL if the
/// register is a PHI destination and the PHI's LiveOutInfo is not valid.
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,15 @@ enum NodeType {
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
#include "llvm/IR/VPIntrinsics.def"

// The `llvm.experimental.convergence.*` intrinsics.
CONVERGENCECTRL_ANCHOR,
CONVERGENCECTRL_ENTRY,
CONVERGENCECTRL_LOOP,
// This does not correspond to any convergence control intrinsic. It used to
// glue a convergence control token to a convergent operation in the DAG,
// which is later translated to an implicit use in the MIR.
CONVERGENCECTRL_GLUE,

/// BUILTIN_OP_END - This must be the last enum value in this list.
/// The target-specific pre-isel opcode values start here.
BUILTIN_OP_END
Expand Down
28 changes: 28 additions & 0 deletions llvm/include/llvm/CodeGen/MachineConvergenceVerifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- MachineConvergenceVerifier.h - Verify convergenctrl ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file declares the MIR specialization of the GenericConvergenceVerifier
/// template.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
#define LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H

#include "llvm/ADT/GenericConvergenceVerifier.h"
#include "llvm/CodeGen/MachineSSAContext.h"

namespace llvm {

using MachineConvergenceVerifier =
GenericConvergenceVerifier<MachineSSAContext>;

} // namespace llvm

#endif // LLVM_CODEGEN_MACHINECONVERGENCEVERIFIER_H
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ class SelectionDAGISel : public MachineFunctionPass {
void Select_ARITH_FENCE(SDNode *N);
void Select_MEMBARRIER(SDNode *N);

void Select_CONVERGENCECTRL_ANCHOR(SDNode *N);
void Select_CONVERGENCECTRL_ENTRY(SDNode *N);
void Select_CONVERGENCECTRL_LOOP(SDNode *N);

void pushStackMapLiveVariable(SmallVectorImpl<SDValue> &Ops, SDValue Operand,
SDLoc DL);
void Select_STACKMAP(SDNode *N);
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -4401,6 +4401,7 @@ class TargetLowering : public TargetLoweringBase {
SmallVector<ISD::InputArg, 32> Ins;
SmallVector<SDValue, 4> InVals;
const ConstantInt *CFIType = nullptr;
SDValue ConvergenceControlToken;

CallLoweringInfo(SelectionDAG &DAG)
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
Expand Down Expand Up @@ -4534,6 +4535,11 @@ class TargetLowering : public TargetLoweringBase {
return *this;
}

CallLoweringInfo &setConvergenceControlToken(SDValue Token) {
ConvergenceControlToken = Token;
return *this;
}

ArgListTy &getArgs() {
return Args;
}
Expand Down
25 changes: 14 additions & 11 deletions llvm/include/llvm/IR/GenericConvergenceVerifierImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
Tokens.clear();
CI.clear();
ConvergenceKind = NoConvergence;
IsSSA = false;
}

template <class ContextT>
Expand All @@ -61,12 +62,16 @@ void GenericConvergenceVerifier<ContextT>::visit(const BlockT &BB) {

template <class ContextT>
void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
auto ID = ContextT::getIntrinsicID(I);
ConvOpKind ConvOp = getConvOp(I);
if (!IsSSA) {
Check(ConvOp == CONV_NONE, "Convergence control requires SSA.",
{Context.print(&I)});
return;
}
auto *TokenDef = findAndCheckConvergenceTokenUsed(I);
bool IsCtrlIntrinsic = true;

switch (ID) {
case Intrinsic::experimental_convergence_entry:
switch (ConvOp) {
case CONV_ENTRY:
Check(isInsideConvergentFunction(I),
"Entry intrinsic can occur only in a convergent function.",
{Context.print(&I)});
Expand All @@ -78,13 +83,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
"same basic block.",
{Context.print(&I)});
LLVM_FALLTHROUGH;
case Intrinsic::experimental_convergence_anchor:
case CONV_ANCHOR:
Check(!TokenDef,
"Entry or anchor intrinsic cannot have a convergencectrl token "
"operand.",
{Context.print(&I)});
break;
case Intrinsic::experimental_convergence_loop:
case CONV_LOOP:
Check(TokenDef, "Loop intrinsic must have a convergencectrl token operand.",
{Context.print(&I)});
Check(!SeenFirstConvOp,
Expand All @@ -93,14 +98,13 @@ void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
{Context.print(&I)});
break;
default:
IsCtrlIntrinsic = false;
break;
}

if (isConvergent(I))
SeenFirstConvOp = true;

if (TokenDef || IsCtrlIntrinsic) {
if (TokenDef || ConvOp != CONV_NONE) {
Check(isConvergent(I),
"Convergence control token can only be used in a convergent call.",
{Context.print(&I)});
Expand Down Expand Up @@ -161,8 +165,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
return;
}

Check(ContextT::getIntrinsicID(*User) ==
Intrinsic::experimental_convergence_loop,
Check(getConvOp(*User) == CONV_LOOP,
"Convergence token used by an instruction other than "
"llvm.experimental.convergence.loop in a cycle that does "
"not contain the token's definition.",
Expand Down Expand Up @@ -199,7 +202,7 @@ void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
for (auto &I : *BB) {
if (auto *Token = Tokens.lookup(&I))
checkToken(Token, &I, LiveTokens);
if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
if (getConvOp(I) != CONV_NONE)
LiveTokens.push_back(&I);
}

Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Support/TargetOpcodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,11 @@ HANDLE_TARGET_OPCODE(MEMBARRIER)
// using.
HANDLE_TARGET_OPCODE(JUMP_TABLE_DEBUG_INFO)

HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ENTRY)
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_ANCHOR)
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_LOOP)
HANDLE_TARGET_OPCODE(CONVERGENCECTRL_GLUE)

/// The following generic opcodes are not supposed to appear after ISel.
/// This is something we might want to relax, but for now, this is convenient
/// to produce diagnostics.
Expand Down
19 changes: 19 additions & 0 deletions llvm/include/llvm/Target/Target.td
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,25 @@ def JUMP_TABLE_DEBUG_INFO : StandardPseudoInstruction {
let isMeta = true;
}

let hasSideEffects = false, isMeta = true, isConvergent = true in {
def CONVERGENCECTRL_ANCHOR : StandardPseudoInstruction {
let OutOperandList = (outs unknown:$dst);
let InOperandList = (ins);
}
def CONVERGENCECTRL_ENTRY : StandardPseudoInstruction {
let OutOperandList = (outs unknown:$dst);
let InOperandList = (ins);
}
def CONVERGENCECTRL_LOOP : StandardPseudoInstruction {
let OutOperandList = (outs unknown:$dst);
let InOperandList = (ins unknown:$src);
}
def CONVERGENCECTRL_GLUE : StandardPseudoInstruction {
let OutOperandList = (outs);
let InOperandList = (ins unknown:$src);
}
}

// Generic opcodes used in GlobalISel.
include "llvm/Target/GenericOpcodes.td"

Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,16 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;

def convergencectrl_anchor : SDNode<"ISD::CONVERGENCECTRL_ANCHOR",
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
def convergencectrl_entry : SDNode<"ISD::CONVERGENCECTRL_ENTRY",
SDTypeProfile<1, 0, [SDTCisVT<0,untyped>]>>;
def convergencectrl_loop : SDNode<"ISD::CONVERGENCECTRL_LOOP",
SDTypeProfile<1, 1,
[SDTCisVT<0,untyped>, SDTCisVT<1,untyped>]>>;
def convergencectrl_glue : SDNode<"ISD::CONVERGENCECTRL_GLUE",
SDTypeProfile<0, 1, [SDTCisVT<0, untyped>]>>;

//===----------------------------------------------------------------------===//
// Selection DAG Condition Codes

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ add_llvm_component_library(LLVMCodeGen
MachineBranchProbabilityInfo.cpp
MachineCFGPrinter.cpp
MachineCombiner.cpp
MachineConvergenceVerifier.cpp
MachineCopyPropagation.cpp
MachineCSE.cpp
MachineCheckDebugify.cpp
Expand Down
86 changes: 86 additions & 0 deletions llvm/lib/CodeGen/MachineConvergenceVerifier.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/MachineConvergenceVerifier.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/IR/GenericConvergenceVerifierImpl.h"

using namespace llvm;

template <>
auto GenericConvergenceVerifier<MachineSSAContext>::getConvOp(
const MachineInstr &MI) -> ConvOpKind {
switch (MI.getOpcode()) {
default:
return CONV_NONE;
case TargetOpcode::CONVERGENCECTRL_ENTRY:
return CONV_ENTRY;
case TargetOpcode::CONVERGENCECTRL_ANCHOR:
return CONV_ANCHOR;
case TargetOpcode::CONVERGENCECTRL_LOOP:
return CONV_LOOP;
}
}

template <>
const MachineInstr *
GenericConvergenceVerifier<MachineSSAContext>::findAndCheckConvergenceTokenUsed(
const MachineInstr &MI) {
const MachineRegisterInfo &MRI = Context.getFunction()->getRegInfo();
const MachineInstr *TokenDef = nullptr;

for (const MachineOperand &MO : MI.uses()) {
if (!MO.isReg())
continue;
Register OpReg = MO.getReg();
if (!OpReg.isVirtual())
continue;

const MachineInstr *Def = MRI.getVRegDef(OpReg);
if (!Def)
continue;
if (getConvOp(*Def) == CONV_NONE)
continue;

CheckOrNull(
MI.isConvergent(),
"Convergence control tokens can only be used by convergent operations.",
{Context.print(OpReg), Context.print(&MI)});

CheckOrNull(!TokenDef,
"An operation can use at most one convergence control token.",
{Context.print(OpReg), Context.print(&MI)});

TokenDef = Def;
}

if (TokenDef)
Tokens[&MI] = TokenDef;

return TokenDef;
}

template <>
bool GenericConvergenceVerifier<MachineSSAContext>::isInsideConvergentFunction(
const MachineInstr &MI) {
// The class MachineFunction does not have any property to indicate whether it
// is convergent. Trivially return true so that the check always passes.
return true;
}

template <>
bool GenericConvergenceVerifier<MachineSSAContext>::isConvergent(
const MachineInstr &MI) {
return MI.isConvergent();
}

template class llvm::GenericConvergenceVerifier<MachineSSAContext>;

0 comments on commit 7988973

Please sign in to comment.