Skip to content

Commit

Permalink
refactor(compare): combine two compare function into one for easier m…
Browse files Browse the repository at this point in the history
…aintainability
  • Loading branch information
amitsingh19975 committed Feb 16, 2022
1 parent 549f8a0 commit 5162569
Showing 1 changed file with 47 additions and 82 deletions.
129 changes: 47 additions & 82 deletions include/boost/numeric/ublas/tensor/operators_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,70 +60,60 @@ struct is_equality_functional_object< std::not_equal_to<> >
: std::true_type
{};

template<class T1, class T2, class L, class R, class BinaryPred>
[[nodiscard]] inline
constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,R> const& rhs, BinaryPred&& pred) noexcept
requires is_equality_functional_object_v<BinaryPred>
template<integral SizeType, typename LE, typename RE>
constexpr auto compare_helper(LE const& le, RE const& re, std::true_type /*unused*/) noexcept
-> std::pair<bool, SizeType>
{

auto const& lexpr = cast_tensor_expression(lhs);
auto const& rexpr = cast_tensor_expression(rhs);

using lvalue_type = decltype(lexpr(0));
using rvalue_type = decltype(rexpr(0));

static_assert( same_exp< lvalue_type, rvalue_type >,
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
"both LHS and RHS should have the same value type"
);

static_assert(
std::is_invocable_r_v<bool, BinaryPred, lvalue_type, rvalue_type>,
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
"the predicate must be a binary predicate, and it must return a bool"
);

auto const& le = retrieve_extents(lexpr);
auto const& re = retrieve_extents(rexpr);

using size_type = typename T1::size_type;
using ::operator==;

constexpr auto zero = SizeType{};

if constexpr( is_static_v< LE > && is_static_v< RE > ){
constexpr bool is_same = std::is_same_v<LE, RE>;
constexpr SizeType size = ( is_same ? SizeType{ product_v< LE > } : zero );
return { is_same, size };
}else{
bool const is_same = ( le == re );
SizeType const size = ( is_same ? SizeType{ product(le) } : zero );
return { is_same, size };
}
}

// returns the pair containing false if extents are not equal
// else true, and the size of the container.
constexpr auto cal_size = [](auto const& le, auto const& re)
-> std::pair<bool, size_type>
{
using lex_t = std::decay_t< decltype(le) >;
using rex_t = std::decay_t< decltype(re) >;

if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
constexpr bool is_same = same_exp< lex_t, rex_t >;
return { is_same, is_same ? product_v< lex_t > : size_type{} };
} else {
bool const is_same = ::operator==(le,re);
return { is_same, is_same ? product( le ) : size_type{} };
template<typename SizeType, typename LE, typename RE>
constexpr auto compare_helper(LE const& le, RE const& re, std::false_type /*unused*/)
noexcept( is_static_v< LE> && is_static_v< RE > ) -> std::pair<bool, SizeType>
{
using ::operator!=;

if constexpr( is_static_v< LE > && is_static_v< RE > ){
static_assert(std::is_same_v< LE, RE >,
"boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
"cannot compare tensors with different shapes."
);

constexpr SizeType size = product_v< LE >;
return { true, size };
}else{
if(le != re){
throw std::runtime_error(
"boost::numeric::ublas::detail::compare_helper(Lextents const& lhs, Rextents const& rhs) : "
"cannot compare tensors with different shapes."
);
}
};

auto const [status, size] = cal_size(le, re);

for(auto i = size_type{}; i < size; ++i){
if(!std::invoke(pred, lexpr(i), rexpr(i)))
return false;
SizeType const size = product( le );
return { true, size };
}

// return false if the status is false
return ( true & status );
}

template<class T1, class T2, class L, class R, class BinaryPred>
[[nodiscard]] inline
constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,R> const& rhs, BinaryPred&& pred)
noexcept(
is_static_v< std::decay_t< decltype(retrieve_extents(lhs)) > > &&
is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
( is_static_v< std::decay_t< decltype(retrieve_extents(lhs)) > > &&
is_static_v< std::decay_t< decltype(retrieve_extents(rhs)) > >
) || is_equality_functional_object_v<BinaryPred>
)
requires ( not is_equality_functional_object_v<BinaryPred> )
{
auto const& lexpr = cast_tensor_expression(lhs);
auto const& rexpr = cast_tensor_expression(rhs);
Expand All @@ -146,41 +136,16 @@ constexpr bool compare(tensor_expression<T1,L> const& lhs, tensor_expression<T2,
auto const& re = retrieve_extents(rexpr);

using size_type = typename T1::size_type;
using is_eq_t = std::conditional_t< is_equality_functional_object_v<BinaryPred>, std::true_type, std::false_type >;

// returns the size of the container
constexpr auto cal_size = [](auto const& le, auto const& re)
-> size_type
{
using lex_t = std::decay_t< decltype(le) >;
using rex_t = std::decay_t< decltype(re) >;

if constexpr(is_static_v< lex_t > && is_static_v< rex_t >){
static_assert(same_exp< lex_t, rex_t >,
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
"cannot compare tensors with different shapes."
);

return product_v< lex_t >;
}else{
if(::operator!=(le,re)){
throw std::runtime_error(
"boost::numeric::ublas::detail::compare(tensor_expresion const& lhs, tensor_expresion const& rhs, BinaryFn&& pred) : "
"cannot compare tensors with different shapes."
);
}

return product( le );
}
};

size_type const size = cal_size(le, re);

auto const [status, size] = compare_helper<size_type>(le, re, is_eq_t{});

for(auto i = size_type{}; i < size; ++i){
if(!std::invoke(pred, lexpr(i), rexpr(i)))
return false;
}

return true;
return status;
}


Expand All @@ -201,7 +166,7 @@ constexpr bool compare(tensor_expression<T,D> const& expr, UnaryPred&& pred) noe
"the predicate must be an unary predicate, and it must return a bool"
);

size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product( e );;
size_type const size = is_static_v< extents_t > ? product_v< extents_t > : product( e );

for(auto i = size_type{}; i < size; ++i){
if(!std::invoke(pred, ue(i)))
Expand Down

0 comments on commit 5162569

Please sign in to comment.