Skip to content

Commit

Permalink
FEEval: do not cast const away during read_dof_values()
Browse files Browse the repository at this point in the history
  • Loading branch information
peterrum committed Jan 14, 2022
1 parent 3447d0a commit 4a44987
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
89 changes: 67 additions & 22 deletions include/deal.II/matrix_free/fe_evaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -3270,6 +3270,22 @@ FEEvaluationBase<dim, n_components_, Number, is_face, VectorizedArrayType>::

namespace internal
{
template <typename VectorType, bool>
struct ConstBlockVectorSelector
{};

template <typename VectorType>
struct ConstBlockVectorSelector<VectorType, true>
{
using BaseVectorType = const typename VectorType::BlockType;
};

template <typename VectorType>
struct ConstBlockVectorSelector<VectorType, false>
{
using BaseVectorType = typename VectorType::BlockType;
};

// allows to select between block vectors and non-block vectors, which
// allows to use a unified interface for extracting blocks on block vectors
// and doing nothing on usual vectors
Expand All @@ -3280,7 +3296,9 @@ namespace internal
template <typename VectorType>
struct BlockVectorSelector<VectorType, true>
{
using BaseVectorType = typename VectorType::BlockType;
using BaseVectorType = typename ConstBlockVectorSelector<
VectorType,
std::is_const<VectorType>::value>::BaseVectorType;

static BaseVectorType *
get_vector_component(VectorType &vec, const unsigned int component)
Expand Down Expand Up @@ -3329,6 +3347,20 @@ namespace internal
}
};

template <typename VectorType>
struct BlockVectorSelector<const std::vector<VectorType>, false>
{
using BaseVectorType = const VectorType;

static const BaseVectorType *
get_vector_component(const std::vector<VectorType> &vec,
const unsigned int component)
{
AssertIndexRange(component, vec.size());
return &vec[component];
}
};

template <typename VectorType>
struct BlockVectorSelector<std::vector<VectorType *>, false>
{
Expand All @@ -3342,6 +3374,20 @@ namespace internal
return vec[component];
}
};

template <typename VectorType>
struct BlockVectorSelector<const std::vector<VectorType *>, false>
{
using BaseVectorType = const VectorType;

static const BaseVectorType *
get_vector_component(const std::vector<VectorType *> &vec,
const unsigned int component)
{
AssertIndexRange(component, vec.size());
return vec[component];
}
};
} // namespace internal


Expand Down Expand Up @@ -4183,15 +4229,13 @@ namespace internal
template <int n_components, typename VectorType>
std::pair<
std::array<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType *,
n_components>,
std::array<
const std::vector<ArrayView<const typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
BaseVectorType::value_type>> *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType::value_type>> *,
n_components>>
get_vector_data(VectorType & src,
const unsigned int first_index,
Expand All @@ -4203,32 +4247,33 @@ namespace internal
// of components is checked in the internal data
std::pair<
std::array<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType *,
n_components>,
std::array<
const std::vector<
ArrayView<const typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType::value_type>> *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType::value_type>> *,
n_components>>
src_data;

for (unsigned int d = 0; d < n_components; ++d)
src_data.first[d] = internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
get_vector_component(
const_cast<typename std::remove_const<VectorType>::type &>(src),
d + first_index);
VectorType,
IsBlockVector<VectorType>::value>::get_vector_component(src,
d +
first_index);

for (unsigned int d = 0; d < n_components; ++d)
src_data.second[d] = get_shared_vector_data(*src_data.first[d],
is_valid_mode_for_sm,
active_fe_index,
dof_info);
src_data.second[d] = get_shared_vector_data(
const_cast<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
BaseVectorType &>(*src_data.first[d]),
is_valid_mode_for_sm,
active_fe_index,
dof_info);

return src_data;
}
Expand Down
3 changes: 2 additions & 1 deletion include/deal.II/matrix_free/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ namespace internal
{
static const bool value =
has_begin<T>::value &&
(has_local_element<T>::value || is_serial_vector<T>::value) &&
(has_local_element<T>::value ||
is_serial_vector<typename std::remove_const<T>::type>::value) &&
std::is_same<typename T::value_type, Number>::value;
};

Expand Down
11 changes: 11 additions & 0 deletions include/deal.II/matrix_free/vector_access_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ namespace internal
VectorizedArrayType *dof_values,
std::integral_constant<bool, true>) const
{
#ifdef DEBUG
for (unsigned int i = 0; i < dofs_per_cell; ++i)
for (unsigned int v = 0; v < VectorizedArrayType::size(); ++v)
vector_access(vec, dof_index + v + i * VectorizedArrayType::size());
#endif

const Number *vec_ptr = vec.begin() + dof_index;
for (unsigned int i = 0; i < dofs_per_cell;
++i, vec_ptr += VectorizedArrayType::size())
Expand Down Expand Up @@ -340,6 +346,11 @@ namespace internal
VectorizedArrayType &res,
std::integral_constant<bool, true>) const
{
#ifdef DEBUG
for (unsigned int v = 0; v < VectorizedArrayType::size(); ++v)
vector_access(vec, indices[v] + constant_offset);
#endif

res.gather(vec.begin() + constant_offset, indices);
}

Expand Down

0 comments on commit 4a44987

Please sign in to comment.