Skip to content

Commit

Permalink
Update due to namespace change; apply review suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
mhoemmen committed Apr 2, 2024
1 parent 62f409e commit 457b822
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 62 deletions.
50 changes: 15 additions & 35 deletions examples/for_each_extents/for_each_extents.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <experimental/mdspan>
#include <mdspan/mdspan.hpp>
#include <cassert>
#include <iostream>
#include <type_traits>
Expand Down Expand Up @@ -26,42 +26,22 @@
#include <ranges>
namespace ranges_views = std::views;

namespace stdex = std::experimental;

auto print_args = [] <class ... Args> (Args&&... args) {
((std::cout << std::forward<Args>(args) << '\n'), ...);
};

template<size_t ... Lefts, size_t ... Rights>
auto concat_index_sequence(std::index_sequence<Lefts...>,
std::index_sequence<Rights...>)
{
return std::index_sequence<Lefts..., Rights...>{};
}

auto reverse_index_sequence(std::index_sequence<> x)
{
return x;
}
template <std::size_t... Is>
auto reverse(std::index_sequence<Is...>) ->
std::index_sequence<sizeof...(Is) - 1 - Is...>;

template<size_t First, size_t ... Rest>
auto reverse_index_sequence(std::index_sequence<First, Rest...>)
{
return concat_index_sequence(
reverse_index_sequence(std::index_sequence<Rest...>{}),
std::index_sequence<First>{});
}

template<size_t N>
auto make_reverse_index_sequence()
{
return reverse_index_sequence(std::make_index_sequence<N>());
}
template <std::size_t N>
using reverse_index_sequence_t =
decltype(reverse(std::make_index_sequence<N>()));

