57 changes: 57 additions & 0 deletions mlir/examples/toy/Ch2/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define TOY_OPS

include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Provide a definition of the 'toy' dialect in the ODS framework so that we
Expand Down Expand Up @@ -106,6 +108,61 @@ def AddOp : Toy_Op<"add"> {
];
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

// Toy's own function operation (replaces use of the builtin function op).
//  - FunctionOpInterface: reuses the generic function machinery (type,
//    arguments, parsing/printing helpers).
//  - IsolatedFromAbove: the body region cannot reference SSA values defined
//    outside the function, so each function is self-contained.
//  - Symbol: the op defines a named symbol in the enclosing symbol table.
def FuncOp : Toy_Op<"func", [
    FunctionOpInterface, IsolatedFromAbove, Symbol
  ]> {
  let summary = "user defined function operation";
  let description = [{
    The "toy.func" operation represents a user defined function. These are
    callable SSA-region operations that contain toy computations.

    Example:

    ```mlir
    toy.func @main() {
      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
      toy.print %1 : tensor<2x2xf64>
      toy.return
    }
    ```
  }];

  // The symbol name and the FunctionType are carried as attributes; the
  // function body lives in a single region.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    TypeAttrOf<FunctionType>:$type
  );
  let regions = (region AnyRegion:$body);

  // Single custom builder (defined in Dialect.cpp) that also creates the
  // entry block; the default ODS builders are suppressed below.
  let builders = [OpBuilder<(ins
    "StringRef":$name, "FunctionType":$type,
    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
  >];
  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getType() {
      return getTypeAttr().getValue().cast<FunctionType>();
    }

    //===------------------------------------------------------------------===//
    // FunctionOpInterface Methods
    //===------------------------------------------------------------------===//

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return type().getResults(); }
  }];
  // Assembly syntax is handled by FuncOp::parse/FuncOp::print in Dialect.cpp,
  // which delegate to the FunctionOpInterface utilities.
  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 34 additions & 0 deletions mlir/examples/toy/Ch2/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"

using namespace mlir;
Expand Down Expand Up @@ -187,6 +188,39 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::SymbolRefAttr::get(builder.getContext(), callee));
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Build a toy.func named `name` with function type `type`, attaching the
/// extra attributes `attrs` and creating an initialized entry block.
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                   llvm::StringRef name, mlir::FunctionType type,
                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // FunctionOpInterface provides a convenient `build` method that will populate
  // the state of our FuncOp, and create an entry block whose arguments mirror
  // the function's input types.
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

/// Parse a toy.func using the generic FunctionOpInterface parser.
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
                                mlir::OperationState &result) {
  // Callback invoked by the generic parser to assemble the FunctionType once
  // the argument and result types are known. Toy functions are never variadic.
  auto assembleFuncType = [](mlir::Builder &builder,
                             llvm::ArrayRef<mlir::Type> argTypes,
                             llvm::ArrayRef<mlir::Type> resultTypes,
                             mlir::function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(argTypes, resultTypes);
  };

  // Delegate the heavy lifting to the FunctionOpInterface utility.
  return mlir::function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, assembleFuncType);
}

