From 067becf9bc2afa9257c83e4261d11714746ba786 Mon Sep 17 00:00:00 2001 From: Eric Cano <37585813+ericcano@users.noreply.github.com> Date: Wed, 16 Nov 2022 14:44:04 +0100 Subject: [PATCH] Add range checking to SoA view-level indexed accessors --- DataFormats/SoATemplate/interface/SoALayout.h | 11 ----- DataFormats/SoATemplate/interface/SoAView.h | 12 ++++- .../SoATemplate/test/SoALayoutAndView_t.cu | 19 ++++++-- .../AlpakaTest/plugins/TestAlpakaAnalyzer.cc | 46 +++++++++++-------- 4 files changed, 54 insertions(+), 34 deletions(-) diff --git a/DataFormats/SoATemplate/interface/SoALayout.h b/DataFormats/SoATemplate/interface/SoALayout.h index 9cee0272a3dff..ad262f48334a0 100644 --- a/DataFormats/SoATemplate/interface/SoALayout.h +++ b/DataFormats/SoATemplate/interface/SoALayout.h @@ -581,17 +581,6 @@ throw std::runtime_error("In " #CLASS "::" #CLASS ": unexpected end pointer."); \ } \ \ - /* Range checker conditional to the macro _DO_RANGECHECK */ \ - SOA_HOST_DEVICE SOA_INLINE \ - void rangeCheck(size_type index) const { \ - if constexpr (_DO_RANGECHECK) { \ - if (index >= elements_) { \ - printf("In " #CLASS "::rangeCheck(): index out of range: %zu with elements: %zu\n", index, elements_); \ - assert(false); \ - } \ - } \ - } \ - \ /* Data members */ \ std::byte* mem_; \ size_type elements_; \ diff --git a/DataFormats/SoATemplate/interface/SoAView.h b/DataFormats/SoATemplate/interface/SoAView.h index 8ced878378f24..0099587a98180 100644 --- a/DataFormats/SoATemplate/interface/SoAView.h +++ b/DataFormats/SoATemplate/interface/SoAView.h @@ -389,7 +389,11 @@ namespace cms::soa { cms::soa::SoAAccessType::mutableAccess>::template Alignment:: \ template RestrictQualifier::ParamReturnType \ LOCAL_NAME(size_type index) { \ - return typename cms::soa::SoAAccessors:: \ + if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \ + if (index >= base_type::elements_) \ + SOA_THROW_OUT_OF_RANGE("Out of range index in mutable " #LOCAL_NAME "(size_type index)") \ + } \ + return typename cms::soa::SoAAccessors:: \ template ColumnType::template AccessType< \ cms::soa::SoAAccessType::mutableAccess>::template Alignment:: \ template RestrictQualifier(const_cast_SoAParametersImpl( \ @@ -423,7 +427,11 @@ namespace cms::soa { cms::soa::SoAAccessType::constAccess>::template Alignment:: \ template RestrictQualifier::ParamReturnType \ LOCAL_NAME(size_type index) const { \ - return typename cms::soa::SoAAccessors:: \ + if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \ + if (index >= elements_) \ + SOA_THROW_OUT_OF_RANGE("Out of range index in const " #LOCAL_NAME "(size_type index)") \ + } \ + return typename cms::soa::SoAAccessors:: \ template ColumnType::template AccessType< \ cms::soa::SoAAccessType::constAccess>::template Alignment:: \ template RestrictQualifier(BOOST_PP_CAT(LOCAL_NAME, Parameters_))(index); \ diff --git a/DataFormats/SoATemplate/test/SoALayoutAndView_t.cu b/DataFormats/SoATemplate/test/SoALayoutAndView_t.cu index 00e3465a1304a..34dda3bd7a803 100644 --- a/DataFormats/SoATemplate/test/SoALayoutAndView_t.cu +++ b/DataFormats/SoATemplate/test/SoALayoutAndView_t.cu @@ -101,7 +101,7 @@ using RangeCheckingHostDeviceView = // We expect to just run one thread. __global__ void rangeCheckKernel(RangeCheckingHostDeviceView soa) { - printf("About to fail range-check in CUDA thread: %d\n", threadIdx.x); + printf("About to fail range-check (operator[]) in CUDA thread: %d\n", threadIdx.x); [[maybe_unused]] auto si = soa[soa.metadata().size()]; printf("Fail: range-check failure should have stopped the kernel.\n"); } @@ -250,10 +250,23 @@ int main(void) { soa1viewRangeChecking(h_soahdLayout); // This should throw an exception [[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size()]; - std::cout << "Fail: expected range-check exception not caught on the host." << std::endl; + std::cout << "Fail: expected range-check exception (operator[]) not caught on the host." << std::endl; assert(false); } catch (const std::out_of_range&) { - std::cout << "Pass: expected range-check exception successfully caught on the host." << std::endl; + std::cout << "Pass: expected range-check exception (operator[]) successfully caught on the host." << std::endl; + } + + try { + // Get a view like the default, except for range checking + SoAHostDeviceLayout::ViewTemplate + soa1viewRangeChecking(h_soahdLayout); + // This should throw an exception + [[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size()]; + std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host." << std::endl; + assert(false); + } catch (const std::out_of_range&) { + std::cout << "Pass: expected range-check exception (view-level index access) successfully caught on the host." + << std::endl; } // Validation of range checking in a kernel diff --git a/HeterogeneousCore/AlpakaTest/plugins/TestAlpakaAnalyzer.cc b/HeterogeneousCore/AlpakaTest/plugins/TestAlpakaAnalyzer.cc index 1058af9c9e77e..2c6ed46c5d282 100644 --- a/HeterogeneousCore/AlpakaTest/plugins/TestAlpakaAnalyzer.cc +++ b/HeterogeneousCore/AlpakaTest/plugins/TestAlpakaAnalyzer.cc @@ -48,6 +48,29 @@ namespace { column.print(out); return out; } + + template + void checkViewAddresses(T const& view) { + assert(view.metadata().addressOf_x() == view.x()); + assert(view.metadata().addressOf_x() == &view.x(0)); + assert(view.metadata().addressOf_x() == &view[0].x()); + assert(view.metadata().addressOf_y() == view.y()); + assert(view.metadata().addressOf_y() == &view.y(0)); + assert(view.metadata().addressOf_y() == &view[0].y()); + assert(view.metadata().addressOf_z() == view.z()); + assert(view.metadata().addressOf_z() == &view.z(0)); + assert(view.metadata().addressOf_z() == &view[0].z()); + assert(view.metadata().addressOf_id() == view.id()); + assert(view.metadata().addressOf_id() == &view.id(0)); + assert(view.metadata().addressOf_id() == &view[0].id()); + assert(view.metadata().addressOf_m() == view.m()); + assert(view.metadata().addressOf_m() == &view.m(0).coeffRef(0, 0)); + assert(view.metadata().addressOf_m() == &view[0].m().coeffRef(0, 0)); + assert(view.metadata().addressOf_r() == &view.r()); + //assert(view.metadata().addressOf_r() == &view.r(0)); // cannot access a scalar with an index + //assert(view.metadata().addressOf_r() == &view[0].r()); // cannot access a scalar via a SoA row-like accessor + } + } // namespace class TestAlpakaAnalyzer : public edm::stream::EDAnalyzer<> { @@ -58,6 +81,8 @@ class TestAlpakaAnalyzer : public edm::stream::EDAnalyzer<> { void analyze(edm::Event const& event, edm::EventSetup const&) override { portabletest::TestHostCollection const& product = event.get(token_); auto const& view = product.const_view(); + auto& mview = product.view(); + auto const& cmview = product.view(); { edm::LogInfo msg("TestAlpakaAnalyzer"); @@ -88,24 +113,9 @@ class TestAlpakaAnalyzer : public edm::stream::EDAnalyzer<> { reinterpret_cast(view.metadata().addressOf_r()); } - assert(view.metadata().addressOf_x() == view.x()); - assert(view.metadata().addressOf_x() == &view.x(0)); - assert(view.metadata().addressOf_x() == &view[0].x()); - assert(view.metadata().addressOf_y() == view.y()); - assert(view.metadata().addressOf_y() == &view.y(0)); - assert(view.metadata().addressOf_y() == &view[0].y()); - assert(view.metadata().addressOf_z() == view.z()); - assert(view.metadata().addressOf_z() == &view.z(0)); - assert(view.metadata().addressOf_z() == &view[0].z()); - assert(view.metadata().addressOf_id() == view.id()); - assert(view.metadata().addressOf_id() == &view.id(0)); - assert(view.metadata().addressOf_id() == &view[0].id()); - assert(view.metadata().addressOf_m() == view.m()); - assert(view.metadata().addressOf_m() == &view.m(0).coeffRef(0, 0)); - assert(view.metadata().addressOf_m() == &view[0].m().coeffRef(0, 0)); - assert(view.metadata().addressOf_r() == &view.r()); - //assert(view.metadata().addressOf_r() == &view.r(0)); // cannot access a scalar with an index - //assert(view.metadata().addressOf_r() == &view[0].r()); // cannot access a scalar via a SoA row-like accessor + checkViewAddresses(view); + checkViewAddresses(mview); + checkViewAddresses(cmview); const portabletest::Matrix matrix{{1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}, {3, 6, 9, 12, 15, 18}}; assert(view.r() == 1.);