Skip to content

Commit

Permalink
Implement StridedArrayView
Browse files Browse the repository at this point in the history
  • Loading branch information
bergbauer committed Sep 8, 2023
1 parent b47d186 commit 3d9fd71
Showing 1 changed file with 149 additions and 106 deletions.
255 changes: 149 additions & 106 deletions include/deal.II/base/array_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,142 @@ class Table;
template <typename number>
class LAPACKFullMatrix;

/**
* Base class of @p ArrayView which allows strided access into the view.
* This is particularly useful when you want to access only one lane of a
* VectorizedArray.
*/
template <typename ElementType, std::size_t stride = 1>
class StridedArrayView
{
public:
/**
* An alias that denotes the "value_type" of this container-like class,
* i.e., the type of the element it "stores" or points to.
*/
using value_type = ElementType;

/**
* Constructor.
*
* @param[in] starting_element A pointer to the first element of the array
* this object should represent.
* @param[in] n_elements The length (in elements) of the chunk of memory
* this object should represent.
*
* @note The object that is constructed from these arguments has no
* knowledge how large the object into which it points really is. As a
* consequence, whenever you call ArrayView::operator[], the array view can
* check that the given index is within the range of the view, but it can't
* check that the view is indeed a subset of the valid range of elements of
* the underlying object that allocated that range. In other words, you need
* to ensure that the range of the view specified by the two arguments to
* this constructor is in fact a subset of the elements of the array into
* which it points. The appropriate way to do this is to use the
* make_array_view() functions.
*/
StridedArrayView(value_type *starting_element, const std::size_t n_elements);

/**
* Return the size (in elements) of the view of memory this object
* represents.
*/
std::size_t
size() const;

/**
* Return a bool whether the array view is empty.
*/
bool
empty() const;

/**
* Return a pointer to the underlying array serving as element storage.
* In case the container is empty a nullptr is returned.
*/
DEAL_II_HOST_DEVICE value_type *
data() const noexcept;

/**
* Return a reference to the $i$th element of the range represented by the
* current object.
*
* This function is marked as @p const because it does not change the
* <em>view object</em>. It may however return a reference to a non-@p const
* memory location depending on whether the template type of the class is @p
* const or not.
*
* This function is only allowed to be called if the underlying data is indeed
* stored in CPU memory.
*/
value_type &
operator[](const std::size_t i) const;

protected:
/**
* A pointer to the first element of the range of locations in memory that
* this object represents.
*/
value_type *starting_element;

/**
* The length of the array this object represents.
*/
std::size_t n_elements;
};



template <typename ElementType, std::size_t stride>
typename StridedArrayView<ElementType, stride>::value_type &
StridedArrayView<ElementType, stride>::operator[](const std::size_t i) const
{
AssertIndexRange(i, this->n_elements);

return *(this->starting_element + stride * i);
}



template <typename ElementType, std::size_t stride>
typename StridedArrayView<ElementType, stride>::value_type *
StridedArrayView<ElementType, stride>::data() const noexcept
{
if (this->n_elements == 0)
return nullptr;
else
return this->starting_element;
}



template <typename ElementType, std::size_t stride>
bool
StridedArrayView<ElementType, stride>::empty() const
{
return this->n_elements == 0;
}



template <typename ElementType, std::size_t stride>
std::size_t
StridedArrayView<ElementType, stride>::size() const
{
return this->n_elements;
}



template <typename ElementType, std::size_t stride>
StridedArrayView<ElementType, stride>::StridedArrayView(
value_type *starting_element,
const std::size_t n_elements)
: starting_element(starting_element)
, n_elements(n_elements)
{}



/**
* A class that represents a window of memory locations of type @p ElementType
Expand Down Expand Up @@ -82,7 +218,7 @@ class LAPACKFullMatrix;
* @ingroup data
*/
template <typename ElementType, typename MemorySpaceType = MemorySpace::Host>
class ArrayView
class ArrayView : public StridedArrayView<ElementType, 1>
{
public:
/**
Expand Down Expand Up @@ -309,26 +445,6 @@ class ArrayView
operator!=(const ArrayView<std::remove_cv_t<value_type>, MemorySpaceType>
&other_view) const;

/**
* Return the size (in elements) of the view of memory this object
* represents.
*/
std::size_t
size() const;

/**
* Return a bool whether the array view is empty.
*/
bool
empty() const;

/**
* Return a pointer to the underlying array serving as element storage.
* In case the container is empty a nullptr is returned.
*/
DEAL_II_HOST_DEVICE value_type *
data() const noexcept;

/**
* Return an iterator pointing to the beginning of the array view.
*/
Expand All @@ -353,33 +469,6 @@ class ArrayView
const_iterator
cend() const;

/**
* Return a reference to the $i$th element of the range represented by the
* current object.
*
* This function is marked as @p const because it does not change the
* <em>view object</em>. It may however return a reference to a non-@p const
* memory location depending on whether the template type of the class is @p
* const or not.
*
* This function is only allowed to be called if the underlying data is indeed
* stored in CPU memory.
*/
value_type &
operator[](const std::size_t i) const;

private:
/**
* A pointer to the first element of the range of locations in memory that
* this object represents.
*/
value_type *starting_element;

/**
* The length of the array this object represents.
*/
std::size_t n_elements;

friend class ArrayView<const ElementType, MemorySpaceType>;
};

