Skip to content

Commit

Permalink
Fix heavy cascade copies by using shared pointers on Var(s) too
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromani committed Apr 5, 2024
1 parent 72c5e02 commit 0ae2685
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .autoformat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

clang-format -i src/*.cpp src/*.h src/functions/*.cpp src/functions/*.h
clang-format -i tests/*.cpp tests/include/*.h
clang-format -i examples/*.cpp

cmake-format -c=.cmake-format CMakeLists.txt -o CMakeLists.txt
cmake-format -c=.cmake-format src/CMakeLists.txt -o src/CMakeLists.txt
cmake-format -c=.cmake-format tests/CMakeLists.txt -o tests/CMakeLists.txt
cmake-format -c=.cmake-format examples/CMakeLists.txt -o examples/CMakeLists.txt
5 changes: 4 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ cmake_minimum_required(VERSION 3.16)

add_executable(example_covariances example_covariances.cpp)
add_executable(example_derivatives example_derivatives.cpp)
add_executable(example_optimization example_optimization.cpp)

target_link_libraries(example_covariances cascade_static)
target_link_libraries(example_derivatives cascade_static)
target_link_libraries(example_optimization cascade_static)

install(TARGETS example_covariances example_derivatives RUNTIME DESTINATION bin)
install(TARGETS example_covariances example_derivatives example_optimization
RUNTIME DESTINATION bin)
26 changes: 26 additions & 0 deletions examples/example_optimization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "cascade.h"

#include <iostream>

using namespace cascade;

int main()
{
Var x = 1.0;
Var y = 0.5;

Var z;

for (int i = 0; i < 100; ++i)
{
std::cout << i << std::endl;
z = x * y;
x = x + 2.0;
}

z = x * y;

z.backprop();

return 0;
}
39 changes: 20 additions & 19 deletions src/var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ void Var::createEdges_(const std::initializer_list<Var> &inputNodes, Var &output
{
for (const Var &x: inputNodes)
{
outputNode.children_.push_back(x);
outputNode.children_.push_back(std::make_shared<Var>(x));
outputNode.node_->children_.push_back(x.node_);
}

for (Var &x: outputNode.children_)
for (const std::shared_ptr<Var> &x: outputNode.children_)
{
x.parents_.push_back(outputNode);
x.node_->parents_.push_back(outputNode.node_);
x->parents_.push_back(std::make_shared<Var>(outputNode));
x->node_->parents_.push_back(outputNode.node_);
}
}

Expand All @@ -186,22 +186,22 @@ std::vector<Var> Var::sortedNodes_() const

nodes.push_back(node);

for (const Var &child: node.children_)
for (const std::shared_ptr<Var> &child: node.children_)
{
auto search = numParents.find(child.id());
auto search = numParents.find(child->id());

if (search != numParents.end())
{
--numParents[child.id()];
--numParents[child->id()];
}
else
{
numParents[child.id()] = child.parents_.size() - 1;
numParents[child->id()] = child->parents_.size() - 1;
}

if (numParents[child.id()] == 0)
if (numParents[child->id()] == 0)
{
stack.push(child);
stack.push(*child);
}
}
}
Expand All @@ -214,9 +214,10 @@ std::vector<Var> Var::inputNodes_() const
const std::vector<Var> nodes = sortedNodes_();

std::vector<Var> inputNodes;
std::copy_if(nodes.begin(), nodes.end(), std::back_inserter(inputNodes), [](const Var &node) {
return node.children_.empty();
});
std::copy_if(nodes.begin(),
nodes.end(),
std::back_inserter(inputNodes),
[](const Var &node) { return node.children_.empty(); });

return inputNodes;
}
Expand All @@ -228,17 +229,17 @@ double Var::covariance_(const Var &x, const Var &y)

// Copy the gradients of the input nodes before backpropagating on the second graph
std::unordered_map<int, double> xGradientMap;
std::for_each(xNodes.begin(), xNodes.end(), [&xGradientMap](const Var &node) {
xGradientMap.emplace(node.id(), node.derivative());
});
std::for_each(xNodes.begin(),
xNodes.end(),
[&xGradientMap](const Var &node) { xGradientMap.emplace(node.id(), node.derivative()); });

y.backprop();
const std::vector<Var> yNodes = y.inputNodes_();

std::unordered_map<int, double> yGradientMap;
std::for_each(yNodes.begin(), yNodes.end(), [&yGradientMap](const Var &node) {
yGradientMap.emplace(node.id(), node.derivative());
});
std::for_each(yNodes.begin(),
yNodes.end(),
[&yGradientMap](const Var &node) { yGradientMap.emplace(node.id(), node.derivative()); });

xNodes.insert(xNodes.end(), yNodes.begin(), yNodes.end());

Expand Down
4 changes: 2 additions & 2 deletions src/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class Var

std::shared_ptr<Node> node_;

std::vector<Var> children_;
std::vector<Var> parents_;
std::vector<std::shared_ptr<Var>> children_;
std::vector<std::weak_ptr<Var>> parents_;
};
} // namespace cascade

Expand Down

0 comments on commit 0ae2685

Please sign in to comment.