/// Print this toy.func in the generic function-op syntax.
void FuncOp::print(mlir::OpAsmPrinter &p) {
  // Dispatch to the FunctionOpInterface provided utility method that prints the
  // function operation; mirrors the parse() implementation above.
  mlir::function_interface_impl::printFunctionOp(p, *this,
                                                 /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 9 additions & 13 deletions mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class MLIRGenImpl {
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);

// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
Expand Down Expand Up @@ -108,31 +104,31 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down
4 changes: 3 additions & 1 deletion mlir/examples/toy/Ch3/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

/// Include the auto-generated header file containing the declaration of the toy
Expand Down
57 changes: 57 additions & 0 deletions mlir/examples/toy/Ch3/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS

include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Provide a definition of the 'toy' dialect in the ODS framework so that we
Expand Down Expand Up @@ -105,6 +107,61 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
];
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

// Toy's own function operation (replaces use of the builtin function op).
//  - FunctionOpInterface: reuses the generic function machinery (type,
//    arguments, parsing/printing helpers).
//  - IsolatedFromAbove: the body region cannot reference SSA values defined
//    outside the function, so each function is self-contained.
//  - Symbol: the op defines a named symbol in the enclosing symbol table.
def FuncOp : Toy_Op<"func", [
    FunctionOpInterface, IsolatedFromAbove, Symbol
  ]> {
  let summary = "user defined function operation";
  let description = [{
    The "toy.func" operation represents a user defined function. These are
    callable SSA-region operations that contain toy computations.

    Example:

    ```mlir
    toy.func @main() {
      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
      toy.print %1 : tensor<2x2xf64>
      toy.return
    }
    ```
  }];

  // The symbol name and the FunctionType are carried as attributes; the
  // function body lives in a single region.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    TypeAttrOf<FunctionType>:$type
  );
  let regions = (region AnyRegion:$body);

  // Single custom builder (defined in Dialect.cpp) that also creates the
  // entry block; the default ODS builders are suppressed below.
  let builders = [OpBuilder<(ins
    "StringRef":$name, "FunctionType":$type,
    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
  >];
  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getType() {
      return getTypeAttr().getValue().cast<FunctionType>();
    }

    //===------------------------------------------------------------------===//
    // FunctionOpInterface Methods
    //===------------------------------------------------------------------===//

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return type().getResults(); }
  }];
  // Assembly syntax is handled by FuncOp::parse/FuncOp::print in Dialect.cpp,
  // which delegate to the FunctionOpInterface utilities.
  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 34 additions & 0 deletions mlir/examples/toy/Ch3/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"

using namespace mlir;
Expand Down Expand Up @@ -174,6 +175,39 @@ mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Build a toy.func named `name` with function type `type`, attaching the
/// extra attributes `attrs` and creating an initialized entry block.
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                   llvm::StringRef name, mlir::FunctionType type,
                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // FunctionOpInterface provides a convenient `build` method that will populate
  // the state of our FuncOp, and create an entry block whose arguments mirror
  // the function's input types.
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

/// Parse a toy.func using the generic FunctionOpInterface parser.
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
                                mlir::OperationState &result) {
  // Callback invoked by the generic parser to assemble the FunctionType once
  // the argument and result types are known. Toy functions are never variadic.
  auto assembleFuncType = [](mlir::Builder &builder,
                             llvm::ArrayRef<mlir::Type> argTypes,
                             llvm::ArrayRef<mlir::Type> resultTypes,
                             mlir::function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(argTypes, resultTypes);
  };

  // Delegate the heavy lifting to the FunctionOpInterface utility.
  return mlir::function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, assembleFuncType);
}

/// Print this toy.func in the generic function-op syntax.
void FuncOp::print(mlir::OpAsmPrinter &p) {
  // Dispatch to the FunctionOpInterface provided utility method that prints the
  // function operation; mirrors the parse() implementation above.
  mlir::function_interface_impl::printFunctionOp(p, *this,
                                                 /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 9 additions & 13 deletions mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class MLIRGenImpl {
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);

// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
Expand Down Expand Up @@ -108,31 +104,31 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch3/toyc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ int dumpMLIR() {
applyPassManagerCLOptions(pm);

// Add a run of the canonicalizer to optimize the mlir module.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module)))
return 4;
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/examples/toy/Ch4/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
Expand Down
58 changes: 58 additions & 0 deletions mlir/examples/toy/Ch4/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS

include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

// Toy's own function operation (replaces use of the builtin function op).
//  - CallableOpInterface (methods declared here, implemented in Dialect.cpp
//    as getCallableRegion/getCallableResults) exposes the function as a call
//    target.
//  - FunctionOpInterface: reuses the generic function machinery.
//  - IsolatedFromAbove: the body region cannot capture outside SSA values.
//  - Symbol: the op defines a named symbol in the enclosing symbol table.
def FuncOp : Toy_Op<"func", [
    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
    IsolatedFromAbove, Symbol
  ]> {
  let summary = "user defined function operation";
  let description = [{
    The "toy.func" operation represents a user defined function. These are
    callable SSA-region operations that contain toy computations.

    Example:

    ```mlir
    toy.func @main() {
      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
      toy.print %1 : tensor<2x2xf64>
      toy.return
    }
    ```
  }];

  // The symbol name and the FunctionType are carried as attributes; the
  // function body lives in a single region.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    TypeAttrOf<FunctionType>:$type
  );
  let regions = (region AnyRegion:$body);

  // Single custom builder (defined in Dialect.cpp) that also creates the
  // entry block; the default ODS builders are suppressed below.
  let builders = [OpBuilder<(ins
    "StringRef":$name, "FunctionType":$type,
    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
  >];
  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getType() {
      return getTypeAttr().getValue().cast<FunctionType>();
    }

    //===------------------------------------------------------------------===//
    // FunctionOpInterface Methods
    //===------------------------------------------------------------------===//

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return type().getResults(); }
  }];
  // Assembly syntax is handled by FuncOp::parse/FuncOp::print in Dialect.cpp,
  // which delegate to the FunctionOpInterface utilities.
  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/examples/toy/Ch4/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

Expand Down Expand Up @@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}

  // All functions within toy can be inlined: this region-level overload
  // unconditionally allows inlining one toy region (a function body) into
  // another, with no extra legality constraints.
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Build a toy.func named `name` with function type `type`, attaching the
/// extra attributes `attrs` and creating an initialized entry block.
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                   llvm::StringRef name, mlir::FunctionType type,
                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // FunctionOpInterface provides a convenient `build` method that will populate
  // the state of our FuncOp, and create an entry block whose arguments mirror
  // the function's input types.
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

