Skip to content

Commit

Permalink
Add bound checks in RangePolicy and MDRangePolicy (kokkos#6617)
Browse files Browse the repository at this point in the history
* Added a bounds check in MDRangePolicy that checks that all lower bounds are less than its upper bound

* Modified the wording on the abort

* Converted the error msg from a stringstream to a string

* Modified abort msg

* Fixed the unit test output based on backend's default iterate direction

* Update core/unit_test/TestMDRangePolicyConstructors.hpp

Formatting.

Co-authored-by: Damien L-G <dalg24+github@gmail.com>

* Updated RangePolicy to have the same precondition as MDRangePolicy

---------

Co-authored-by: Damien L-G <dalg24+github@gmail.com>
  • Loading branch information
ldh4 and dalg24 committed Jan 17, 2024
1 parent 9593413 commit 179d2e6
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
10 changes: 10 additions & 0 deletions core/src/KokkosExp_MDRangePolicy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ struct MDRangePolicy : public Kokkos::Impl::PolicyTraits<Properties...> {
}
for (int i = rank_start; i != rank_end; i += increment) {
const index_type length = m_upper[i] - m_lower[i];

if (m_upper[i] < m_lower[i]) {
std::string msg =
"Kokkos::MDRangePolicy bounds error: The lower bound (" +
std::to_string(m_lower[i]) + ") is greater than its upper bound (" +
std::to_string(m_upper[i]) + ") in dimension " + std::to_string(i) +
".";
Kokkos::abort(msg.c_str());
}

if (m_tile[i] <= 0) {
m_tune_tile_size = true;
if ((inner_direction == Iterate::Right && (i < rank - 1)) ||
Expand Down
22 changes: 18 additions & 4 deletions core/src/Kokkos_ExecPolicy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,18 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
inline RangePolicy(const typename traits::execution_space& work_space,
const member_type work_begin, const member_type work_end)
: m_space(work_space),
m_begin(work_begin < work_end ? work_begin : 0),
m_end(work_begin < work_end ? work_end : 0),
m_begin(work_begin),
m_end(work_end),
m_granularity(0),
m_granularity_mask(0) {
check_bounds_validity();
set_auto_chunk_size();
}

/** \brief Total range */
inline RangePolicy(const member_type work_begin, const member_type work_end)
: RangePolicy(typename traits::execution_space(), work_begin, work_end) {
check_bounds_validity();
set_auto_chunk_size();
}

Expand All @@ -136,10 +138,11 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
const member_type work_begin, const member_type work_end,
Args... args)
: m_space(work_space),
m_begin(work_begin < work_end ? work_begin : 0),
m_end(work_begin < work_end ? work_end : 0),
m_begin(work_begin),
m_end(work_end),
m_granularity(0),
m_granularity_mask(0) {
check_bounds_validity();
set_auto_chunk_size();
set(args...);
}
Expand All @@ -149,6 +152,7 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
inline RangePolicy(const member_type work_begin, const member_type work_end,
Args... args)
: RangePolicy(typename traits::execution_space(), work_begin, work_end) {
check_bounds_validity();
set_auto_chunk_size();
set(args...);
}
Expand Down Expand Up @@ -218,6 +222,16 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
m_granularity_mask = m_granularity - 1;
}

inline void check_bounds_validity() {
if (m_end < m_begin) {
std::string msg = "Kokkos::RangePolicy bounds error: The lower bound (" +
std::to_string(m_begin) +
") is greater than the upper bound (" +
std::to_string(m_end) + ").";
Kokkos::abort(msg.c_str());
}
}

public:
/** \brief Subrange for a partition's rank and size.
*
Expand Down
16 changes: 16 additions & 0 deletions core/unit_test/TestMDRangePolicyConstructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ TEST(TEST_CATEGORY_DEATH, policy_bounds_unsafe_narrowing_conversions) {
},
"unsafe narrowing conversion");
}

TEST(TEST_CATEGORY_DEATH, policy_invalid_bounds) {
using Policy = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>;

::testing::FLAGS_gtest_death_test_style = "threadsafe";

auto dim = (Policy::inner_direction == Kokkos::Iterate::Right) ? 1 : 0;

ASSERT_DEATH(
{
(void)Policy({100, 100}, {90, 90});
},
"Kokkos::MDRangePolicy bounds error: The lower bound \\(100\\) is "
"greater than its upper bound \\(90\\) in dimension " +
std::to_string(dim) + "\\.");
}
#endif

} // namespace
13 changes: 13 additions & 0 deletions core/unit_test/TestRangePolicyConstructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,17 @@ TEST(TEST_CATEGORY, range_policy_runtime_parameters) {
}
}

TEST(TEST_CATEGORY_DEATH, range_policy_invalid_bounds) {
using Policy = Kokkos::RangePolicy<TEST_EXECSPACE>;
using ChunkSize = Kokkos::ChunkSize;

ASSERT_DEATH({ (void)Policy(100, 90); },
"Kokkos::RangePolicy bounds error: The lower bound \\(100\\) is "
"greater than the upper bound \\(90\\)\\.");

ASSERT_DEATH({ (void)Policy(TEST_EXECSPACE(), 100, 90, ChunkSize(10)); },
"Kokkos::RangePolicy bounds error: The lower bound \\(100\\) is "
"greater than the upper bound \\(90\\)\\.");
}

} // namespace

0 comments on commit 179d2e6

Please sign in to comment.