1818#include " mlir/Target/LLVMIR/Export.h"
1919#include " mlir/Target/LLVMIR/ModuleTranslation.h"
2020
21+ #include " llvm/ADT/ScopeExit.h"
2122#include " llvm/IR/Constants.h"
2223#include " llvm/IR/IRBuilder.h"
2324#include " llvm/IR/LLVMContext.h"
2425#include " llvm/IR/Module.h"
2526#include " llvm/Support/FormatVariadic.h"
27+ #include " llvm/Transforms/Utils/ModuleUtils.h"
2628
2729using namespace mlir ;
2830
@@ -31,9 +33,13 @@ namespace {
3133class SelectObjectAttrImpl
3234 : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
3335 SelectObjectAttrImpl> {
36+ // Returns the selected object for embedding.
37+ gpu::ObjectAttr getSelectedObject (gpu::BinaryOp op) const ;
38+
3439public:
3540 // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
36- // global binary string.
41+ // global binary string which gets loaded/unloaded into a global module
42+ // object through a global ctor/dtor.
3743 LogicalResult embedBinary (Attribute attribute, Operation *operation,
3844 llvm::IRBuilderBase &builder,
3945 LLVM::ModuleTranslation &moduleTranslation) const ;
@@ -45,23 +51,9 @@ class SelectObjectAttrImpl
4551 Operation *binaryOperation,
4652 llvm::IRBuilderBase &builder,
4753 LLVM::ModuleTranslation &moduleTranslation) const ;
48-
49- // Returns the selected object for embedding.
50- gpu::ObjectAttr getSelectedObject (gpu::BinaryOp op) const ;
5154};
52- // Returns an identifier for the global string holding the binary.
53- std::string getBinaryIdentifier (StringRef binaryName) {
54- return binaryName.str () + " _bin_cst" ;
55- }
5655} // namespace
5756
58- void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels (
59- DialectRegistry ®istry) {
60- registry.addExtension (+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
61- SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
62- });
63- }
64-
6557gpu::ObjectAttr
6658SelectObjectAttrImpl::getSelectedObject (gpu::BinaryOp op) const {
6759 ArrayRef<Attribute> objects = op.getObjectsAttr ().getValue ();
@@ -96,6 +88,94 @@ SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
9688 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
9789}
9890
91+ static Twine getModuleIdentifier (StringRef moduleName) {
92+ return moduleName + " _module" ;
93+ }
94+
95+ namespace llvm {
96+ static LogicalResult embedBinaryImpl (StringRef moduleName,
97+ gpu::ObjectAttr object, Module &module ) {
98+
99+ // Embed the object as a global string.
100+ // Add null for assembly output for JIT paths that expect null-terminated
101+ // strings.
102+ bool addNull = (object.getFormat () == gpu::CompilationTarget::Assembly);
103+ StringRef serializedStr = object.getObject ().getValue ();
104+ Constant *serializedCst =
105+ ConstantDataArray::getString (module .getContext (), serializedStr, addNull);
106+ GlobalVariable *serializedObj =
107+ new GlobalVariable (module , serializedCst->getType (), true ,
108+ GlobalValue::LinkageTypes::InternalLinkage,
109+ serializedCst, moduleName + " _binary" );
110+ serializedObj->setAlignment (MaybeAlign (8 ));
111+ serializedObj->setUnnamedAddr (GlobalValue::UnnamedAddr::None);
112+
113+ // Default JIT optimization level.
114+ auto optLevel = APInt::getZero (32 );
115+
116+ if (DictionaryAttr objectProps = object.getProperties ()) {
117+ if (auto section = dyn_cast_or_null<StringAttr>(
118+ objectProps.get (gpu::elfSectionName))) {
119+ serializedObj->setSection (section.getValue ());
120+ }
121+ // Check if there's an optimization level embedded in the object.
122+ if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get (" O" )))
123+ optLevel = optAttr.getValue ();
124+ }
125+
126+ IRBuilder<> builder (module .getContext ());
127+ auto i32Ty = builder.getInt32Ty ();
128+ auto i64Ty = builder.getInt64Ty ();
129+ auto ptrTy = builder.getPtrTy (0 );
130+ auto voidTy = builder.getVoidTy ();
131+
132+ // Embed the module as a global object.
133+ auto *modulePtr = new GlobalVariable (
134+ module , ptrTy, /* isConstant=*/ false , GlobalValue::InternalLinkage,
135+ /* Initializer=*/ ConstantPointerNull::get (ptrTy),
136+ getModuleIdentifier (moduleName));
137+
138+ auto *loadFn = Function::Create (FunctionType::get (voidTy, /* IsVarArg=*/ false ),
139+ GlobalValue::InternalLinkage,
140+ moduleName + " _load" , module );
141+ loadFn->setSection (" .text.startup" );
142+ auto *loadBlock = BasicBlock::Create (module .getContext (), " entry" , loadFn);
143+ builder.SetInsertPoint (loadBlock);
144+ Value *moduleObj = [&] {
145+ if (object.getFormat () == gpu::CompilationTarget::Assembly) {
146+ FunctionCallee moduleLoadFn = module .getOrInsertFunction (
147+ " mgpuModuleLoadJIT" , FunctionType::get (ptrTy, {ptrTy, i32Ty}, false ));
148+ Constant *optValue = ConstantInt::get (i32Ty, optLevel);
149+ return builder.CreateCall (moduleLoadFn, {serializedObj, optValue});
150+ } else {
151+ FunctionCallee moduleLoadFn = module .getOrInsertFunction (
152+ " mgpuModuleLoad" , FunctionType::get (ptrTy, {ptrTy, i64Ty}, false ));
153+ Constant *binarySize =
154+ ConstantInt::get (i64Ty, serializedStr.size () + (addNull ? 1 : 0 ));
155+ return builder.CreateCall (moduleLoadFn, {serializedObj, binarySize});
156+ }
157+ }();
158+ builder.CreateStore (moduleObj, modulePtr);
159+ builder.CreateRetVoid ();
160+ appendToGlobalCtors (module , loadFn, /* Priority=*/ 123 );
161+
162+ auto *unloadFn = Function::Create (
163+ FunctionType::get (voidTy, /* IsVarArg=*/ false ),
164+ GlobalValue::InternalLinkage, moduleName + " _unload" , module );
165+ unloadFn->setSection (" .text.startup" );
166+ auto *unloadBlock =
167+ BasicBlock::Create (module .getContext (), " entry" , unloadFn);
168+ builder.SetInsertPoint (unloadBlock);
169+ FunctionCallee moduleUnloadFn = module .getOrInsertFunction (
170+ " mgpuModuleUnload" , FunctionType::get (voidTy, ptrTy, false ));
171+ builder.CreateCall (moduleUnloadFn, builder.CreateLoad (ptrTy, modulePtr));
172+ builder.CreateRetVoid ();
173+ appendToGlobalDtors (module , unloadFn, /* Priority=*/ 123 );
174+
175+ return success ();
176+ }
177+ } // namespace llvm
178+
99179LogicalResult SelectObjectAttrImpl::embedBinary (
100180 Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
101181 LLVM::ModuleTranslation &moduleTranslation) const {
@@ -113,29 +193,8 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
113193 if (!object)
114194 return failure ();
115195
116- llvm::Module *module = moduleTranslation.getLLVMModule ();
117-
118- // Embed the object as a global string.
119- // Add null for assembly output for JIT paths that expect null-terminated
120- // strings.
121- bool addNull = (object.getFormat () == gpu::CompilationTarget::Assembly);
122- llvm::Constant *binary = llvm::ConstantDataArray::getString (
123- builder.getContext (), object.getObject ().getValue (), addNull);
124- llvm::GlobalVariable *serializedObj =
125- new llvm::GlobalVariable (*module , binary->getType (), true ,
126- llvm::GlobalValue::LinkageTypes::InternalLinkage,
127- binary, getBinaryIdentifier (op.getName ()));
128-
129- if (object.getProperties ()) {
130- if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
131- object.getProperties ().get (gpu::elfSectionName))) {
132- serializedObj->setSection (section.getValue ());
133- }
134- }
135- serializedObj->setLinkage (llvm::GlobalValue::LinkageTypes::InternalLinkage);
136- serializedObj->setAlignment (llvm::MaybeAlign (8 ));
137- serializedObj->setUnnamedAddr (llvm::GlobalValue::UnnamedAddr::None);
138- return success ();
196+ return embedBinaryImpl (op.getName (), object,
197+ *moduleTranslation.getLLVMModule ());
139198}
140199
141200namespace llvm {
@@ -153,15 +212,6 @@ class LaunchKernel {
153212 // Get the module function callee.
154213 FunctionCallee getModuleFunctionFn ();
155214
156- // Get the module load callee.
157- FunctionCallee getModuleLoadFn ();
158-
159- // Get the module load JIT callee.
160- FunctionCallee getModuleLoadJITFn ();
161-
162- // Get the module unload callee.
163- FunctionCallee getModuleUnloadFn ();
164-
165215 // Get the stream create callee.
166216 FunctionCallee getStreamCreateFn ();
167217
@@ -261,24 +311,6 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
261311 FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false ));
262312}
263313
264- llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn () {
265- return module .getOrInsertFunction (
266- " mgpuModuleLoad" ,
267- FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false ));
268- }
269-
270- llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn () {
271- return module .getOrInsertFunction (
272- " mgpuModuleLoadJIT" ,
273- FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false ));
274- }
275-
276- llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn () {
277- return module .getOrInsertFunction (
278- " mgpuModuleUnload" ,
279- FunctionType::get (voidTy, ArrayRef<Type *>({ptrTy}), false ));
280- }
281-
282314llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn () {
283315 return module .getOrInsertFunction (" mgpuStreamCreate" ,
284316 FunctionType::get (ptrTy, false ));
@@ -301,9 +333,9 @@ llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
301333llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName (StringRef moduleName,
302334 StringRef kernelName) {
303335 std::string globalName =
304- std::string (formatv (" {0}_{1}_kernel_name " , moduleName, kernelName));
336+ std::string (formatv (" {0}_{1}_name " , moduleName, kernelName));
305337
306- if (GlobalVariable *gv = module .getGlobalVariable (globalName))
338+ if (GlobalVariable *gv = module .getGlobalVariable (globalName, true ))
307339 return gv;
308340
309341 return builder.CreateGlobalString (kernelName, globalName);
@@ -346,16 +378,13 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
346378}
347379
348380// Emits LLVM IR to launch a kernel function:
349- // %0 = call %binarygetter
350- // %1 = call %moduleLoad(%0)
351- // %2 = <see generateKernelNameConstant>
352- // %3 = call %moduleGetFunction(%1, %2)
353- // %4 = call %streamCreate()
354- // %5 = <see generateParamsArray>
355- // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
356- // call %streamSynchronize(%4)
357- // call %streamDestroy(%4)
358- // call %moduleUnload(%1)
381+ // %1 = load %global_module_object
382+ // %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
383+ // %3 = call @mgpuStreamCreate()
384+ // %4 = <see createKernelArgArray()>
385+ // call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
386+ // call @mgpuStreamSynchronize(%3)
387+ // call @mgpuStreamDestroy(%3)
359388llvm::LogicalResult
360389llvm::LaunchKernel::createKernelLaunch (mlir::gpu::LaunchFuncOp op,
361390 mlir::gpu::ObjectAttr object) {
@@ -385,58 +414,29 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
385414 // Create the argument array.
386415 Value *argArray = createKernelArgArray (op);
387416
388- // Default JIT optimization level.
389- llvm::Constant *optV = llvm::ConstantInt::get (i32Ty, 0 );
390- // Check if there's an optimization level embedded in the object.
391- DictionaryAttr objectProps = object.getProperties ();
392- mlir::Attribute optAttr;
393- if (objectProps && (optAttr = objectProps.get (" O" ))) {
394- auto optLevel = dyn_cast<IntegerAttr>(optAttr);
395- if (!optLevel)
396- return op.emitError (" the optimization level must be an integer" );
397- optV = llvm::ConstantInt::get (i32Ty, optLevel.getValue ());
398- }
399-
400- // Load the kernel module.
401- StringRef moduleName = op.getKernelModuleName ().getValue ();
402- std::string binaryIdentifier = getBinaryIdentifier (moduleName);
403- Value *binary = module .getGlobalVariable (binaryIdentifier, true );
404- if (!binary)
405- return op.emitError () << " Couldn't find the binary: " << binaryIdentifier;
406-
407- auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
408- if (!binaryVar)
409- return op.emitError () << " Binary is not a global variable: "
410- << binaryIdentifier;
411- llvm::Constant *binaryInit = binaryVar->getInitializer ();
412- auto binaryDataSeq =
413- dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
414- if (!binaryDataSeq)
415- return op.emitError () << " Couldn't find binary data array: "
416- << binaryIdentifier;
417- llvm::Constant *binarySize =
418- llvm::ConstantInt::get (i64Ty, binaryDataSeq->getNumElements () *
419- binaryDataSeq->getElementByteSize ());
420-
421- Value *moduleObject =
422- object.getFormat () == gpu::CompilationTarget::Assembly
423- ? builder.CreateCall (getModuleLoadJITFn (), {binary, optV})
424- : builder.CreateCall (getModuleLoadFn (), {binary, binarySize});
425-
426417 // Load the kernel function.
427- Value *moduleFunction = builder.CreateCall (
428- getModuleFunctionFn (),
429- {moduleObject,
430- getOrCreateFunctionName (moduleName, op.getKernelName ().getValue ())});
418+ StringRef moduleName = op.getKernelModuleName ().getValue ();
419+ Twine moduleIdentifier = getModuleIdentifier (moduleName);
420+ Value *modulePtr = module .getGlobalVariable (moduleIdentifier.str (), true );
421+ if (!modulePtr)
422+ return op.emitError () << " Couldn't find the binary: " << moduleIdentifier;
423+ Value *moduleObj = builder.CreateLoad (ptrTy, modulePtr);
424+ Value *functionName = getOrCreateFunctionName (moduleName, op.getKernelName ());
425+ Value *moduleFunction =
426+ builder.CreateCall (getModuleFunctionFn (), {moduleObj, functionName});
431427
432428 // Get the stream to use for execution. If there's no async object then create
433429 // a stream to make a synchronous kernel launch.
434430 Value *stream = nullptr ;
435- bool handleStream = false ;
431+ // Sync & destroy the stream, for synchronous launches.
432+ auto destroyStream = make_scope_exit ([&]() {
433+ builder.CreateCall (getStreamSyncFn (), {stream});
434+ builder.CreateCall (getStreamDestroyFn (), {stream});
435+ });
436436 if (mlir::Value asyncObject = op.getAsyncObject ()) {
437437 stream = llvmValue (asyncObject);
438+ destroyStream.release ();
438439 } else {
439- handleStream = true ;
440440 stream = builder.CreateCall (getStreamCreateFn (), {});
441441 }
442442
@@ -462,14 +462,12 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
462462 argArray, nullPtr, paramsCount}));
463463 }
464464
465- // Sync & destroy the stream, for synchronous launches.
466- if (handleStream) {
467- builder.CreateCall (getStreamSyncFn (), {stream});
468- builder.CreateCall (getStreamDestroyFn (), {stream});
469- }
470-
471- // Unload the kernel module.
472- builder.CreateCall (getModuleUnloadFn (), {moduleObject});
473-
474465 return success ();
475466}
467+
468+ void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels (
469+ DialectRegistry ®istry) {
470+ registry.addExtension (+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
471+ SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
472+ });
473+ }
0 commit comments