Skip to content

Commit

Permalink
[mlir][LLVMIR] Clean up the definitions of ReturnOp/CallOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Mogball committed Aug 11, 2022
1 parent af77e5e commit 5e0c3b4
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 125 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class CallOpInterface;

namespace LLVM {
namespace detail {
Expand Down
13 changes: 3 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Expand Up @@ -29,10 +29,9 @@ namespace LLVM {
class LLVMFuncOp;

/// Helper functions to lookup or create the declaration for commonly used
/// external C function calls. Such ops can then be invoked by creating a CallOp
/// with the proper arguments via `createLLVMCall`.
/// The list of functions provided here must be implemented separately (e.g. as
/// part of a support runtime library or as part of the libc).
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
Expand All @@ -58,12 +57,6 @@ LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {});

/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
LLVM::LLVMFuncOp fn,
ValueRange args = {},
ArrayRef<Type> resultTypes = {});

} // namespace LLVM
} // namespace mlir

Expand Down
62 changes: 34 additions & 28 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Expand Up @@ -645,14 +645,16 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

def LLVM_CallOp : LLVM_Op<"call",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{


In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
implements this behavior by providing a variadic `call` operation for 0- and
1-result functions. Even though MLIR supports multi-result functions, LLVM
Expand All @@ -678,29 +680,20 @@ def LLVM_CallOp : LLVM_Op<"call",
llvm.call %1(%0) : (f32) -> ()
```
}];

let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>,
DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
let results = (outs Variadic<LLVM_Type>);
Variadic<LLVM_Type>,
DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
let results = (outs Optional<LLVM_Type>:$result);

let builders = [
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
Type resultType = func.getFunctionType().getReturnType();
if (!resultType.isa<LLVM::LLVMVoidType>())
$_state.addTypes(resultType);
$_state.addAttribute("callee", SymbolRefAttr::get(func));
$_state.addAttributes(attributes);
$_state.addOperands(operands);
}]>,
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, results, SymbolRefAttr::get(callee), operands);
}]>,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, results,
StringAttr::get($_builder.getContext(), callee), operands);
}]>];
CArg<"ValueRange", "{}">:$args)>
];

let hasCustomAssemblyFormat = 1;
}

Expand Down Expand Up @@ -878,25 +871,38 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
falseOperands);
}]>, LLVM_TerminatorPassthroughOpBuilder];
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
let arguments = (ins Variadic<LLVM_Type>:$args);
let arguments = (ins Optional<LLVM_Type>:$arg);
let assemblyFormat = "attr-dict ($arg^ `:` type($arg))?";

let builders = [
OpBuilder<(ins "ValueRange":$args), [{
build($_builder, $_state, TypeRange(), args);
}]>
];

let hasVerifier = 1;

string llvmBuilder = [{
if ($_numOperands != 0)
builder.CreateRet($args[0]);
builder.CreateRet($arg[0]);
else
builder.CreateRetVoid();
}];

let assemblyFormat = "attr-dict ($args^ `:` type($args))?";
let hasVerifier = 1;
}
def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {

def LLVM_ResumeOp : LLVM_TerminatorOp<"resume"> {
let arguments = (ins LLVM_Type:$value);
string llvmBuilder = [{ builder.CreateResume($value); }];
let assemblyFormat = "$value attr-dict `:` type($value)";
let hasVerifier = 1;
}
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
string llvmBuilder = [{ builder.CreateUnreachable(); }];
let assemblyFormat = "attr-dict";
}
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Expand Up @@ -350,8 +350,8 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
// requires the size parameter be an integral multiple of the alignment
// parameter.
auto makeConstant = [&](uint64_t c) {
return rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getI64Type(), c);
return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
rewriter.getI64Type(), c);
};
coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
coroSize =
Expand All @@ -365,13 +365,12 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
auto coroAlloc = rewriter.create<LLVM::CallOp>(
loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
ValueRange{coroAlign, coroSize});
loc, allocFuncOp, ValueRange{coroAlign, coroSize});

// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));

return success();
}
Expand Down Expand Up @@ -400,8 +399,7 @@ class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
SymbolRefAttr::get(freeFuncOp),
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
ValueRange(coroMem.getResult()));

return success();
Expand Down
21 changes: 8 additions & 13 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Expand Up @@ -164,7 +164,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);

if (resultIsNowArg) {
rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
wrapperFuncOp.getArgument(0));
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
} else {
Expand Down Expand Up @@ -265,7 +265,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,

if (resultIsNowArg) {
Value result = builder.create<LLVM::LoadOp>(loc, args.front());
builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
builder.create<LLVM::ReturnOp>(loc, result);
} else {
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
Expand Down Expand Up @@ -617,26 +617,21 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
}

// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
if (numArguments == 1) {
if (numArguments <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
}

// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
auto packedType = getTypeConverter()->packFunctionResults(
llvm::to_vector<4>(op.getOperandTypes()));
auto packedType =
getTypeConverter()->packFunctionResults(op.getOperandTypes());

Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed,
updatedOperands[i], i);
for (auto &it : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, it.value(),
it.index());
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Expand Up @@ -220,7 +220,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult(0);
Value printfDesc = printfBeginCall.getResult();

// Create a global constant for the format string
unsigned stringNumber = 0;
Expand Down Expand Up @@ -259,7 +259,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.args().empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult(0);
printfDesc = appendFormatCall.getResult();

// __ockl_printf_append_args takes 7 values per append call
constexpr size_t argsPerAppend = 7;
Expand Down Expand Up @@ -293,7 +293,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
printfDesc = call.getResult(0);
printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
return success();
Expand Down
17 changes: 8 additions & 9 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Expand Up @@ -482,7 +482,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = this->getElementPtrType(memRefType);
auto stream = adaptor.asyncDependencies().front();
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult();
allocatedPtr =
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

Expand Down Expand Up @@ -539,7 +539,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
continue;
auto idx = operand.getOperandNumber();
auto stream = adaptor.getOperands()[idx];
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, stream});
newOperands[idx] = event;
streams.insert(stream);
Expand Down Expand Up @@ -612,8 +612,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// into the stream just after the last use of the original token operand.
auto *defOp = std::get<0>(pair).getDefiningOp();
rewriter.setInsertionPointAfter(defOp);
auto event =
eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, operand});
events.push_back(event);
} else {
Expand All @@ -623,7 +622,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
}
}
rewriter.restoreInsertionPoint(insertionPoint);
auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
for (auto event : events)
streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
for (auto event : events)
Expand Down Expand Up @@ -784,11 +783,11 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp.getKernelModuleName().getValue(),
launchOp.getKernelName().getValue(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
loc, rewriter, {module.getResult(0), kernelName});
loc, rewriter, {module.getResult(), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
: adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
Expand All @@ -798,7 +797,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
: zero;
launchKernelCallBuilder.create(
loc, rewriter,
{function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
{function.getResult(), adaptor.gridSizeX(), adaptor.gridSizeY(),
adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
/*extra=*/nullpointer});
Expand All @@ -814,7 +813,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
streamDestroyCallBuilder.create(loc, rewriter, stream);
rewriter.eraseOp(launchOp);
}
moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());

return success();
}
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Expand Up @@ -60,17 +60,17 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return failure();

LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp = rewriter.create<LLVM::CallOp>(
op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
auto callOp =
rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);

if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult(0)});
rewriter.replaceOp(op, {callOp.getResult()});
return success();
}

Value truncated = rewriter.create<LLVM::FPTruncOp>(
op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult(0));
callOp.getResult());
rewriter.replaceOp(op, {truncated});
return success();
}
Expand Down
Expand Up @@ -374,7 +374,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
loc, TypeRange{getPointerType()}, kInitVulkan);
// The result of `initVulkan` function is a pointer to Vulkan runtime, we
// need to pass that pointer to each Vulkan runtime call.
auto vulkanRuntime = initVulkanCall.getResult(0);
auto vulkanRuntime = initVulkanCall.getResult();

// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Expand Up @@ -273,7 +273,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Value memory =
toDynamic
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
.getResult(0)
.getResult()
: builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
/*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
Expand Down

0 comments on commit 5e0c3b4

Please sign in to comment.