Skip to content

Commit

Permalink
[SPIRV] add SPIRVPrepareFunctions pass and update other passes
Browse files Browse the repository at this point in the history
The patch adds SPIRVPrepareFunctions pass, which modifies function
signatures containing aggregate arguments and/or return values before
IR translation. Information about the original signatures is stored in
metadata. It is used during call lowering to restore correct SPIR-V types
of function arguments and return values. This pass also substitutes some
llvm intrinsic calls to function calls, generating the necessary functions
in the module, as the SPIRV translator does.

The patch also includes changes in other modules, fixing errors and
enabling many SPIR-V features that were omitted earlier. And 15 LIT tests
are also added to demonstrate the new functionality.

Differential Revision: https://reviews.llvm.org/D129730

Co-authored-by: Aleksandr Bezzubikov <zuban32s@gmail.com>
Co-authored-by: Michal Paszkowski <michal.paszkowski@outlook.com>
Co-authored-by: Andrey Tretyakov <andrey1.tretyakov@intel.com>
Co-authored-by: Konrad Trifunovic <konrad.trifunovic@intel.com>
  • Loading branch information
5 people committed Jul 22, 2022
1 parent 0ccb6da commit b8e1544
Show file tree
Hide file tree
Showing 42 changed files with 2,350 additions and 285 deletions.
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Expand Up @@ -20,12 +20,13 @@ let TargetPrefix = "spv" in {

def int_spv_gep : Intrinsic<[llvm_anyptr_ty], [llvm_i1_ty, llvm_any_ty, llvm_vararg_ty], [ImmArg<ArgIndex<0>>]>;
def int_spv_load : Intrinsic<[llvm_i32_ty], [llvm_anyptr_ty, llvm_i16_ty, llvm_i8_ty], [ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
def int_spv_store : Intrinsic<[], [llvm_i32_ty, llvm_anyptr_ty, llvm_i16_ty, llvm_i8_ty], [ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
def int_spv_store : Intrinsic<[], [llvm_any_ty, llvm_anyptr_ty, llvm_i16_ty, llvm_i8_ty], [ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
def int_spv_extractv : Intrinsic<[llvm_any_ty], [llvm_i32_ty, llvm_vararg_ty]>;
def int_spv_insertv : Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_any_ty, llvm_vararg_ty]>;
def int_spv_extractelt : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_anyint_ty]>;
def int_spv_insertelt : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_any_ty, llvm_anyint_ty]>;
def int_spv_const_composite : Intrinsic<[llvm_i32_ty], [llvm_vararg_ty]>;
def int_spv_bitcast : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Expand Up @@ -25,6 +25,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVMCInstLower.cpp
SPIRVModuleAnalysis.cpp
SPIRVPreLegalizer.cpp
SPIRVPrepareFunctions.cpp
SPIRVRegisterBankInfo.cpp
SPIRVRegisterInfo.cpp
SPIRVSubtarget.cpp
Expand All @@ -43,6 +44,7 @@ add_llvm_target(SPIRVCodeGen
SelectionDAG
Support
Target
TransformUtils

ADD_TO_COMPONENT
SPIRV
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
Expand Up @@ -1068,5 +1068,15 @@ StringRef getKernelProfilingInfoName(KernelProfilingInfo e) {
}
llvm_unreachable("Unexpected operand");
}

std::string getExtInstSetName(InstructionSet e) {
switch (e) {
CASE(InstructionSet, OpenCL_std)
CASE(InstructionSet, GLSL_std_450)
CASE(InstructionSet, SPV_AMD_shader_trinary_minmax)
break;
}
llvm_unreachable("Unexpected operand");
}
} // namespace SPIRV
} // namespace llvm
13 changes: 13 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
Expand Up @@ -706,6 +706,19 @@ enum class KernelProfilingInfo : uint32_t {
CmdExecTime = 0x1,
};
StringRef getKernelProfilingInfoName(KernelProfilingInfo e);

