Skip to content

Commit

Permalink
Make memref promotion during std->LLVM lowering the default calling c…
Browse files Browse the repository at this point in the history
…onvention

During the conversion from the standard dialect to the LLVM dialect,
memref-typed arguments are promoted from registers to memory and passed into
functions by pointer. This had been introduced into the lowering to work around
the abesnce of calling convention modeling in MLIR to enable better
interoperability with LLVM IR generated from C, and has been exerciced for
several months. Make this promotion the default calling covention when
converting to the LLVM dialect. This adds the documentation, simplifies the
code and makes the conversion consistent across function operations and
function types used in other places, e.g. in high-order functions or
attributes, which would not follow the same rule previously.

PiperOrigin-RevId: 285751280
  • Loading branch information
ftynse authored and tensorflower-gardener committed Dec 16, 2019
1 parent 6a43e0f commit 7c64c7d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 29 deletions.
73 changes: 73 additions & 0 deletions g3doc/ConversionToLLVMDialect.md
Expand Up @@ -107,6 +107,18 @@ Examples:
memref<*xf32> -> !llvm.type<"{i64, i8*}">
```

**In function signatures,** `memref` is passed as a _pointer_ to the structured
defined above to comply with the calling convention.

Example:

```mlir
// A function type with memref as argument
(memref<?xf32>) -> ()
// is transformed into the LLVM function with pointer-to-structure argument.
!llvm.type<"void({ float*, float*, i64, [1 x i64], [1 x i64]}*) ">
```

### Function Types

Function types get converted to LLVM function types. The arguments are converted
Expand Down Expand Up @@ -231,6 +243,67 @@ func @bar() {
"use_i32"(%3) : (!llvm.type<"i32">) -> ()
"use_i64"(%4) : (!llvm.type<"i64">) -> ()
}
```

### Calling Convention for `memref`

For function _arguments_ of `memref` type, ranked or unranked, the type of the
argument is a _pointer_ to the memref descriptor type defined above. The caller
of such function is required to store the descriptor in memory and guarantee
that the storage remains live until the callee returns. The caller can than pass
the pointer to that memory as function argument. The callee loads from the
pointers it was passed as arguments in the entry block of the function, making
the descriptor passed in as argument available for use similarly to
ocally-defined descriptors.

This convention is implemented in the conversion of `std.func` and `std.call` to
the LLVM dialect. Conversions from other dialects should take it into account.
The motivation for this convention is to simplify the ABI for interfacing with
other LLVM modules, in particular those generated from C sources, while avoiding
platform-specific aspects until MLIR has a proper ABI modeling.

Example:

```mlir
func @foo(memref<?xf32>) -> () {
%c0 = constant 0 : index
load %arg0[%c0] : memref<?xf32>
return
}
func @bar(%arg0: index) {
%0 = alloc(%arg0) : memref<?xf32>
call @foo(%0) : (memref<?xf32>)-> ()
return
}
// Gets converted to the following IR.
// Accepts a pointer to the memref descriptor.
llvm.func @foo(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) {
// Loads the descriptor so that it can be used similarly to locally
// created descriptors.
%0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
}
llvm.func @bar(%arg0: !llvm.i64) {
// ... Allocation ...
// Definition of the descriptor.
%7 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// ... Filling in the descriptor ...
%14 = // The final value of the allocated descriptor.
// Allocate the memory for the descriptor and store it.
%15 = llvm.mlir.constant(1 : index) : !llvm.i64
%16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
: (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
// Pass the pointer to the function.
llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
llvm.return
}
```

## Repeated Successor Removal
Expand Down
46 changes: 17 additions & 29 deletions lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -123,9 +123,15 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
FunctionType type, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs()))
if (failed(convertSignatureArg(en.index(), en.value(), result)))
for (auto &en : llvm::enumerate(type.getInputs())) {
Type type = en.value();
auto converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
return {};
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
converted = converted.getPointerTo();
result.addInputs(en.index(), converted);
}

SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
Expand Down Expand Up @@ -522,41 +528,23 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Pack the result types into a struct.
Type packedResult;
if (type.getNumResults() != 0)
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
return matchFailure();
LLVM::LLVMType resultType = packedResult
? packedResult.cast<LLVM::LLVMType>()
: LLVM::LLVMType::getVoidTy(&dialect);

SmallVector<LLVM::LLVMType, 4> argTypes;
argTypes.reserve(type.getNumInputs());

// Store the positions of memref-typed arguments so that we can emit loads
// from them to follow the calling convention.
SmallVector<unsigned, 4> promotedArgIndices;
promotedArgIndices.reserve(type.getNumInputs());
for (auto en : llvm::enumerate(type.getInputs())) {
if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
promotedArgIndices.push_back(en.index());
}

// Convert the original function arguments. Struct arguments are promoted to
// pointer to struct arguments to allow calling external functions with
// various ABIs (e.g. compiled from C/C++ on platform X).
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
for (auto en : llvm::enumerate(type.getInputs())) {
auto t = en.value();
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return matchFailure();
if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) {
converted = converted.getPointerTo();
promotedArgIndices.push_back(en.index());
}
argTypes.push_back(converted);
}
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
result.addInputs(idx, argTypes[idx]);

auto llvmType = LLVM::LLVMType::getFunctionTy(
resultType, argTypes, varargsAttr && varargsAttr.getValue());
auto llvmType = lowering.convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);

// Only retain those attributes that are not constructed by build.
SmallVector<NamedAttribute, 4> attributes;
Expand Down
8 changes: 8 additions & 0 deletions test/Conversion/StandardToLLVM/convert-funcs.mlir
Expand Up @@ -18,6 +18,14 @@ func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
//CHECK: llvm.func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))

// Check that memrefs are converted to pointers-to-struct if appear as function arguments.
// CHECK: llvm.func @memref_call_conv(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
func @memref_call_conv(%arg0: memref<?xf32>)

// Same in nested functions.
// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void ({ float*, float*, i64, [1 x i64], [1 x i64] }*)*">)
func @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())

//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">)
Expand Down

0 comments on commit 7c64c7d

Please sign in to comment.