This repository has been archived by the owner on Feb 1, 2020. It is now read-only.
/
graph.h
295 lines (281 loc) · 9.4 KB
/
graph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
/*!
* Copyright (c) 2016 by Contributors
* \file graph.h
* \brief Configuration of nnvm as well as its basic graph data structures.
*/
#ifndef NNVM_GRAPH_H_
#define NNVM_GRAPH_H_
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "./base.h"
#include "./node.h"
#include "./symbolic.h"
namespace nnvm {
class IndexedGraph;
/*!
 * \brief Symbolic computation graph.
 *  It is the intermediate representation that optimization passes work on.
 */
class Graph {
 public:
  /*! \brief Entries that form the outputs of the computation graph. */
  std::vector<NodeEntry> outputs;
  /*!
   * \brief Named attributes of the graph.
   *
   *  Every value is held through a std::shared_ptr, so a single attribute
   *  object may be shared by several graphs. Keeping each attribute
   *  immutable is strongly recommended. A copy-on-write scheme is also
   *  safe: copy the value when the shared_ptr is not uniquely owned, and
   *  reuse its storage when it is.
   */
  std::unordered_map<std::string, std::shared_ptr<any>> attrs;
  /*!
   * \brief Read-only access to an attribute stored in attrs.
   * \tparam T type the attribute is expected to hold.
   * \param attr_name name of the attribute.
   * \return const reference to the stored value; a CHECK fails when
   *         attr_name is not present in attrs.
   */
  template<typename T>
  inline const T& GetAttr(const std::string& attr_name) const;
  /*!
   * \brief Take an attribute out of attrs with copy-on-write semantics.
   *  The entry is erased from attrs. Its value is moved out when no other
   *  owner shares the underlying shared_ptr, and copied otherwise.
   * \tparam T type the attribute is expected to hold.
   * \param attr_name name of the attribute.
   * \return the attribute value, owned by the caller.
   */
  template<typename T>
  inline T MoveCopyAttr(const std::string& attr_name);
  /*!
   * \brief Index of this graph; it is created on demand if it does not exist.
   * \return the IndexedGraph for this graph.
   * \sa IndexedGraph
   */
  const IndexedGraph& indexed_graph();

