From cf6090aeaeb782d1daff54b0ca5c2c281d7008db Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 19 Sep 2018 09:29:52 -0700 Subject: [PATCH] Node refactor --- Makefile | 2 +- src/ir/Expr.h | 30 ++- src/ir/FunctionBase.h | 2 +- src/ir/IR.cpp | 60 ++--- src/ir/IR.h | 4 +- src/ir/Range.h | 7 +- src/tvm/container.h | 591 ------------------------------------------ src/tvm/ir_functor.h | 264 ------------------- src/tvm/node.cpp | 58 ----- src/tvm/node.h | 329 ----------------------- 10 files changed, 58 insertions(+), 1289 deletions(-) delete mode 100644 src/tvm/container.h delete mode 100644 src/tvm/ir_functor.h delete mode 100644 src/tvm/node.cpp delete mode 100644 src/tvm/node.h diff --git a/Makefile b/Makefile index f1c6ea7..ab1d160 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ DMLC_CORE_PATH ?= ../dmlc-core LDFLAGS = -pthread -lm CFLAGS = -std=c++11 -Wall -O2\ - -Iinclude -I${DMLC_CORE_PATH}/include -Isrc -fPIC + -Iinclude -I${DMLC_CORE_PATH}/include -I../include -I../dlpack/include -Isrc -fPIC ifdef no_rtti CFLAGS += -fno-rtti diff --git a/src/ir/Expr.h b/src/ir/Expr.h index 1323acd..cb52ac4 100644 --- a/src/ir/Expr.h +++ b/src/ir/Expr.h @@ -4,6 +4,10 @@ /** \file * Base classes for Halide expressions (\ref HalideIR::Expr) and statements (\ref HalideIR::Internal::Stmt) */ +#include +#include +#include +#include #include #include @@ -13,16 +17,20 @@ #include "base/Float16.h" #include "base/Type.h" #include "base/Util.h" -#include "tvm/node.h" -#include "tvm/ir_functor.h" -#include "tvm/container.h" + namespace HalideIR { -namespace Internal { +using tvm::Node; +using tvm::NodeRef; +using tvm::Array; +using tvm::NodePtr; +using tvm::make_node; + +namespace IR { +using tvm::AttrVisitor; +} // namespace IR -using IR::Node; -using IR::NodeRef; -using IR::Array; +namespace Internal { struct Variable; class IRVisitor; @@ -148,7 +156,7 @@ struct StmtNode : public BaseStmtNode { and dispatches visitors. */ struct IRHandle : public NodeRef { IRHandle() {} - IRHandle(std::shared_ptr p) : NodeRef(p) {} + IRHandle(NodePtr p) : NodeRef(p) {} /** return internal content as IRNode */ inline const IRNode* get() const { @@ -169,7 +177,7 @@ struct Expr : public Internal::IRHandle { Expr() : Internal::IRHandle() {} /** Make an expression from a concrete expression node pointer (e.g. Add) */ - explicit Expr(std::shared_ptr n) : IRHandle(n) {} + explicit Expr(NodePtr n) : IRHandle(n) {} /** Make an expression representing numeric constants of various types. */ // @{ @@ -236,7 +244,7 @@ struct ExprEqual { */ struct VarExpr : public Expr { VarExpr() : Expr() { } - explicit VarExpr(std::shared_ptr n) : Expr(n) {} + explicit VarExpr(NodePtr n) : Expr(n) {} /** * constructor from variable * Choose first have name then type, with default int32 @@ -278,7 +286,7 @@ enum class ForType : int { /** A reference-counted handle to a statement node. */ struct Stmt : public IRHandle { Stmt() : IRHandle() {} - Stmt(std::shared_ptr n) : IRHandle(n) {} + Stmt(NodePtr n) : IRHandle(n) {} /** Dispatch to the correct visitor method for this node. E.g. if * this node is actually an Add node, then this will call diff --git a/src/ir/FunctionBase.h b/src/ir/FunctionBase.h index ba3c803..823c49c 100644 --- a/src/ir/FunctionBase.h +++ b/src/ir/FunctionBase.h @@ -23,7 +23,7 @@ class FunctionRef : public NodeRef { public: /*! \brief constructor */ FunctionRef() {} - FunctionRef(std::shared_ptr n) : NodeRef(n) {} + FunctionRef(NodePtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/src/ir/IR.cpp b/src/ir/IR.cpp index 4876dea..87fae95 100644 --- a/src/ir/IR.cpp +++ b/src/ir/IR.cpp @@ -33,7 +33,7 @@ Expr IntImm::make(Type t, int64_t value) { // Then sign-extending to get them back value >>= (64 - t.bits()); - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = t; node->value = value; return Expr(node); @@ -49,7 +49,7 @@ Expr UIntImm::make(Type t, uint64_t value) { value <<= (64 - t.bits()); value >>= (64 - t.bits()); - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = t; node->value = value; return Expr(node); @@ -58,7 +58,7 @@ Expr UIntImm::make(Type t, uint64_t value) { Expr FloatImm::make(Type t, double value) { internal_assert(t.is_float() && t.is_scalar()) << "FloatImm must be a scalar Float\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = t; switch (t.bits()) { case 16: @@ -78,7 +78,7 @@ Expr FloatImm::make(Type t, double value) { } Expr StringImm::make(const std::string &val) { - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = type_of(); node->value = val; return Expr(node); @@ -89,7 +89,7 @@ Expr Cast::make(Type t, Expr v) { internal_assert(v.defined()) << "Cast of undefined\n"; internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = t; node->value = std::move(v); return Expr(node); @@ -103,7 +103,7 @@ Expr And::make(Expr a, Expr b) { internal_assert(b.type().is_bool()) << "rhs of And is not a bool\n"; internal_assert(a.type() == b.type()) << "And of mismatched types\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = Bool(a.type().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -117,7 +117,7 @@ Expr Or::make(Expr a, Expr b) { internal_assert(b.type().is_bool()) << "rhs of Or is not a bool\n"; internal_assert(a.type() == b.type()) << "Or of mismatched types\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = Bool(a.type().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -127,7 +127,7 @@ Expr Or::make(Expr a, Expr b) { Expr Not::make(Expr a) { internal_assert(a.defined()) << "Not of undefined\n"; internal_assert(a.type().is_bool()) << "argument of Not is not a bool\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = Bool(a.type().lanes()); node->a = std::move(a); return Expr(node); @@ -143,7 +143,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { condition.type().lanes() == true_value.type().lanes()) << "In Select, vector lanes of condition must either be 1, or equal to vector lanes of arguments\n"; - std::shared_ptr(); + NodePtr(); node->type = true_value.type(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -159,7 +159,7 @@ Expr Load::make(Type type, VarExpr buffer_var, Expr index, Expr predicate) { internal_assert(type.lanes() == predicate.type().lanes()) << "Vector lanes of Load must match vector lanes of predicate\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = type; node->buffer_var = std::move(buffer_var); node->index = std::move(index); @@ -176,7 +176,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { internal_assert(stride.type().is_scalar()) << "Ramp with vector stride\n"; internal_assert(lanes > 1) << "Ramp of lanes <= 1\n"; internal_assert(stride.type() == base.type()) << "Ramp of mismatched types\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); internal_assert(base.defined()) << "Ramp of undefined\n"; node->type = base.type().with_lanes(lanes); node->base = base; @@ -189,7 +189,7 @@ Expr Broadcast::make(Expr value, int lanes) { internal_assert(value.defined()) << "Broadcast of undefined\n"; internal_assert(value.type().is_scalar()) << "Broadcast of vector\n"; internal_assert(lanes != 1) << "Broadcast of lanes 1\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = value.type().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; @@ -200,7 +200,7 @@ Expr Let::make(VarExpr var, Expr value, Expr body) { internal_assert(value.defined()) << "Let of undefined\n"; internal_assert(body.defined()) << "Let of undefined\n"; internal_assert(value.type() == var.type()) << "Let var mismatch\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = body.type(); node->var = std::move(var); node->value = std::move(value); @@ -212,7 +212,7 @@ Stmt LetStmt::make(VarExpr var, Expr value, Stmt body) { internal_assert(value.defined()) << "Let of undefined\n"; internal_assert(body.defined()) << "Let of undefined\n"; internal_assert(value.type() == var.type()) << "Let var mismatch\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); @@ -220,7 +220,7 @@ Stmt LetStmt::make(VarExpr var, Expr value, Stmt body) { } Stmt AttrStmt::make(NodeRef node, std::string attr_key, Expr value, Stmt body) { - auto n = std::make_shared(); + auto n = make_node(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -234,7 +234,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { message.as()) << "AssertStmt message must be an int or string:" << message << "\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); @@ -244,7 +244,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) { internal_assert(body.defined()) << "ProducerConsumer of undefined\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->func = std::move(func); node->is_producer = is_producer; node->body = std::move(body); @@ -262,7 +262,7 @@ Stmt For::make(VarExpr loop_var, internal_assert(loop_var.type().is_scalar()) << "For with vector loop_var"; internal_assert(body.defined()) << "For of undefined\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -280,7 +280,7 @@ Stmt Store::make(VarExpr buffer_var, Expr value, Expr index, Expr predicate) { internal_assert(value.type().lanes() == predicate.type().lanes()) << "Vector lanes of Store must match vector lanes of predicate\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->buffer_var = std::move(buffer_var); node->value = std::move(value); node->index = std::move(index); @@ -296,7 +296,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar internal_assert(args[i].defined()) << "Provide to undefined location\n"; } - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->func = std::move(func); node->value_index = value_index; node->value = std::move(value); @@ -317,7 +317,7 @@ Stmt Allocate::make(VarExpr buffer_var, internal_assert(condition.defined()) << "Allocate with undefined condition\n"; internal_assert(condition.type().is_bool()) << "Allocate condition is not boolean\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->buffer_var = std::move(buffer_var); node->type = type; node->extents = std::move(extents); @@ -365,7 +365,7 @@ int32_t Allocate::constant_allocation_size() const { } Stmt Free::make(VarExpr buffer_var) { - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->buffer_var = buffer_var; return Stmt(node); } @@ -382,7 +382,7 @@ Stmt Realize::make(FunctionRef func, int value_index, Type type, internal_assert(condition.defined()) << "Realize with undefined condition\n"; internal_assert(condition.type().is_bool()) << "Realize condition is not boolean\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->func = std::move(func); node->value_index = value_index; node->type = type; @@ -399,7 +399,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, Type type, Region bounds) internal_assert(bounds[i]->min.type().is_scalar()) << "Prefetch of vector size\n"; internal_assert(bounds[i]->extent.type().is_scalar()) << "Prefetch of vector size\n"; } - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->func = std::move(func); node->value_index = value_index; node->type = type; @@ -411,7 +411,7 @@ Stmt Block::make(Stmt first, Stmt rest) { internal_assert(first.defined()) << "Block of undefined\n"; internal_assert(rest.defined()) << "Block of undefined\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); if (const Block *b = first.as()) { // Use a canonical block nesting order @@ -440,7 +440,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { internal_assert(condition.defined() && then_case.defined()) << "IfThenElse of undefined\n"; // else_case may be null. - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); @@ -450,7 +450,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { Stmt Evaluate::make(Expr v) { internal_assert(v.defined()) << "Evaluate of undefined\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->value = v; return Stmt(node); } @@ -467,7 +467,7 @@ Expr Call::make(Type type, std::string name, Array args, CallType call_typ } } - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = type; node->name = std::move(name); node->args = std::move(args); @@ -478,7 +478,7 @@ Expr Call::make(Type type, std::string name, Array args, CallType call_typ } VarExpr Variable::make(Type type, std::string name_hint) { - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = type; node->name_hint = std::move(name_hint); return VarExpr(node); @@ -502,7 +502,7 @@ Expr Shuffle::make(Array vectors, << "Shuffle vector index out of range: " << i << "\n"; } - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = element_ty.with_lanes((int)indices.size()); node->vectors = std::move(vectors); node->indices = std::move(indices); diff --git a/src/ir/IR.h b/src/ir/IR.h index 1668926..6374d12 100644 --- a/src/ir/IR.h +++ b/src/ir/IR.h @@ -109,7 +109,7 @@ struct BinaryOpNode : public ExprNode { internal_assert(a.defined()) << "BinaryOp of undefined\n"; internal_assert(b.defined()) << "BinaryOp of undefined\n"; internal_assert(a.type() == b.type()) << "BinaryOp of mismatched types\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = a.type(); node->a = std::move(a); node->b = std::move(b); @@ -175,7 +175,7 @@ struct CmpOpNode : public ExprNode { internal_assert(a.defined()) << "CmpOp of undefined\n"; internal_assert(b.defined()) << "CmpOp of undefined\n"; internal_assert(a.type() == b.type()) << "BinaryOp of mismatched types\n"; - std::shared_ptr node = std::make_shared(); + NodePtr node = make_node(); node->type = Bool(a.type().lanes()); node->a = std::move(a); node->b = std::move(b); diff --git a/src/ir/Range.h b/src/ir/Range.h index 61c4ff3..2ebc624 100644 --- a/src/ir/Range.h +++ b/src/ir/Range.h @@ -20,7 +20,7 @@ class Range : public NodeRef { public: /*! \brief constructor */ Range() {} - Range(std::shared_ptr n) : NodeRef(n) {} + Range(NodePtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -68,7 +68,10 @@ inline const RangeNode* Range::operator->() const { inline Range Range::make_by_min_extent(Expr min, Expr extent) { internal_assert(min.type() == extent.type()) << "Region min and extent must have same type\n"; - return Range(std::make_shared(min, extent)); + NodePtr n = make_node(); + n->min = min; + n->extent = extent; + return Range(n); } // overload print function diff --git a/src/tvm/container.h b/src/tvm/container.h deleted file mode 100644 index 1fc21a3..0000000 --- a/src/tvm/container.h +++ /dev/null @@ -1,591 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file container.h - * \brief Array/Map container in the DSL graph. - */ -#ifndef TVM_CONTAINER_H_ -#define TVM_CONTAINER_H_ - -#include -#include -#include -#include -#include -#include "./node.h" - -namespace tvm { - -/*! \brief array node content in array */ -class ArrayNode : public Node { - public: - /*! \brief the data content */ - std::vector > data; - - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to array have no effect. - } - - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); -}; - -/*! \brief map node content */ -class MapNode : public Node { - public: - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to map have no effect. - } - // hash function - struct Hash { - size_t operator()(const std::shared_ptr& n) const { - return std::hash()(n.get()); - } - }; - // comparator - struct Equal { - bool operator()( - const std::shared_ptr& a, - const std::shared_ptr& b) const { - return a.get() == b.get(); - } - }; - - /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - std::shared_ptr, - std::shared_ptr, - Hash, Equal>; - - /*! \brief the data content */ - ContainerType data; - - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); -}; - - -/*! \brief specialized map node with string as key */ -class StrMapNode : public Node { - public: - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to map have no effect. - } - /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - std::string, - std::shared_ptr >; - - /*! \brief the data content */ - ContainerType data; - - static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - explicit IterAdapter(TIter iter) : iter_(iter) {} - inline IterAdapter& operator++() { // NOLINT(*) - ++iter_; - return *this; - } - inline IterAdapter& operator++(int) { // NOLINT(*) - ++iter_; - return *this; - } - inline IterAdapter operator+(int offset) const { // NOLINT(*) - return IterAdapter(iter_ + offset); - } - inline bool operator==(IterAdapter other) const { - return iter_ == other.iter_; - } - inline bool operator!=(IterAdapter other) const { - return !(*this == other); - } - inline const typename Converter::ResultType operator*() const { - return Converter::convert(*iter_); - } - - private: - TIter iter_; -}; - -/*! - * \brief Array container of NodeRef in DSL graph. - * Array implements copy on write semantics, which means array is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam T The content NodeRef type. - */ -template::value>::type > -class Array : public NodeRef { - public: - /*! - * \brief default constructor - */ - Array() { - node_ = std::make_shared(); - } - /*! - * \brief move constructor - * \param other source - */ - Array(Array && other) { // NOLINT(*) - node_ = std::move(other.node_); - } - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array &other) { // NOLINT(*) - node_ = other.node_; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(std::shared_ptr n) : NodeRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType begin, IterType end) { - assign(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Array(std::initializer_list init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array && other) { - node_ = std::move(other.node_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array & other) { - node_ = other.node_; - return *this; - } - /*! - * \brief reset the array to content from iterator. - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - void assign(IterType begin, IterType end) { - auto n = std::make_shared(); - for (IterType it = begin; it != end; ++it) { - n->data.push_back((*it).node_); - } - node_ = std::move(n); - } - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - inline const T operator[](size_t i) const { - return T(static_cast(node_.get())->data[i]); - } - /*! \return The size of the array */ - inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); - } - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - inline ArrayNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { - node_ = std::make_shared( - *static_cast(node_.get())); - } - return static_cast(node_.get()); - } - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - inline void push_back(const T& item) { - ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item.node_); - } - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - inline void Set(size_t i, const T& value) { - ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value.node_; - } - /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } - /*! \brief specify container node */ - using ContainerType = ArrayNode; - - struct Ptr2NodeRef { - using ResultType = T; - static inline T convert(const std::shared_ptr& n) { - return T(n); - } - }; - using iterator = IterAdapter >::const_iterator>; - - using reverse_iterator = IterAdapter< - Ptr2NodeRef, - std::vector >::const_reverse_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); - } - /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(node_.get())->data.rbegin()); - } - /*! \return rend iterator */ - inline reverse_iterator rend() const { - return reverse_iterator(static_cast(node_.get())->data.rend()); - } -}; - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template::value || - std::is_base_of::value >::type, - typename = typename std::enable_if::value>::type> -class Map : public NodeRef { - public: - /*! - * \brief default constructor - */ - Map() { - node_ = std::make_shared(); - } - /*! - * \brief move constructor - * \param other source - */ - Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); - } - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map &other) { // NOLINT(*) - node_ = other.node_; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(std::shared_ptr n) : NodeRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - assign(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list > init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief constructor from vector - * \param init The vector - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map && other) { - node_ = std::move(other.node_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map & other) { - node_ = other.node_; - return *this; - } - /*! - * \brief reset the array to content from iterator. - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - void assign(IterType begin, IterType end) { - auto n = std::make_shared(); - for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first.node_, - i->second.node_)); - } - node_ = std::move(n); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - inline const V operator[](const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - inline const V at(const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); - } - /*! \return The size of the array */ - inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); - } - /*! \return The size of the array */ - inline size_t count(const K& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key.node_); - } - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - inline MapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { - node_ = std::make_shared( - *static_cast(node_.get())); - } - return static_cast(node_.get()); - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - inline void Set(const K& key, const V& value) { - MapNode* n = this->CopyOnWrite(); - n->data[key.node_] = value.node_; - } - - /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } - /*! \brief specify container node */ - using ContainerType = MapNode; - - struct Ptr2NodeRef { - using ResultType = std::pair; - static inline ResultType convert(const std::pair< - std::shared_ptr, - std::shared_ptr >& n) { - return std::make_pair(K(n.first), V(n.second)); - } - }; - - using iterator = IterAdapter< - Ptr2NodeRef, MapNode::ContainerType::const_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); - } - /*! \return begin iterator */ - inline iterator find(const K& key) const { - return iterator(static_cast(node_.get())->data.find(key.node_)); - } -}; - -// specialize of string map -template -class Map : public NodeRef { - public: - // for code reuse - Map() { - node_ = std::make_shared(); - } - Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); - } - Map(const Map &other) { // NOLINT(*) - node_ = other.node_; - } - explicit Map(std::shared_ptr n) : NodeRef(n) {} - template - Map(IterType begin, IterType end) { - assign(begin, end); - } - Map(std::initializer_list > init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - - template - Map(const std::unordered_map& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - Map& operator=(Map && other) { - node_ = std::move(other.node_); - return *this; - } - Map& operator=(const Map & other) { - node_ = other.node_; - return *this; - } - template - void assign(IterType begin, IterType end) { - auto n = std::make_shared(); - for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, - i->second.node_)); - } - node_ = std::move(n); - } - inline const V operator[](const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); - } - inline const V at(const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); - } - inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); - } - inline size_t count(const std::string& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key); - } - inline StrMapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { - node_ = std::make_shared( - *static_cast(node_.get())); - } - return static_cast(node_.get()); - } - inline void Set(const std::string& key, const V& value) { - StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value.node_; - } - inline bool empty() const { - return size() == 0; - } - using ContainerType = StrMapNode; - - struct Ptr2NodeRef { - using ResultType = std::pair; - static inline ResultType convert(const std::pair< - std::string, - std::shared_ptr >& n) { - return std::make_pair(n.first, V(n.second)); - } - }; - - using iterator = IterAdapter< - Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); - } - /*! \return begin iterator */ - inline iterator find(const std::string& key) const { - return iterator(static_cast(node_.get())->data.find(key)); - } -}; - -} // namespace tvm - -namespace HalideIR { -namespace IR { - -using tvm::Array; -using tvm::Map; - -} // namespace IR -} // namespace HalideIR - -#endif // TVM_CONTAINER_H_ diff --git a/src/tvm/ir_functor.h b/src/tvm/ir_functor.h deleted file mode 100644 index 97e9dcb..0000000 --- a/src/tvm/ir_functor.h +++ /dev/null @@ -1,264 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file ir_functor.h - * \brief Defines the IRFunctor data structures. - */ -#ifndef TVM_IR_FUNCTOR_H_ -#define TVM_IR_FUNCTOR_H_ - -#include -#include -#include -#include -#include "base/Debug.h" -#include "./node.h" - -namespace tvm { -/*! - * \brief A dynamical dispatched functor on NodeRef in the first argument. - * - * \code - * IRFunctor tostr; - * tostr.set_dispatch([](const Add* op, std::string prefix) { - * return prefix + "Add"; - * }); - * tostr.set_dispatch([](const IntImm* op) { - * return prefix + "IntImm" - * }); - * - * Expr x = make_const(1); - * Expr y = x + x; - * // dispatch to IntImm, outputs "MyIntImm" - * LOG(INFO) << tostr(x, "My"); - * // dispatch to IntImm, outputs "MyAdd" - * LOG(INFO) << tostr(y, "My"); - * \endcode - * - * \tparam FType function signiture - * This type if only defined for FType with function signiture - */ -template -class IRFunctor; - -template -class IRFunctor { - private: - using Function = std::function; - using TSelf = IRFunctor; - /*! \brief internal function table */ - std::vector func_; - - public: - /*! \brief the result type of this functor */ - using result_type = R; - /*! - * \brief Whether the functor can dispatch the corresponding Node - * \param n The node to be dispatched - * \return Whether dispatching function is registered for n's type. - */ - inline bool can_dispatch(const NodeRef& n) const { - uint32_t type_index = n.type_index(); - return type_index < func_.size() && func_[type_index] != nullptr; - } - /*! - * \brief invoke the functor , dispatch on type of n - * \param n The Node argument - * \param args The additional arguments - * \return The result. - */ - inline R operator()(const NodeRef& n, Args... args) const { - uint32_t type_index = n.type_index(); - internal_assert(type_index < func_.size() && - func_[type_index] != nullptr) - << "IRFunctor calls un-registered function on type " - << Node::TypeIndex2Key(type_index); - return func_[type_index](n, std::forward(args)...); - } - /*! - * \brief set the dispacher for type TNode - * \param f The function to be set. - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template - inline TSelf& set_dispatch(Function f) { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); - if (func_.size() <= tindex) { - func_.resize(tindex + 1, nullptr); - } - internal_assert(func_[tindex] == nullptr) - << "Dispatch for " << Node::TypeIndex2Key(tindex) - << " is already set"; - func_[tindex] = f; - return *this; - } - /*! - * \brief set the dispacher for type TNode - * This allows f to used detailed const Node pointer to replace NodeRef - * - * \param f The function to be set. - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template - inline TSelf& set_dispatch(std::function f) { // NOLINT(*) - Function fun = [f](const NodeRef& n, Args... args) { - return f(static_cast(n.node_.get()), - std::forward(args)...); - }; - return this->set_dispatch(fun); - } - /*! - * \brief unset the dispacher for type TNode - * - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template - inline TSelf& clear_dispatch() { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); - CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; - func_[tindex] = nullptr; - return *this; - } -}; - -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -/*! \brief helper macro to generate string concat */ -#define TVM_STR_CONCAT_(__x, __y) __x##__y -#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) - -#define TVM_REGISTER_VAR_DEF(ClsName) \ - static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName - -/*! - * \brief Useful macro to set IRFunctor dispatch in a global static field. - * - * \code - * // Use IRFunctor to implement IRPrinter similar to Visitor Pattern. - * // vtable allows easy patch in of new Node types, without changing - * // interface of IRPrinter. - * - * class IRPrinter { - * public: - * std::ostream& stream; - * // the dispatch function. - * void print(Expr e) { - * const static FType& f = *vtable(); - * f(e, this); - * } - * - * using FType = IRFunctor; - * // function to return global function table - * static FType& vtable(); - * }; - * - * // in cpp/cc file - * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0 - * static FType inst; return inst; - * } - * - * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) - * .set_dispatch([](const Add* n, IRPrinter* p) { - * p->print(n->a); - * p->stream << '+' - * p->print(n->b); - * }); - * - * - * \endcode - * - * \param ClsName The name of the class - * \param FField The static function that returns a singleton of IRFunctor. - */ -#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ - TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ - ClsName::FField() - - /*! - * \brief A container for a list of callbacks. All callbacks are invoked when - * the object is destructed. - */ -class IRFunctorCleanList { -public: - ~IRFunctorCleanList() { - for (auto &f : clean_items) { - f(); - } - } - - void append(std::function func) { - clean_items.push_back(func); - } - -private: - std::vector< std::function > clean_items; -}; - -/*! -* \brief A wrapper around IRFunctor that will record calls to set_dispatch -* and make a corresponding call to clear_dispatch when the last copy of -* the IRFunctorStaticRegistry is destructed. When assigned to a static variable, -* this can be used by NNVM and other libraries to unregister callbacks when -* the library is unloaded. This prevents crashes when the underlying IRFunctor -* is destructed as it will no longer contain std::function instances allocated -* by a library that has been unloaded. -*/ -template -class IRFunctorStaticRegistry; - -template -class IRFunctorStaticRegistry { -private: - IRFunctor *irf_; - std::shared_ptr free_list; - - using TSelf = IRFunctorStaticRegistry; - -public: - IRFunctorStaticRegistry(IRFunctor *irf) { - irf_ = irf; - free_list = std::make_shared(); - } - - template - inline TSelf& set_dispatch(std::function f) { // NOLINT(*) - irf_->template set_dispatch(f); - auto irf_copy = irf_; - free_list.get()->append([irf_copy] { - irf_copy->template clear_dispatch(); - }); - return *this; - } -}; - -/*! -* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows -* the compiler to deduce the template types. -*/ -template -IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( - IRFunctor *irf) { - return IRFunctorStaticRegistry(irf); -} - -#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ - static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName - -/*! -* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry. -* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of -* TVM_STATIC_IR_FUNCTOR. -*/ -#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \ - TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ - MakeIRFunctorStaticRegistry(&ClsName::FField()) - -} // namespace tvm - -#endif // TVM_IR_FUNCTOR_H_ diff --git a/src/tvm/node.cpp b/src/tvm/node.cpp deleted file mode 100644 index 6956bec..0000000 --- a/src/tvm/node.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * Implementation of IR Node API - * \file node.cpp - */ -#include -#include -#include -#include -#include "./node.h" - -namespace tvm { - -namespace { -// single manager of operator information. -struct TypeManager { - // mutex to avoid registration from multiple threads. - // recursive is needed for trigger(which calls UpdateAttrMap) - std::mutex mutex; - std::atomic type_counter{0}; - std::unordered_map key2index; - std::vector index2key; - // get singleton of the - static TypeManager* Global() { - static TypeManager inst; - return &inst; - } -}; -} // namespace - -const bool Node::_DerivedFrom(uint32_t tid) const { - static uint32_t tindex = TypeKey2Index(Node::_type_key); - return tid == tindex; -} - -// this is slow, usually caller always hold the result in a static variable. -uint32_t Node::TypeKey2Index(const char* key) { - TypeManager *t = TypeManager::Global(); - std::lock_guard(t->mutex); - std::string skey = key; - auto it = t->key2index.find(skey); - if (it != t->key2index.end()) { - return it->second; - } - uint32_t tid = ++(t->type_counter); - t->key2index[skey] = tid; - t->index2key.push_back(skey); - return tid; -} - -const char* Node::TypeIndex2Key(uint32_t index) { - TypeManager *t = TypeManager::Global(); - std::lock_guard(t->mutex); - internal_assert(index != 0); - return t->index2key.at(index - 1).c_str(); -} - -} // namespace tvm diff --git a/src/tvm/node.h b/src/tvm/node.h deleted file mode 100644 index 1bd05b4..0000000 --- a/src/tvm/node.h +++ /dev/null @@ -1,329 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file node.h - * \brief Defines the Node data structures. - */ -#ifndef TVM_NODE_H_ -#define TVM_NODE_H_ - -#include -#include -#include -#include -#include "base/Type.h" - -/** namespace of tvm base code */ -namespace tvm { - -using HalideIR::Type; -// forward declaration -class Node; -class NodeRef; - -namespace runtime { -// forward declaration -class NDArray; -} // namespace runtime -/*! - * \brief Visitor class to each node content. - * The content is going to be called for each field. - */ -class EXPORT AttrVisitor { - public: -//! \cond Doxygen_Suppress - virtual void Visit(const char* key, double* value) = 0; - virtual void Visit(const char* key, int64_t* value) = 0; - virtual void Visit(const char* key, uint64_t* value) = 0; - virtual void Visit(const char* key, int* value) = 0; - virtual void Visit(const char* key, bool* value) = 0; - virtual void Visit(const char* key, std::string* value) = 0; - virtual void Visit(const char* key, void** value) = 0; - virtual void Visit(const char* key, Type* value) = 0; - virtual void Visit(const char* key, NodeRef* value) = 0; - virtual void Visit(const char* key, runtime::NDArray* value) = 0; - template::value>::type> - void Visit(const char* key, ENum* ptr) { - static_assert(std::is_same::type>::value, - "declare enum to be enum int to use visitor"); - this->Visit(key, reinterpret_cast(ptr)); - } -//! \endcond -}; - -/*! - * \brief base class of node container in DSL AST. - * All object's internal is stored as std::shared_ptr - */ -class EXPORT Node : public std::enable_shared_from_this { - public: - /*! \brief virtual destructor */ - virtual ~Node() {} - /*! \return The unique type key of the node */ - virtual const char* type_key() const = 0; - /*! - * \brief Apply visitor to each field of the Node - * Visitor could mutate the content of the node. - * override if Node contains attribute fields. - * \param visitor The visitor - */ - virtual void VisitAttrs(AttrVisitor* visitor) {} - /*! \return the type index of the node */ - virtual const uint32_t type_index() const = 0; - /*! - * \brief Whether this node derives from node with type_index=tid. - * Implemented by TVM_DECLARE_NODE_TYPE_INFO - * - * \param tid The type index. - * \return the check result. - */ - virtual const bool _DerivedFrom(uint32_t tid) const; - /*! - * \brief get a runtime unique type index given a type key - * \param type_key Type key of a type. - * \return the corresponding type index. - */ - static uint32_t TypeKey2Index(const char* type_key); - /*! - * \brief get type key from type index. - * \param index The type index - * \return the corresponding type key. - */ - static const char* TypeIndex2Key(uint32_t index); - /*! - * \return whether the type is derived from - */ - template - inline bool derived_from() const; - /*! - * \return whether the node is of type T - * \tparam The type to be checked. - */ - template - inline bool is_type() const; - /*! - * \brief Get a NodeRef that holds reference to this Node. - * - * \note This is enabled by enable_shared_from_this. - * \return the NodeRef - */ - inline NodeRef GetNodeRef() const; - // node ref can see this - friend class NodeRef; - static constexpr const char* _type_key = "Node"; -}; - - -/*! \brief base class of all node reference object */ -class NodeRef { - public: - /*! \brief type indicate the container type */ - using ContainerType = Node; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator==(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool same_as(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator<(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator!=(const NodeRef& other) const; - /*! \return the hash function for NodeRef */ - inline size_t hash() const; - /*! \return whether the expression is null */ - inline bool defined() const; - /*! \return the internal type index of IRNode */ - inline uint32_t type_index() const; - /*! \return the internal node pointer */ - inline const Node* get() const; - /*! \return the internal node pointer */ - inline const Node* operator->() const; - /*! - * \brief Downcast this ir node to its actual type (e.g. Add, or - * Select). This returns nullptr if the node is not of the requested - * type. Example usage: - * - * if (const Add *add = node->as()) { - * // This is an add node - * } - * - * \note This function only works if T is the final type, - * use as_derived when T can also be base types. - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as() const; - - /*! - * \brief A more powerful version of as that also works with - * intermediate base types. - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as_derived() const; - - /*! \brief default constructor */ - NodeRef() = default; - explicit NodeRef(std::shared_ptr node) : node_(node) {} - - /*! \brief the internal node object, do not touch */ - std::shared_ptr node_; -}; - -/*! - * \brief helper macro to declare type information in a base node. - */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - const bool _DerivedFrom(uint32_t tid) const override { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -/*! - * \brief helper macro to declare type information in a terminal node - */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - const char* type_key() const final { \ - return TypeName::_type_key; \ - } \ - const uint32_t type_index() const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - return tidx; \ - } \ - const bool _DerivedFrom(uint32_t tid) const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -// implementations of inline functions after this -template -inline bool Node::is_type() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return type_id == this->type_index(); -} - -template -inline bool Node::derived_from() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return this->_DerivedFrom(type_id); -} - -inline NodeRef Node::GetNodeRef() const { - // const_cast because NodeRef requires std::shared_ptr, - // This is fine as NodeRef mostly only gives you back const Node*, - // of course things can be breached as it is C++ - return NodeRef(const_cast(this)->shared_from_this()); -} - -inline const Node* NodeRef::get() const { - return node_.get(); -} - -inline const Node* NodeRef::operator->() const { - return node_.get(); -} - -inline bool NodeRef::defined() const { - return node_.get() != nullptr; -} - -inline bool NodeRef::operator==(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::same_as(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::operator<(const NodeRef& other) const { - return node_.get() < other.node_.get(); -} - -inline bool NodeRef::operator!=(const NodeRef& other) const { - return node_.get() != other.node_.get(); -} - -inline size_t NodeRef::hash() const { - return std::hash()(node_.get()); -} - -inline uint32_t NodeRef::type_index() const { - internal_assert(node_.get() != nullptr) - << "null type"; - return get()->type_index(); -} - -template -inline const T* NodeRef::as() const { - const Node* ptr = static_cast(get()); - if (ptr && ptr->is_type()) { - return static_cast(ptr); - } - return nullptr; -} - -template -inline const T* NodeRef::as_derived() const { - const Node* ptr = static_cast(get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; -} - -/*! \brief The hash function for nodes */ -struct NodeHash { - size_t operator()(const NodeRef& a) const { - return a.hash(); - } -}; - -/*! \brief The equal comparator for nodes */ -struct NodeEqual { - bool operator()(const NodeRef& a, const NodeRef& b) const { - return a.get() == b.get(); - } -}; -} // namespace tvm - -// expose the data structure to HalideIR -namespace HalideIR { -namespace IR { - -using tvm::Node; -using tvm::NodeRef; -using tvm::AttrVisitor; - -} // namespace IR -} // namespace HalideIR - -namespace std { -template <> -struct hash<::tvm::NodeRef> { - std::size_t operator()(const ::tvm::NodeRef& k) const { - return k.hash(); - } -}; - -} // namespace std - -#endif // TVM_NODE_H_