/// Parse a toy.func using the generic FunctionOpInterface parser.
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
                                mlir::OperationState &result) {
  // Callback invoked by the generic parser to assemble the FunctionType once
  // the argument and result types are known. Toy functions are never variadic.
  auto assembleFuncType = [](mlir::Builder &builder,
                             llvm::ArrayRef<mlir::Type> argTypes,
                             llvm::ArrayRef<mlir::Type> resultTypes,
                             mlir::function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(argTypes, resultTypes);
  };

  // Delegate the heavy lifting to the FunctionOpInterface utility.
  return mlir::function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, assembleFuncType);
}

/// Print this toy.func in the generic function-op syntax.
void FuncOp::print(mlir::OpAsmPrinter &p) {
  // Dispatch to the FunctionOpInterface provided utility method that prints the
  // function operation; mirrors the parse() implementation above.
  mlir::function_interface_impl::printFunctionOp(p, *this,
                                                 /*isVariadic=*/false);
}

/// Returns the region on the function operation that is callable, i.e. the
/// function body (CallableOpInterface hook used during call resolution).
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }

/// Returns the results types that the callable region produces when
/// executed; taken directly from the function's declared type
/// (CallableOpInterface hook).
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
  return getType().getResults();
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 9 additions & 13 deletions mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class MLIRGenImpl {
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);

// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
Expand Down Expand Up @@ -108,31 +104,31 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch4/toyc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ int dumpMLIR() {

// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
Expand Down
4 changes: 3 additions & 1 deletion mlir/examples/toy/Ch5/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
Expand Down
58 changes: 58 additions & 0 deletions mlir/examples/toy/Ch5/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS

include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

// Toy's own function operation (replaces use of the builtin function op).
//  - CallableOpInterface (methods declared here, implemented in Dialect.cpp
//    as getCallableRegion/getCallableResults) exposes the function as a call
//    target.
//  - FunctionOpInterface: reuses the generic function machinery.
//  - IsolatedFromAbove: the body region cannot capture outside SSA values.
//  - Symbol: the op defines a named symbol in the enclosing symbol table.
def FuncOp : Toy_Op<"func", [
    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
    IsolatedFromAbove, Symbol
  ]> {
  let summary = "user defined function operation";
  let description = [{
    The "toy.func" operation represents a user defined function. These are
    callable SSA-region operations that contain toy computations.

    Example:

    ```mlir
    toy.func @main() {
      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
      toy.print %1 : tensor<2x2xf64>
      toy.return
    }
    ```
  }];

  // The symbol name and the FunctionType are carried as attributes; the
  // function body lives in a single region.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    TypeAttrOf<FunctionType>:$type
  );
  let regions = (region AnyRegion:$body);

  // Single custom builder (defined in Dialect.cpp) that also creates the
  // entry block; the default ODS builders are suppressed below.
  let builders = [OpBuilder<(ins
    "StringRef":$name, "FunctionType":$type,
    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
  >];
  let extraClassDeclaration = [{
    /// Returns the type of this function.
    /// FIXME: We should drive this via the ODS `type` param.
    FunctionType getType() {
      return getTypeAttr().getValue().cast<FunctionType>();
    }

    //===------------------------------------------------------------------===//
    // FunctionOpInterface Methods
    //===------------------------------------------------------------------===//

    /// Returns the argument types of this function.
    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

    /// Returns the result types of this function.
    ArrayRef<Type> getResultTypes() { return type().getResults(); }
  }];
  // Assembly syntax is handled by FuncOp::parse/FuncOp::print in Dialect.cpp,
  // which delegate to the FunctionOpInterface utilities.
  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/examples/toy/Ch5/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

Expand Down Expand Up @@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}

  // All functions within toy can be inlined: this region-level overload
  // unconditionally allows inlining one toy region (a function body) into
  // another, with no extra legality constraints.
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Build a toy.func named `name` with function type `type`, attaching the
/// extra attributes `attrs` and creating an initialized entry block.
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                   llvm::StringRef name, mlir::FunctionType type,
                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // FunctionOpInterface provides a convenient `build` method that will populate
  // the state of our FuncOp, and create an entry block whose arguments mirror
  // the function's input types.
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

/// Parse a toy.func using the generic FunctionOpInterface parser.
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
                                mlir::OperationState &result) {
  // Callback invoked by the generic parser to assemble the FunctionType once
  // the argument and result types are known. Toy functions are never variadic.
  auto assembleFuncType = [](mlir::Builder &builder,
                             llvm::ArrayRef<mlir::Type> argTypes,
                             llvm::ArrayRef<mlir::Type> resultTypes,
                             mlir::function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(argTypes, resultTypes);
  };

  // Delegate the heavy lifting to the FunctionOpInterface utility.
  return mlir::function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, assembleFuncType);
}

