Skip to content

Commit

Permalink
Ensure reproducibility when using MT Geant4 with Celeritas offloading (
Browse files Browse the repository at this point in the history
…#1061)

* Add kernel to reseed RNG states based on Geant4 event ID
* Reseed Celeritas RNG at the start of an event in LocalTransporter
* Add Celeritas step diagnostic to celer-g4 when offloading is enabled
* Move RngReseed to celeritas/random
* Use same type for RNG engine initializer seed and RNG params seed
* Add a reseed method to stepper
* Deprecate SetEventId and always reseed RNG
* Always reseed RNGs in serial mode
  • Loading branch information
amandalund committed Dec 15, 2023
1 parent 94fb9e4 commit bc3ccd5
Show file tree
Hide file tree
Showing 16 changed files with 307 additions and 14 deletions.
4 changes: 2 additions & 2 deletions app/celer-g4/EventAction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ void EventAction::BeginOfEventAction(G4Event const* event)
if (SharedParams::CeleritasDisabled())
return;

// Set event ID in local transporter
// Set event ID in local transporter and reseed Celeritas RNG
ExceptionConverter call_g4exception{"celer0002"};
CELER_TRY_HANDLE(transport_->SetEventId(event->GetEventID()),
CELER_TRY_HANDLE(transport_->InitializeEvent(event->GetEventID()),
call_g4exception);
}

Expand Down
19 changes: 17 additions & 2 deletions app/celer-g4/GeantDiagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "corecel/sys/MemRegistry.hh"
#include "corecel/sys/MultiExceptionHandler.hh"
#include "celeritas/Types.hh"
#include "celeritas/global/ActionRegistry.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/user/StepDiagnostic.hh"

#include "GlobalSetup.hh"

Expand Down Expand Up @@ -57,9 +59,22 @@ GeantDiagnostics::GeantDiagnostics(SharedParams const& params)
if (global_setup.StepDiagnostic())
{
// Create the track step diagnostic and add to output registry
step_diagnostic_ = std::make_shared<GeantStepDiagnostic>(
global_setup.GetStepDiagnosticBins(), num_threads);
auto num_bins = GlobalSetup::Instance()->GetStepDiagnosticBins();
step_diagnostic_
= std::make_shared<GeantStepDiagnostic>(num_bins, num_threads);
output_reg->insert(step_diagnostic_);

// Add the Celeritas step diagnostic if Celeritas offloading is enabled
if (params)
{
auto step_diagnostic = std::make_shared<celeritas::StepDiagnostic>(
params.Params()->action_reg()->next_id(),
params.Params()->particle(),
num_bins,
num_threads);
params.Params()->action_reg()->insert(step_diagnostic);
output_reg->insert(step_diagnostic);
}
}

if (!params)
Expand Down
15 changes: 13 additions & 2 deletions src/accel/LocalTransporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <string>
#include <type_traits>
#include <CLHEP/Units/SystemOfUnits.h>
#include <G4MTRunManager.hh>
#include <G4ParticleDefinition.hh>
#include <G4Threading.hh>
#include <G4ThreeVector.hh>
Expand Down Expand Up @@ -114,14 +115,24 @@ LocalTransporter::LocalTransporter(SetupOptions const& options,

//---------------------------------------------------------------------------//
/*!
* Set the event ID at the start of an event.
* Set the event ID and reseed the Celeritas RNG at the start of an event.
*/
void LocalTransporter::SetEventId(int id)
void LocalTransporter::InitializeEvent(int id)
{
CELER_EXPECT(*this);
CELER_EXPECT(id >= 0);

event_id_ = EventId(id);
track_counter_ = 0;

if (!(G4Threading::IsMultithreadedApplication()
&& G4MTRunManager::SeedOncePerCommunication()))
{
// Since Geant4 schedules events dynamically, reseed the Celeritas RNGs
// using the Geant4 event ID for reproducibility. This guarantees that
// an event can be reproduced given the event ID.
step_->reseed(event_id_);
}
}

//---------------------------------------------------------------------------//
Expand Down
7 changes: 5 additions & 2 deletions src/accel/LocalTransporter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ class LocalTransporter
inline void
Initialize(SetupOptions const& options, SharedParams const& params);

// Set the event ID
void SetEventId(int);
// Set the event ID and reseed the Celeritas RNG (remove in v1.0)
[[deprecated]] void SetEventId(int id) { this->InitializeEvent(id); }

// Set the event ID and reseed the Celeritas RNG at the start of an event
void InitializeEvent(int);

// Offload this track
void Push(G4Track const&);
Expand Down
7 changes: 4 additions & 3 deletions src/accel/SimpleOffload.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ void SimpleOffload::BeginOfRunAction(G4Run const*)

//---------------------------------------------------------------------------//
/*!
* Send Celeritas the event ID.
* Send Celeritas the event ID and reseed the Celeritas RNG.
*/
void SimpleOffload::BeginOfEventAction(G4Event const* event)
{
if (!*this)
return;

// Set event ID in local transporter
// Set event ID in local transporter and reseed RNG for reproducibility
ExceptionConverter call_g4exception{"celer0002"};
CELER_TRY_HANDLE(local_->SetEventId(event->GetEventID()), call_g4exception);
CELER_TRY_HANDLE(local_->InitializeEvent(event->GetEventID()),
call_g4exception);
}

//---------------------------------------------------------------------------//
Expand Down
1 change: 1 addition & 0 deletions src/celeritas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ celeritas_polysource(global/alongstep/AlongStepUniformMscAction)
celeritas_polysource(global/alongstep/AlongStepRZMapFieldMscAction)
celeritas_polysource(phys/detail/DiscreteSelectAction)
celeritas_polysource(phys/detail/PreStepAction)
celeritas_polysource(random/RngReseed)
celeritas_polysource(random/detail/CuHipRngStateInit)
celeritas_polysource(track/detail/TrackInitAlgorithms)
celeritas_polysource(track/detail/TrackSortUtils)
Expand Down
16 changes: 16 additions & 0 deletions src/celeritas/global/Stepper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "corecel/sys/ScopedProfiling.hh"
#include "orange/OrangeData.hh"
#include "celeritas/Types.hh"
#include "celeritas/random/RngParams.hh"
#include "celeritas/random/RngReseed.hh"
#include "celeritas/track/TrackInitParams.hh"

#include "CoreParams.hh"
Expand Down Expand Up @@ -103,6 +105,20 @@ auto Stepper<M>::operator()(SpanConstPrimary primaries) -> result_type
return (*this)();
}

//---------------------------------------------------------------------------//
/*!
 * Reinitialize this stepper's RNG states for the given event.
 *
 * The shared RNG parameters and this stepper's RNG state collection are
 * handed to \c reseed_rng together with the integer value of \c event_id, so
 * the states afterwards depend only on the RNG parameters and the event ID.
 * Call this at the start of an event to make that event reproducible from
 * its ID alone.
 */
template<MemSpace M>
void Stepper<M>::reseed(EventId event_id)
{
    // Shared RNG parameters, viewed in this stepper's memory space
    auto const& rng_params = get_ref<M>(*params_->rng());
    // RNG states owned by this stepper
    auto const& rng_states = state_.ref().rng;

    reseed_rng(rng_params, rng_states, event_id.get());
}

//---------------------------------------------------------------------------//
// EXPLICIT INSTANTIATION
//---------------------------------------------------------------------------//
Expand Down
6 changes: 6 additions & 0 deletions src/celeritas/global/Stepper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class StepperInterface
// Transport existing states and these new primaries
virtual StepperResult operator()(SpanConstPrimary primaries) = 0;

// Reseed the RNGs at the start of an event for reproducibility
virtual void reseed(EventId event_id) = 0;

//! Get action sequence for timing diagnostics
virtual ActionSequence const& actions() const = 0;

Expand Down Expand Up @@ -139,6 +142,9 @@ class Stepper final : public StepperInterface
// Transport existing states and these new primaries
StepperResult operator()(SpanConstPrimary primaries) final;

// Reseed the RNGs at the start of an event for reproducibility
void reseed(EventId event_id) final;

//! Get action sequence for timing diagnostics
ActionSequence const& actions() const final { return *actions_; }

Expand Down
40 changes: 40 additions & 0 deletions src/celeritas/random/RngReseed.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.cc
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/cont/Range.hh"
#include "corecel/sys/ThreadId.hh"

#include "RngEngine.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG states on host at the start of an event.
 *
 * Each state is initialized using the same seed (taken from \c params) and a
 * distinct subsequence:
 * \code
 * subsequence = event_id * state.size() + track_slot
 * \endcode
 * so different track slots, and the same track slot in different events, are
 * assigned different subsequences.
 *
 * \param params   RNG parameters providing the seed
 * \param state    RNG states to overwrite (one per track slot)
 * \param event_id Integer ID of the event being started
 */
void reseed_rng(HostCRef<RngParamsData> const& params,
                HostRef<RngStateData> const& state,
                size_type event_id)
{
    // Do the arithmetic in the subsequence's own type. Multiplying two
    // size_type values first would wrap around whenever size_type is narrower
    // than the subsequence, making distinct events share subsequences.
    using Subsequence = decltype(RngEngine::Initializer_t::subsequence);
    Subsequence const first_subsequence
        = static_cast<Subsequence>(event_id) * state.size();

    for (auto tid : range(TrackSlotId{state.size()}))
    {
        RngEngine::Initializer_t init;
        init.seed = params.seed;
        init.subsequence = first_subsequence + tid.get();

        // Assigning an initializer to the engine overwrites this slot's state
        RngEngine engine(params, state, tid);
        engine = init;
    }
}

//---------------------------------------------------------------------------//
} // namespace celeritas
67 changes: 67 additions & 0 deletions src/celeritas/random/RngReseed.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//---------------------------------*-CUDA-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.cu
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/device_runtime_api.h"
#include "corecel/Assert.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/KernelParamCalculator.device.hh"

