refactor(expression): add function cast_tensor_expression for casting
This function casts any `tensor_expression` to its child class. It also
casts recursively, unwrapping nested layers of `tensor_expression` until it
reaches the real expression stored inside.
amitsingh19975 committed Feb 13, 2022
1 parent d70a701 commit a128dfd
Showing 2 changed files with 64 additions and 49 deletions.
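Before the diffs themselves, here is a self-contained toy sketch of the idea (`wrapper`, `leaf`, and `cast_expression` are illustrative stand-ins, not the real uBLAS identifiers). Each expression layer's nullary `operator()` exposes the layer below it, and the cast recurses until the static type of the result is no longer a wrapper:

```cpp
#include <cstddef>
#include <iostream>
#include <type_traits>

// Toy stand-in for tensor_expression<T,E>: the nullary operator() exposes
// the expression one layer down.
template<typename E>
struct wrapper {
    E child;
    constexpr auto const& operator()() const noexcept { return child; }
};

// Trait in the spirit of does_exp_need_cast: true only for wrapper itself.
template<typename T> struct needs_cast : std::false_type {};
template<typename E> struct needs_cast<wrapper<E>> : std::true_type {};
template<typename T>
constexpr bool needs_cast_v = needs_cast<std::decay_t<T>>::value;

// Stand-in for a concrete expression with an indexed operator().
struct leaf {
    int operator()(std::size_t i) const { return static_cast<int>(i) * 2; }
};

// The committed idea: peel wrapper layers recursively until the result is
// no longer a wrapper, then hand back that real expression.
template<typename E>
constexpr auto const& cast_expression(wrapper<E> const& e) noexcept {
    auto const& res = e();
    if constexpr (needs_cast_v<decltype(res)>)
        return cast_expression(res);  // still wrapped: peel another layer
    else
        return res;                   // real expression reached
}

int main() {
    wrapper<wrapper<leaf>> expr{ { leaf{} } };
    // expr(3) would not compile: wrapper's operator() takes no index.
    // The recursive cast reaches the leaf, whose operator()(i) is fine:
    std::cout << cast_expression(expr)(3) << '\n';  // prints 6
}
```

The recursion terminates because the trait matches only the wrapper template itself, so the first non-wrapper type falls into the `else` branch.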
63 changes: 31 additions & 32 deletions include/boost/numeric/ublas/tensor/expression.hpp
@@ -43,6 +43,27 @@ static constexpr bool does_exp_need_cast_v = does_exp_need_cast< std::decay_t<T>
template<typename E, typename T>
struct does_exp_need_cast< tensor_expression<T,E> > : std::true_type{};

+/**
+ * @brief A safer way of casting a `tensor_expression`, because it handles
+ * recursive (nested) expressions. Without it, callers usually reach for
+ * `operator()` with an index argument, which the `tensor_expression` base
+ * class does not support, so the call fails to compile unless the expression
+ * is first cast to its derived type.
+ *
+ * @tparam T type of the tensor
+ * @tparam E type of the child stored inside the tensor_expression
+ * @param e tensor_expression that needs to be cast
+ * @return the child of the tensor_expression that is not itself a tensor_expression
+ */
+template<typename T, typename E>
+constexpr auto const& cast_tensor_exression(tensor_expression<T,E> const& e) noexcept{
+  auto const& res = e();
+  if constexpr(does_exp_need_cast_v<decltype(res)>)
+    return cast_tensor_exression(res);
+  else
+    return res;
+}
+
template<typename E, typename T>
constexpr auto is_tensor_expression_impl(tensor_expression<T,E> const*) -> std::true_type;

@@ -137,33 +158,15 @@ struct binary_tensor_expression
  binary_tensor_expression(const binary_tensor_expression& l) = delete;
  binary_tensor_expression& operator=(binary_tensor_expression const& l) noexcept = delete;

+  constexpr auto const& left_expr() const noexcept{ return cast_tensor_exression(el); }
+  constexpr auto const& right_expr() const noexcept{ return cast_tensor_exression(er); }
+
-  [[nodiscard]] inline
-  constexpr decltype(auto) operator()(size_type i) const
-    requires (does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
-  {
-    return op(el()(i), er()(i));
-  }
-
-  [[nodiscard]] inline
-  constexpr decltype(auto) operator()(size_type i) const
-    requires (does_exp_need_cast_v<expression_type_left> && !does_exp_need_cast_v<expression_type_right>)
-  {
-    return op(el()(i), er(i));
-  }
-
-  [[nodiscard]] inline
-  constexpr decltype(auto) operator()(size_type i) const
-    requires (!does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
-  {
-    return op(el(i), er()(i));
-  }
-
  [[nodiscard]] inline
-  constexpr decltype(auto) operator()(size_type i) const {
-    return op(el(i), er(i));
+  constexpr decltype(auto) operator()(size_type i) const {
+    return op(left_expr()(i), right_expr()(i));
  }

private:
  expression_type_left el;
  expression_type_right er;
  binary_operation op;
@@ -211,19 +214,15 @@ struct unary_tensor_expression
  constexpr unary_tensor_expression() = delete;
  unary_tensor_expression(unary_tensor_expression const& l) = delete;
  unary_tensor_expression& operator=(unary_tensor_expression const& l) noexcept = delete;

-  [[nodiscard]] inline constexpr
-  decltype(auto) operator()(size_type i) const
-    requires does_exp_need_cast_v<expression_type>
-  {
-    return op(e()(i));
-  }
-
+  constexpr auto const& expr() const noexcept{ return cast_tensor_exression(e); }
+
  [[nodiscard]] inline constexpr
-  decltype(auto) operator()(size_type i) const {
-    return op(e(i));
+  decltype(auto) operator()(size_type i) const {
+    return op(expr()(i));
  }

private:
  expression_type e;
  unary_operation op;
};
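The payoff in `binary_tensor_expression` is that four `requires`-constrained `operator()` overloads collapse into one, since the cast decision now lives in the `left_expr()`/`right_expr()` accessors. A hedged sketch of that shape, restating the toy types so the block compiles on its own (`binary_node` and `std::plus` stand in for the real expression node and its `binary_operation`; this `cast_expression` accepts any type and is a no-op on unwrapped children, a slight generalization of the committed helper):

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <type_traits>

template<typename E>
struct wrapper {
    E child;
    constexpr auto const& operator()() const noexcept { return child; }
};

template<typename T> struct needs_cast : std::false_type {};
template<typename E> struct needs_cast<wrapper<E>> : std::true_type {};

// Generic peel: hands back the argument itself once it is no longer wrapped.
template<typename E>
constexpr auto const& cast_expression(E const& e) noexcept {
    if constexpr (needs_cast<std::decay_t<E>>::value)
        return cast_expression(e());
    else
        return e;
}

struct leaf {
    int operator()(std::size_t i) const { return static_cast<int>(i); }
};

template<typename EL, typename ER, typename OP>
struct binary_node {
    EL el; ER er; OP op;
    constexpr auto const& left_expr()  const noexcept { return cast_expression(el); }
    constexpr auto const& right_expr() const noexcept { return cast_expression(er); }
    // One overload covers all four wrapped/unwrapped child combinations:
    constexpr decltype(auto) operator()(std::size_t i) const {
        return op(left_expr()(i), right_expr()(i));
    }
};

int main() {
    // Left child doubly wrapped, right child bare; the same operator() serves both.
    binary_node<wrapper<wrapper<leaf>>, leaf, std::plus<>> n{ { { leaf{} } }, leaf{}, {} };
    std::cout << n(5) << '\n';  // 5 + 5 = 10
}
```

`retrieve_extents` and `all_extents_equal` in the next file lean on the same accessors, which is why they can bind `lexpr`/`rexpr`/`uexpr` once at the top instead of branching on whether a cast is needed.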
50 changes: 33 additions & 17 deletions include/boost/numeric/ublas/tensor/expression_evaluation.hpp
@@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
  static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
    "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

+  auto const& lexpr = expr.left_expr();
+  auto const& rexpr = expr.right_expr();
+
  if constexpr ( same_exp<T,EL> )
-    return expr.el.extents();
+    return lexpr.extents();

  else if constexpr ( same_exp<T,ER> )
-    return expr.er.extents();
+    return rexpr.extents();

  else if constexpr ( has_tensor_types_v<T,EL> )
-    return retrieve_extents(expr.el);
+    return retrieve_extents(lexpr);

  else if constexpr ( has_tensor_types_v<T,ER> )
-    return retrieve_extents(expr.er);
+    return retrieve_extents(rexpr);
}

#ifdef _MSC_VER
@@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)

  static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
    "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

+  auto const& uexpr = expr.expr();
+
  if constexpr ( same_exp<T,E> )
-    return expr.e.extents();
+    return uexpr.extents();

  else if constexpr ( has_tensor_types_v<T,E> )
-    return retrieve_extents(expr.e);
+    return retrieve_extents(uexpr);
}

} // namespace boost::numeric::ublas::detail
@@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& exp
  using ::operator==;
  using ::operator!=;

+  auto const& lexpr = expr.left_expr();
+  auto const& rexpr = expr.right_expr();
+
  if constexpr ( same_exp<T,EL> )
-    if(e != expr.el.extents())
+    if(e != lexpr.extents())
      return false;

  if constexpr ( same_exp<T,ER> )
-    if(e != expr.er.extents())
+    if(e != rexpr.extents())
      return false;

  if constexpr ( has_tensor_types_v<T,EL> )
-    if(!all_extents_equal(expr.el, e))
+    if(!all_extents_equal(lexpr, e))
      return false;

  if constexpr ( has_tensor_types_v<T,ER> )
-    if(!all_extents_equal(expr.er, e))
+    if(!all_extents_equal(rexpr, e))
      return false;

  return true;
@@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, ex

  using ::operator==;

+  auto const& uexpr = expr.expr();
+
  if constexpr ( same_exp<T,E> )
-    if(e != expr.e.extents())
+    if(e != uexpr.extents())
      return false;

  if constexpr ( has_tensor_types_v<T,E> )
-    if(!all_extents_equal(expr.e, e))
+    if(!all_extents_equal(uexpr, e))
      return false;

  return true;
@@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
  if(!all_extents_equal(expr, lhs.extents() ))
    throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");

-  #pragma omp parallel for
+  auto const& rhs = cast_tensor_exression(expr);
+
+  #pragma omp parallel for
  for(auto i = 0u; i < lhs.size(); ++i)
-    lhs(i) = expr()(i);
+    lhs(i) = rhs(i);
}

/** @brief Evaluates expression for a tensor_core
@@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression<other_tensor_type, derived_
    throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
  }

+  auto const& rhs = cast_tensor_exression(expr);
+
  #pragma omp parallel for
  for(auto i = 0u; i < lhs.size(); ++i)
-    lhs(i) = expr()(i);
+    lhs(i) = rhs(i);
}

/** @brief Evaluates expression for a tensor_core
@@ -330,9 +344,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
  if(!all_extents_equal( expr, lhs.extents() ))
    throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");

+  auto const& rhs = cast_tensor_exression(expr);
+
  #pragma omp parallel for
  for(auto i = 0u; i < lhs.size(); ++i)
-    fn(lhs(i), expr()(i));
+    fn(lhs(i), rhs(i));
}


@@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
template<class tensor_type, class unary_fn>
inline void eval(tensor_type& lhs, unary_fn const& fn)
{
-#pragma omp parallel for
+  #pragma omp parallel for
  for(auto i = 0u; i < lhs.size(); ++i)
    fn(lhs(i));
}
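A closing note on the `eval` overloads: the unwrap is hoisted out of the hot loop, so `cast_tensor_exression(expr)` binds `rhs` once and the OpenMP loop indexes the real expression directly, instead of re-entering the wrapper through `expr()(i)` on every iteration. A minimal sketch of that loop shape, again with restated toy types (`assign` is a hypothetical name, and the pragma is simply ignored when OpenMP is off):

```cpp
#include <cstddef>
#include <type_traits>
#include <vector>

template<typename E>
struct wrapper {
    E child;
    constexpr auto const& operator()() const noexcept { return child; }
};

template<typename T> struct needs_cast : std::false_type {};
template<typename E> struct needs_cast<wrapper<E>> : std::true_type {};

template<typename E>
constexpr auto const& cast_expression(E const& e) noexcept {
    if constexpr (needs_cast<std::decay_t<E>>::value)
        return cast_expression(e());
    else
        return e;
}

struct leaf {
    double operator()(std::size_t i) const { return 0.5 * static_cast<double>(i); }
};

// Mirrors the new eval() shape: unwrap once, then index inside the loop.
template<typename Expr>
void assign(std::vector<double>& lhs, Expr const& expr) {
    auto const& rhs = cast_expression(expr);  // hoisted out of the loop
    #pragma omp parallel for
    for (auto i = 0u; i < lhs.size(); ++i)
        lhs[i] = rhs(i);
}

int main() {
    std::vector<double> out(8);
    assign(out, wrapper<leaf>{ leaf{} });  // works the same wrapped or bare
}
```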
