Commit 4c2ece1
[DoubleGrad PR PaddlePaddle#7] paddle.grad() to copy backward graph before backward run
jim19930609 committed Apr 2, 2022
1 parent 1239f69 commit 4c2ece1
Showing 11 changed files with 147 additions and 51 deletions.
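For context: this commit is part of the double-grad series, where paddle.grad() must copy the backward graph before the backward run so that the original graph survives for higher-order differentiation. Below is a minimal usage sketch (not part of this commit) of the scenario that exercises this path; the values in the comments assume x = 2.0.

import paddle

x = paddle.to_tensor([2.0], stop_gradient=False)
y = x * x * x  # y = x^3

# First-order grad; create_graph=True asks for a differentiable backward pass
(dy_dx,) = paddle.grad([y], [x], create_graph=True)  # dy/dx = 3 * x^2 = 12

# Second-order grad runs backward again over the (copied) backward graph
(d2y_dx2,) = paddle.grad([dy_dx], [x])  # d2y/dx2 = 6 * x = 12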
15 changes: 9 additions & 6 deletions paddle/fluid/eager/accumulation/accumulation_node.h
@@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase {
// Constructor: configure fwd input tensors to grad node
explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
VLOG(6) << "Construct GradNodeAccumulation";
weak_grad_ = meta->WeakGrad();
if (meta) {
weak_grad_ = meta->WeakGrad();
}

SetDefaultGradInOutMeta();
}

@@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase {

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

std::string name() { return "GradNodeAccumulation"; }

/**
@@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase {
inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
void ApplyReduceHooks();

std::shared_ptr<GradNodeBase> Copy() const override {
return std::shared_ptr<GradNodeAccumulation>(
new GradNodeAccumulation(nullptr));
}

private:
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;

@@ -44,11 +44,6 @@ class GradNodeScale : public GradNodeBase {

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors);

7 changes: 1 addition & 6 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -2479,22 +2479,17 @@ static std::string GenerateGradNodeHeaderContents(
"\n"
" void ClearTensorWrappers() override { \n"
"%s\n"
" is_tensor_wrappers_cleared = true;\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n "
"\n"
" // SetX, SetY, ...\n"
"%s\n"
" // SetAttrMap\n"
"%s\n"
" bool IsTensorWrappersCleared() override { \n"
" return is_tensor_wrappers_cleared;\n"
" }\n"
" private:\n"
" // TensorWrappers\n"
"%s\n"
" bool is_tensor_wrappers_cleared = false;\n"
"\n"
" // Attribute Map\n"
"%s\n"
"};";
@@ -125,23 +125,23 @@ class {} : public egr::GradNodeBase {{
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
SetIsTensorWrappersCleared(true);
}}
std::shared_ptr<GradNodeBase> Copy() const override {{
auto copied_node = std::make_shared<{}>(*this);
return copied_node;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
@@ -1212,7 +1212,7 @@ def GenerateNodeDeclaration(self):
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
grad_node_name, clear_tensor_wrapper_str, grad_node_name,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)

100 changes: 95 additions & 5 deletions paddle/fluid/eager/backward.cc
@@ -50,7 +50,15 @@ class GeneralGrad {
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto* orig_target_node = auto_grad_meta->GetMutableGradNode().get();

PADDLE_ENFORCE(
orig_to_copied_node_mapping_.count(orig_target_node),
paddle::platform::errors::InvalidArgument(
"Unable to find target node in orig_to_copied_node_mapping_."
"Most likely the starting nodes were not copied correctly."));
auto* target_node = orig_to_copied_node_mapping_[orig_target_node];

PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for %s:[%d] or it's"
@@ -249,7 +257,14 @@ class GeneralGrad {
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto* orig_target_node = auto_grad_meta->GetMutableGradNode().get();

PADDLE_ENFORCE(
orig_to_copied_node_mapping_.count(orig_target_node),
paddle::platform::errors::InvalidArgument(
"Unable to find target node in orig_to_copied_node_mapping_."
"Most likely the starting nodes were not copied correctly."));
auto* target_node = orig_to_copied_node_mapping_[orig_target_node];

auto iter = results_map.find(target_node);
if (iter != results_map.end()) {
@@ -328,6 +343,64 @@ class GeneralGrad {
results_map.clear();
}

GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
if (orig_to_copied_node_mapping_.count(orig_node.get())) {
return orig_to_copied_node_mapping_[orig_node.get()];
}
std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

// Save node and update mapping
orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get();
copied_grad_nodes_.push_back(copied_node);

return copied_node.get();
}

void ReconstructBackwardGraph(
const std::queue<GradNodeBase*>& orig_init_queue) {
std::queue<GradNodeBase*> queue = orig_init_queue;

// BFS and recursively copy the grad nodes
while (!queue.empty()) {
GradNodeBase* orig_node = queue.front();
queue.pop();

PADDLE_ENFORCE(
orig_to_copied_node_mapping_.count(orig_node),
paddle::platform::errors::Fatal(
"Cannot reconstruct backward graph,"
"unable to find copied target for certain grad node."));
GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node];

const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
std::vector<std::vector<Edge>>& copied_edges =
copied_node->GetMutableEdges();
for (size_t i = 0; i < orig_edges.size(); i++) {
for (size_t j = 0; j < orig_edges[i].size(); j++) {
const Edge& orig_edge = orig_edges[i][j];
Edge& copied_edge = copied_edges[i][j];

std::shared_ptr<GradNodeBase> orig_next_node =
orig_edge.GetMutableGradNode();
if (!orig_next_node) continue;

// Copy Next Node
std::shared_ptr<GradNodeBase> copied_next_node =
orig_next_node->Copy();
orig_to_copied_node_mapping_[orig_next_node.get()] =
copied_next_node.get();
copied_grad_nodes_.push_back(copied_next_node);

// Update Edge's Grad Node
copied_edge.SetGradNode(copied_next_node);

// Update BFS queue
queue.push(orig_next_node.get());
}
}
}
}

private:
GeneralGrad() = default;
static GeneralGrad* general_grad_;
@@ -345,6 +418,10 @@ class GeneralGrad {
std::unordered_set<GradNodeBase*> /* pre nodes */>
depending_nodes;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;

std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
std::unordered_map<GradNodeBase*, GradNodeBase*> orig_to_copied_node_mapping_;

DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};

@@ -444,6 +521,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// 1. Init queue with starting nodes
// 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue;
std::queue<GradNodeBase*> orig_queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict;
for (size_t i = 0; i < tensors.size(); i++) {
@@ -467,6 +545,16 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}

GradNodeBase* grad_node = shared_grad_node.get();
if (is_general_grad) {
// Save orig grad node
orig_queue.push(grad_node);

// Replace grad_node with copied grad_node
grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

// Record potential startup grad node
GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
}

// Prepare GradTensorHolder
if (!node_input_buffers_dict.count(grad_node)) {
@@ -510,9 +598,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(

// Prepare queue, potential startup_nodes
queue.push(grad_node);
if (is_general_grad) {
GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
}
}

if (is_general_grad) {
// Copy Backward Graph
GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
}

VLOG(6) << "Update In degree Map for backward";
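As a rough illustration of the ReconstructBackwardGraph/CopyGradNode logic added above, here is a small Python sketch. Node, next_nodes, and copy() are hypothetical stand-ins for GradNodeBase, its adjacency edges, and GradNodeBase::Copy(); they are not Paddle's real classes. Starting nodes are assumed to be copied up front (as CopyGradNode does in RunBackward); a BFS then copies every reachable successor and re-points the copied edges at the copies. Unlike the C++ code above, this sketch reuses an existing copy when a node is reached a second time.

from collections import deque

class Node:
    def __init__(self, name):
        self.name = name
        self.next_nodes = []          # stands in for adj_edges_

    def copy(self):                   # stands in for GradNodeBase::Copy()
        return Node(self.name)

def reconstruct_backward_graph(orig_start_nodes, orig_to_copied):
    # orig_to_copied already maps each starting node to its copy
    queue = deque(orig_start_nodes)
    while queue:
        orig = queue.popleft()
        copied = orig_to_copied[orig]
        for orig_next in orig.next_nodes:
            # Copy each reachable successor once, then keep traversing the original graph
            if orig_next not in orig_to_copied:
                orig_to_copied[orig_next] = orig_next.copy()
                queue.append(orig_next)
            # Re-point the copied edge at the copied successor
            copied.next_nodes.append(orig_to_copied[orig_next])
    return orig_to_copied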
11 changes: 4 additions & 7 deletions paddle/fluid/eager/custom_operator/custom_operator_node.h
@@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase {
}

// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) // NOLINT
virtual std::vector<std::vector<paddle::experimental::Tensor>>
operator()( // NOLINT
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) // NOLINT
override;

std::string name() {
@@ -64,10 +65,6 @@
}

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }

4 changes: 4 additions & 0 deletions paddle/fluid/eager/grad_node_info.cc
@@ -322,6 +322,10 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
return adj_edges_;
}

std::vector<std::vector<Edge>>& GradNodeBase::GetMutableEdges() {
return adj_edges_;
}

std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
27 changes: 26 additions & 1 deletion paddle/fluid/eager/grad_node_info.h
@@ -113,7 +113,16 @@ class GradNodeBase {

virtual void ClearTensorWrappers() = 0;

virtual bool IsTensorWrappersCleared() = 0;
/**
* Self-Copy interface designed for use in DoubleGrad
* **/
virtual std::shared_ptr<GradNodeBase> Copy() const {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Self-copy not supported for current GradNode."
"Please override GradNodeBase::Copy() method if necessary."));
return nullptr;
}

/**
* AddEdges is designed to set input tensors' backward Node as current
* node's Edges.
@@ -191,6 +200,16 @@
/**
* GetEdges is designed to get all edges of current node**/
const std::vector<std::vector<Edge>>& GetEdges() const;
std::vector<std::vector<Edge>>& GetMutableEdges();

/**
* The following interfaces are designed for no_need_buffer
* **/
bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; }

void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) {
is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared;
}

private:
// TODO(zhanlve): Merge adj_edges_ into GradOutMeta
@@ -218,6 +237,7 @@
// We handle complex to real conversion only if any complex GradIn is involved
bool need_complex_to_real_ = false;
int64_t next_hook_id_{0};
bool is_tensor_wrappers_cleared_ = false;
};

class Edge {
@@ -246,6 +266,11 @@
return grad_node_;
}

void SetGradNode(const std::shared_ptr<GradNodeBase>& node) {
VLOG(6) << "Reseting Edge's Grad Node";
grad_node_ = node;
}

std::pair<size_t, size_t> GetEdgeRankInfo() const {
return std::make_pair(in_slot_id_, in_rank_);
}
5 changes: 0 additions & 5 deletions paddle/fluid/eager/pylayer/py_layer_node.h
@@ -40,11 +40,6 @@ class GradNodePyLayer : public GradNodeBase {

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

std::string name() {
return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name);
}
@@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase {
GradTestNode() : GradNodeBase() { val_ = 1.0; }
std::string name() override { return "GradTestNode"; }
std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) override {
val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
->data<float>()[0];
@@ -50,10 +50,6 @@
return res;
}
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
float val_;
};
} // namespace eager_test
4 changes: 0 additions & 4 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -401,10 +401,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
}

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

// SetAttrMap
void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