 private:
  // Cached IndexedGraph that backs indexed_graph().
  // NOTE(review): Graph uses the implicit copy constructor, so copies share
  // this cache; a copy whose outputs are changed afterwards may observe a
  // stale index — confirm how graph.cc invalidates it.
  std::shared_ptr<const IndexedGraph> indexed_graph_;
};
/*!
 * \brief Auxiliary data structure to index a graph.
 *  It maps each Node in the graph to a consecutive integer node_id, and each
 *  IndexedGraph::NodeEntry to a consecutive integer entry_id. This allows
 *  properties of Node and NodeEntry to be stored in compact vectors and
 *  accessed quickly without resorting to a hash map.
 *
 *  The node_id and entry_rptr are the same as in the JSON graph produced by
 *  the SaveJSON pass.
 *
 *  Only Graph (a friend) can construct an IndexedGraph; instances are neither
 *  copy-constructible nor copy-assignable.
 */
class IndexedGraph {
 public:
  /*! \brief represents a data entry (one output of a node) in the graph */
  struct NodeEntry {
    /*! \brief the source node id in the computation graph */
    uint32_t node_id;
    /*! \brief index of output from the source. */
    uint32_t index;
    /*! \brief version of the node */
    uint32_t version;
  };
  /*! \brief Node data structure in IndexedGraph */
  struct Node {
    /*! \brief pointer to the source node */
    const nnvm::Node* source;
    /*! \brief inputs to the node */
    array_view<NodeEntry> inputs;
    /*! \brief control flow dependencies to the node */
    array_view<uint32_t> control_deps;
  };
  /*! \return number of nodes in the graph */
  inline size_t num_nodes() const {
    return nodes_.size();
  }
  /*! \return total number of NodeEntry in the graph */
  inline size_t num_node_entries() const {
    // NOTE(review): back() requires entry_rptr_ to be non-empty; presumably
    // the constructor always fills it as a CSR row pointer — confirm.
    return entry_rptr_.back();
  }
  /*!
   * \brief Get a unique entry id between 0 and num_node_entries()
   *  for a given IndexedGraph::NodeEntry.
   * \param node_id The node index.
   * \param index the output index.
   * \return the unique index.
   */
  inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
    return entry_rptr_[node_id] + index;
  }
  /*!
   * \brief Get a unique entry id between 0 and num_node_entries()
   *  for a given IndexedGraph::NodeEntry.
   * \param e The entry to query for index.
   * \return the unique index.
   */
  inline uint32_t entry_id(const NodeEntry& e) const {
    return entry_rptr_[e.node_id] + e.index;
  }
  /*!
   * \brief Get a unique entry id between 0 and num_node_entries()
   *  for a given nnvm::NodeEntry.
   * \param e The entry to query for index.
   * \return the unique index.
   * \throws std::out_of_range if e.node is not part of this graph
   *         (propagated from node_id()).
   */
  inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
    return entry_rptr_[node_id(e.node.get())] + e.index;
  }
  /*!
   * \brief Get the corresponding node id for a given Node in the IndexedGraph.
   * \param node The Node to query for index.
   * \return the node index.
   * \throws std::out_of_range if node is not part of this graph.
   */
  inline uint32_t node_id(const nnvm::Node* node) const {
    return node2index_.at(node);
  }
  /*!
   * \brief Get the corresponding Node structure for a given node_id.
   *  No bounds check is performed on node_id.
   * \param node_id The node id.
   * \return const reference to the corresponding IndexedGraph::Node.
   */
  inline const Node& operator[](uint32_t node_id) const {
    return nodes_[node_id];
  }
  /*!
   * \brief Get the corresponding Node structure.
   * \param node The pointer to the Node structure.
   * \return const reference to the corresponding IndexedGraph::Node.
   * \throws std::out_of_range if node is not part of this graph.
   */
  inline const Node& operator[](const nnvm::Node* node) const {
    return nodes_[node_id(node)];
  }
  /*! \return list of argument (input) node ids */
  inline const std::vector<uint32_t>& input_nodes() const {
    return input_nodes_;
  }
  /*! \return set of mutable input node ids */
  inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
    return mutable_input_nodes_;
  }
  /*! \return list of output entries */
  inline const std::vector<NodeEntry>& outputs() const {
    return outputs_;
  }
  // Non-copyable. NOTE(review): the array_view members of Node presumably
  // refer into input_entries_ / control_deps_, so a member-wise copy would
  // alias the source object's storage. Deleting only the copy constructor
  // would still leave an implicitly-declared copy assignment operator, so
  // both are deleted. (No move operations are implicitly declared either,
  // because a copy constructor is user-declared.)
  IndexedGraph(const IndexedGraph&) = delete;
  IndexedGraph& operator=(const IndexedGraph&) = delete;

 private:
  friend class Graph;
  /*!
   * \brief Construct an IndexedGraph from a normal Graph.
   * \param other The source graph.
   */
  explicit IndexedGraph(const Graph& other);
  // Node pointers in CSR structure.
  std::vector<Node> nodes_;
  // Index to all input nodes.
  std::vector<uint32_t> input_nodes_;
  // Index to all mutable input nodes.
  std::unordered_set<uint32_t> mutable_input_nodes_;
  // space to store the outputs entries
  std::vector<NodeEntry> outputs_;
  // mapping from node to index.
  std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
  // CSR pointer of node entries
  std::vector<size_t> entry_rptr_;
  // space to store input entries of each node
  std::vector<NodeEntry> input_entries_;
  // control flow dependencies
  std::vector<uint32_t> control_deps_;
};
/*!
 * \brief Visit every node reachable from heads in post-order DFS.
 *  The visiting order is deterministic and topologically sorted: a node is
 *  visited after its inputs and control dependencies that it reaches first.
 *  Each distinct node is passed to fvisit at most once.
 * \tparam FVisit type of the visitor callable.
 * \param heads the entries to start the traversal from.
 * \param fvisit callable invoked with a const reference to each node's
 *        NodePtr, e.g. a function of type void(const NodePtr&).
 */
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
                     FVisit fvisit);
