Skip to content

Commit

Permalink
Merge pull request #147 from bluescarni/pr/pairwise_reduce
Browse files Browse the repository at this point in the history
Pairwise product
  • Loading branch information
bluescarni committed May 30, 2021
2 parents 85bd420 + 73785cb commit c5a1828
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 19 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if(NOT CMAKE_BUILD_TYPE)
FORCE)
endif()

project(heyoka VERSION 0.9.0 LANGUAGES CXX C)
project(heyoka VERSION 0.10.0 LANGUAGES CXX C)

list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/yacma")

Expand Down Expand Up @@ -234,8 +234,8 @@ if(HEYOKA_BUILD_STATIC_LIBRARY)
else()
# Setup of the heyoka shared library.
add_library(heyoka SHARED "${HEYOKA_SRC_FILES}")
set_property(TARGET heyoka PROPERTY VERSION "9.0")
set_property(TARGET heyoka PROPERTY SOVERSION 9)
set_property(TARGET heyoka PROPERTY VERSION "10.0")
set_property(TARGET heyoka PROPERTY SOVERSION 10)
set_target_properties(heyoka PROPERTIES CXX_VISIBILITY_PRESET hidden)
set_target_properties(heyoka PROPERTIES VISIBILITY_INLINES_HIDDEN TRUE)
endif()
Expand Down
9 changes: 9 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

0.10.0 (unreleased)
-------------------

New
~~~

- Add a pairwise product primitive
(`#146 <https://github.com/bluescarni/heyoka/pull/146>`__).

0.9.0 (2021-05-25)
------------------

Expand Down
1 change: 1 addition & 0 deletions include/heyoka/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ HEYOKA_DLL_PUBLIC expression diff(const expression &, const std::string &);
HEYOKA_DLL_PUBLIC expression diff(const expression &, const expression &);

HEYOKA_DLL_PUBLIC expression pairwise_sum(std::vector<expression>);
HEYOKA_DLL_PUBLIC expression pairwise_prod(std::vector<expression>);

HEYOKA_DLL_PUBLIC double eval_dbl(const expression &, const std::unordered_map<std::string, double> &,
const std::vector<double> & = {});
Expand Down
65 changes: 49 additions & 16 deletions src/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,36 +702,69 @@ expression subs(const expression &e, const std::unordered_map<std::string, expre
return std::visit([&smap](const auto &arg) { return subs(arg, smap); }, e.value());
}

// Pairwise summation of a vector of expressions.
// https://en.wikipedia.org/wiki/Pairwise_summation
expression pairwise_sum(std::vector<expression> sum)
namespace detail
{
if (sum.size() == std::numeric_limits<decltype(sum.size())>::max()) {
throw std::overflow_error("Overflow detected in pairwise_sum()");
}

if (sum.empty()) {
return expression{0.};
namespace
{

// Pairwise reduction of a vector of expressions.
template <typename F>
expression pairwise_reduce(const F &func, std::vector<expression> list)
{
assert(!list.empty());

// LCOV_EXCL_START
if (list.size() == std::numeric_limits<decltype(list.size())>::max()) {
throw std::overflow_error("Overflow detected in pairwise_reduce()");
}
// LCOV_EXCL_STOP

while (sum.size() != 1u) {
std::vector<expression> new_sum;
while (list.size() != 1u) {
std::vector<expression> new_list;

for (decltype(sum.size()) i = 0; i < sum.size(); i += 2u) {
if (i + 1u == sum.size()) {
for (decltype(list.size()) i = 0; i < list.size(); i += 2u) {
if (i + 1u == list.size()) {
// We are at the last element of the vector
// and the size of the vector is odd. Just append
// the existing value.
new_sum.push_back(std::move(sum[i]));
new_list.push_back(std::move(list[i]));
} else {
new_sum.push_back(std::move(sum[i]) + std::move(sum[i + 1u]));
new_list.push_back(func(std::move(list[i]), std::move(list[i + 1u])));
}
}

new_sum.swap(sum);
new_list.swap(list);
}

return std::move(list[0]);
}

} // namespace

} // namespace detail

// Pairwise summation of a vector of expressions.
// https://en.wikipedia.org/wiki/Pairwise_summation
expression pairwise_sum(std::vector<expression> sum)
{
if (sum.empty()) {
return expression{0.};
}

return detail::pairwise_reduce([](expression &&a, expression &&b) { return std::move(a) + std::move(b); },
std::move(sum));
}

// Pairwise product.
expression pairwise_prod(std::vector<expression> prod)
{
if (prod.empty()) {
return expression{1.};
}

return sum[0];
return detail::pairwise_reduce([](expression &&a, expression &&b) { return std::move(a) * std::move(b); },
std::move(prod));
}

double eval_dbl(const expression &e, const std::unordered_map<std::string, double> &map,
Expand Down
26 changes: 26 additions & 0 deletions test/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,29 @@ TEST_CASE("has time")
REQUIRE(has_time((x + y) * (hy::time + 1_dbl)));
REQUIRE(has_time((x + y) * (par[0] * hy::time + 1_dbl)));
}

TEST_CASE("pairwise_sum")
{
auto [x0, x1, x2, x3, x4, x5] = make_vars("x0", "x1", "x2", "x3", "x4", "x5");

REQUIRE(pairwise_sum({}) == 0_dbl);
REQUIRE(pairwise_sum({x0}) == x0);
REQUIRE(pairwise_sum({x0, x1}) == x0 + x1);
REQUIRE(pairwise_sum({x0, x1, x2}) == x0 + x1 + x2);
REQUIRE(pairwise_sum({x0, x1, x2, x3}) == (x0 + x1) + (x2 + x3));
REQUIRE(pairwise_sum({x0, x1, x2, x3, x4}) == (x0 + x1) + (x2 + x3) + x4);
REQUIRE(pairwise_sum({x0, x1, x2, x3, x4, x5}) == ((x0 + x1) + (x2 + x3)) + (x4 + x5));
}

TEST_CASE("pairwise_prod")
{
auto [x0, x1, x2, x3, x4, x5] = make_vars("x0", "x1", "x2", "x3", "x4", "x5");

REQUIRE(pairwise_prod({}) == 1_dbl);
REQUIRE(pairwise_prod({x0}) == x0);
REQUIRE(pairwise_prod({x0, x1}) == x0 * x1);
REQUIRE(pairwise_prod({x0, x1, x2}) == x0 * x1 * x2);
REQUIRE(pairwise_prod({x0, x1, x2, x3}) == (x0 * x1) * (x2 * x3));
REQUIRE(pairwise_prod({x0, x1, x2, x3, x4}) == (x0 * x1) * (x2 * x3) * x4);
REQUIRE(pairwise_prod({x0, x1, x2, x3, x4, x5}) == ((x0 * x1) * (x2 * x3)) * (x4 * x5));
}

0 comments on commit c5a1828

Please sign in to comment.