This repository has been archived by the owner on Feb 1, 2020. It is now read-only.
/
node.h
131 lines (119 loc) · 3.39 KB
/
node.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*!
* Copyright (c) 2016 by Contributors
* \file node.h
* \brief Graph node data structure.
*/
#ifndef NNVM_NODE_H_
#define NNVM_NODE_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./base.h"
#include "./op.h"
namespace nnvm {
// Forward declare node.
class Node;
/*!
 * \brief Pointer type used for every reference to a Node.
 *
 * Refer to nodes through this alias rather than spelling out the pointer
 * type, so the ownership model can be changed in one place if needed.
 * It is currently a std::shared_ptr<Node>.
 */
using NodePtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
  /*!
   * \brief Construct an empty entry: null node, index 0, version 0.
   *
   * Without this constructor a default-initialized NodeEntry
   * (e.g. `NodeEntry e;`) would leave index and version indeterminate,
   * and reading them would be undefined behavior.
   */
  NodeEntry() : node(), index(0), version(0) {}
  /*!
   * \brief Construct an entry referring to output \p idx of \p src.
   *
   * Intentionally not explicit, so brace-initialization such as
   * `NodeEntry{node, index, version}` and `return {node, index, version};`
   * compiles exactly as it did when NodeEntry was an aggregate.
   * \p ver defaults to 0, matching the value aggregate initialization gave
   * an omitted trailing member.
   * \param src the source node producing the data.
   * \param idx index of the output of \p src.
   * \param ver version of the input Variable (see version).
   */
  NodeEntry(NodePtr src, uint32_t idx, uint32_t ver = 0)
      : node(std::move(src)), index(idx), version(ver) {}
  /*! \brief the source node of this data */
  NodePtr node;
  /*! \brief index of output from the source. */
  uint32_t index;
  /*!
   * \brief version of input Variable.
   * This field can only be nonzero when this->node is a Variable node.
   * version is increased by one each time a Variable gets composed to a mutation Op.
   * This information can help decide the order of operations when a sequence of mutations happens.
   */
  uint32_t version;
};
/*!
 * \brief Attributes attached to an operation node.
 *
 * Bundles the operator, the node name and the operator's parameters
 * (for example an axis) in positional, dictionary and parsed forms.
 */
struct NodeAttrs {
  /*!
   * \brief Operator applied by this node.
   * Placeholder variables leave this as nullptr.
   */
  const Op* op = nullptr;
  /*! \brief Name of the node. */
  std::string name;
  /*! \brief Positional attributes, stored as doubles. */
  std::vector<double> scalars;
  /*! \brief Attributes as a string-to-string dictionary. */
  std::unordered_map<std::string, std::string> dict;
  /*!
   * \brief Parsed form of the attributes, for fast access.
   * Generated if OpProperty.attr_parser is registered.
   */
  any parsed;
};
/*!
 * \brief A single operation (or placeholder variable) in a computation graph.
 */
class Node {
 public:
  /*! \brief Attributes of this node (operator, name, parameters). */
  NodeAttrs attrs;
  /*! \brief Data entries consumed by this node. */
  std::vector<NodeEntry> inputs;
  /*!
   * \brief Optional control-flow dependencies: nodes whose operations
   * must be performed before this one.
   */
  std::vector<NodePtr> control_deps;
  /*! \return the operator of this node (nullptr for a variable). */
  inline const Op* op() const;
  /*!
   * \brief Whether this node is a placeholder variable.
   * Equivalent to op() == nullptr.
   * \return true for a placeholder input variable.
   */
  inline bool is_variable() const;
  /*! \return number of inputs to this node, as reported by its operator. */
  inline uint32_t num_inputs() const;
  /*! \return number of outputs of this node. */
  inline uint32_t num_outputs() const;
  /*!
   * \brief Create a new, empty Node.
   * \return a NodePtr owning the created node.
   */
  static NodePtr Create();
  /*! \brief Destructor; defined out of line, not in this header. */
  ~Node();
};
// implementation of functions.
inline const Op* Node::op() const {
  // The operator is stored in the node attributes; variables hold nullptr.
  return attrs.op;
}
inline bool Node::is_variable() const {
return this->op() == nullptr;
}
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op()->get_num_outputs == nullptr) {
return this->op()->num_outputs;
} else {
return this->op()->get_num_outputs(this->attrs);
}
}
inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op()->get_num_inputs == nullptr) {
return this->op()->num_inputs;
} else {
return this->op()->get_num_inputs(this->attrs);
}
}
} // namespace nnvm
#endif // NNVM_NODE_H_