Skip to content

Commit

Permalink
kokkos#6805: fixes to strided layout and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Mar 7, 2024
1 parent cb3f579 commit 6401ce2
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 89 deletions.
81 changes: 79 additions & 2 deletions core/src/View/MDSpan/Kokkos_MDSpan_Layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ struct ArrayLayoutFromLayout<Experimental::layout_right_padded<padding_value>> {
padding = {};
};

template <>
struct ArrayLayoutFromLayout<layout_stride> {
using type = Kokkos::LayoutStride;
};

template <class ArrayLayout>
struct LayoutFromArrayLayout;

Expand Down Expand Up @@ -75,8 +80,11 @@ struct ViewOffsetFromExtents {
Kokkos::Impl::ViewOffset<typename data_analysis::dimension, array_layout>;
};

template <class ArrayLayout>
struct ArrayLayoutFromMappingImpl;

template <class ArrayLayout, class MDSpanType>
KOKKOS_INLINE_FUNCTION auto array_layout_from_mapping(
KOKKOS_INLINE_FUNCTION auto array_layout_leftright_from_mapping_impl(
const typename MDSpanType::mapping_type &mapping) {
using mapping_type = typename MDSpanType::mapping_type;
using extents_type = typename mapping_type::extents_type;
Expand All @@ -97,12 +105,81 @@ KOKKOS_INLINE_FUNCTION auto array_layout_from_mapping(
rank > 7 ? dimension_from_extent(ext, 7) : KOKKOS_IMPL_CTOR_DEFAULT_ARG};
}

template <>
struct ArrayLayoutFromMappingImpl<Kokkos::LayoutLeft> {
template <class MDSpanType>
static Kokkos::LayoutLeft construct(
const typename MDSpanType::mapping_type &mapping) {
return array_layout_leftright_from_mapping_impl<Kokkos::LayoutLeft,
MDSpanType>(mapping);
}
};

template <>
struct ArrayLayoutFromMappingImpl<Kokkos::LayoutRight> {
template <class MDSpanType>
static Kokkos::LayoutRight construct(
const typename MDSpanType::mapping_type &mapping) {
return array_layout_leftright_from_mapping_impl<Kokkos::LayoutRight,
MDSpanType>(mapping);
}
};

template <>
struct ArrayLayoutFromMappingImpl<Kokkos::LayoutStride> {
template <class MDSpanType>
static Kokkos::LayoutStride construct(
const typename MDSpanType::mapping_type &mapping) {
using mapping_type = typename MDSpanType::mapping_type;
using extents_type = typename mapping_type::extents_type;

static constexpr auto rank = extents_type::rank();
const auto &ext = mapping.extents();

static_assert(rank <= ARRAY_LAYOUT_MAX_RANK,
"Unsupported rank for mdspan (must be <= 8)");
return Kokkos::LayoutStride{
rank > 0 ? dimension_from_extent(ext, 0) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 0 ? mapping.stride(0) : 0,
rank > 1 ? dimension_from_extent(ext, 1) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 1 ? mapping.stride(1) : 0,
rank > 2 ? dimension_from_extent(ext, 2) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 2 ? mapping.stride(2) : 0,
rank > 3 ? dimension_from_extent(ext, 3) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 3 ? mapping.stride(3) : 0,
rank > 4 ? dimension_from_extent(ext, 4) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 4 ? mapping.stride(4) : 0,
rank > 5 ? dimension_from_extent(ext, 5) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 5 ? mapping.stride(5) : 0,
rank > 6 ? dimension_from_extent(ext, 6) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 6 ? mapping.stride(6) : 0,
rank > 7 ? dimension_from_extent(ext, 7) : KOKKOS_IMPL_CTOR_DEFAULT_ARG,
rank > 7 ? mapping.stride(7) : 0,
};
}
};

template <class ArrayLayout, class MDSpanType>
KOKKOS_INLINE_FUNCTION auto array_layout_from_mapping(
const typename MDSpanType::mapping_type &mapping) {
return ArrayLayoutFromMappingImpl<ArrayLayout>::template construct<MDSpanType>(
mapping);
}

template <class MDSpanType, class VM>
KOKKOS_INLINE_FUNCTION auto mapping_from_view_mapping(const VM &view_mapping) {
using mapping_type = typename MDSpanType::mapping_type;
using extents_type = typename mapping_type::extents_type;

return mapping_type(extents_from_view_mapping<extents_type>(view_mapping));
if constexpr (std::is_same_v<typename mapping_type::layout_type,
Kokkos::layout_stride>) {
std::array<std::size_t, VM::Rank> strides;
view_mapping.stride(strides.data());
return mapping_type(extents_from_view_mapping<extents_type>(view_mapping),
strides);
} else {
return mapping_type(extents_from_view_mapping<extents_type>(view_mapping));
}
}

template <class ElementType, class Extents, class LayoutPolicy,
Expand Down

0 comments on commit 6401ce2

Please sign in to comment.