Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
esseivaju committed Dec 5, 2023
1 parent e6ade1f commit 7029c2b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
13 changes: 8 additions & 5 deletions src/celeritas/track/detail/TrackSortUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,14 @@ void sort_tracks(HostRef<CoreStateData> const& states, TrackOrder order)
{
case TrackOrder::partition_status:
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()});
AlivePredicate{states.sim.status.data()});
case TrackOrder::sort_along_step_action:
case TrackOrder::sort_step_limit_action:
return sort_impl(states.track_slots,
id_comparator{get_action_ptr(states, order)});
IdComparator{get_action_ptr(states, order)});
case TrackOrder::sort_particle_type:
return sort_impl(
states.track_slots,
id_comparator{states.particles.particle_id.data()});
return sort_impl(states.track_slots,
IdComparator{states.particles.particle_id.data()});
default:
CELER_ASSERT_UNREACHABLE();
}
Expand Down Expand Up @@ -136,6 +135,10 @@ void count_tracks_per_action(
backfill_action_count(offsets, size);
}

//---------------------------------------------------------------------------//
/*!
* Fill missing action offsets.
*/
void backfill_action_count(Span<ThreadId> offsets, size_type num_actions)
{
CELER_EXPECT(offsets.size() >= 2);
Expand Down
21 changes: 17 additions & 4 deletions src/celeritas/track/detail/TrackSortUtils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ using ThreadItems
using TrackSlots = ThreadItems<TrackSlotId::size_type>;

//---------------------------------------------------------------------------//

/*!
* Partition track_slots based on predicate.
*/
template<class F>
void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id)
{
Expand All @@ -59,7 +61,10 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id)
}

//---------------------------------------------------------------------------//

/*!
* Reorder OpaqueId's based on track_slots so that track_slots[tid] corresponds
* to ids[tid] instead of ids[track_slots[tid]].
*/
template<class Id>
__global__ void
reorder_ids_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
Expand All @@ -75,6 +80,10 @@ reorder_ids_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
}
}

//---------------------------------------------------------------------------//
/*!
* Sort track slots using ids as keys.
*/
template<class Id, class IdT = typename Id::size_type>
void sort_impl(TrackSlots const& track_slots,
ObserverPtr<Id const> ids,
Expand All @@ -96,7 +105,11 @@ void sort_impl(TrackSlots const& track_slots,
CELER_DEVICE_CHECK_ERROR();
}

// PRE: actions are sorted
//---------------------------------------------------------------------------//
/*!
* Calculate thread boundaries based on action ID.
* \pre actions are sorted
*/
__global__ void
tracks_per_action_kernel(ObserverPtr<ActionId const> actions,
ObserverPtr<TrackSlotId::size_type const> track_slots,
Expand Down Expand Up @@ -175,7 +188,7 @@ void sort_tracks(DeviceRef<CoreStateData> const& states, TrackOrder order)
{
case TrackOrder::partition_status:
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
AlivePredicate{states.sim.status.data()},
states.stream_id);
case TrackOrder::sort_along_step_action:
case TrackOrder::sort_step_limit_action:
Expand Down
11 changes: 7 additions & 4 deletions src/celeritas/track/detail/TrackSortUtils.hh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ void sort_tracks(DeviceRef<CoreStateData> const&, TrackOrder);

//---------------------------------------------------------------------------//
// Count tracks associated with each action

void count_tracks_per_action(
HostRef<CoreStateData> const&,
Span<ThreadId>,
Expand All @@ -70,12 +69,14 @@ void count_tracks_per_action(
Collection<ThreadId, Ownership::value, MemSpace::mapped, ActionId>&,
TrackOrder);

//---------------------------------------------------------------------------//
// Fill missing action offsets.
void backfill_action_count(Span<ThreadId>, size_type);

//---------------------------------------------------------------------------//
// HELPER CLASSES AND FUNCTIONS
//---------------------------------------------------------------------------//
struct alive_predicate
struct AlivePredicate
{
ObserverPtr<TrackStatus const> status_;

Expand All @@ -86,7 +87,7 @@ struct alive_predicate
};

template<class Id>
struct id_comparator
struct IdComparator
{
ObserverPtr<Id const> ids_;

Expand All @@ -107,6 +108,8 @@ struct ActionAccessor
}
};

//---------------------------------------------------------------------------//
// Return the correct action pointer based on the track sort order
template<Ownership W, MemSpace M>
CELER_FUNCTION ObserverPtr<ActionId const>
get_action_ptr(CoreStateData<W, M> const& states, TrackOrder order)
Expand All @@ -127,7 +130,7 @@ get_action_ptr(CoreStateData<W, M> const& states, TrackOrder order)
//---------------------------------------------------------------------------//

template<class Id>
id_comparator(ObserverPtr<Id>) -> id_comparator<Id>;
IdComparator(ObserverPtr<Id>) -> IdComparator<Id>;

//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
Expand Down

0 comments on commit 7029c2b

Please sign in to comment.