Skip to content

Commit

Permalink
Node refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 19, 2018
1 parent f519848 commit cf6090a
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 1,289 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
DMLC_CORE_PATH ?= ../dmlc-core
LDFLAGS = -pthread -lm
CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -I${DMLC_CORE_PATH}/include -Isrc -fPIC
-Iinclude -I${DMLC_CORE_PATH}/include -I../include -I../dlpack/include -Isrc -fPIC

ifdef no_rtti
CFLAGS += -fno-rtti
Expand Down
30 changes: 19 additions & 11 deletions src/ir/Expr.h
Expand Up @@ -4,6 +4,10 @@
/** \file
* Base classes for Halide expressions (\ref HalideIR::Expr) and statements (\ref HalideIR::Internal::Stmt)
*/
#include <tvm/node/node.h>
#include <tvm/node/memory.h>
#include <tvm/node/ir_functor.h>
#include <tvm/node/container.h>

#include <string>
#include <vector>
Expand All @@ -13,16 +17,20 @@
#include "base/Float16.h"
#include "base/Type.h"
#include "base/Util.h"
#include "tvm/node.h"
#include "tvm/ir_functor.h"
#include "tvm/container.h"


namespace HalideIR {
namespace Internal {
using tvm::Node;
using tvm::NodeRef;
using tvm::Array;
using tvm::NodePtr;
using tvm::make_node;

namespace IR {
using tvm::AttrVisitor;
} // namespace IR

using IR::Node;
using IR::NodeRef;
using IR::Array;
namespace Internal {

struct Variable;
class IRVisitor;
Expand Down Expand Up @@ -148,7 +156,7 @@ struct StmtNode : public BaseStmtNode {
and dispatches visitors. */
struct IRHandle : public NodeRef {
IRHandle() {}
IRHandle(std::shared_ptr<Node> p) : NodeRef(p) {}
IRHandle(NodePtr<Node> p) : NodeRef(p) {}

/** return internal content as IRNode */
inline const IRNode* get() const {
Expand All @@ -169,7 +177,7 @@ struct Expr : public Internal::IRHandle {
Expr() : Internal::IRHandle() {}

/** Make an expression from a concrete expression node pointer (e.g. Add) */
explicit Expr(std::shared_ptr<IR::Node> n) : IRHandle(n) {}
explicit Expr(NodePtr<Node> n) : IRHandle(n) {}

/** Make an expression representing numeric constants of various types. */
// @{
Expand Down Expand Up @@ -236,7 +244,7 @@ struct ExprEqual {
*/
struct VarExpr : public Expr {
VarExpr() : Expr() { }
explicit VarExpr(std::shared_ptr<IR::Node> n) : Expr(n) {}
explicit VarExpr(NodePtr<Node> n) : Expr(n) {}
/**
* constructor from variable
* Choose first have name then type, with default int32
Expand Down Expand Up @@ -278,7 +286,7 @@ enum class ForType : int {
/** A reference-counted handle to a statement node. */
struct Stmt : public IRHandle {
Stmt() : IRHandle() {}
Stmt(std::shared_ptr<IR::Node> n) : IRHandle(n) {}
Stmt(NodePtr<Node> n) : IRHandle(n) {}

/** Dispatch to the correct visitor method for this node. E.g. if
* this node is actually an Add node, then this will call
Expand Down
2 changes: 1 addition & 1 deletion src/ir/FunctionBase.h
Expand Up @@ -23,7 +23,7 @@ class FunctionRef : public NodeRef {
public:
/*! \brief constructor */
FunctionRef() {}
FunctionRef(std::shared_ptr<Node> n) : NodeRef(n) {}
FunctionRef(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down
60 changes: 30 additions & 30 deletions src/ir/IR.cpp
Expand Up @@ -33,7 +33,7 @@ Expr IntImm::make(Type t, int64_t value) {
// Then sign-extending to get them back
value >>= (64 - t.bits());

std::shared_ptr<IntImm> node = std::make_shared<IntImm>();
NodePtr<IntImm> node = make_node<IntImm>();
node->type = t;
node->value = value;
return Expr(node);
Expand All @@ -49,7 +49,7 @@ Expr UIntImm::make(Type t, uint64_t value) {
value <<= (64 - t.bits());
value >>= (64 - t.bits());

std::shared_ptr<UIntImm> node = std::make_shared<UIntImm>();
NodePtr<UIntImm> node = make_node<UIntImm>();
node->type = t;
node->value = value;
return Expr(node);
Expand All @@ -58,7 +58,7 @@ Expr UIntImm::make(Type t, uint64_t value) {
Expr FloatImm::make(Type t, double value) {
internal_assert(t.is_float() && t.is_scalar())
<< "FloatImm must be a scalar Float\n";
std::shared_ptr<FloatImm> node = std::make_shared<FloatImm>();
NodePtr<FloatImm> node = make_node<FloatImm>();
node->type = t;
switch (t.bits()) {
case 16:
Expand All @@ -78,7 +78,7 @@ Expr FloatImm::make(Type t, double value) {
}

Expr StringImm::make(const std::string &val) {
std::shared_ptr<StringImm> node = std::make_shared<StringImm>();
NodePtr<StringImm> node = make_node<StringImm>();
node->type = type_of<const char *>();
node->value = val;
return Expr(node);
Expand All @@ -89,7 +89,7 @@ Expr Cast::make(Type t, Expr v) {
internal_assert(v.defined()) << "Cast of undefined\n";
internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths\n";

std::shared_ptr<Cast> node = std::make_shared<Cast>();
NodePtr<Cast> node = make_node<Cast>();
node->type = t;
node->value = std::move(v);
return Expr(node);
Expand All @@ -103,7 +103,7 @@ Expr And::make(Expr a, Expr b) {
internal_assert(b.type().is_bool()) << "rhs of And is not a bool\n";
internal_assert(a.type() == b.type()) << "And of mismatched types\n";

std::shared_ptr<And> node = std::make_shared<And>();
NodePtr<And> node = make_node<And>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
Expand All @@ -117,7 +117,7 @@ Expr Or::make(Expr a, Expr b) {
internal_assert(b.type().is_bool()) << "rhs of Or is not a bool\n";
internal_assert(a.type() == b.type()) << "Or of mismatched types\n";

std::shared_ptr<Or> node = std::make_shared<Or>();
NodePtr<Or> node = make_node<Or>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
Expand All @@ -127,7 +127,7 @@ Expr Or::make(Expr a, Expr b) {
Expr Not::make(Expr a) {
internal_assert(a.defined()) << "Not of undefined\n";
internal_assert(a.type().is_bool()) << "argument of Not is not a bool\n";
std::shared_ptr<Not> node = std::make_shared<Not>();
NodePtr<Not> node = make_node<Not>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
return Expr(node);
Expand All @@ -143,7 +143,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
condition.type().lanes() == true_value.type().lanes())
<< "In Select, vector lanes of condition must either be 1, or equal to vector lanes of arguments\n";

std::shared_ptr<Select> node = std::make_shared<Select>();
NodePtr<Select> node = make_node<Select>();
node->type = true_value.type();
node->condition = std::move(condition);
node->true_value = std::move(true_value);
Expand All @@ -159,7 +159,7 @@ Expr Load::make(Type type, VarExpr buffer_var, Expr index, Expr predicate) {
internal_assert(type.lanes() == predicate.type().lanes())
<< "Vector lanes of Load must match vector lanes of predicate\n";

std::shared_ptr<Load> node = std::make_shared<Load>();
NodePtr<Load> node = make_node<Load>();
node->type = type;
node->buffer_var = std::move(buffer_var);
node->index = std::move(index);
Expand All @@ -176,7 +176,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) {
internal_assert(stride.type().is_scalar()) << "Ramp with vector stride\n";
internal_assert(lanes > 1) << "Ramp of lanes <= 1\n";
internal_assert(stride.type() == base.type()) << "Ramp of mismatched types\n";
std::shared_ptr<Ramp> node = std::make_shared<Ramp>();
NodePtr<Ramp> node = make_node<Ramp>();
internal_assert(base.defined()) << "Ramp of undefined\n";
node->type = base.type().with_lanes(lanes);
node->base = base;
Expand All @@ -189,7 +189,7 @@ Expr Broadcast::make(Expr value, int lanes) {
internal_assert(value.defined()) << "Broadcast of undefined\n";
internal_assert(value.type().is_scalar()) << "Broadcast of vector\n";
internal_assert(lanes != 1) << "Broadcast of lanes 1\n";
std::shared_ptr<Broadcast> node = std::make_shared<Broadcast>();
NodePtr<Broadcast> node = make_node<Broadcast>();
node->type = value.type().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
Expand All @@ -200,7 +200,7 @@ Expr Let::make(VarExpr var, Expr value, Expr body) {
internal_assert(value.defined()) << "Let of undefined\n";
internal_assert(body.defined()) << "Let of undefined\n";
internal_assert(value.type() == var.type()) << "Let var mismatch\n";
std::shared_ptr<Let> node = std::make_shared<Let>();
NodePtr<Let> node = make_node<Let>();
node->type = body.type();
node->var = std::move(var);
node->value = std::move(value);
Expand All @@ -212,15 +212,15 @@ Stmt LetStmt::make(VarExpr var, Expr value, Stmt body) {
internal_assert(value.defined()) << "Let of undefined\n";
internal_assert(body.defined()) << "Let of undefined\n";
internal_assert(value.type() == var.type()) << "Let var mismatch\n";
std::shared_ptr<LetStmt> node = std::make_shared<LetStmt>();
NodePtr<LetStmt> node = make_node<LetStmt>();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
return Stmt(node);
}

Stmt AttrStmt::make(NodeRef node, std::string attr_key, Expr value, Stmt body) {
auto n = std::make_shared<AttrStmt>();
auto n = make_node<AttrStmt>();
n->node = node;
n->attr_key = std::move(attr_key);
n->value = std::move(value);
Expand All @@ -234,7 +234,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
message.as<StringImm>()) << "AssertStmt message must be an int or string:"
<< message << "\n";

std::shared_ptr<AssertStmt> node = std::make_shared<AssertStmt>();
NodePtr<AssertStmt> node = make_node<AssertStmt>();
node->condition = std::move(condition);
node->message = std::move(message);
node->body = std::move(body);
Expand All @@ -244,7 +244,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
internal_assert(body.defined()) << "ProducerConsumer of undefined\n";

std::shared_ptr<ProducerConsumer> node = std::make_shared<ProducerConsumer>();
NodePtr<ProducerConsumer> node = make_node<ProducerConsumer>();
node->func = std::move(func);
node->is_producer = is_producer;
node->body = std::move(body);
Expand All @@ -262,7 +262,7 @@ Stmt For::make(VarExpr loop_var,
internal_assert(loop_var.type().is_scalar()) << "For with vector loop_var";
internal_assert(body.defined()) << "For of undefined\n";

std::shared_ptr<For> node = std::make_shared<For>();
NodePtr<For> node = make_node<For>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
node->extent = std::move(extent);
Expand All @@ -280,7 +280,7 @@ Stmt Store::make(VarExpr buffer_var, Expr value, Expr index, Expr predicate) {
internal_assert(value.type().lanes() == predicate.type().lanes())
<< "Vector lanes of Store must match vector lanes of predicate\n";

std::shared_ptr<Store> node = std::make_shared<Store>();
NodePtr<Store> node = make_node<Store>();
node->buffer_var = std::move(buffer_var);
node->value = std::move(value);
node->index = std::move(index);
Expand All @@ -296,7 +296,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> ar
internal_assert(args[i].defined()) << "Provide to undefined location\n";
}

std::shared_ptr<Provide> node = std::make_shared<Provide>();
NodePtr<Provide> node = make_node<Provide>();
node->func = std::move(func);
node->value_index = value_index;
node->value = std::move(value);
Expand All @@ -317,7 +317,7 @@ Stmt Allocate::make(VarExpr buffer_var,
internal_assert(condition.defined()) << "Allocate with undefined condition\n";
internal_assert(condition.type().is_bool()) << "Allocate condition is not boolean\n";

std::shared_ptr<Allocate> node = std::make_shared<Allocate>();
NodePtr<Allocate> node = make_node<Allocate>();
node->buffer_var = std::move(buffer_var);
node->type = type;
node->extents = std::move(extents);
Expand Down Expand Up @@ -365,7 +365,7 @@ int32_t Allocate::constant_allocation_size() const {
}

Stmt Free::make(VarExpr buffer_var) {
std::shared_ptr<Free> node = std::make_shared<Free>();
NodePtr<Free> node = make_node<Free>();
node->buffer_var = buffer_var;
return Stmt(node);
}
Expand All @@ -382,7 +382,7 @@ Stmt Realize::make(FunctionRef func, int value_index, Type type,
internal_assert(condition.defined()) << "Realize with undefined condition\n";
internal_assert(condition.type().is_bool()) << "Realize condition is not boolean\n";

std::shared_ptr<Realize> node = std::make_shared<Realize>();
NodePtr<Realize> node = make_node<Realize>();
node->func = std::move(func);
node->value_index = value_index;
node->type = type;
Expand All @@ -399,7 +399,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, Type type, Region bounds)
internal_assert(bounds[i]->min.type().is_scalar()) << "Prefetch of vector size\n";
internal_assert(bounds[i]->extent.type().is_scalar()) << "Prefetch of vector size\n";
}
std::shared_ptr<Prefetch> node = std::make_shared<Prefetch>();
NodePtr<Prefetch> node = make_node<Prefetch>();
node->func = std::move(func);
node->value_index = value_index;
node->type = type;
Expand All @@ -411,7 +411,7 @@ Stmt Block::make(Stmt first, Stmt rest) {
internal_assert(first.defined()) << "Block of undefined\n";
internal_assert(rest.defined()) << "Block of undefined\n";

std::shared_ptr<Block> node = std::make_shared<Block>();
NodePtr<Block> node = make_node<Block>();

if (const Block *b = first.as<Block>()) {
// Use a canonical block nesting order
Expand Down Expand Up @@ -440,7 +440,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
internal_assert(condition.defined() && then_case.defined()) << "IfThenElse of undefined\n";
// else_case may be null.

std::shared_ptr<IfThenElse> node = std::make_shared<IfThenElse>();
NodePtr<IfThenElse> node = make_node<IfThenElse>();
node->condition = std::move(condition);
node->then_case = std::move(then_case);
node->else_case = std::move(else_case);
Expand All @@ -450,7 +450,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
Stmt Evaluate::make(Expr v) {
internal_assert(v.defined()) << "Evaluate of undefined\n";

std::shared_ptr<Evaluate> node = std::make_shared<Evaluate>();
NodePtr<Evaluate> node = make_node<Evaluate>();
node->value = v;
return Stmt(node);
}
Expand All @@ -467,7 +467,7 @@ Expr Call::make(Type type, std::string name, Array<Expr> args, CallType call_typ
}
}

std::shared_ptr<Call> node = std::make_shared<Call>();
NodePtr<Call> node = make_node<Call>();
node->type = type;
node->name = std::move(name);
node->args = std::move(args);
Expand All @@ -478,7 +478,7 @@ Expr Call::make(Type type, std::string name, Array<Expr> args, CallType call_typ
}

VarExpr Variable::make(Type type, std::string name_hint) {
std::shared_ptr<Variable> node = std::make_shared<Variable>();
NodePtr<Variable> node = make_node<Variable>();
node->type = type;
node->name_hint = std::move(name_hint);
return VarExpr(node);
Expand All @@ -502,7 +502,7 @@ Expr Shuffle::make(Array<Expr> vectors,
<< "Shuffle vector index out of range: " << i << "\n";
}

std::shared_ptr<Shuffle> node = std::make_shared<Shuffle>();
NodePtr<Shuffle> node = make_node<Shuffle>();
node->type = element_ty.with_lanes((int)indices.size());
node->vectors = std::move(vectors);
node->indices = std::move(indices);
Expand Down
4 changes: 2 additions & 2 deletions src/ir/IR.h
Expand Up @@ -109,7 +109,7 @@ struct BinaryOpNode : public ExprNode<T> {
internal_assert(a.defined()) << "BinaryOp of undefined\n";
internal_assert(b.defined()) << "BinaryOp of undefined\n";
internal_assert(a.type() == b.type()) << "BinaryOp of mismatched types\n";
std::shared_ptr<T> node = std::make_shared<T>();
NodePtr<T> node = make_node<T>();
node->type = a.type();
node->a = std::move(a);
node->b = std::move(b);
Expand Down Expand Up @@ -175,7 +175,7 @@ struct CmpOpNode : public ExprNode<T> {
internal_assert(a.defined()) << "CmpOp of undefined\n";
internal_assert(b.defined()) << "CmpOp of undefined\n";
internal_assert(a.type() == b.type()) << "BinaryOp of mismatched types\n";
std::shared_ptr<T> node = std::make_shared<T>();
NodePtr<T> node = make_node<T>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
Expand Down

0 comments on commit cf6090a

Please sign in to comment.