Expand All @@ -390,8 +479,7 @@ class ArrayView

template <typename ElementType, typename MemorySpaceType>
inline ArrayView<ElementType, MemorySpaceType>::ArrayView()
: starting_element(nullptr)
, n_elements(0)
: StridedArrayView<ElementType, 1>(nullptr, 0)
{}


Expand All @@ -400,8 +488,7 @@ template <typename ElementType, typename MemorySpaceType>
inline ArrayView<ElementType, MemorySpaceType>::ArrayView(
value_type *starting_element,
const std::size_t n_elements)
: starting_element(starting_element)
, n_elements(n_elements)
: StridedArrayView<ElementType, 1>(starting_element, n_elements)
{}


Expand All @@ -419,17 +506,15 @@ ArrayView<ElementType, MemorySpaceType>::reinit(value_type *starting_element,

template <typename ElementType, typename MemorySpaceType>
inline ArrayView<ElementType, MemorySpaceType>::ArrayView(ElementType &element)
: starting_element(&element)
, n_elements(1)
: StridedArrayView<ElementType, 1>(&element, 1)
{}



template <typename ElementType, typename MemorySpaceType>
inline ArrayView<ElementType, MemorySpaceType>::ArrayView(
const ArrayView<std::remove_cv_t<value_type>, MemorySpaceType> &view)
: starting_element(view.starting_element)
, n_elements(view.n_elements)
: StridedArrayView<ElementType, 1>(view.starting_element, view.n_elements)
{}


Expand Down Expand Up @@ -516,8 +601,8 @@ inline bool
ArrayView<ElementType, MemorySpaceType>::operator==(
const ArrayView<const value_type, MemorySpaceType> &other_view) const
{
return (other_view.data() == starting_element) &&
(other_view.size() == n_elements);
return (other_view.data() == this->starting_element) &&
(other_view.size() == this->n_elements);
}


Expand All @@ -528,8 +613,8 @@ ArrayView<ElementType, MemorySpaceType>::operator==(
const ArrayView<std::remove_cv_t<value_type>, MemorySpaceType> &other_view)
const
{
return (other_view.data() == starting_element) &&
(other_view.size() == n_elements);
return (other_view.data() == this->starting_element) &&
(other_view.size() == this->n_elements);
}


Expand All @@ -544,19 +629,6 @@ ArrayView<ElementType, MemorySpaceType>::operator!=(



template <typename ElementType, typename MemorySpaceType>
inline DEAL_II_HOST_DEVICE
typename ArrayView<ElementType, MemorySpaceType>::value_type *
ArrayView<ElementType, MemorySpaceType>::data() const noexcept
{
if (n_elements == 0)
return nullptr;
else
return starting_element;
}



template <typename ElementType, typename MemorySpaceType>
inline bool
ArrayView<ElementType, MemorySpaceType>::operator!=(
Expand All @@ -568,29 +640,11 @@ ArrayView<ElementType, MemorySpaceType>::operator!=(



template <typename ElementType, typename MemorySpaceType>
inline std::size_t
ArrayView<ElementType, MemorySpaceType>::size() const
{
return n_elements;
}



template <typename ElementType, typename MemorySpaceType>
inline bool
ArrayView<ElementType, MemorySpaceType>::empty() const
{
return n_elements == 0;
}



template <typename ElementType, typename MemorySpaceType>
inline typename ArrayView<ElementType, MemorySpaceType>::iterator
ArrayView<ElementType, MemorySpaceType>::begin() const
{
return starting_element;
return this->starting_element;
}


Expand All @@ -599,7 +653,7 @@ template <typename ElementType, typename MemorySpaceType>
inline typename ArrayView<ElementType, MemorySpaceType>::iterator
ArrayView<ElementType, MemorySpaceType>::end() const
{
return starting_element + n_elements;
return this->starting_element + this->n_elements;
}


Expand All @@ -608,7 +662,7 @@ template <typename ElementType, typename MemorySpaceType>
inline typename ArrayView<ElementType, MemorySpaceType>::const_iterator
ArrayView<ElementType, MemorySpaceType>::cbegin() const
{
return starting_element;
return this->starting_element;
}


Expand All @@ -617,18 +671,7 @@ template <typename ElementType, typename MemorySpaceType>
inline typename ArrayView<ElementType, MemorySpaceType>::const_iterator
ArrayView<ElementType, MemorySpaceType>::cend() const
{
return starting_element + n_elements;
}



template <typename ElementType, typename MemorySpaceType>
inline typename ArrayView<ElementType, MemorySpaceType>::value_type &
ArrayView<ElementType, MemorySpaceType>::operator[](const std::size_t i) const
{
AssertIndexRange(i, n_elements);

return *(starting_element + i);
return this->starting_element + this->n_elements;
}


Expand Down

0 comments on commit 3d9fd71

Please sign in to comment.