/// Print this toy.func in the generic function-op syntax.
void FuncOp::print(mlir::OpAsmPrinter &p) {
  // Dispatch to the FunctionOpInterface provided utility method that prints the
  // function operation; mirrors the parse() implementation above.
  mlir::function_interface_impl::printFunctionOp(p, *this,
                                                 /*isVariadic=*/false);
}

/// Returns the region on the function operation that is callable, i.e. the
/// function body (CallableOpInterface hook used during call resolution).
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }

/// Returns the results types that the callable region produces when
/// executed; taken directly from the function's declared type
/// (CallableOpInterface hook).
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
  return getType().getResults();
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 37 additions & 17 deletions mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"

Expand Down Expand Up @@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//

/// Lowers the remaining `toy.func` to a non-toy function op by moving its
/// region into a freshly created function of the same name and type.
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
  using OpConversionPattern<toy::FuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // We only lower the main function as we expect that all other functions
    // have been inlined. Any other toy.func simply fails to match here.
    if (op.getName() != "main")
      return failure();

    // Verify that the given main has no inputs and results.
    if (op.getNumArguments() || op.getType().getNumResults()) {
      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
        diag << "expected 'main' to have 0 inputs and 0 results";
      });
    }

    // Create a new non-toy function, with the same region. The body is moved
    // (not cloned) into the new function before the toy op is erased.
    auto func =
        rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
Expand All @@ -286,28 +318,16 @@ struct ToyToAffineLoweringPass
} // namespace

void ToyToAffineLoweringPass::runOnOperation() {
FuncOp function = getOperation();

// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;

// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}

// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());

// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();

// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
Expand All @@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

Expand Down
22 changes: 9 additions & 13 deletions mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class MLIRGenImpl {
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);

// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
Expand Down Expand Up @@ -108,31 +104,31 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();
Expand Down
9 changes: 5 additions & 4 deletions mlir/examples/toy/Ch5/toyc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,18 @@ int dumpMLIR() {

// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
}

if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());

// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

Expand Down
4 changes: 3 additions & 1 deletion mlir/examples/toy/Ch6/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
Expand Down
58 changes: 58 additions & 0 deletions mlir/examples/toy/Ch6/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS

include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.

Example:

