Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,37 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}

// Creating external wrappers with UseBarePtrCallConv=true.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
LLVM::LLVMFuncOp newFuncOp) {
// Create the auxiliary function.
auto wrapperFunc = rewriter.cloneWithoutRegions(newFuncOp);
wrapperFunc.setSymName(
llvm::formatv("_mlir_ciface_{0}", newFuncOp.getName()).str());

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(wrapperFunc.addEntryBlock());
auto call =
rewriter.create<LLVM::CallOp>(loc, newFuncOp, wrapperFunc.getArguments());
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
}

static void wrapExternalFunction(OpBuilder &builder, Location loc,
LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
// Create the auxiliary function.
auto wrapperFunc = builder.cloneWithoutRegions(newFuncOp);
wrapperFunc.setSymName(
llvm::formatv("_mlir_ciface_{0}", newFuncOp.getName()).str());

// This wrapper should only be visible in this module.
newFuncOp.setLinkage(LLVM::Linkage::Private);
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
auto call =
builder.create<LLVM::CallOp>(loc, wrapperFunc, newFuncOp.getArguments());
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}

/// Modifies the body of the function to construct the `MemRefDescriptor` from
/// the bare pointer calling convention lowering of `memref` types.
static void modifyFuncOpToUseBarePtrCallingConv(
Expand Down Expand Up @@ -502,6 +533,16 @@ struct FuncOpConversion : public FuncOpConversionBase {
modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp->getLoc(),
*getTypeConverter(), *newFuncOp,
funcOp.getFunctionType().getInputs());
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
if (newFuncOp->isVarArg())
return funcOp->emitError("C interface for variadic functions is not "
"supported yet.");
if (newFuncOp->isExternal())
wrapExternalFunction(rewriter, newFuncOp->getLoc(), *newFuncOp);
else
wrapForExternalCallers(rewriter, funcOp->getLoc(), *newFuncOp);
}
}

rewriter.eraseOp(funcOp);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1" -split-input-file | FileCheck %s
// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 use-bare-pointers-for-host=1" -split-input-file | FileCheck %s

module attributes {gpu.container_module} {
gpu.module @kernels [#nvvm.target] {
Expand All @@ -15,7 +15,8 @@ module attributes {gpu.container_module} {
llvm.return
}
}
func.func @foo() {
// CHECK: @foo
func.func @foo() attributes {llvm.emit_c_interface} {
// CHECK: [[MEMREF:%.*]] = gpu.alloc () : memref<10xf32, 1>
// CHECK: [[DESCRIPTOR:%.*]] = builtin.unrealized_conversion_cast [[MEMREF]] : memref<10xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[PTR:%.*]] = llvm.extractvalue [[DESCRIPTOR]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
Expand All @@ -28,3 +29,4 @@ module attributes {gpu.container_module} {
return
}
}
// CHECK: @_mlir_ciface_foo