Skip to content

Commit

Permalink
Support sorting tracks by particle types (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
esseivaju committed Dec 2, 2023
1 parent 2e2d09d commit d7ba565
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 37 deletions.
1 change: 1 addition & 0 deletions src/celeritas/Types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ char const* to_cstring(TrackOrder value)
"sort_along_step_action",
"sort_step_limit_action",
"sort_action",
"sort_particle_type",
};
return to_cstring_impl(value);
}
Expand Down
1 change: 1 addition & 0 deletions src/celeritas/Types.hh
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ enum class TrackOrder
sort_along_step_action, //!< Sort only by the along-step action id
sort_step_limit_action, //!< Sort only by the step limit action id
sort_action, //!< Sort by along-step id, then post-step ID
sort_particle_type, //!< Sort by particle type
size_
};

Expand Down
44 changes: 33 additions & 11 deletions src/celeritas/track/SortTracksAction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@ bool is_sort_trackorder(TrackOrder to)
TrackOrder::sort_step_limit_action,
TrackOrder::sort_along_step_action,
TrackOrder::sort_action,
TrackOrder::sort_particle_type,
};
return std::find(std::begin(allowed), std::end(allowed), to)
!= std::end(allowed);
}

/*!
* Checks whether the TrackOrder sort tracks using an ActionId.
*/
inline bool is_sort_by_action(TrackOrder to)
{
return to == TrackOrder::sort_along_step_action
|| to == TrackOrder::sort_step_limit_action
|| to == TrackOrder::sort_action;
}
//---------------------------------------------------------------------------//
} // namespace

Expand Down Expand Up @@ -66,6 +77,9 @@ SortTracksAction::SortTracksAction(ActionId id, TrackOrder track_order)
// Sort *before* post-step action, i.e. *after* pre-post and
// along-step
return ActionOrder::sort_pre_post;
case TrackOrder::sort_particle_type:
// Sorth at the beginning of the step
return ActionOrder::sort_start;
default:
CELER_ASSERT_UNREACHABLE();
}
Expand All @@ -86,6 +100,8 @@ std::string SortTracksAction::label() const
return "sort-tracks-along-step";
case TrackOrder::sort_step_limit_action:
return "sort-tracks-post-step";
case TrackOrder::sort_particle_type:
return "sort-tracks-start";
default:
CELER_ASSERT_UNREACHABLE();
}
Expand All @@ -97,11 +113,14 @@ std::string SortTracksAction::label() const
void SortTracksAction::execute(CoreParams const&, CoreStateHost& state) const
{
detail::sort_tracks(state.ref(), track_order_);
detail::count_tracks_per_action(
state.ref(),
state.action_thread_offsets()[AllItems<ThreadId, MemSpace::host>{}],
state.action_thread_offsets(),
track_order_);
if (is_sort_by_action(track_order_))
{
detail::count_tracks_per_action(
state.ref(),
state.action_thread_offsets()[AllItems<ThreadId, MemSpace::host>{}],
state.action_thread_offsets(),
track_order_);
}
}

//---------------------------------------------------------------------------//
Expand All @@ -111,12 +130,15 @@ void SortTracksAction::execute(CoreParams const&, CoreStateHost& state) const
void SortTracksAction::execute(CoreParams const&, CoreStateDevice& state) const
{
detail::sort_tracks(state.ref(), track_order_);
detail::count_tracks_per_action(
state.ref(),
state.native_action_thread_offsets()[AllItems<ThreadId,
MemSpace::device>{}],
state.action_thread_offsets(),
track_order_);
if (is_sort_by_action(track_order_))
{
detail::count_tracks_per_action(
state.ref(),
state.native_action_thread_offsets()[AllItems<ThreadId,
MemSpace::device>{}],
state.action_thread_offsets(),
track_order_);
}
}

//---------------------------------------------------------------------------//
Expand Down
67 changes: 41 additions & 26 deletions src/celeritas/track/detail/TrackSortUtils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,36 +60,38 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id)

//---------------------------------------------------------------------------//