```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];

let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);

let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}

//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/examples/toy/Ch6/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

Expand Down Expand Up @@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}

// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}

//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}

/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }

/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 37 additions & 17 deletions mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"

Expand Down Expand Up @@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//

struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
using OpConversionPattern<toy::FuncOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();

// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}

// Create a new non-toy function, with the same region.
auto func =
rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
rewriter.eraseOp(op);
return success();
}
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
Expand All @@ -286,28 +318,16 @@ struct ToyToAffineLoweringPass
} // namespace

void ToyToAffineLoweringPass::runOnOperation() {
auto function = getOperation();

// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;

// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}

// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());

// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();

// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
Expand All @@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

Expand Down
22 changes: 9 additions & 13 deletions mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class MLIRGenImpl {
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);

// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
Expand Down Expand Up @@ -108,31 +104,31 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();
Expand Down
9 changes: 5 additions & 4 deletions mlir/examples/toy/Ch6/toyc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,18 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
}

if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());

// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

Expand Down
4 changes: 3 additions & 1 deletion mlir/examples/toy/Ch7/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"
Expand Down
58 changes: 58 additions & 0 deletions mlir/examples/toy/Ch7/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS

include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -153,6 +155,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.

Example:

```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];

let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);

let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}

//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/examples/toy/Ch7/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

Expand Down Expand Up @@ -49,6 +50,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}

// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}

//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -284,6 +291,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}

/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }

/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 37 additions & 17 deletions mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"

Expand Down Expand Up @@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//

struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
using OpConversionPattern<toy::FuncOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();

// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}

// Create a new non-toy function, with the same region.
auto func =
rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
rewriter.eraseOp(op);
return success();
}
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
Expand All @@ -286,28 +318,16 @@ struct ToyToAffineLoweringPass
} // namespace

void ToyToAffineLoweringPass::runOnOperation() {
auto function = getOperation();

// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;

// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}

// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());

// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();

// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
Expand All @@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

Expand Down
22 changes: 10 additions & 12 deletions mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,9 @@ class MLIRGenImpl {

for (auto &record : moduleAST) {
if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
auto func = mlirGen(*funcAST);
mlir::toy::FuncOp func = mlirGen(*funcAST);
if (!func)
return nullptr;

theModule.push_back(func);
functionMap.insert({func.getName(), func});
} else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
if (failed(mlirGen(*str)))
Expand Down Expand Up @@ -105,7 +103,7 @@ class MLIRGenImpl {
std::pair<mlir::Value, VarDeclExprAST *>>;

/// A mapping for the functions that have been code generated to MLIR.
llvm::StringMap<mlir::FuncOp> functionMap;
llvm::StringMap<mlir::toy::FuncOp> functionMap;

/// A mapping for named struct types to the underlying MLIR type and the
/// original AST node.
Expand Down Expand Up @@ -157,7 +155,7 @@ class MLIRGenImpl {

/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());

// This is a generic function, the return type will be inferred later.
Expand All @@ -170,23 +168,23 @@ class MLIRGenImpl {
argTypes.push_back(type);
}
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}

/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
SymbolTableScopeT varScope(symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;

// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();

// Declare all the function arguments in the symbol table.
Expand Down Expand Up @@ -519,7 +517,7 @@ class MLIRGenImpl {
emitError(location) << "no defined function found for '" << callee << "'";
return nullptr;
}
mlir::FuncOp calledFunc = calledFuncIt->second;
mlir::toy::FuncOp calledFunc = calledFuncIt->second;
return builder.create<GenericCallOp>(
location, calledFunc.getType().getResult(0),
mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();
Expand Down
9 changes: 5 additions & 4 deletions mlir/examples/toy/Ch7/toyc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,19 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
}

if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());

// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

Expand Down
60 changes: 60 additions & 0 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS

include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -627,6 +629,64 @@ def PDLInterp_ForEachOp
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// pdl_interp::FuncOp
//===----------------------------------------------------------------------===//

def PDLInterp_FuncOp : PDLInterp_Op<"func", [
FunctionOpInterface, IsolatedFromAbove, Symbol
]> {
let summary = "PDL Interpreter Function Operation";
let description = [{
`pdl_interp.func` operations act as interpreter functions. These are
callable SSA-region operations that contain other interpreter operations.
Interpreter functions are used for both the matching and the rewriting
portion of the interpreter.

Example:

```mlir
pdl_interp.func @rewriter(%root: !pdl.operation) {
%op = pdl_interp.create_operation "foo.new_operation"
pdl_interp.erase %root
pdl_interp.finalize
}
```
}];

let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region MinSizedRegion<1>:$body);

// Create the function with the given name and type. This also automatically
// inserts the entry block for the function.
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}

//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }

/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 0 additions & 11 deletions mlir/include/mlir/IR/BuiltinOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,13 @@ def FuncOp : Builtin_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getType().getResults(); }

