Implement HPX::in_parallel (kokkos#6143)
* Implement HPX::in_parallel

* Don't use gtest macros in parallel regions in HPX in_parallel test
msimberg committed May 30, 2023
1 parent e88537f commit ab6f756
Showing 4 changed files with 283 additions and 12 deletions.
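In short, after this change Kokkos::Experimental::HPX::in_parallel() reports whether the calling worker thread is executing inside an HPX-backend parallel region, rather than unconditionally returning false. A minimal standalone usage sketch (illustration only, not part of the diff):

#include <Kokkos_Core.hpp>

// Illustration only: exercises the new in_parallel() semantics from host code.
int main(int argc, char *argv[]) {
  Kokkos::ScopeGuard guard(argc, argv);

  // Outside any parallel region the flag is false.
  bool outside = Kokkos::Experimental::HPX::in_parallel();

  Kokkos::View<bool *, Kokkos::Experimental::HPX> inside("inside", 1);
  Kokkos::parallel_for(
      Kokkos::RangePolicy<Kokkos::Experimental::HPX>(0, 1),
      KOKKOS_LAMBDA(const int i) {
        // Inside the region the executing worker thread reports true.
        inside(i) = Kokkos::Experimental::HPX::in_parallel();
      });
  Kokkos::fence();

  return (!outside && inside(0)) ? 0 : 1;
}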
25 changes: 25 additions & 0 deletions core/src/HPX/Kokkos_HPX.cpp
@@ -103,6 +103,31 @@ void HPX::print_configuration(std::ostream &os, const bool) const {
os << hpx::configuration_string() << '\n';
}

bool &HPX::impl_get_in_parallel() noexcept {
static thread_local bool in_parallel = false;
return in_parallel;
}

HPX::impl_in_parallel_scope::impl_in_parallel_scope() noexcept {
KOKKOS_EXPECTS(!impl_get_in_parallel());
impl_get_in_parallel() = true;
}

HPX::impl_in_parallel_scope::~impl_in_parallel_scope() noexcept {
KOKKOS_EXPECTS(impl_get_in_parallel());
impl_get_in_parallel() = false;
}

HPX::impl_not_in_parallel_scope::impl_not_in_parallel_scope() noexcept {
KOKKOS_EXPECTS(impl_get_in_parallel());
impl_get_in_parallel() = false;
}

HPX::impl_not_in_parallel_scope::~impl_not_in_parallel_scope() noexcept {
KOKKOS_EXPECTS(!impl_get_in_parallel());
impl_get_in_parallel() = true;
}

void HPX::impl_decrement_active_parallel_region_count() {
std::unique_lock<hpx::spinlock> l(m_active_parallel_region_count_mutex);
if (--m_active_parallel_region_count == 0) {
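For reference, a standalone sketch of the pattern the new guards implement (hypothetical names, not part of the diff): a thread-local flag toggled through RAII scope guards, so the flag is restored even if the guarded code returns early or throws.

#include <cassert>

// Hypothetical stand-ins for impl_get_in_parallel() and impl_in_parallel_scope.
thread_local bool g_in_parallel = false;

struct in_parallel_guard {
  in_parallel_guard() noexcept {
    assert(!g_in_parallel);  // guards must not nest, as in KOKKOS_EXPECTS above
    g_in_parallel = true;
  }
  ~in_parallel_guard() {
    assert(g_in_parallel);
    g_in_parallel = false;
  }
  in_parallel_guard(in_parallel_guard const &) = delete;
  in_parallel_guard &operator=(in_parallel_guard const &) = delete;
};

int main() {
  assert(!g_in_parallel);
  {
    in_parallel_guard p;  // flag is true only within this scope
    assert(g_in_parallel);
  }
  assert(!g_in_parallel);
}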
81 changes: 69 additions & 12 deletions core/src/HPX/Kokkos_HPX.hpp
@@ -201,9 +201,30 @@ class HPX {
return impl_get_instance_data().m_instance_id;
}

static bool &impl_get_in_parallel() noexcept;

struct impl_in_parallel_scope {
impl_in_parallel_scope() noexcept;
~impl_in_parallel_scope() noexcept;
impl_in_parallel_scope(impl_in_parallel_scope &&) = delete;
impl_in_parallel_scope(impl_in_parallel_scope const &) = delete;
impl_in_parallel_scope &operator=(impl_in_parallel_scope &&) = delete;
impl_in_parallel_scope &operator=(impl_in_parallel_scope const &) = delete;
};

struct impl_not_in_parallel_scope {
impl_not_in_parallel_scope() noexcept;
~impl_not_in_parallel_scope() noexcept;
impl_not_in_parallel_scope(impl_not_in_parallel_scope &&) = delete;
impl_not_in_parallel_scope(impl_not_in_parallel_scope const &) = delete;
impl_not_in_parallel_scope &operator=(impl_not_in_parallel_scope &&) =
delete;
impl_not_in_parallel_scope &operator=(impl_not_in_parallel_scope const &) =
delete;
};

static bool in_parallel(HPX const & = HPX()) noexcept {
-    // TODO: Very awkward to keep track of. What should this really return?
-    return false;
+    return impl_get_in_parallel();
}

static void impl_decrement_active_parallel_region_count();
@@ -333,7 +354,10 @@ class HPX {
hpx::threads::thread_stacksize stacksize =
hpx::threads::thread_stacksize::default_) const {
impl_bulk_plain_erased(force_synchronous, is_light_weight_policy,
-    {[functor](Index i) { functor.execute_range(i); }},
+    {[functor](Index i) {
+      impl_in_parallel_scope p;
+      functor.execute_range(i);
+    }},
n, stacksize);
}

@@ -391,11 +415,20 @@
Functor const &functor, Index const n,
hpx::threads::thread_stacksize stacksize =
hpx::threads::thread_stacksize::default_) const {
-    impl_bulk_setup_finalize_erased(
-        force_synchronous, is_light_weight_policy,
-        {[functor](Index i) { functor.execute_range(i); }},
-        {[functor]() { functor.setup(); }},
-        {[functor]() { functor.finalize(); }}, n, stacksize);
+    impl_bulk_setup_finalize_erased(force_synchronous, is_light_weight_policy,
+                                    {[functor](Index i) {
+                                      impl_in_parallel_scope p;
+                                      functor.execute_range(i);
+                                    }},
+                                    {[functor]() {
+                                      impl_in_parallel_scope p;
+                                      functor.setup();
+                                    }},
+                                    {[functor]() {
+                                      impl_in_parallel_scope p;
+                                      functor.finalize();
+                                    }},
+                                    n, stacksize);
}

static constexpr const char *name() noexcept { return "HPX"; }
@@ -1259,7 +1292,13 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
const WorkRange range(m_policy, t, num_worker_threads);
execute_chunk(range.begin(), range.end(), update_sum, false);

-    barrier.arrive_and_wait();
+    {
+      // Since arrive_and_wait may yield and resume on another worker thread we
+      // set in_parallel = false on the current thread before suspending and set
+      // it again to true when we resume.
+      Kokkos::Experimental::HPX::impl_not_in_parallel_scope p;
+      barrier.arrive_and_wait();
+    }

if (t == 0) {
final_reducer.init(reinterpret_cast<pointer_type>(
@@ -1281,7 +1320,13 @@
}
}

-    barrier.arrive_and_wait();
+    {
+      // Since arrive_and_wait may yield and resume on another worker thread we
+      // set in_parallel = false on the current thread before suspending and set
+      // it again to true when we resume.
+      Kokkos::Experimental::HPX::impl_not_in_parallel_scope p;
+      barrier.arrive_and_wait();
+    }

reference_type update_base =
Analysis::Reducer::reference(reinterpret_cast<pointer_type>(
@@ -1362,7 +1407,13 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
const WorkRange range(m_policy, t, num_worker_threads);
execute_chunk(range.begin(), range.end(), update_sum, false);

-    barrier.arrive_and_wait();
+    {
+      // Since arrive_and_wait may yield and resume on another worker thread we
+      // set in_parallel = false on the current thread before suspending and set
+      // it again to true when we resume.
+      Kokkos::Experimental::HPX::impl_not_in_parallel_scope p;
+      barrier.arrive_and_wait();
+    }

if (t == 0) {
final_reducer.init(reinterpret_cast<pointer_type>(
@@ -1384,7 +1435,13 @@
}
}

-    barrier.arrive_and_wait();
+    {
+      // Since arrive_and_wait may yield and resume on another worker thread we
+      // set in_parallel = false on the current thread before suspending and set
+      // it again to true when we resume.
+      Kokkos::Experimental::HPX::impl_not_in_parallel_scope p;
+      barrier.arrive_and_wait();
+    }

reference_type update_base =
Analysis::Reducer::reference(reinterpret_cast<pointer_type>(
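The comments around arrive_and_wait above are the subtle part of this change: a task suspended at the barrier may resume on a different worker thread, so the thread-local flag is cleared before the wait and restored on resume. Below is a minimal sketch of that guard nesting using std::barrier and plain OS threads (which do not migrate, so this shows only the ordering, not the HPX runtime behavior); the guard names are hypothetical stand-ins for the impl_*_scope types.

#include <barrier>
#include <cassert>
#include <thread>

thread_local bool t_in_parallel = false;

struct set_guard {  // stand-in for impl_in_parallel_scope
  set_guard() { t_in_parallel = true; }
  ~set_guard() { t_in_parallel = false; }
};

struct clear_guard {  // stand-in for impl_not_in_parallel_scope
  clear_guard() { t_in_parallel = false; }
  ~clear_guard() { t_in_parallel = true; }
};

int main() {
  std::barrier<> barrier(2);
  auto worker = [&] {
    set_guard in_parallel;  // executing a parallel chunk: flag is true
    {
      // Clear before a potentially suspending wait; restore afterwards so a
      // task resumed elsewhere never observes a stale "in parallel" flag.
      clear_guard suspended;
      barrier.arrive_and_wait();
    }
    assert(t_in_parallel);  // restored after the wait
  };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
}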
6 changes: 6 additions & 0 deletions core/unit_test/CMakeLists.txt
@@ -602,6 +602,12 @@ if(Kokkos_ENABLE_HPX)
hpx/TestHPX_IndependentInstancesRefCounting.cpp
hpx/TestHPX_IndependentInstancesSynchronization.cpp
)
KOKKOS_ADD_EXECUTABLE_AND_TEST(
CoreUnitTest_HPX_InParallel
SOURCES
UnitTestMainInit.cpp
hpx/TestHPX_InParallel.cpp
)
endif()

if(Kokkos_ENABLE_OPENMPTARGET)
183 changes: 183 additions & 0 deletions core/unit_test/hpx/TestHPX_InParallel.cpp
@@ -0,0 +1,183 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>
#include <Kokkos_Core.hpp>

// These tests check that HPX::in_parallel() returns true on HPX worker
// threads inside parallel_for, parallel_reduce, and parallel_scan regions,
// and false on the calling thread before dispatch, after dispatch, and
// after fences.

namespace {
inline constexpr int n = 1 << 10;

TEST(hpx, in_parallel_for_range_policy) {
Kokkos::View<bool *, Kokkos::Experimental::HPX> a("a", n);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::RangePolicy<Kokkos::Experimental::HPX> policy(0, n);
Kokkos::parallel_for(
"parallel_for_range_policy", policy, KOKKOS_LAMBDA(const int i) {
a(i) = Kokkos::Experimental::HPX::in_parallel();
});

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

for (int i = 0; i < n; ++i) {
ASSERT_TRUE(a(i));
}
}

TEST(hpx, in_parallel_for_mdrange_policy) {
Kokkos::View<bool *, Kokkos::Experimental::HPX> a("a", n);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::MDRangePolicy<Kokkos::Experimental::HPX, Kokkos::Rank<2>> policy(
{0, 0}, {n, 1});
Kokkos::parallel_for(
"parallel_for_mdrange_policy", policy,
KOKKOS_LAMBDA(const int i, const int) {
a(i) = Kokkos::Experimental::HPX::in_parallel();
});

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

for (int i = 0; i < n; ++i) {
ASSERT_TRUE(a(i));
}
}

TEST(hpx, in_parallel_for_team_policy) {
Kokkos::View<bool *, Kokkos::Experimental::HPX> a("a", n);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::TeamPolicy<Kokkos::Experimental::HPX> policy(n, 1);
using member_type = decltype(policy)::member_type;
Kokkos::parallel_for(
"parallel_for_team_policy", policy,
KOKKOS_LAMBDA(const member_type &handle) {
a(handle.league_rank()) = Kokkos::Experimental::HPX::in_parallel();
});

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

for (int i = 0; i < n; ++i) {
ASSERT_TRUE(a(i));
}
}

TEST(hpx, in_parallel_reduce_range_policy) {
Kokkos::View<int, Kokkos::Experimental::HPX> a("a");

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::RangePolicy<Kokkos::Experimental::HPX> policy(0, n);
Kokkos::parallel_reduce(
"parallel_reduce_range_policy", policy,
KOKKOS_LAMBDA(const int, int &x) {
if (!Kokkos::Experimental::HPX::in_parallel()) {
++x;
}
},
a);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

ASSERT_EQ(a(), 0);
}

TEST(hpx, in_parallel_reduce_mdrange_policy) {
Kokkos::View<int, Kokkos::Experimental::HPX> a("a");

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::MDRangePolicy<Kokkos::Experimental::HPX, Kokkos::Rank<2>> policy(
{0, 0}, {n, 1});
Kokkos::parallel_reduce(
"parallel_reduce_mdrange_policy", policy,
KOKKOS_LAMBDA(const int, const int, int &x) {
if (!Kokkos::Experimental::HPX::in_parallel()) {
++x;
}
},
a);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

ASSERT_EQ(a(), 0);
}

TEST(hpx, in_parallel_reduce_team_policy) {
Kokkos::View<int, Kokkos::Experimental::HPX> a("a");

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::TeamPolicy<Kokkos::Experimental::HPX> policy(n, 1);
using member_type = decltype(policy)::member_type;
Kokkos::parallel_reduce(
"parallel_reduce_team_policy", policy,
KOKKOS_LAMBDA(const member_type &, int &x) {
if (!Kokkos::Experimental::HPX::in_parallel()) {
++x;
}
},
a);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

ASSERT_EQ(a(), 0);
}

TEST(hpx, in_parallel_scan_range_policy) {
Kokkos::View<int *, Kokkos::Experimental::HPX> a("a", n);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

Kokkos::RangePolicy<Kokkos::Experimental::HPX> policy(0, n);
Kokkos::parallel_scan(
"parallel_scan_range_policy", policy,
KOKKOS_LAMBDA(const int, int &x, bool) {
if (!Kokkos::Experimental::HPX::in_parallel()) {
++x;
}
},
a);

ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());
Kokkos::fence();
ASSERT_FALSE(Kokkos::Experimental::HPX::in_parallel());

for (int i = 0; i < n; ++i) {
ASSERT_EQ(a(i), 0);
}
}
} // namespace
