diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 4299501bc5c6..3283943f34e1 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -211,7 +211,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
                             work_load_list=self._work_load_list)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
-                        force_rebind=False, shared_module=self._curr_module)
+                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
             self._buckets[bucket_key] = module

         self._curr_module = self._buckets[bucket_key]
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e0f2085001b5..7e096f68e338 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -321,16 +321,24 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
                          const std::vector<OpReqType>& grad_req_type,
                          const std::vector<NDArray>& aux_states,
                          Executor* shared_exec) {
+  std::vector<SharedStorageEntry> shared_pool;
+  if (shared_exec != nullptr) {
+    for (auto& nd : dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_) {
+      size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+      shared_pool.emplace_back(nd.ctx().dev_id, bytes);
+    }
+  }
+
   nnvm::Graph g = InitGraph(symbol, default_ctx,
                             ctx_map, in_args, arg_grad_store,
-                            grad_req_type, aux_states);
+                            grad_req_type, aux_states, shared_pool);
   g = AttachOpExecs(g);
   g = AttachOpResources(g);
   graph_ = std::move(g);
   if (shared_exec != nullptr) {
-    this->InitDataEntryMemory(dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_);
+    this->InitDataEntryMemory((&dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_));
   } else {
-    this->InitDataEntryMemory({});
+    this->InitDataEntryMemory(nullptr);
   }
   {
     // initialize output arrays
@@ -356,9 +364,11 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
                                const std::vector<NDArray>& in_args,
                                const std::vector<NDArray>& arg_grad_store,
                                const std::vector<OpReqType>& grad_req_type,
-                               const std::vector<NDArray>& aux_states) {
+                               const std::vector<NDArray>& aux_states,
+                               const std::vector<SharedStorageEntry> shared_pool) {
   // setup gradient
   nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store);
+  g.attrs["shared_pool"] = std::make_shared<nnvm::any>(shared_pool);
   g = AssignContext(g, default_ctx, ctx_map,
                     in_args,
                     grad_store_,
@@ -420,7 +430,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
 }

 // initialize the memory of each entries
-void GraphExecutor::InitDataEntryMemory(const std::vector<NDArray>& shared_pool) {
+void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
   using nnvm::DTypeVector;
   using nnvm::ShapeVector;
   using nnvm::StorageVector;
@@ -475,9 +485,11 @@ void GraphExecutor::InitDataEntryMemory(const std::vector<NDArray>& shared_pool)
   }
   // construct the re-use pool, if needed
   std::multimap<size_t, NDArray> free_pool;
-  for (const NDArray& nd : shared_pool) {
-    size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
-    free_pool.insert(std::make_pair(bytes, nd));
+  if (shared_pool != nullptr) {
+    for (const NDArray& nd : *shared_pool) {
+      size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
+      free_pool.insert(std::make_pair(bytes, nd));
+    }
   }
   // remake the data pool
   data_pool_.clear();
@@ -498,7 +510,12 @@ void GraphExecutor::InitDataEntryMemory(const std::vector<NDArray>& shared_pool)
       CHECK_LE(nword, std::numeric_limits<index_t>::max());
       // allocate float arrays
       TShape shape{index_t(nword)};
-      data_pool_.emplace_back(NDArray(shape, ctx));
+      NDArray nd(shape, ctx);
+      data_pool_.push_back(nd);
+      // put the newly allocated array into the shared pool
+      if (shared_pool != nullptr) {
+        shared_pool->push_back(nd);
+      }
     }
   }
   CHECK_EQ(data_pool_.size(), pool_info.size());
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index cae7c28aafd6..0f07f9716d5c 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -28,6 +28,8 @@ using nnvm::Graph;
 class GraphExecutor : public Executor {
  public:
   using Executor::MonitorCallback;
+  using SharedStorageEntry = std::pair<int, size_t>;
+
   virtual ~GraphExecutor();
   void Forward(bool is_train) override;
   void PartialForward(bool is_train, int step, int *step_left) override;
@@ -66,7 +68,8 @@ class GraphExecutor : public Executor {
             const std::vector<NDArray>& in_args,
             const std::vector<NDArray>& arg_grad_store,
             const std::vector<OpReqType>& grad_req_type,
-            const std::vector<NDArray>& aux_states);
+            const std::vector<NDArray>& aux_states,
+            const std::vector<SharedStorageEntry> shared_pool);
   // initialize the full graph, including gradient.
   Graph InitFullGraph(nnvm::Symbol symbol,
                       const std::vector<OpReqType>& grad_req_type,
@@ -76,7 +79,7 @@ class GraphExecutor : public Executor {
   // initialize the resources in the graph
   // initialize the memory of data entries
   // shared_pool: extra memory shared from other parts
-  void InitDataEntryMemory(const std::vector<NDArray>& shared_pool);
+  void InitDataEntryMemory(std::vector<NDArray>* shared_pool);
   // run ops from topo order start to end
   void RunOps(bool is_train, size_t topo_start, size_t topo_end);
   // internal graph
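
Note (reviewer sketch, not part of the patch): the reuse scheme that InitDataEntryMemory implements after this change works roughly as follows. The free pool is a std::multimap keyed by byte size; each request takes the smallest pooled block that is large enough, and a fresh allocation is published back into the caller's shared pool so that executors bound later (for example, the per-bucket modules that now all share with the default bucket's module) can reuse it. The sketch below is a minimal, self-contained illustration under hypothetical names (Block, AllocFromPool); it is not MXNet API.

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

// Stand-in for NDArray: a chunk of device memory identified only by its size.
struct Block {
  std::size_t bytes;
};

// Best-fit reuse, mirroring the patched InitDataEntryMemory: hand out the
// smallest free block that can hold `bytes`; otherwise allocate fresh and
// publish the new block to *shared_pool for future executors.
Block AllocFromPool(std::size_t bytes,
                    std::multimap<std::size_t, Block>* free_pool,
                    std::vector<Block>* shared_pool) {
  auto it = free_pool->lower_bound(bytes);  // first block with size >= bytes
  if (it != free_pool->end()) {
    Block reused = it->second;
    free_pool->erase(it);                   // each block is handed out once
    return reused;
  }
  Block fresh{bytes};
  if (shared_pool != nullptr) {
    shared_pool->push_back(fresh);          // analogue of shared_pool->push_back(nd)
  }
  return fresh;
}

int main() {
  std::vector<Block> shared_pool;           // owned by the first executor
  std::multimap<std::size_t, Block> free_pool;

  Block a = AllocFromPool(4096, &free_pool, &shared_pool);  // fresh: pool is empty

  // A second executor rebuilds its free pool from the shared pool, as the
  // patched code does, and satisfies a smaller request by reuse.
  for (const Block& b : shared_pool) free_pool.emplace(b.bytes, b);
  Block c = AllocFromPool(1024, &free_pool, &shared_pool);  // reuses the 4096-byte block

  std::cout << a.bytes << " " << c.bytes << std::endl;      // prints: 4096 4096
  return 0;
}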