/// Verify the type attribute of this function. Returns failure and emits
/// an error if the attribute is invalid.
LogicalResult verifyType() {
auto type = getTypeAttr().getValue();
if (!type.isa<FunctionType>())
return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() +
"' attribute of function type");
return success();
}

//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//

bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/IR/FunctionImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
bool allowVariadic,
FuncTypeBuilder funcTypeBuilder);

/// Printer implementation for function-like operations. Accepts lists of
/// argument and result types to use while printing.
void printFunctionOp(OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes,
bool isVariadic, ArrayRef<Type> resultTypes);
/// Printer implementation for function-like operations.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);

/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/FunctionInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_IR_FUNCTIONINTERFACES_H
#define MLIR_IR_FUNCTIONINTERFACES_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
Expand Down
53 changes: 47 additions & 6 deletions mlir/include/mlir/IR/FunctionInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,41 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
InterfaceMethod<[{
Verify the contents of the body of this function.

Note: The default implementation merely checks that of the entry block
exists, it has the same number arguments as the function type.
Note: The default implementation merely checks that if the entry block
exists, it has the same number and type of arguments as the function type.
}],
"::mlir::LogicalResult", "verifyBody", (ins),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
if ($_op.isExternal())
return success();

unsigned numArguments = $_op.getNumArguments();
if ($_op.front().getNumArguments() != numArguments)
ArrayRef<Type> fnInputTypes = $_op.getArgumentTypes();
Block &entryBlock = $_op.front();

unsigned numArguments = fnInputTypes.size();
if (entryBlock.getNumArguments() != numArguments)
return $_op.emitOpError("entry block must have ")
<< numArguments << " arguments to match function signature";

for (unsigned i = 0, e = fnInputTypes.size(); i != e; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (fnInputTypes[i] != argType) {
return $_op.emitOpError("type of entry block argument #")
<< i << '(' << argType
<< ") must match the type of the corresponding argument in "
<< "function signature(" << fnInputTypes[i] << ')';
}
}

return success();
}]>,
InterfaceMethod<[{
Verify the type attribute of the function for derived op-specific
invariants.
}],
"::mlir::LogicalResult", "verifyType">,
"::mlir::LogicalResult", "verifyType", (ins),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
return success();
}]>,
];

let extraClassDeclaration = [{
Expand All @@ -108,6 +124,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
/// Return the name of the function.
StringRef getName() { return SymbolTable::getSymbolName(*this); }
}];
let extraTraitClassDeclaration = [{
//===------------------------------------------------------------------===//
// Builders
//===------------------------------------------------------------------===//

/// Build the function with the given name, attributes, and type. This
/// builder also inserts an entry block into the function body with the
/// given argument types.
static void buildWithEntryBlock(
OpBuilder &builder, OperationState &state, StringRef name, Type type,
ArrayRef<NamedAttribute> attrs, ArrayRef<Type> inputTypes) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(function_interface_impl::getTypeAttrName(),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());

// Add the function body.
Region *bodyRegion = state.addRegion();
Block *body = new Block();
bodyRegion->push_back(body);
for (Type input : inputTypes)
body->addArgument(input, state.location);
}
}];
let extraSharedClassDeclaration = [{
/// Block list iterator types.
using BlockListType = Region::BlockListType;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,11 @@ class SizedRegion<int numBlocks> : Region<
CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
"region with " # numBlocks # " blocks">;

// A region with at least the given number of blocks.
class MinSizedRegion<int numBlocks> : Region<
CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">,
"region with at least " # numBlocks # " blocks">;

// A variadic region constraint. It expands to zero or more of the base region.
class VariadicRegion<Region region>
: Region<region.predicate, region.summary>;
Expand Down
20 changes: 11 additions & 9 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace {
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);

/// Generate code for matching and rewriting based on the pattern operations
/// within the module.
Expand Down Expand Up @@ -110,7 +110,7 @@ struct PatternLowering {
OpBuilder builder;

/// The matcher function used for all match related logic within PDL patterns.
FuncOp matcherFunc;
pdl_interp::FuncOp matcherFunc;

/// The rewriter module containing the all rewrite related logic within PDL
/// patterns.
Expand All @@ -137,7 +137,8 @@ struct PatternLowering {
};
} // namespace

PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
ModuleOp rewriterModule)
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}

