From 4c2ece12fa652f573e1c8cfb24bd8a95f9699f96 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Sat, 2 Apr 2022 02:03:58 +0000 Subject: [PATCH] [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run --- .../eager/accumulation/accumulation_node.h | 15 +-- .../eager_generated/backwards/scale_node.h | 5 - .../auto_code_generator/eager_generator.cc | 7 +- .../final_state_generator/eager_gen.py | 14 +-- paddle/fluid/eager/backward.cc | 100 +++++++++++++++++- .../custom_operator/custom_operator_node.h | 11 +- paddle/fluid/eager/grad_node_info.cc | 4 + paddle/fluid/eager/grad_node_info.h | 27 ++++- paddle/fluid/eager/pylayer/py_layer_node.h | 5 - .../data_structure_tests/grad_node_test.h | 6 +- .../eager/to_static/run_program_op_node.h | 4 - 11 files changed, 147 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h index 2e38d7e9e91e2..38d5533c3d606 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.h +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase { // Constructor: configure fwd input tensors to grad node explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) { VLOG(6) << "Construct GradNodeAccumulation"; - weak_grad_ = meta->WeakGrad(); + if (meta) { + weak_grad_ = meta->WeakGrad(); + } + SetDefaultGradInOutMeta(); } @@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - std::string name() { return "GradNodeAccumulation"; } /** @@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase { inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; } void ApplyReduceHooks(); + std::shared_ptr Copy() const override { + return std::shared_ptr( + new GradNodeAccumulation(nullptr)); + } + private: std::weak_ptr weak_grad_; diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h index 0b942d2a06707..25293d7bdbf1a 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h @@ -44,11 +44,6 @@ class GradNodeScale : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - void SetTensorWrappers_X( const std::vector& tensors); diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 3c322565884f2..303d6482500f9 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -2479,7 +2479,7 @@ static std::string GenerateGradNodeHeaderContents( "\n" " void ClearTensorWrappers() override { \n" "%s\n" - " is_tensor_wrappers_cleared = true;\n" + " SetIsTensorWrappersCleared(true);\n" " }\n" " std::string name() override { return \" GradNode%s \"; } \n " "\n" @@ -2487,14 +2487,9 @@ static std::string GenerateGradNodeHeaderContents( "%s\n" " // SetAttrMap\n" "%s\n" - " bool IsTensorWrappersCleared() override { \n" - " return is_tensor_wrappers_cleared;\n" - " }\n" " private:\n" " // 
TensorWrappers\n" "%s\n" - " bool is_tensor_wrappers_cleared = false;\n" - "\n" " // Attribute Map\n" "%s\n" "};"; diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fb86c5da6856c..c18e8936b2219 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -125,7 +125,12 @@ class {} : public egr::GradNodeBase {{ void ClearTensorWrappers() override {{ {} - is_tensor_wrappers_cleared = true; + SetIsTensorWrappersCleared(true); + }} + + std::shared_ptr Copy() const override {{ + auto copied_node = std::make_shared<{}>(*this); + return copied_node; }} // SetTensorWrapperX, SetTensorWrapperY, ... @@ -133,15 +138,10 @@ class {} : public egr::GradNodeBase {{ // SetAttributes {} - bool IsTensorWrappersCleared() override {{ - return is_tensor_wrappers_cleared; - }} private: // TensorWrappers {} - bool is_tensor_wrappers_cleared = false; - // Attributes {} }}; @@ -1212,7 +1212,7 @@ def GenerateNodeDeclaration(self): grad_node_name = GetGradNodeName(forward_op_name) self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, clear_tensor_wrapper_str, + grad_node_name, clear_tensor_wrapper_str, grad_node_name, set_tensor_wrapper_methods_str, set_attribute_methods_str, tensor_wrapper_members_str, attribute_members_str) diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 0ce2f17cb45be..12ff48a06ffb5 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -50,7 +50,15 @@ class GeneralGrad { for (size_t i = 0; i < num_inputs; i++) { AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(inputs[i]); - auto target_node = auto_grad_meta->GetMutableGradNode().get(); + auto* orig_target_node = auto_grad_meta->GetMutableGradNode().get(); + + PADDLE_ENFORCE( + orig_to_copied_node_mapping_.count(orig_target_node), + paddle::platform::errors::InvalidArgument( + "Unable to find target node in orig_to_copied_node_mapping_." + "Most likely the starting nodes were not copied correctly.")); + auto* target_node = orig_to_copied_node_mapping_[orig_target_node]; + PADDLE_ENFORCE_NOT_NULL(target_node, paddle::platform::errors::Fatal( "There is no grad op for %s:[%d] or it's" @@ -249,7 +257,14 @@ class GeneralGrad { for (size_t i = 0; i < inputs.size(); ++i) { auto& input = inputs[i]; AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input); - auto target_node = auto_grad_meta->GetMutableGradNode().get(); + auto* orig_target_node = auto_grad_meta->GetMutableGradNode().get(); + + PADDLE_ENFORCE( + orig_to_copied_node_mapping_.count(orig_target_node), + paddle::platform::errors::InvalidArgument( + "Unable to find target node in orig_to_copied_node_mapping_." 
+ "Most likely the starting nodes were not copied correctly.")); + auto* target_node = orig_to_copied_node_mapping_[orig_target_node]; auto iter = results_map.find(target_node); if (iter != results_map.end()) { @@ -328,6 +343,64 @@ class GeneralGrad { results_map.clear(); } + GradNodeBase* CopyGradNode(const std::shared_ptr& orig_node) { + if (orig_to_copied_node_mapping_.count(orig_node.get())) { + return orig_to_copied_node_mapping_[orig_node.get()]; + } + std::shared_ptr copied_node = orig_node->Copy(); + + // Save node and update mapping + orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get(); + copied_grad_nodes_.push_back(copied_node); + + return copied_node.get(); + } + + void ReconstructBackwardGraph( + const std::queue& orig_init_queue) { + std::queue queue = orig_init_queue; + + // BFS and recursively copy the grad nodes + while (!queue.empty()) { + GradNodeBase* orig_node = queue.front(); + queue.pop(); + + PADDLE_ENFORCE( + orig_to_copied_node_mapping_.count(orig_node), + paddle::platform::errors::Fatal( + "Cannot reconstruct backward graph," + "unable to find copied target for certain grad node.")); + GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node]; + + const std::vector>& orig_edges = orig_node->GetEdges(); + std::vector>& copied_edges = + copied_node->GetMutableEdges(); + for (size_t i = 0; i < orig_edges.size(); i++) { + for (size_t j = 0; j < orig_edges[i].size(); j++) { + const Edge& orig_edge = orig_edges[i][j]; + Edge& copied_edge = copied_edges[i][j]; + + std::shared_ptr orig_next_node = + orig_edge.GetMutableGradNode(); + if (!orig_next_node) continue; + + // Copy Next Node + std::shared_ptr copied_next_node = + orig_next_node->Copy(); + orig_to_copied_node_mapping_[orig_next_node.get()] = + copied_next_node.get(); + copied_grad_nodes_.push_back(copied_next_node); + + // Update Edge's Grad Node + copied_edge.SetGradNode(copied_next_node); + + // Update BFS queue + queue.push(orig_next_node.get()); + } + } + } + } + private: GeneralGrad() = default; static GeneralGrad* general_grad_; @@ -345,6 +418,10 @@ class GeneralGrad { std::unordered_set /* pre nodes */> depending_nodes; std::unordered_map results_map; + + std::vector> copied_grad_nodes_; + std::unordered_map orig_to_copied_node_mapping_; + DISABLE_COPY_AND_ASSIGN(GeneralGrad); }; @@ -444,6 +521,7 @@ std::vector RunBackward( // 1. Init queue with starting nodes // 2. 
Prepare initial input buffers std::queue queue; + std::queue orig_queue; std::unordered_map> node_input_buffers_dict; for (size_t i = 0; i < tensors.size(); i++) { @@ -467,6 +545,16 @@ std::vector RunBackward( } GradNodeBase* grad_node = shared_grad_node.get(); + if (is_general_grad) { + // Save orig grad node + orig_queue.push(grad_node); + + // Replace grad_node with copied grad_node + grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node); + + // Record potential startup grad node + GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node); + } // Prepare GradTensorHolder if (!node_input_buffers_dict.count(grad_node)) { @@ -510,9 +598,11 @@ std::vector RunBackward( // Prepare queue, potential startup_nodes queue.push(grad_node); - if (is_general_grad) { - GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node); - } + } + + if (is_general_grad) { + // Copy Backward Graph + GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue); } VLOG(6) << "Update In degree Map for backward"; diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.h b/paddle/fluid/eager/custom_operator/custom_operator_node.h index 33b56fc8c863a..2637c0f8eeaa6 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.h +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.h @@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase { } // Functor: perform backward computations - virtual std::vector> operator()( - std::vector>& grads, - bool create_graph = false) // NOLINT + virtual std::vector> + operator()( // NOLINT + std::vector>& grads, // NOLINT + bool create_graph = false) // NOLINT override; std::string name() { @@ -64,10 +65,6 @@ class RunCustomOpNode : public GradNodeBase { } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } void SetAttrs(const std::vector& attr) { attrs_ = attr; } diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 5f3dfe8e513ed..4e7dec9afbe7e 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -322,6 +322,10 @@ const std::vector>& GradNodeBase::GetEdges() const { return adj_edges_; } +std::vector>& GradNodeBase::GetMutableEdges() { + return adj_edges_; +} + std::vector> GradNodeBase::ApplyGradientHooks( const std::vector>& tensors) { diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 0d07f780dda9d..0d11583dce32d 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -113,7 +113,16 @@ class GradNodeBase { virtual void ClearTensorWrappers() = 0; - virtual bool IsTensorWrappersCleared() = 0; + /** + * Self-Copy interface designed for use in DoubleGrad + * **/ + virtual std::shared_ptr Copy() const { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Self-copy not supported for current GradNode." + "Please override GradNodeBase::Copy() method if necessary.")); + return nullptr; + } + /** * AddEdges is designed to set input tensors' backward Node as current * node's Edges. 
@@ -191,6 +200,16 @@ class GradNodeBase { /** * GetEdges is designed to get all edges of current node**/ const std::vector>& GetEdges() const; + std::vector>& GetMutableEdges(); + + /** + * The following interfaces are designed for no_need_buffer + * **/ + bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; } + + void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) { + is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared; + } private: // TODO(zhanlve): Merge adj_edges_ into GradOutMeta @@ -218,6 +237,7 @@ class GradNodeBase { // We handle complex to real conversion only if any complex GradIn is involved bool need_complex_to_real_ = false; int64_t next_hook_id_{0}; + bool is_tensor_wrappers_cleared_ = false; }; class Edge { @@ -246,6 +266,11 @@ class Edge { return grad_node_; } + void SetGradNode(const std::shared_ptr& node) { + VLOG(6) << "Reseting Edge's Grad Node"; + grad_node_ = node; + } + std::pair GetEdgeRankInfo() const { return std::make_pair(in_slot_id_, in_rank_); } diff --git a/paddle/fluid/eager/pylayer/py_layer_node.h b/paddle/fluid/eager/pylayer/py_layer_node.h index cd0a517afbf0f..f9d8f641571e5 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.h +++ b/paddle/fluid/eager/pylayer/py_layer_node.h @@ -40,11 +40,6 @@ class GradNodePyLayer : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - std::string name() { return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name); } diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h index dff12fdfc34a1..f574ae05d5a91 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h @@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase { GradTestNode() : GradNodeBase() { val_ = 1.0; } std::string name() override { return "GradTestNode"; } std::vector> operator()( - std::vector>& grads, + std::vector>& grads, // NOLINT bool create_graph = false) override { val_ = std::dynamic_pointer_cast(grads[0][0].impl()) ->data()[0]; @@ -50,10 +50,6 @@ class GradTestNode : public egr::GradNodeBase { return res; } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } float val_; }; } // namespace eager_test diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 781d6616e38ae..d777373910338 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -401,10 +401,6 @@ class GradNodeRunProgram : public egr::GradNodeBase { } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } // SetAttrMap void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
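
How the pieces above fit together: GradNodeBase gains a virtual Copy(), GeneralGrad::CopyGradNode() clones each starting grad node and records it in orig_to_copied_node_mapping_, and GeneralGrad::ReconstructBackwardGraph() walks the original graph breadth-first, cloning every reachable node and re-pointing the cloned edges at the clones through Edge::SetGradNode(). When is_general_grad is set, RunBackward() then traverses only the copied graph, so the original backward graph stays intact for higher-order differentiation. The standalone C++ sketch below illustrates the same copy-and-rewire pattern in isolation; Node, CopyGraph() and the member names here are illustrative stand-ins, not Paddle APIs.

    #include <iostream>
    #include <memory>
    #include <queue>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // Minimal stand-in for a grad node: a name plus edges to successor nodes.
    struct Node {
      std::string name;
      std::vector<std::shared_ptr<Node>> next;

      explicit Node(std::string n) : name(std::move(n)) {}
      virtual ~Node() = default;

      // Polymorphic self-copy, analogous to the new GradNodeBase::Copy().
      virtual std::shared_ptr<Node> Copy() const {
        return std::make_shared<Node>(*this);
      }
    };

    // Breadth-first copy of the graph reachable from `start`. Each original
    // node is cloned exactly once; cloned edges are rewired to the cloned
    // successors, so the original graph is never mutated.
    std::shared_ptr<Node> CopyGraph(const std::shared_ptr<Node>& start) {
      std::unordered_map<Node*, std::shared_ptr<Node>> orig_to_copied;
      std::queue<Node*> queue;

      orig_to_copied[start.get()] = start->Copy();
      queue.push(start.get());

      while (!queue.empty()) {
        Node* orig = queue.front();
        queue.pop();
        const std::shared_ptr<Node>& copied = orig_to_copied[orig];

        for (size_t i = 0; i < orig->next.size(); ++i) {
          const std::shared_ptr<Node>& orig_next = orig->next[i];
          if (!orig_next) continue;

          auto it = orig_to_copied.find(orig_next.get());
          if (it == orig_to_copied.end()) {
            it = orig_to_copied.emplace(orig_next.get(), orig_next->Copy()).first;
            queue.push(orig_next.get());
          }
          // Re-point the copied edge at the copied successor
          // (the counterpart of Edge::SetGradNode in the patch).
          copied->next[i] = it->second;
        }
      }
      return orig_to_copied[start.get()];
    }

    int main() {
      // relu_grad -> mul_grad -> accumulation: a tiny backward chain.
      auto acc = std::make_shared<Node>("accumulation");
      auto mul = std::make_shared<Node>("mul_grad");
      auto relu = std::make_shared<Node>("relu_grad");
      mul->next = {acc};
      relu->next = {mul};

      auto copied = CopyGraph(relu);
      copied->next[0]->name = "mutated_copy";  // touch only the copy

      // The original graph is untouched: prints "mul_grad mutated_copy".
      std::cout << relu->next[0]->name << " " << copied->next[0]->name << "\n";
      return 0;
    }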
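
One consequence of copying up front is that result lookup must go through the mapping as well: both places in backward.cc that resolve an input tensor to its target grad node now translate the original node through orig_to_copied_node_mapping_ before consulting results_map, and copied_grad_nodes_ holds shared_ptr ownership of the clones so the copied graph stays alive for the whole backward run.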