Skip to content

Commit

Permalink
Add Ch.6 of the Toy tutorial.
Browse files Browse the repository at this point in the history
This chapters introduces the notion of a full conversion, and adds support for lowering down to the LLVM dialect, LLVM IR, and thus code generation.

PiperOrigin-RevId: 275337786
  • Loading branch information
River707 authored and tensorflower-gardener committed Oct 17, 2019
1 parent 5b03e69 commit 0372eb4
Show file tree
Hide file tree
Showing 33 changed files with 4,254 additions and 2 deletions.
1 change: 1 addition & 0 deletions mlir/examples/toy/CMakeLists.txt
Expand Up @@ -11,3 +11,4 @@ add_subdirectory(Ch2)
add_subdirectory(Ch3)
add_subdirectory(Ch4)
add_subdirectory(Ch5)
add_subdirectory(Ch6)
51 changes: 51 additions & 0 deletions mlir/examples/toy/Ch6/CMakeLists.txt
@@ -0,0 +1,51 @@
add_subdirectory(include)

set(LLVM_LINK_COMPONENTS
Core
Support
)

set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
add_public_tablegen_target(ToyCh6CombineIncGen)

add_toy_chapter(toyc-ch6
toyc.cpp
parser/AST.cpp
mlir/MLIRGen.cpp
mlir/Dialect.cpp
mlir/DeadFunctionEliminationPass.cpp
mlir/LowerToAffineLoops.cpp
mlir/LowerToLLVM.cpp
mlir/ShapeInferencePass.cpp
mlir/ToyCombine.cpp
)

add_dependencies(toyc-ch6 ToyCh6ShapeInferenceInterfaceIncGen)
add_dependencies(toyc-ch6 ToyCh6OpsIncGen)
add_dependencies(toyc-ch6 ToyCh6CombineIncGen)
add_dependencies(toyc-ch6 MLIRCallOpInterfacesIncGen)
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch6
PRIVATE
MLIRAffineOps
MLIRAnalysis
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRLoopToStandard
MLIRParser
MLIRPass
MLIRStandardOps
MLIRStandardToLLVM
MLIRTargetLLVMIR
MLIRTransforms
)

whole_archive_link(toyc-ch6
MLIRAffineOps
MLIRLLVMIR
MLIRStandardOps
)
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch6/include/CMakeLists.txt
@@ -0,0 +1 @@
add_subdirectory(toy)
253 changes: 253 additions & 0 deletions mlir/examples/toy/Ch6/include/toy/AST.h
@@ -0,0 +1,253 @@
//===- AST.h - Node definition for the Toy AST ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the AST for the Toy language. It is optimized for
// simplicity, not efficiency. The AST forms a tree structure where each node
// references its children using std::unique_ptr<>.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TUTORIAL_TOY_AST_H_
#define MLIR_TUTORIAL_TOY_AST_H_

#include "toy/Lexer.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <vector>

namespace toy {

/// A variable type with shape information.
struct VarType {
std::vector<int64_t> shape;
};

/// Base class for all expression nodes.
class ExprAST {
public:
enum ExprASTKind {
Expr_VarDecl,
Expr_Return,
Expr_Num,
Expr_Literal,
Expr_Var,
Expr_BinOp,
Expr_Call,
Expr_Print,
};

ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}

virtual ~ExprAST() = default;

ExprASTKind getKind() const { return kind; }

const Location &loc() { return location; }

private:
const ExprASTKind kind;
Location location;
};

/// A block-list of expressions.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;

/// Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;

public:
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}

double getValue() { return Val; }

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};

/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;

public:
LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
std::vector<int64_t> dims)
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}

std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
std::vector<int64_t> &getDims() { return dims; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
};

/// Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string name;

public:
VariableExprAST(Location loc, const std::string &name)
: ExprAST(Expr_Var, loc), name(name) {}

llvm::StringRef getName() { return name; }

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};

/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
std::unique_ptr<ExprAST> initVal;

public:
VarDeclExprAST(Location loc, const std::string &name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}

llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
VarType &getType() { return type; }

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};

/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;

public:
ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
: ExprAST(Expr_Return, loc), expr(std::move(expr)) {}

llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
return llvm::NoneType();
}

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
};

/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;

public:
char getOp() { return Op; }
ExprAST *getLHS() { return LHS.get(); }
ExprAST *getRHS() { return RHS.get(); }

BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
RHS(std::move(RHS)) {}

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
};

/// Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;

public:
CallExprAST(Location loc, const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}

llvm::StringRef getCallee() { return Callee; }
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
};

/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
std::unique_ptr<ExprAST> Arg;

public:
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}

ExprAST *getArg() { return Arg.get(); }

/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
};

/// This class represents the "prototype" for a function, which captures its
/// name, and its argument names (thus implicitly the number of arguments the
/// function takes).
class PrototypeAST {
Location location;
std::string name;
std::vector<std::unique_ptr<VariableExprAST>> args;

public:
PrototypeAST(Location location, const std::string &name,
std::vector<std::unique_ptr<VariableExprAST>> args)
: location(location), name(name), args(std::move(args)) {}

const Location &loc() { return location; }
const std::string &getName() const { return name; }
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
return args;
}
};

/// This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprASTList> Body;

public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprASTList> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
PrototypeAST *getProto() { return Proto.get(); }
ExprASTList *getBody() { return Body.get(); }
};

/// This class represents a list of functions to be processed together
class ModuleAST {
std::vector<FunctionAST> functions;

public:
ModuleAST(std::vector<FunctionAST> functions)
: functions(std::move(functions)) {}

auto begin() -> decltype(functions.begin()) { return functions.begin(); }
auto end() -> decltype(functions.end()) { return functions.end(); }
};

void dump(ModuleAST &);

} // namespace toy

#endif // MLIR_TUTORIAL_TOY_AST_H_
9 changes: 9 additions & 0 deletions mlir/examples/toy/Ch6/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
add_public_tablegen_target(ToyCh6OpsIncGen)

set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(ToyCh6ShapeInferenceInterfaceIncGen)
55 changes: 55 additions & 0 deletions mlir/examples/toy/Ch6/include/toy/Dialect.h
@@ -0,0 +1,55 @@
//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "toy/ShapeInferenceInterface.h"

namespace mlir {
namespace toy {

/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override some general behavior exposed via virtual
/// methods.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);

/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
static llvm::StringRef getDialectNamespace() { return "toy"; }
};

/// Include the auto-generated header file containing the declarations of the
/// toy operations.
#define GET_OP_CLASSES
#include "toy/Ops.h.inc"

} // end namespace toy
} // end namespace mlir

#endif // MLIR_TUTORIAL_TOY_DIALECT_H_

0 comments on commit 0372eb4

Please sign in to comment.