Expand All @@ -150,7 +151,7 @@ void PatternLowering::lower(ModuleOp module) {

// Insert the root operation, i.e. argument to the matcher, at the root
// position.
Block *matcherEntryBlock = matcherFunc.addEntryBlock();
Block *matcherEntryBlock = &matcherFunc.front();
values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));

// Generate a root matcher node from the provided PDL module.
Expand Down Expand Up @@ -590,13 +591,14 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {

SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
FuncOp rewriterFunc =
FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(llvm::None, llvm::None));
builder.setInsertionPointToEnd(rewriterModule.getBody());
auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(llvm::None, llvm::None));
rewriterSymbolTable.insert(rewriterFunc);

// Generate the rewriter function body.
builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
builder.setInsertionPointToEnd(&rewriterFunc.front());

// Map an input operand of the pattern to a generated interpreter value.
DenseMap<Value, Value> rewriteValues;
Expand Down Expand Up @@ -902,7 +904,7 @@ void PDLToPDLInterpPass::runOnOperation() {
// Create the main matcher function This function contains all of the match
// related functionality from patterns in the module.
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
FuncOp matcherFunc = builder.create<FuncOp>(
auto matcherFunc = builder.create<pdl_interp::FuncOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
/*results=*/llvm::None),
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1997,8 +1997,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
// Returns a null type if any of the types provided are non-LLVM types, or if
// there is more than one output type.
static Type
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc,
ArrayRef<Type> inputs, ArrayRef<Type> outputs,
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
ArrayRef<Type> outputs,
function_interface_impl::VariadicFlag variadicFlag) {
Builder &b = parser.getBuilder();
if (outputs.size() > 1) {
Expand Down Expand Up @@ -2159,7 +2159,7 @@ LogicalResult LLVMFuncOp::verify() {
}

/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
/// - entry block arguments are of LLVM types and match the function signature.
/// - entry block arguments are of LLVM types.
LogicalResult LLVMFuncOp::verifyRegions() {
if (isExternal())
return success();
Expand All @@ -2171,9 +2171,6 @@ LogicalResult LLVMFuncOp::verifyRegions() {
if (!isCompatibleType(argType))
return emitOpError("entry block argument #")
<< i << " is not of LLVM type";
if (getType().getParamType(i) != argType)
return emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}

return success();
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"

using namespace mlir;
using namespace mlir::pdl_interp;
Expand Down Expand Up @@ -161,6 +162,29 @@ LogicalResult ForEachOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// pdl_interp::FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs) {
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 1 addition & 23 deletions mlir/lib/IR/BuiltinDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,7 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
}

void FuncOp::print(OpAsmPrinter &p) {
FunctionType fnType = getType();
function_interface_impl::printFunctionOp(
p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
}

LogicalResult FuncOp::verify() {
// If this function is external there is nothing to do.
if (isExternal())
return success();

// Verify that the argument list of the function and the arg list of the entry
// block line up. The trait already verified that the number of arguments is
// the same between the signature and the block.
auto fnInputTypes = getType().getInputs();
Block &entryBlock = front();
for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
return emitOpError("type of entry block argument #")
<< i << '(' << entryBlock.getArgument(i).getType()
<< ") must match the type of the corresponding argument in "
<< "function signature(" << fnInputTypes[i] << ')';

return success();
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

/// Clone the internal blocks from this function into dest and all attributes
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/IR/FunctionImplementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ void mlir::function_interface_impl::printFunctionAttributes(
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
}

void mlir::function_interface_impl::printFunctionOp(
OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
ArrayRef<Type> resultTypes) {
void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
FunctionOpInterface op,
bool isVariadic) {
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
Expand All @@ -358,6 +358,8 @@ void mlir::function_interface_impl::printFunctionOp(
p << visibility.getValue() << ' ';
p.printSymbolName(funcName);

ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
{visibilityAttrName});
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ class Generator {
private:
/// Allocate memory indices for the results of operations within the matcher
/// and rewriters.
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
ModuleOp rewriterModule);

/// Generate the bytecode for the given operation.
void generate(Region *region, ByteCodeWriter &writer);
Expand Down Expand Up @@ -482,7 +483,7 @@ struct ByteCodeLiveRange {
} // namespace

void Generator::generate(ModuleOp module) {
FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
pdl_interp::PDLInterpDialect::getMatcherFunctionName());
ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
pdl_interp::PDLInterpDialect::getRewriterModuleName());
Expand All @@ -494,7 +495,7 @@ void Generator::generate(ModuleOp module) {

// Generate code for the rewriter functions.
ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
for (Operation &op : rewriterFunc.getOps())
generate(&op, rewriterByteCodeWriter);
Expand All @@ -514,11 +515,11 @@ void Generator::generate(ModuleOp module) {
}
}

void Generator::allocateMemoryIndices(FuncOp matcherFunc,
void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
ModuleOp rewriterModule) {
// Rewriters use simplistic allocation scheme that simply assigns an index to
// each result.
for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
auto processRewriterValue = [&](Value val) {
valueToMemIndex.try_emplace(val, index++);
Expand Down
12 changes: 1 addition & 11 deletions mlir/test/Dialect/LLVMIR/func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ module {
// -----

module {
// expected-error@+1 {{entry block argument #0 is not of LLVM type}}
// expected-error@+1 {{entry block argument #0('tensor<*xf32>') must match the type of the corresponding argument in function signature('i64')}}
"llvm.func"() ({
^bb0(%arg0: tensor<*xf32>):
llvm.return
Expand All @@ -189,16 +189,6 @@ module {

// -----

module {
// expected-error@+1 {{entry block argument #0 does not match the function signature}}
"llvm.func"() ({
^bb0(%arg0: i32):
llvm.return
}) {sym_name = "wrong_arg_number", type = !llvm.func<void (i64)>} : () -> ()
}

// -----

module {
// expected-error@+1 {{failed to construct function type: expected LLVM type for function arguments}}
llvm.func @foo(tensor<*xf32>)
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Examples/Toy/Ch2/codegen.toy
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def main() {
print(d);
}

# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch2/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch2/scalar.toy
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def main() {
print(a);
}

# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Examples/Toy/Ch3/codegen.toy
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def main() {
print(d);
}

# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch3/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch3/scalar.toy
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def main() {
print(a);
}

# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Examples/Toy/Ch3/transpose_transpose.toy
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def main() {
print(b);
}

# CHECK-LABEL: func @transpose_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @transpose_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_0]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_2:%.*]] = toy.generic_call @transpose_transpose([[VAL_1]]) : (tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: toy.print [[VAL_2]] : tensor<*xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch3/trivial_reshape.toy
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main() {
print(c);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Examples/Toy/Ch4/codegen.toy
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def main() {
print(d);
}

# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch4/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch4/scalar.toy
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def main() {
print(a);
}

# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Examples/Toy/Ch4/shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

// Check the result of inlining+shape inference on an input module.

func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand All @@ -19,10 +19,10 @@ func @main() {
toy.return
}

// CHECK-NOT: func private @multiply_transpose
// CHECK-NOT: toy.func private @multiply_transpose
// CHECK-NOT: tensor<*xf64>

// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch4/transpose_transpose.toy
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main() {
print(b);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch4/trivial_reshape.toy
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main() {
print(c);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Examples/Toy/Ch5/codegen.toy
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def main() {
print(d);
}

# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch5/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}
8 changes: 4 additions & 4 deletions mlir/test/Examples/Toy/Ch5/shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

// Check the result of inlining+shape inference on an input module.

func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand All @@ -19,10 +19,10 @@ func @main() {
toy.return
}

// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: toy.func @multiply_transpose
// CHECK-NOT: tensor<*xf64>

// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch5/transpose_transpose.toy
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main() {
print(b);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch5/trivial_reshape.toy
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main() {
print(c);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: toyc-ch6 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch6 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Examples/Toy/Ch6/codegen.toy
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def main() {
print(d);
}

# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/llvm-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: toyc-ch6 %s -emit=llvm -opt

func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/scalar.toy
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def main() {
print(a);
}

# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Examples/Toy/Ch6/shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

// Check the result of inlining+shape inference on an input module.

func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
Expand All @@ -19,10 +19,10 @@ func @main() {
toy.return
}

// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: toy.func @multiply_transpose
// CHECK-NOT: tensor<*xf64>

// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/transpose_transpose.toy
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main() {
print(b);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch6/trivial_reshape.toy
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main() {
print(c);
}

# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>
Expand Down
Loading