// inline function implementations
template<typename T>
inline const T& Graph::GetAttr(const std::string& attr_name) const {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
return nnvm::get<T>(*it->second);
}
template<typename T>
inline T Graph::MoveCopyAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
std::shared_ptr<any> sptr = it->second;
attrs.erase(it);
if (sptr.unique()) {
return std::move(nnvm::get<T>(*sptr));
} else {
return nnvm::get<T>(*sptr);
}
}
/*!
 * \brief Generic iterative post-order depth-first traversal.
 *
 *  Heads are processed in order. From each unseen head the graph is walked
 *  through getinput, and fvisit is called on a node once every one of its
 *  indegree(node) inputs has been either explored or found already seen.
 *  Nodes are deduplicated by hash(node): a node whose key has already been
 *  seen is neither pushed nor visited again, so each distinct key reaches
 *  fvisit at most once. An explicit stack replaces recursion, so deep graphs
 *  do not grow the call stack.
 *
 * \tparam GNode    node handle type (copied into the traversal stack).
 * \tparam HashType identity key type, stored in a std::unordered_set.
 * \param heads     starting nodes of the traversal.
 * \param fvisit    callback invoked on each node in post-order.
 * \param hash      maps a node to its identity key.
 * \param indegree  number of inputs of a node.
 * \param getinput  returns the i-th input of a node, for i < indegree(node).
 */
template <typename GNode, typename HashType,
          typename FVisit, typename HashFunc,
          typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads,
                       FVisit fvisit,
                       HashFunc hash,
                       InDegree indegree,
                       GetInput getinput) {
  // Each frame is (node, index of the next input to explore).
  std::vector<std::pair<GNode, uint32_t> > stack;
  std::unordered_set<HashType> visited;
  for (const GNode& head : heads) {
    // insert().second is true only for a first-seen key: one hash lookup
    // instead of count() followed by insert().
    if (visited.insert(hash(head)).second) {
      stack.emplace_back(head, 0u);
    }
    while (!stack.empty()) {
      std::pair<GNode, uint32_t>& frame = stack.back();
      if (frame.second == indegree(frame.first)) {
        // Every input has been handled: emit the node in post-order.
        fvisit(frame.first);
        stack.pop_back();
      } else {
        // Take a value copy before growing the stack: emplace_back may
        // reallocate, invalidating `frame` and anything aliasing it.
        GNode input = getinput(frame.first, frame.second++);
        if (visited.insert(hash(input)).second) {
          stack.emplace_back(std::move(input), 0u);
        }
      }
    }
  }
}
/*!
 * \brief Definition of DFSVisit: post-order DFS over nnvm nodes.
 *  Nodes are identified by their raw Node* (so each node is visited once),
 *  and a node's inputs are its data inputs followed by its control_deps.
 *  fvisit receives a const reference to each node's NodePtr.
 *  fvisit is taken by value, so state changes made by a stateful visitor
 *  stay in this function's copy and are not seen by the caller.
 */
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
                     FVisit fvisit) {
  typedef const NodePtr* GNode;
  std::vector<GNode> head_nodes(heads.size());
  std::transform(heads.begin(), heads.end(), head_nodes.begin(),
                 [](const NodeEntry& e)->GNode {
                   return &e.node;
                 });
  PostOrderDFSVisit<GNode, Node*>(
      head_nodes,
      // Capture by reference: a closure's operator() is const, so a by-copy
      // capture rejects visitors with a non-const call operator (e.g.
      // `mutable` lambdas) and copies the visitor needlessly. The closure is
      // used only synchronously within this call, so the reference is safe.
      [&fvisit](GNode n) { fvisit(*n); },  // FVisit
      [](GNode n)->Node* { return n->get(); },  // HashFunc
      [](GNode n)->uint32_t {  // InDegree
        return static_cast<uint32_t>(
            (*n)->inputs.size() + (*n)->control_deps.size());
      },
      [](GNode n, uint32_t index)->GNode {  // GetInput
        // Indices [0, inputs.size()) address data inputs; the rest address
        // control dependencies.
        if (index < (*n)->inputs.size()) {
          return &(*n)->inputs.at(index).node;
        } else {
          return &(*n)->control_deps.at(index - (*n)->inputs.size());
        }
      });
}
} // namespace nnvm
#endif // NNVM_GRAPH_H_