#include "RngEngine.hh"

namespace celeritas
{
namespace
{
//---------------------------------------------------------------------------//
// KERNELS
//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG states on device at the start of an event.
 *
 * Each thread reseeds the state whose index equals its thread ID, using the
 * seed from \c params and the subsequence
 * \code
 * subsequence = event_id * state.size() + track_slot
 * \endcode
 * Threads whose index is not less than \c state.size() do nothing.
 */
__global__ void reseed_rng_kernel(DeviceCRef<RngParamsData> const params,
                                  DeviceRef<RngStateData> const state,
                                  size_type event_id)
{
    auto tid = TrackSlotId{
        celeritas::KernelParamCalculator::thread_id().unchecked_get()};
    if (!(tid.get() < state.size()))
    {
        // Thread index is past the last RNG state: nothing to reseed
        return;
    }

    RngEngine::Initializer_t init;
    init.seed = params.seed;

    // Do the arithmetic in the subsequence's own type. Multiplying two
    // size_type values first would wrap around whenever size_type is narrower
    // than the subsequence, making distinct events share subsequences.
    using Subsequence = decltype(init.subsequence);
    init.subsequence = static_cast<Subsequence>(event_id) * state.size()
                       + tid.get();

    // Assigning an initializer to the engine overwrites this slot's state
    RngEngine rng(params, state, tid);
    rng = init;
}

//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
// KERNEL INTERFACE
//---------------------------------------------------------------------------//
/*!
 * Launch the device kernel that reseeds every RNG state for a new event.
 *
 * All states receive the seed stored in \c params; each one is assigned its
 * own subsequence computed from \c event_id and the state's index, so the
 * random streams of different track slots are not statistically correlated.
 *
 * Both \c state and \c params must be valid (checked as preconditions).
 */
void reseed_rng(DeviceCRef<RngParamsData> const& params,
                DeviceRef<RngStateData> const& state,
                size_type event_id)
{
    CELER_EXPECT(state);
    CELER_EXPECT(params);

    // Request one kernel thread per RNG state
    auto const num_states = state.size();
    CELER_LAUNCH_KERNEL(reseed_rng, num_states, 0, params, state, event_id);
}

//---------------------------------------------------------------------------//
} // namespace celeritas
43 changes: 43 additions & 0 deletions src/celeritas/random/RngReseed.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2023 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/RngReseed.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"

#include "RngData.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//
// Reinitialize the RNG states on host/device at the start of an event.
//
// Arguments, in order, for both overloads:
//   1. RNG parameters (provide the seed shared by all states)
//   2. RNG states to overwrite
//   3. integer event ID, used to pick a distinct subsequence for each state

// Device overload: defined in RngReseed.cu, or by the inline stub below when
// CELER_USE_DEVICE is off
void reseed_rng(DeviceCRef<RngParamsData> const&,
DeviceRef<RngStateData> const&,
size_type);

// Host overload: defined in RngReseed.cc
void reseed_rng(HostCRef<RngParamsData> const&,
HostRef<RngStateData> const&,
size_type);

#if !CELER_USE_DEVICE
//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG states on device at the start of an event.
 *
 * Stub compiled only when \c CELER_USE_DEVICE is off: it ignores all of its
 * arguments and marks the call as unreachable, since device references are
 * not expected to exist in such a build.
 */
inline void reseed_rng(DeviceCRef<RngParamsData> const&,
DeviceRef<RngStateData> const&,
size_type)
{
// Reaching this point indicates a logic error in the caller
CELER_ASSERT_UNREACHABLE();
}
#endif

//---------------------------------------------------------------------------//
} // namespace celeritas
2 changes: 1 addition & 1 deletion src/celeritas/random/XorwowRngData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct XorwowRngParamsData
*/
struct XorwowRngInitializer
{
ull_int seed{0};
Array<unsigned int, 1> seed{0};
ull_int subsequence{0};
ull_int offset{0};
};
Expand Down
2 changes: 1 addition & 1 deletion src/celeritas/random/XorwowRngEngine.hh
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ XorwowRngEngine::operator=(Initializer_t const& init)
auto& s = state_->xorstate;

// Initialize the state from the seed
SplitMix64 rng{init.seed};
SplitMix64 rng{init.seed[0]};
uint64_t seed = rng();
s[0] = static_cast<uint_t>(seed);
s[1] = static_cast<uint_t>(seed >> 32);
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ set(CELERITASTEST_PREFIX celeritas/random)

celeritas_add_device_test(celeritas/random/RngEngine)
celeritas_add_test(celeritas/random/Selector.test.cc)
celeritas_add_test(celeritas/random/RngReseed.test.cc)
celeritas_add_test(celeritas/random/XorwowRngEngine.test.cc GPU)

celeritas_add_test(celeritas/random/distribution/BernoulliDistribution.test.cc)
Expand Down

0 comments on commit bc3ccd5

Please sign in to comment.