This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
graph_executor.h
294 lines (288 loc) · 10.5 KB
/
graph_executor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
/*!
* Copyright (c) 2015 by Contributors
* \file graph_executor.h
* \brief Executor to execute the Forward and Backward on Composition Graph.
*/
#ifndef MXNET_SYMBOL_GRAPH_EXECUTOR_H_
#define MXNET_SYMBOL_GRAPH_EXECUTOR_H_
#include <mxnet/c_api.h>
#include <mxnet/symbolic.h>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./static_graph.h"
#include "./graph_memory_allocator.h"
#if MKL_EXPERIMENTAL == 1
#include <mkl_memory.h>
#endif
namespace mxnet {
/*!
* \brief Executor of a computation graph.
*/
class GraphExecutor : public Executor {
 public:
  GraphExecutor() = default;
  virtual ~GraphExecutor();
  void Forward(bool is_train) override;
  void PartialForward(bool is_train, int step, int *step_left) override;
  void Backward(const std::vector<NDArray> &head_grads) override;
  const std::vector<NDArray> &outputs() const override {
    return heads_ndarray_;
  }
  void Print(std::ostream &os) const override; // NOLINT(*)
  /*!
   * \brief Install a monitor callback.
   * \param callback the callback to install; must be non-empty.
   */
  void SetMonitorCallback(const MonitorCallback& callback) {
    CHECK(callback) << "invalid callback";
    monitor_callback_ = callback;
  }
  /*!
   * \brief Implement Executor::Bind; only call this once per executor.
   * \param symbol the symbol describing the computation graph.
   * \param default_ctx context for nodes without an explicit group mapping.
   * \param ctx_map mapping from group name to device context.
   * \param in_args NDArrays holding the argument values.
   * \param arg_grad_store NDArrays that receive the argument gradients.
   * \param grad_req_type per-argument gradient write request; must have the
   *        same length as arg_grad_store.
   * \param aux_states NDArrays holding the auxiliary states.
   * \param shared_exec optional executor to share internal memory with;
   *        when non-null it must itself be a GraphExecutor.
   */
  inline void Init(Symbol symbol,
                   const Context& default_ctx,
                   const std::map<std::string, Context>& ctx_map,
                   const std::vector<NDArray> &in_args,
                   const std::vector<NDArray> &arg_grad_store,
                   const std::vector<OpReqType> &grad_req_type,
                   const std::vector<NDArray> &aux_states,
                   Executor* shared_exec = nullptr) {
    // Environment switches controlling inplace memory reuse and bulk execution.
    enable_inplace_allocation_ = dmlc::GetEnv("MXNET_EXEC_ENABLE_INPLACE", true);
    prefer_bulk_execution_ = dmlc::GetEnv("MXNET_EXEC_PREFER_BULK_EXEC", true);
    if (shared_exec != nullptr) {
      GraphExecutor* gexec = dynamic_cast<GraphExecutor*>(shared_exec);
      CHECK(gexec) << "Input executor for sharing memory must have GraphExecutor type.";
      shared_mem_ = gexec->shared_mem_;
    } else {
      shared_mem_ = std::make_shared<GraphStoragePool>();
    }
    CHECK_EQ(grad_req_type.size(), arg_grad_store.size());
    // A backward pass is needed as soon as any argument requests a gradient.
    bool need_backward = false;
    for (auto req : grad_req_type) {
      if (req != kNullOp) need_backward = true;
    }
    this->InitGraph(symbol, default_ctx, ctx_map,
                    in_args, arg_grad_store, grad_req_type,
                    need_backward);
    this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type, aux_states);
    this->InitOperators();
    this->InitDataEntryMemory();
    this->InitResources();
    this->InitCachedOps();
    this->InitOpSegs();
  }