enum class InstructionSet : uint32_t {
OpenCL_std = 0,
GLSL_std_450 = 1,
SPV_AMD_shader_trinary_minmax = 2,
};
std::string getExtInstSetName(InstructionSet e);

// TODO: implement other mnemonics.
enum class Opcode : uint32_t {
InBoundsPtrAccessChain = 70,
PtrCastToGeneric = 121,
};
} // namespace SPIRV
} // namespace llvm

Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Expand Up @@ -59,7 +59,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
}

void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) {
llvm_unreachable("Unimplemented recordOpExtInstImport");
// TODO: insert {Reg, Set} into ExtInstSetIDs map.
}

void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
Expand Down Expand Up @@ -176,7 +176,18 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
}

void SPIRVInstPrinter::printOpExtInst(const MCInst *MI, raw_ostream &O) {
llvm_unreachable("Unimplemented printOpExtInst");
// The fixed operands have already been printed, so just need to decide what
// type of ExtInst operands to print based on the instruction set and number.
MCInstrDesc MCDesc = MII.get(MI->getOpcode());
unsigned NumFixedOps = MCDesc.getNumOperands();
const auto NumOps = MI->getNumOperands();
if (NumOps == NumFixedOps)
return;

O << ' ';

// TODO: implement special printing for OpenCLExtInst::vstor*.
printRemainingVariableOps(MI, NumFixedOps, O, true);
}

void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Expand Up @@ -19,6 +19,7 @@ class SPIRVSubtarget;
class InstructionSelector;
class RegisterBankInfo;

ModulePass *createSPIRVPrepareFunctionsPass();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
InstructionSelector *
Expand Down
164 changes: 161 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Expand Up @@ -21,6 +21,7 @@
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
Expand Down Expand Up @@ -58,9 +59,14 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputModuleSection(SPIRV::ModuleSectionType MSType);
void outputEntryPoints();
void outputDebugSourceAndStrings(const Module &M);
void outputOpExtInstImports(const Module &M);
void outputOpMemoryModel();
void outputOpFunctionEnd();
void outputExtFuncDecls();
void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
SPIRV::ExecutionMode EM);
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();

void emitInstruction(const MachineInstr *MI) override;
Expand Down Expand Up @@ -127,6 +133,8 @@ void SPIRVAsmPrinter::emitFunctionBodyEnd() {
}

void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
if (MAI->MBBsToSkip.contains(&MBB))
return;
MCInst LabelInst;
LabelInst.setOpcode(SPIRV::OpLabel);
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
Expand Down Expand Up @@ -237,6 +245,13 @@ void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
}

void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
// Output OpSourceExtensions.
for (auto &Str : MAI->SrcExt) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpSourceExtension);
addStringImm(Str.first(), Inst);
outputMCInst(Inst);
}
// Output OpSource.
MCInst Inst;
Inst.setOpcode(SPIRV::OpSource);
Expand All @@ -246,6 +261,19 @@ void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
outputMCInst(Inst);
}

void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
for (auto &CU : MAI->ExtInstSetMap) {
unsigned Set = CU.first;
Register Reg = CU.second;
MCInst Inst;
Inst.setOpcode(SPIRV::OpExtInstImport);
Inst.addOperand(MCOperand::createReg(Reg));
addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)),
Inst);
outputMCInst(Inst);
}
}

void SPIRVAsmPrinter::outputOpMemoryModel() {
MCInst Inst;
Inst.setOpcode(SPIRV::OpMemoryModel);
Expand Down Expand Up @@ -301,6 +329,135 @@ void SPIRVAsmPrinter::outputExtFuncDecls() {
}
}

// Encode LLVM type by SPIR-V execution mode VecTypeHint.
static unsigned encodeVecTypeHint(Type *Ty) {
if (Ty->isHalfTy())
return 4;
if (Ty->isFloatTy())
return 5;
if (Ty->isDoubleTy())
return 6;
if (IntegerType *IntTy = dyn_cast<IntegerType>(Ty)) {
switch (IntTy->getIntegerBitWidth()) {
case 8:
return 0;
case 16:
return 1;
case 32:
return 2;
case 64:
return 3;
default:
llvm_unreachable("invalid integer type");
}
}
if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty)) {
Type *EleTy = VecTy->getElementType();
unsigned Size = VecTy->getNumElements();
return Size << 16 | encodeVecTypeHint(EleTy);
}
llvm_unreachable("invalid type");
}