template<class Id>
__global__ void
reorder_actions_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
ObserverPtr<ActionId const> actions,
ObserverPtr<ActionId::size_type> out_actions,
size_type size)
reorder_ids_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
ObserverPtr<Id const> ids,
ObserverPtr<typename Id::size_type> ids_out,
size_type size)
{
if (ThreadId tid = celeritas::KernelParamCalculator::thread_id();
tid < size)
{
out_actions.get()[tid.get()]
= actions.get()[track_slots.get()[tid.get()]].unchecked_get();
ids_out.get()[tid.get()]
= ids.get()[track_slots.get()[tid.get()]].unchecked_get();
}
}

template<class Id, class IdT = typename Id::size_type>
void sort_impl(TrackSlots const& track_slots,
ObserverPtr<ActionId const> actions,
ObserverPtr<Id const> ids,
StreamId stream_id)
{
DeviceVector<ActionId::size_type> reordered_actions(track_slots.size(),
stream_id);
CELER_LAUNCH_KERNEL(reorder_actions,
track_slots.size(),
celeritas::device().stream(stream_id).get(),
track_slots.data(),
actions,
make_observer(reordered_actions.data()),
track_slots.size());
DeviceVector<IdT> reordered_ids(track_slots.size(), stream_id);
CELER_LAUNCH_KERNEL_TEMPLATE_1(reorder_ids,
Id,
track_slots.size(),
celeritas::device().stream(stream_id).get(),
track_slots.data(),
ids,
make_observer(reordered_ids.data()),
track_slots.size());
thrust::sort_by_key(thrust_execute_on(stream_id),
reordered_actions.data(),
reordered_actions.data() + reordered_actions.size(),
reordered_ids.data(),
reordered_ids.data() + reordered_ids.size(),
device_pointer_cast(track_slots.data()));
CELER_DEVICE_CHECK_ERROR();
}
Expand Down Expand Up @@ -197,14 +199,27 @@ void sort_tracks(DeviceRef<CoreStateData> const& states, TrackOrder order)
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
states.stream_id);
case TrackOrder::sort_along_step_action:
return sort_impl(states.track_slots,
states.sim.along_step_action.data(),
states.stream_id);
case TrackOrder::sort_step_limit_action:
return sort_impl(states.track_slots,
states.sim.post_step_action.data(),
states.stream_id);
case TrackOrder::sort_along_step_action: {
using Id =
typename decltype(states.sim.along_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.along_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_step_limit_action: {
using Id =
typename decltype(states.sim.post_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.post_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_particle_type: {
using Id =
typename decltype(states.particles.particle_id)::value_type;
return sort_impl<Id>(states.track_slots,
states.particles.particle_id.data(),
states.stream_id);
}
default:
CELER_ASSERT_UNREACHABLE();
}
Expand Down
23 changes: 23 additions & 0 deletions src/corecel/sys/KernelParamCalculator.device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,29 @@
CELER_DEVICE_CHECK_ERROR(); \
} while (0)

/*!
* \def CELER_LAUNCH_KERNEL_TEMPLATE_1
*
* Create a kernel param calculator with the given kernel with
* one template parameter, assuming the unction itself has a \c _kernel
* suffix, and launch with the given block/thread sizes and arguments list.
*/
#define CELER_LAUNCH_KERNEL_TEMPLATE_1(NAME, T1, THREADS, STREAM, ...) \
do \
{ \
static const ::celeritas::KernelParamCalculator calc_launch_params_( \
#NAME, NAME##_kernel<T1>); \
auto grid_ = calc_launch_params_(THREADS); \
\
CELER_LAUNCH_KERNEL_IMPL(NAME##_kernel, \
grid_.blocks_per_grid, \
grid_.threads_per_block, \
0, \
STREAM, \
__VA_ARGS__); \
CELER_DEVICE_CHECK_ERROR(); \
} while (0)

#if CELERITAS_USE_CUDA
# define CELER_LAUNCH_KERNEL_IMPL(KERNEL, GRID, BLOCK, SHARED, STREAM, ...) \
KERNEL<<<GRID, BLOCK, SHARED, STREAM>>>(__VA_ARGS__)
Expand Down

0 comments on commit d7ba565

Please sign in to comment.