 protected:
  // internal class that wraps a BackwardOp as a ForwardOp
  class BackwardOpWrapper;
  // type of a data entry
  enum DataEntryType {
    // memory is bound by external NDArray in Bind
    kBindByExternal,
    // to be bound by external NDArray in Forward and Backward
    kTobeBindByExternal,
    // internal memory, allocated
    kInternalAllocated,
    // internal memory, to be allocated
    kNotInitialized
  };
  // Additional information about each data entry
  struct DataEntryInfo {
    // the actual data for the entry
    NDArray data;
    // mkl private memory holder
#if MKL_EXPERIMENTAL == 1
    std::shared_ptr<MKLMemHolder> mkl_mem_;
#endif
    // write request to this entry
    OpReqType op_req;
    // the operation node that will take
    // this DataEntry as inplace input
    int inplace_op_id;
    // data entry type
    DataEntryType type;
    // shape of this entry
    TShape shape;
    // data type of this entry; -1 until assigned during initialization
    int type_flag;
    // storage id from allocator if it is internal allocation.
    GraphStorageAllocator::StorageID storage_id;
    // reference count on how many times this entry is being used.
    // That is how many operators and heads need this DataEntry.
    // This is a temporal variable that is used during initialization.
    uint32_t temp_ref_count;
    // real permanent ref count
    uint32_t ref_count;
    // constructor
    DataEntryInfo()
        : op_req(kNullOp),
          inplace_op_id(-1),
          type(kNotInitialized),
          type_flag(-1),
          storage_id(GraphStorageAllocator::kBadStorageID),
          temp_ref_count(0), ref_count(0) {
#if MKL_EXPERIMENTAL == 1
      mkl_mem_ = MKLMemHolder::create();
#endif
    }
  };
  // all the information needed to push the op to engine
  struct OpExecEntry {
    // execution function of the op
    Engine::AsyncFn exec_fun;
    // variables to read from
    std::vector<Engine::VarHandle> use_vars;
    // variables to mutate
    std::vector<Engine::VarHandle> mutate_vars;
    // constructor
    OpExecEntry() : exec_fun(nullptr) {}
  };
  // Information about an operational node
  struct OpNode {
    // whether this op node is activated
    bool activated;
    // the context of the node
    Context ctx;
    // data entry information about outputs of op
    std::vector<DataEntryInfo> outputs;
    // auxiliary data information of op
    std::vector<DataEntryInfo> aux_states;
    // The following parts are constructed in InitOpNodes
    // the real operator
    std::shared_ptr<Operator> op;
    // op context, that is defined for this op.
    OpContext op_ctx;
    // executor, this is only allocated for nodes
    // whose inputs, outputs are pre-defined.
    // otherwise cached_exec.exec_fun == nullptr
    OpExecEntry cached_exec;
    // cached operator handle
    Engine::OprHandle cached_opr{nullptr};
    // constructor
    OpNode() : activated(false) {}
    // Manually delete the cached engine operator.
    // This needs to happen before the NDArrays it references are deleted.
    inline void DeleteOperator() {
      if (cached_opr != nullptr) {
        Engine::Get()->DeleteOperator(cached_opr);
        cached_opr = nullptr;
      }
    }
  };
  // a cached segment operator that executes a segment of ops
  // NOTE(review): fields are set by CreateCachedSegOpr; no ctor/in-class init
  // is added here on purpose — that would make the struct a non-aggregate
  // under C++11 and could break brace-initialization at construction sites.
  struct CachedSegOpr {
    // context of the operator
    Context ctx;
    // begin in topo order
    size_t topo_begin;
    // end in topo order
    size_t topo_end;
    // the cached operator
    Engine::OprHandle opr;
  };
  /*!
   * \brief Get the inplace options of a node.
   * This function is overriden for both Forward and Backward node.
   *
   * \param node_id node index of node in StaticGraph
   * \param in_data the input data entry to the node
   * \param out_data the output data entry in the graph
   * \return the paired inplace option.
   */
  template<typename T>
  inline std::vector<std::pair<T, T> > GetInplaceOption(
      uint32_t node_id,
      const std::vector<T> &in_data,
      const std::vector<T> &out_data) const;
  /*!
   * \brief Get resource requirement of a node.
   * This function is overriden for both Forward and Backward node.
   * \param node_id node index of node in StaticGraph
   * \return the desired resource request.
   */
  inline std::vector<ResourceRequest> GetResource(uint32_t node_id) const;
  /*!
   * \brief Get number of outputs of a node.
   * This function is overriden for both Forward and Backward node.
   * \param node_id node index of node in StaticGraph
   * \return the number of outputs of the node.
   */
  inline int GetNumOutputs(uint32_t node_id) const;
  /*!
   * \brief get execution entry for an OpNode.
   * This function can only be called after initialization is done.
   * \param node_id the id of operational node.
   * \return the execution entry.
   */
  inline OpExecEntry GetOpExecEntry(uint32_t node_id);
  /*!
   * \brief Try to create a cached operator to run segments between start and end.
   * \param topo_start beginning of segment
   * \param topo_end end of segment
   * \return the cached operator.
   * The ret.opr can be nullptr if the creation failed.
   */
  CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
  // initialize the internal graph structure
  void InitGraph(const Symbol &symbol,
                 const Context& default_ctx,
                 const std::map<std::string, Context>& ctx_map,
                 const std::vector<NDArray> &in_args,
                 const std::vector<NDArray> &arg_grad_store,
                 const std::vector<OpReqType> &grad_req_type,
                 bool need_backward);
  // initialize internal DataEntryInfo, reference counting
  void InitDataEntryInfo(const std::vector<NDArray> &in_args,
                         const std::vector<NDArray> &arg_grad_store,
                         const std::vector<OpReqType> &grad_req_type,
                         const std::vector<NDArray> &aux_states);
  // initialize internal data entries NDArray
  void InitDataEntryMemory();
  // initialize the internal resources for each op
  void InitResources();
  // initialize OpNode data structure
  void InitOperators();
  // initialize the cached engine operators of the OpNodes
  void InitCachedOps();
  // initialize segments of code to run together as a group.
  void InitOpSegs();
  // assign context to the graph, this will mutate the graph.
  void AssignContext(const Context default_ctx,
                     const std::map<std::string, Context>& ctx_map,
                     const std::vector<NDArray> &in_args,
                     const std::vector<NDArray> &arg_grad_store,
                     const std::vector<OpReqType> &grad_req_type,
                     std::vector<Context> *ctx_plan);
  // run ops from topo order start to end
  void RunOps(bool is_train, size_t topo_start, size_t topo_end);
  // internal computational graph
  StaticGraph graph_;
  // topological order of nodes in computation graph
  // backward nodes always follow forward nodes
  std::vector<uint32_t> topo_order_;
  // whether to enable inplace allocation (MXNET_EXEC_ENABLE_INPLACE)
  bool enable_inplace_allocation_{false};
  // total allocated space in bytes
  size_t total_allocated_bytes_{0};
  // total allocated temp space
  size_t total_allocated_temp_{0};
  // number of forward nodes in the graph
  size_t num_forward_nodes_{0};
  // whether to enable bulk execution (MXNET_EXEC_PREFER_BULK_EXEC)
  bool prefer_bulk_execution_{false};
  // head gradient node in the graph, if there is backward pass
  std::vector<uint32_t> head_grad_nodes_;
  // mirror map of nodes, experimental feature, normally can be ignored.
  std::map<uint32_t, uint32_t> mirror_source_map_;
  // gradient entries of the arguments, if there is backward pass
  std::vector<StaticGraph::DataEntry> arg_grads_;
  // operational nodes
  std::vector<OpNode> op_nodes_;
  // head NDArrays
  std::vector<NDArray> heads_ndarray_;
  // shared NDArrays
  std::shared_ptr<GraphStoragePool> shared_mem_;
  // monitor call back
  std::function<void(const char*, void*)> monitor_callback_;
  // cached segment operators
  std::vector<CachedSegOpr> cached_seg_opr_;
};  // class GraphExecutor
} // namespace mxnet
#endif // MXNET_SYMBOL_GRAPH_EXECUTOR_H_