Skip to content

Commit

Permalink
Use llvm.func to define functions with wrapped LLVM IR function type
Browse files Browse the repository at this point in the history
This function-like operation allows one to define functions that have wrapped
LLVM IR function type, in particular variadic functions. The operation was
added in parallel to the existing lowering flow, this commit only switches the
flow to use it.

Using a custom function type makes the LLVM IR dialect type system more
consistent and avoids complex conversion rules for functions that previously
had to use the built-in function type instead of a wrapped LLVM IR dialect type
and perform conversions during the analysis.

PiperOrigin-RevId: 273910855
  • Loading branch information
ftynse authored and tensorflower-gardener committed Oct 10, 2019
1 parent 309b455 commit 5e7959a
Show file tree
Hide file tree
Showing 29 changed files with 324 additions and 307 deletions.
2 changes: 0 additions & 2 deletions mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
Expand Up @@ -152,8 +152,6 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(applyFullConversion(module, target, patterns, &converter)))
return failure();

Expand Down
23 changes: 11 additions & 12 deletions mlir/examples/toy/Ch5/mlir/LateLowering.cpp
Expand Up @@ -138,14 +138,14 @@ class PrintOpConversion : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());

auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
Type retTy = printfFunc.getType().getResult(0);
Type retTy = printfFunc.getType().getFunctionResultType();

// Create our loop nest now
using namespace edsc;
Expand Down Expand Up @@ -218,24 +218,23 @@ class PrintOpConversion : public ConversionPattern {

/// Return the prototype declaration for printf in the module, create it if
/// necessary.
FuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
if (printfFunc)
return printfFunc;

// Create a function declaration for printf, signature is `i32 (i8*, ...)`
Builder builder(module);
OpBuilder builder(module.getBodyRegion());
auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();

auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
printfFunc = FuncOp::create(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
module.push_back(printfFunc);
return printfFunc;
auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
printfTy,
ArrayRef<NamedAttribute>());
}
};

