Skip to content

Commit

Permalink
fix wrong mtv call from test multiplication.
Browse files Browse the repository at this point in the history
  • Loading branch information
bassoy committed Dec 11, 2022
1 parent 949ffaa commit 50e19e5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 38 deletions.
2 changes: 1 addition & 1 deletion include/boost/numeric/ublas/tensor/multiplication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ void ttv0(SizeType const r,
* [n] = size(B(..,:,..))
*
*
* @param[in] k if k = 0
* @param[in] k C[i1] = sum(A[i1,i2] * B[i2]) if k = 1 or C[i2] = sum(A[i1,i2] * B[i1]) if k = 0
* @param[in] m number of rows of A
* @param[in] n number of columns of A
* @param[out] c pointer to C
Expand Down
35 changes: 22 additions & 13 deletions test/tensor/multiplication/test_multiplication_mtv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
// Google and Fraunhofer IOSB, Ettlingen, Germany
//

#include <boost/test/unit_test.hpp>
#include "../fixture_utility.hpp"
#include <boost/test/unit_test.hpp>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/vector.hpp>

#include <boost/numeric/ublas/tensor/multiplication.hpp>

BOOST_AUTO_TEST_SUITE(test_multiplication_mtv,
*boost::unit_test::description("Validate Matrix Times Vector")
*boost::unit_test::description("Test Matrix Times Vector")
)


Expand Down Expand Up @@ -62,11 +64,13 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_extents_dynamic,
auto wc = ublas::to_strides(nc,layout_type{});
auto c = vector_t (ublas::product(nc), value_type{0});

ublas::detail::recursive::mtv(
ublas::detail::recursive::mtv(
m,
c.data(), nc.data(), wc.data(),
a.data(), na.data(), wa.data(),
b.data());
na[0],
na[1],
c.data(), 1ul,
a.data(), wa[0], wa[1],
b.data(), 1ul);

auto v = value_type{static_cast<inner_t>(na[m])};
BOOST_CHECK(std::equal(c.begin(),c.end(),a.begin(), [v](auto cc, auto aa){return cc == v*aa;}));
Expand Down Expand Up @@ -123,9 +127,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_extents_static_rank,

ublas::detail::recursive::mtv(
m,
c.data(), nc.data(), wc.data(),
a.data(), na.data(), wa.data(),
b.data());
na[0],
na[1],
c.data(), 1ul,
a.data(), wa[0], wa[1],
b.data(), 1ul);

auto v = value_type{static_cast<inner_t>(na[m])};
BOOST_CHECK(std::equal(c.begin(),c.end(),a.begin(), [v](auto cc, auto aa){return cc == v*aa;}));
Expand Down Expand Up @@ -214,11 +220,14 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_extents_static,
auto c = std::array<value_type, ublas::product_v<nc_type> >();
std::fill(std::begin(c), std::end(c), value_type{0});


ublas::detail::recursive::mtv(
m,
c.data(), nc.data(), wc.data(),
a.data(), na.data(), wa.data(),
b.data());
na[0],
na[1],
c.data(), 1ul,
a.data(), wa[0], wa[1],
b.data(), 1ul);

auto v = value_type{static_cast<inner_t>(na[m])};
BOOST_CHECK(std::equal(c.begin(),c.end(),a.begin(), [v](auto cc, auto aa){return cc == v*aa;}));
Expand All @@ -229,4 +238,4 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_extents_static,
});
}

BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()
24 changes: 0 additions & 24 deletions test/tensor/test_main.cpp

This file was deleted.

0 comments on commit 50e19e5

Please sign in to comment.