Skip to content

Commit

Permalink
FEEval: do not cast const away during read_dof_values()
Browse files Browse the repository at this point in the history
  • Loading branch information
peterrum committed Jan 15, 2022
1 parent 3447d0a commit a6ec42c
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 25 deletions.
93 changes: 69 additions & 24 deletions include/deal.II/matrix_free/fe_evaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -3270,17 +3270,35 @@ FEEvaluationBase<dim, n_components_, Number, is_face, VectorizedArrayType>::

namespace internal
{
// given a block vector return the underlying vector type
// including constness (specified by bool)
template <typename VectorType, bool>
struct ConstBlockVectorSelector;

template <typename VectorType>
struct ConstBlockVectorSelector<VectorType, true>
{
using BaseVectorType = const typename VectorType::BlockType;
};

template <typename VectorType>
struct ConstBlockVectorSelector<VectorType, false>
{
using BaseVectorType = typename VectorType::BlockType;
};

// allows to select between block vectors and non-block vectors, which
// allows to use a unified interface for extracting blocks on block vectors
// and doing nothing on usual vectors
template <typename VectorType, bool>
struct BlockVectorSelector
{};
struct BlockVectorSelector;

template <typename VectorType>
struct BlockVectorSelector<VectorType, true>
{
using BaseVectorType = typename VectorType::BlockType;
using BaseVectorType = typename ConstBlockVectorSelector<
VectorType,
std::is_const<VectorType>::value>::BaseVectorType;

static BaseVectorType *
get_vector_component(VectorType &vec, const unsigned int component)
Expand Down Expand Up @@ -3329,6 +3347,20 @@ namespace internal
}
};

template <typename VectorType>
struct BlockVectorSelector<const std::vector<VectorType>, false>
{
using BaseVectorType = const VectorType;

static const BaseVectorType *
get_vector_component(const std::vector<VectorType> &vec,
const unsigned int component)
{
AssertIndexRange(component, vec.size());
return &vec[component];
}
};

template <typename VectorType>
struct BlockVectorSelector<std::vector<VectorType *>, false>
{
Expand All @@ -3342,6 +3374,20 @@ namespace internal
return vec[component];
}
};

template <typename VectorType>
struct BlockVectorSelector<const std::vector<VectorType *>, false>
{
using BaseVectorType = const VectorType;

static const BaseVectorType *
get_vector_component(const std::vector<VectorType *> &vec,
const unsigned int component)
{
AssertIndexRange(component, vec.size());
return vec[component];
}
};
} // namespace internal


Expand Down Expand Up @@ -4183,15 +4229,13 @@ namespace internal
template <int n_components, typename VectorType>
std::pair<
std::array<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType *,
n_components>,
std::array<
const std::vector<ArrayView<const typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
BaseVectorType::value_type>> *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType::value_type>> *,
n_components>>
get_vector_data(VectorType & src,
const unsigned int first_index,
Expand All @@ -4203,32 +4247,33 @@ namespace internal
// of components is checked in the internal data
std::pair<
std::array<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType *,
n_components>,
std::array<
const std::vector<
ArrayView<const typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::
value>::BaseVectorType::value_type>> *,
VectorType,
IsBlockVector<VectorType>::value>::BaseVectorType::value_type>> *,
n_components>>
src_data;

for (unsigned int d = 0; d < n_components; ++d)
src_data.first[d] = internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
get_vector_component(
const_cast<typename std::remove_const<VectorType>::type &>(src),
d + first_index);
VectorType,
IsBlockVector<VectorType>::value>::get_vector_component(src,
d +
first_index);

for (unsigned int d = 0; d < n_components; ++d)
src_data.second[d] = get_shared_vector_data(*src_data.first[d],
is_valid_mode_for_sm,
active_fe_index,
dof_info);
src_data.second[d] = get_shared_vector_data(
const_cast<typename internal::BlockVectorSelector<
typename std::remove_const<VectorType>::type,
IsBlockVector<typename std::remove_const<VectorType>::type>::value>::
BaseVectorType &>(*src_data.first[d]),
is_valid_mode_for_sm,
active_fe_index,
dof_info);

return src_data;
}
Expand Down
3 changes: 2 additions & 1 deletion include/deal.II/matrix_free/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ namespace internal
{
static const bool value =
has_begin<T>::value &&
(has_local_element<T>::value || is_serial_vector<T>::value) &&
(has_local_element<T>::value ||
is_serial_vector<typename std::remove_const<T>::type>::value) &&
std::is_same<typename T::value_type, Number>::value;
};

Expand Down
20 changes: 20 additions & 0 deletions include/deal.II/matrix_free/vector_access_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,20 @@ namespace internal
VectorizedArrayType *dof_values,
std::integral_constant<bool, true>) const
{
#ifdef DEBUG
// in debug mode, run non-vectorized version because this path
// has additional checks (e.g., regarding ghosting)
process_dofs_vectorized(dofs_per_cell,
dof_index,
vec,
dof_values,
std::integral_constant<bool, false>());
#else
const Number *vec_ptr = vec.begin() + dof_index;
for (unsigned int i = 0; i < dofs_per_cell;
++i, vec_ptr += VectorizedArrayType::size())
dof_values[i].load(vec_ptr);
#endif
}


Expand Down Expand Up @@ -340,7 +350,17 @@ namespace internal
VectorizedArrayType &res,
std::integral_constant<bool, true>) const
{
#ifdef DEBUG
// in debug mode, run non-vectorized version because this path
// has additional checks (e.g., regarding ghosting)
process_dof_gather(indices,
vec,
constant_offset,
res,
std::integral_constant<bool, false>());
#else
res.gather(vec.begin() + constant_offset, indices);
#endif
}


Expand Down

0 comments on commit a6ec42c

Please sign in to comment.