static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
SPIRV::ModuleAnalysisInfo *MAI) {
for (const MDOperand &MDOp : MDN->operands()) {
if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
Constant *C = CMeta->getValue();
if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
Inst.addOperand(MCOperand::createImm(Const->getZExtValue()));
} else if (auto *CE = dyn_cast<Function>(C)) {
Register FuncReg = MAI->getFuncReg(CE->getName().str());
assert(FuncReg.isValid());
Inst.addOperand(MCOperand::createReg(FuncReg));
}
}
}
}

void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
SPIRV::ExecutionMode EM) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(Reg));
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
addOpsFromMDNode(Node, Inst, MAI);
outputMCInst(Inst);
}

void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
if (F.isDeclaration())
continue;
Register FReg = MAI->getFuncReg(F.getGlobalIdentifier());
assert(FReg.isValid());
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSize);
if (MDNode *Node = F.getMetadata("work_group_size_hint"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSizeHint);
if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::SubgroupSize);
if (MDNode *Node = F.getMetadata("vec_type_hint")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(FReg));
unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
Inst.addOperand(MCOperand::createImm(EM));
unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
Inst.addOperand(MCOperand::createImm(TypeCode));
outputMCInst(Inst);
}
}
}

void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
outputModuleSection(SPIRV::MB_Annotations);
// Process llvm.global.annotations special global variable.
for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) {
if ((*F).getName() != "llvm.global.annotations")
continue;
const GlobalVariable *V = &(*F);
const ConstantArray *CA = cast<ConstantArray>(V->getOperand(0));
for (Value *Op : CA->operands()) {
ConstantStruct *CS = cast<ConstantStruct>(Op);
// The first field of the struct contains a pointer to
// the annotated variable.
Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts();
if (!isa<Function>(AnnotatedVar))
llvm_unreachable("Unsupported value in llvm.global.annotations");
Function *Func = cast<Function>(AnnotatedVar);
Register Reg = MAI->getFuncReg(Func->getGlobalIdentifier());

// The second field contains a pointer to a global annotation string.
GlobalVariable *GV =
cast<GlobalVariable>(CS->getOperand(1)->stripPointerCasts());

StringRef AnnotationString;
getConstantStringInfo(GV, AnnotationString);
MCInst Inst;
Inst.setOpcode(SPIRV::OpDecorate);
Inst.addOperand(MCOperand::createReg(Reg));
unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
Inst.addOperand(MCOperand::createImm(Dec));
addStringImm(AnnotationString, Inst);
outputMCInst(Inst);
}
}
}

void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
Expand All @@ -311,13 +468,14 @@ void SPIRVAsmPrinter::outputModuleSections() {
// Output instructions according to the Logical Layout of a Module:
// TODO: 1,2. All OpCapability instructions, then optional OpExtension
// instructions.
// TODO: 3. Optional OpExtInstImport instructions.
// 3. Optional OpExtInstImport instructions.
outputOpExtInstImports(*M);
// 4. The single required OpMemoryModel instruction.
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
// TODO:
outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.
outputDebugSourceAndStrings(*M);
Expand All @@ -326,7 +484,7 @@ void SPIRVAsmPrinter::outputModuleSections() {
// 7c. Debug: all OpModuleProcessed instructions.
outputModuleSection(SPIRV::MB_DebugModuleProcessed);
// 8. All annotation instructions (all decorations).
outputModuleSection(SPIRV::MB_Annotations);
outputAnnotations(*M);
// 9. All type declarations (OpTypeXXX instructions), all constant
// instructions, and all global variable declarations. This section is
// the first section to allow use of: OpLine and OpNoLine debug information;
Expand Down

0 comments on commit b8e1544

Please sign in to comment.