Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Tensor::unroll(). #16467

Merged
merged 1 commit into from
Jan 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 29 additions & 48 deletions include/deal.II/base/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,6 @@ class Tensor<0, dim, Number>
*/
Number value;

/**
* Internal helper function for unroll.
*/
template <typename Iterator>
Iterator
unroll_recursion(const Iterator current, const Iterator end) const;

// Allow an arbitrary Tensor to access the underlying values.
template <int, int, typename>
friend class Tensor;
Expand Down Expand Up @@ -885,13 +878,6 @@ class Tensor
// ... avoid a compiler warning in case of dim == 0 and ensure that the
// array always has positive size.

/**
* Internal helper function for unroll.
*/
template <typename Iterator>
Iterator
unroll_recursion(const Iterator current, const Iterator end) const;

/**
* This constructor is for internal use. It provides a way
* to create constexpr constructors for Tensor<rank, dim, Number>
Expand Down Expand Up @@ -1255,24 +1241,6 @@ constexpr DEAL_II_HOST_DEVICE inline DEAL_II_ALWAYS_INLINE



template <int dim, typename Number>
template <typename Iterator>
Iterator
Tensor<0, dim, Number>::unroll_recursion(const Iterator current,
const Iterator end) const
{
(void)end;
Assert(dim != 0,
ExcMessage("Cannot unroll an object of type Tensor<0,0,Number>"));
Assert(std::distance(current, end) >= 1,
ExcMessage("The provided iterator range must contain at least one "
"element."));
*current = value;
return std::next(current);
}



template <int dim, typename Number>
constexpr inline void
Tensor<0, dim, Number>::clear()
Expand All @@ -1289,8 +1257,14 @@ template <class Iterator>
inline void
Tensor<0, dim, Number>::unroll(const Iterator begin, const Iterator end) const
{
(void)end;
AssertDimension(std::distance(begin, end), n_independent_components);
unroll_recursion(begin, end);
Assert(dim != 0,
ExcMessage("Cannot unroll an object of type Tensor<0,0,Number>"));
Assert(std::distance(begin, end) >= 1,
ExcMessage("The provided iterator range must contain at least one "
"element."));
*begin = value;
}


Expand Down Expand Up @@ -1775,25 +1749,32 @@ inline void
Tensor<rank_, dim, Number>::unroll(const Iterator begin,
const Iterator end) const
{
AssertDimension(std::distance(begin, end), n_independent_components);
unroll_recursion(begin, end);
if constexpr (rank_ > 1)
{
// For higher-rank tensors, we recurse to the sub-tensors:
Iterator next = begin;
for (unsigned int i = 0; i < dim; ++i)
{
values[i].unroll(next, end);
std::advance(
next, Tensor<rank_ - 1, dim, Number>::n_independent_components);
}
}
else
{
// For rank-1 tensors, we can simply copy the current elements from
// our linear array into the output range, and then return the
// iterator moved forward by 'dim' elements:
Assert(std::distance(begin, end) >= dim,
ExcMessage(
"The provided iterator range must contain at least 'dim' "
"elements."));
std::copy(&values[0], &values[dim], begin);
}
}



template <int rank_, int dim, typename Number>
template <typename Iterator>
Iterator
Tensor<rank_, dim, Number>::unroll_recursion(const Iterator current,
const Iterator end) const
{
Iterator next = current;
for (unsigned int i = 0; i < dim; ++i)
next = values[i].unroll_recursion(next, end);
return next;
}


template <int rank_, int dim, typename Number>
constexpr inline unsigned int
Tensor<rank_, dim, Number>::component_to_unrolled_index(
Expand Down