template<class Callable, class IndexType,
std::size_t ... Extents, std::size_t ... RankIndices>
void for_each_in_extents_impl(Callable&& f,
stdex::extents<IndexType, Extents...> e,
Kokkos::extents<IndexType, Extents...> e,
std::index_sequence<RankIndices...> rank_sequence)
{
// In the layout_left case, caller passes in N-1, N-2, ..., 1, 0.
Expand All @@ -82,17 +62,17 @@ void for_each_in_extents_impl(Callable&& f,

template<class Callable, class IndexType, std::size_t ... Extents, class Layout>
void for_each_in_extents(Callable&& f,
stdex::extents<IndexType, Extents...> e,
Kokkos::extents<IndexType, Extents...> e,
Layout)
{
using layout_type = std::remove_cvref_t<Layout>;
if constexpr (std::is_same_v<layout_type, stdex::layout_left>) {
if constexpr (std::is_same_v<layout_type, Kokkos::layout_left>) {
for_each_in_extents_impl(std::forward<Callable>(f), e,
make_reverse_index_sequence<e.rank()>());
reverse_index_sequence_t<e.rank()>{});
}
else { // layout_right or any other layout
for_each_in_extents_impl(std::forward<Callable>(f), e,
std::make_index_sequence<e.rank()>());
reverse_index_sequence_t<e.rank()>{});
}
}

Expand All @@ -101,14 +81,14 @@ void for_each_in_extents(Callable&& f,
int main() {

#if defined(MDSPAN_EXAMPLE_CAN_USE_STD_RANGES)
stdex::extents<int, 2, 3> e;
Kokkos::extents<int, 2, 3> e;
auto printer = [] (int i, int j) {
std::cout << "(" << i << "," << j << ")\n";
};
std::cout << "layout_right:\n";
for_each_in_extents(printer, e, stdex::layout_right{});
for_each_in_extents(printer, e, Kokkos::layout_right{});
std::cout << "\nlayout_left:\n";
for_each_in_extents(printer, e, stdex::layout_left{});
for_each_in_extents(printer, e, Kokkos::layout_left{});
#endif // defined(MDSPAN_EXAMPLE_CAN_USE_STD_RANGES)

return 0;
Expand Down
52 changes: 25 additions & 27 deletions examples/for_each_extents/for_each_extents_no_ranges.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#include <experimental/mdspan>
#include <mdspan/mdspan.hpp>
#include <array>
#include <iostream>
#include <tuple>
#include <type_traits>

namespace stdex = std::experimental;

// There's no separate feature test macro for the C++20 feature
// of lambdas with named template parameters (P0428R2).
#if __cplusplus >= 202002L
Expand Down Expand Up @@ -80,11 +78,11 @@ auto print_pack = []<class ... InputTypes>(InputTypes&& ... input) {
// This example shows that you can do
// index arithmetic on an index sequence.
template<class IndexType, std::size_t ... Extents>
auto right_extents( stdex::extents<IndexType, Extents...> e )
auto right_extents( Kokkos::extents<IndexType, Extents...> e )
{
static_assert(sizeof...(Extents) != 0);
return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
return stdex::extents<IndexType, e.static_extent(Indices + 1)...>{
return Kokkos::extents<IndexType, e.static_extent(Indices + 1)...>{
e.extent(Indices + 1)...
};
}( std::make_index_sequence<sizeof...(Extents) - 1>() );
Expand All @@ -101,10 +99,10 @@ auto right_extents( stdex::extents<IndexType, Extents...> e )
// This needs to be a lambda or function object,
// not a templated function.
auto split_extents_at_leftmost =
[]<class IndexType, std::size_t... Extents>(stdex::extents<IndexType, Extents...> e)
[]<class IndexType, std::size_t... Extents>(Kokkos::extents<IndexType, Extents...> e)
{
static_assert(sizeof...(Extents) != 0);
stdex::extents<IndexType, e.static_extent(0)> left_ext(
Kokkos::extents<IndexType, e.static_extent(0)> left_ext(
e.extent(0));
return std::tuple{left_ext, right_extents(e)};
};
Expand All @@ -116,22 +114,22 @@ auto split_extents_at_leftmost =
// Returns a new extents object representing
// all but the rightmost extent of e.
template<class IndexType, std::size_t ... Extents>
auto left_extents( stdex::extents<IndexType, Extents...> e )
auto left_extents( Kokkos::extents<IndexType, Extents...> e )
{
static_assert(sizeof...(Extents) != 0);
return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
return stdex::extents<IndexType, e.static_extent(Indices)...>{
return Kokkos::extents<IndexType, e.static_extent(Indices)...>{
e.extent(Indices)...
};
}( std::make_index_sequence<sizeof...(Extents) - 1>() );
}

// This needs to be a lambda or function object, not a templated function.
auto split_extents_at_rightmost =
[]<class IndexType, std::size_t ... Extents>(stdex::extents<IndexType, Extents...> e)
[]<class IndexType, std::size_t ... Extents>(Kokkos::extents<IndexType, Extents...> e)
{
static_assert(sizeof...(Extents) != 0);
stdex::extents<IndexType, e.static_extent(e.rank() - 1)> right_ext(
Kokkos::extents<IndexType, e.static_extent(e.rank() - 1)> right_ext(
e.extent(e.rank() - 1));
return std::tuple{left_extents(e), right_ext};
};
Expand All @@ -149,10 +147,10 @@ auto split_extents_at_rightmost =
// optimization information -- e.g., whether we want
// to apply "#pragma omp simd" to a particular extent.
template<class Callable, class IndexType, std::size_t Extent>
void for_each_one_extent(Callable&& callable, stdex::extents<IndexType, Extent> ext)
void for_each_one_extent(Callable&& callable, Kokkos::extents<IndexType, Extent> ext)
{
// If it's a run-time extent, do a run-time loop.
if constexpr(ext.static_extent(0) == stdex::dynamic_extent) {
if constexpr(ext.static_extent(0) == Kokkos::dynamic_extent) {
for(IndexType index = 0; index < ext.extent(0); ++index) {
std::forward<Callable>(callable)(index);
}
Expand All @@ -176,7 +174,7 @@ void for_each_one_extent(Callable&& callable, stdex::extents<IndexType, Extent>
template<class Callable, class IndexType, std::size_t ... Extents>
void for_each_in_extents_row_major(
Callable&& callable,
stdex::extents<IndexType, Extents...> ext)
Kokkos::extents<IndexType, Extents...> ext)
{
if constexpr(ext.rank() == 0) {
return;
Expand All @@ -203,12 +201,12 @@ void for_each_in_extents_row_major(
// The implementation differs in only two places from the row-major version.
// This suggests a way to generalize.
//
// Overloading on stdex::extents<IndexType, LeftExtents..., RightExtent>
// Overloading on extents<IndexType, LeftExtents..., RightExtent>
// works fine for the row major case, but not for the column major case.
template<class Callable, class IndexType, std::size_t ... Extents>
void for_each_in_extents_col_major(
Callable&& callable,
stdex::extents<IndexType, Extents...> ext)
Kokkos::extents<IndexType, Extents...> ext)
{
if constexpr(ext.rank() == 0) {
return;
Expand Down Expand Up @@ -242,7 +240,7 @@ void for_each_in_extents_col_major(
template<class Callable, class IndexType, std::size_t ... Extents,
class ExtentsReorderer, class ExtentsSplitter, class IndicesReorderer>
void for_each_in_extents_impl(Callable&& callable,
stdex::extents<IndexType, Extents...> ext,
Kokkos::extents<IndexType, Extents...> ext,
ExtentsReorderer reorder_extents,
ExtentsSplitter split_extents,
IndicesReorderer reorder_indices)
Expand Down Expand Up @@ -280,18 +278,18 @@ void for_each_in_extents_impl(Callable&& callable,
}

auto extents_identity = []<class IndexType, std::size_t ... Extents>(
stdex::extents<IndexType, Extents...> ext)
Kokkos::extents<IndexType, Extents...> ext)
{
return ext;
};

auto extents_reverse = []<class IndexType, std::size_t ... Extents>(
stdex::extents<IndexType, Extents...> ext)
Kokkos::extents<IndexType, Extents...> ext)
{
constexpr std::size_t N = ext.rank();

return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
return stdex::extents<
return Kokkos::extents<
IndexType,
ext.static_extent(N - 1 - Indices)...
>{ ext.extent(N - 1 - Indices)... };
Expand Down Expand Up @@ -325,8 +323,8 @@ auto indices_reverse = [](auto... args) {
// Row-major iteration
template<class Callable, class IndexType, std::size_t ... Extents>
void for_each_in_extents(Callable&& callable,
stdex::extents<IndexType, Extents...> ext,
stdex::layout_right)
Kokkos::extents<IndexType, Extents...> ext,
Kokkos::layout_right)
{
for_each_in_extents_impl(std::forward<Callable>(callable), ext,
extents_identity, split_extents_at_leftmost, indices_identity);
Expand All @@ -335,8 +333,8 @@ void for_each_in_extents(Callable&& callable,
// Column-major iteration
template<class Callable, class IndexType, std::size_t ... Extents>
void for_each_in_extents(Callable&& callable,
stdex::extents<IndexType, Extents...> ext,
stdex::layout_left)
Kokkos::extents<IndexType, Extents...> ext,
Kokkos::layout_left)
{
for_each_in_extents_impl(std::forward<Callable>(callable), ext,
extents_reverse, split_extents_at_rightmost, indices_reverse);
Expand All @@ -349,7 +347,7 @@ int main() {
#if ! defined(__clang__) && defined(MDSPAN_EXAMPLE_CAN_USE_LAMBDA_TEMPLATE_PARAM_LIST)
// The functions work for any combination
// of compile-time or run-time extents.
stdex::extents<int, 3, stdex::dynamic_extent, 5> e{4};
Kokkos::extents<int, 3, Kokkos::dynamic_extent, 5> e{4};

std::cout << "\nRow major:\n";
for_each_in_extents_row_major(print_pack, e);
Expand All @@ -358,10 +356,10 @@ int main() {
for_each_in_extents_col_major(print_pack, e);

std::cout << "\nfor_each_in_extents: row major:\n";
for_each_in_extents(print_pack, e, stdex::layout_right{});
for_each_in_extents(print_pack, e, Kokkos::layout_right{});

std::cout << "\nfor_each_in_extents: column major:\n";
for_each_in_extents(print_pack, e, stdex::layout_left{});
for_each_in_extents(print_pack, e, Kokkos::layout_left{});
#endif // defined(MDSPAN_EXAMPLE_CAN_USE_LAMBDA_TEMPLATE_PARAM_LIST)

return 0;
Expand Down

0 comments on commit 457b822

Please sign in to comment.