Expand Down Expand Up @@ -369,10 +368,10 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
ConversionTarget target(getContext());
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
LLVM::LLVMDialect, StandardOpsDialect>();
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType());
});
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
&typeConverter))) {
emitError(UnknownLoc::get(getModule().getContext()),
Expand Down
42 changes: 10 additions & 32 deletions mlir/g3doc/ConversionToLLVMDialect.md
Expand Up @@ -137,49 +137,27 @@ Examples:

### Function Signature Conversion

MLIR function type is built into the representation, even the functions in
dialects including a first-class function type must have the built-in MLIR
function type. During the conversion to LLVM IR, function signatures are
converted as follows:

- the outer type remains the built-in MLIR function;
- function arguments are converted individually following these rules;
- function results:
- zero-result functions remain zero-result;
- single-result functions have their result type converted according to
these rules;
- multi-result functions have a single result type of the wrapped LLVM IR
structure type with elements corresponding to the converted original
results.

Rationale: function definitions remain analyzable within MLIR without having to
abstract away the function type. In order to remain consistent with the regular
MLIR functions, we do not introduce a `void` result type since we cannot create
a value of `void` type that MLIR passes might expect to be returned from a
function.
LLVM IR functions are defined by a custom operation. The function itself has a
wrapped LLVM IR function type converted as described above. The function
definition operation uses MLIR syntax.

Examples:

```mlir {.mlir}
// zero-ary function type with no results.
func @foo() -> ()
// remains as is
func @foo() -> ()
// gets LLVM type void().
llvm.func @foo() -> ()
// unary function with one result
// function with one result
func @bar(i32) -> (i64)
// has its argument and result type converted
func @bar(!llvm.type<"i32">) -> !llvm.type<"i64">
// binary function with one result
func @baz(i32, f32) -> (i64)
// has its arguments handled separately
func @baz(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"i64">
// gets converted to LLVM type i64(i32).
func @bar(!llvm.i32) -> !llvm.i64
// binary function with two results
// function with two results
func @qux(i32, f32) -> (i64, f64)
// has its result aggregated into a structure type
func @qux(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"{i64, double}">
func @qux(!llvm.i32, !llvm.float) -> !llvm.type<"{i64, double}">
// function-typed arguments or results in higher-order functions
func @quux(() -> ()) -> (() -> ())
Expand Down
24 changes: 24 additions & 0 deletions mlir/g3doc/Dialects/LLVM.md
Expand Up @@ -50,6 +50,30 @@ specific LLVM IR type.
All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
of an operation is that used in LLVM IR prefixed with "`llvm.`".

### LLVM functions

MLIR functions are defined by an operation that is not built into the IR itself.
The LLVM IR dialect provides an `llvm.func` operation to define functions
compatible with LLVM IR. These functions have wrapped LLVM IR function type but
use MLIR syntax to express it. They are required to have exactly one result
type. LLVM function operation is intended to capture additional properties of
LLVM functions, such as linkage and calling convention, that may be modeled
differently by the built-in MLIR function.

```mlir {.mlir}
// The type of @bar is !llvm<"i64 (i64)">
llvm.func @bar(%arg0: !llvm.i64) -> !llvm.i64 {
llvm.return %arg0 : !llvm.i64
}
// Type type of @foo is !llvm<"void (i64)">
// !llvm.void type is omitted
llvm.func @foo(%arg0: !llvm.i64) {
llvm.return
}
```

### LLVM IR operations

The following operations are currently supported. The semantics of these
Expand Down
5 changes: 1 addition & 4 deletions mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
Expand Up @@ -25,15 +25,12 @@

namespace mlir {

class FuncOp;
class Location;
class ModuleOp;
class OpBuilder;
class Value;

namespace LLVM {
class LLVMDialect;
}
} // namespace LLVM

template <typename T> class OpPassBase;

Expand Down
Expand Up @@ -50,6 +50,12 @@ class LLVMTypeConverter : public TypeConverter {
/// non-standard or non-builtin types.
Type convertType(Type t) override;

/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
SignatureConversion &result);

/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one values is
/// returned, create an LLVM IR structure type with elements that correspond
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/GPU/GPUDialect.h
Expand Up @@ -55,7 +55,7 @@ class GPUDialect : public Dialect {

/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(FuncOp function);
static bool isKernel(Operation *op);

LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attr) override;
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Expand Up @@ -64,6 +64,9 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
LLVMDialect &getDialect();
llvm::Type *getUnderlyingType() const;

/// Utilities to identify types.
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }

/// Array type utilities.
LLVMType getArrayElementType();
unsigned getArrayNumElements();
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Expand Up @@ -525,11 +525,15 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",

let builders = [
OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
"LLVMType type, ArrayRef<NamedAttribute> attrs, "
"LLVMType type, ArrayRef<NamedAttribute> attrs = {}, "
"ArrayRef<NamedAttributeList> argAttrs = {}">
];

let extraClassDeclaration = [{
// Add an entry block to an empty function, and set up the block arguments
// to match the signature of the function.
Block *addEntryBlock();

LLVMType getType() {
return getAttrOfType<TypeAttr>(getTypeAttrName())
.getValue().cast<LLVMType>();
Expand Down
7 changes: 4 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Expand Up @@ -34,13 +34,14 @@

namespace mlir {
class Attribute;
class FuncOp;
class Location;
class ModuleOp;
class Operation;

namespace LLVM {

class LLVMFuncOp;

// Implementation class for module translation. Holds a reference to the module
// being translated, and the mappings between the original and the translated
// functions, basic blocks and values. It is practically easier to hold these
Expand Down Expand Up @@ -75,8 +76,8 @@ class ModuleTranslation {
private:
LogicalResult convertFunctions();
void convertGlobals();
LogicalResult convertOneFunction(FuncOp func);
void connectPHINodes(FuncOp func);
LogicalResult convertOneFunction(LLVMFuncOp func);
void connectPHINodes(LLVMFuncOp func);
LogicalResult convertBlock(Block &bb, bool ignoreArguments);

template <typename Range>
Expand Down

0 comments on commit 5e7959a

Please sign in to comment.