Skip to content

Commit

Permalink
Revert "Use external CUB API instead of internal ones (#900)"
Browse files Browse the repository at this point in the history
This reverts commit 9cd5d0e.
  • Loading branch information
crozhon committed Nov 7, 2021
1 parent 9cd5d0e commit 9479665
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 87 deletions.
48 changes: 26 additions & 22 deletions gunrock/util/reduce_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,29 @@ cudaError_t cubSegmentedReduce(util::Array1D<uint64_t, char> &cub_temp_space,
cudaError_t retval = cudaSuccess;
size_t request_bytes = 0;

retval = cub::DeviceSegmentedReduce::
Reduce(NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_segments,
segment_offsets.GetPointer(util::DEVICE),
segment_offsets.GetPointer(util::DEVICE) + 1, reduction_op,
initial_value, stream, debug_synchronous);
retval = cub::DispatchSegmentedReduce<
InputT *, OutputT *, SizeT *, SizeT,
ReductionOp>::Dispatch(NULL, request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_segments,
segment_offsets.GetPointer(util::DEVICE),
segment_offsets.GetPointer(util::DEVICE) + 1,
reduction_op, initial_value, stream,
debug_synchronous);
if (retval) return retval;

retval = cub_temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval = cub::DeviceSegmentedReduce::
Reduce(cub_temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_segments,
segment_offsets.GetPointer(util::DEVICE),
segment_offsets.GetPointer(util::DEVICE) + 1,
reduction_op, initial_value, stream, debug_synchronous);
retval = cub::DispatchSegmentedReduce<
InputT *, OutputT *, SizeT *, SizeT,
ReductionOp>::Dispatch(cub_temp_space.GetPointer(util::DEVICE),
request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_segments,
segment_offsets.GetPointer(util::DEVICE),
segment_offsets.GetPointer(util::DEVICE) + 1,
reduction_op, initial_value, stream,
debug_synchronous);
if (retval) return retval;

return retval;
Expand All @@ -115,21 +120,20 @@ cudaError_t cubReduce(util::Array1D<uint64_t, char> &cub_temp_space,
size_t request_bytes = 0;

retval =
cub::DeviceReduce::
Reduce(NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_keys, reduction_op,
initial_value, stream, debug_synchronous);
cub::DispatchReduce<InputT *, OutputT *, SizeT, ReductionOp>::Dispatch(
NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_keys, reduction_op,
initial_value, stream, debug_synchronous);
if (retval) return retval;

retval = cub_temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval =
cub::DeviceReduce::
Reduce(cub_temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE), num_keys, reduction_op,
initial_value, stream, debug_synchronous);
cub::DispatchReduce<InputT *, OutputT *, SizeT, ReductionOp>::Dispatch(
cub_temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE), keys_out.GetPointer(util::DEVICE),
num_keys, reduction_op, initial_value, stream, debug_synchronous);
if (retval) return retval;

return retval;
Expand Down
26 changes: 16 additions & 10 deletions gunrock/util/select_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,22 +193,28 @@ cudaError_t cubSelectIf(util::Array1D<uint64_t, char> &cub_temp_space,
bool debug_synchronous = false) {
cudaError_t retval = cudaSuccess;

typedef cub::NullType *FlagIterator;
typedef cub::NullType EqualityOp;

size_t request_bytes = 0;
retval = cub::DeviceSelect::
If(NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
num_selected.GetPointer(util::DEVICE), num_keys, select_op, stream,
debug_synchronous);
retval = cub::DispatchSelectIf<
InputT *, FlagIterator, OutputT *, SizeT *, SelectOp, EqualityOp, SizeT,
false>::Dispatch(NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
NULL, keys_out.GetPointer(util::DEVICE),
num_selected.GetPointer(util::DEVICE), select_op,
EqualityOp(), num_keys, stream, debug_synchronous);
if (retval) return retval;

retval = cub_temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval = cub::DeviceSelect::
If(cub_temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE), keys_out.GetPointer(util::DEVICE),
num_selected.GetPointer(util::DEVICE), num_keys, select_op, stream,
debug_synchronous);
retval = cub::DispatchSelectIf<
InputT *, FlagIterator, OutputT *, SizeT *, SelectOp, EqualityOp, SizeT,
false>::Dispatch(cub_temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE), NULL,
keys_out.GetPointer(util::DEVICE),
num_selected.GetPointer(util::DEVICE), select_op,
EqualityOp(), num_keys, stream, debug_synchronous);
if (retval) return retval;

return retval;
Expand Down
98 changes: 43 additions & 55 deletions gunrock/util/sort_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,23 @@ cudaError_t SegmentedSort(util::Array1D<SizeT, KeyT> &in,
util::Array1D<uint64_t, char> temp_space;
size_t request_bytes = 0;

retval = cub::DeviceSegmentedRadixSort::
SortKeys(NULL, request_bytes, keys, num_items, num_segments,
offsets.GetPointer(util::DEVICE),
offsets.GetPointer(util::DEVICE) + 1, begin_bit, end_bit,
stream, debug_synchronous);
retval = cub::DeviceSegmentedRadixSort::SortKeys(NULL, request_bytes,
keys, num_items, num_segments,
offsets.GetPointer(util::DEVICE),
offsets.GetPointer(util::DEVICE) + 1,
begin_bit, end_bit, stream, debug_synchronous);
if(retval) return retval;

retval = temp_space.EnsureSize_(request_bytes, util::DEVICE);
if(retval) return retval;

retval = cub::DeviceSegmentedRadixSort::
SortKeys(temp_space.GetPointer(util::DEVICE), request_bytes, keys,
num_items, num_segments, offsets.GetPointer(util::DEVICE),
offsets.GetPointer(util::DEVICE) + 1, begin_bit, end_bit,
stream, debug_synchronous);
retval = cub::DeviceSegmentedRadixSort::SortKeys(
temp_space.GetPointer(util::DEVICE),
request_bytes,
keys, num_items, num_segments,
offsets.GetPointer(util::DEVICE),
offsets.GetPointer(util::DEVICE) + 1,
begin_bit, end_bit, stream, debug_synchronous);

if(retval) return retval;

Expand Down Expand Up @@ -286,26 +288,19 @@ cudaError_t cubSortPairs(util::Array1D<uint64_t, char> &temp_space,
values_out.GetPointer(util::DEVICE));

size_t request_bytes = 0;
retval = cub::DeviceRadixSort::
SortPairs(NULL, request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
values_in.GetPointer(util::DEVICE),
values_out.GetPointer(util::DEVICE), num_items, begin_bit,
end_bit, stream, debug_synchronous);
retval = cub::DispatchRadixSort<false, KeyT, ValueT, SizeT>::Dispatch(
NULL, request_bytes, keys, values, num_items, begin_bit, end_bit, false,
stream, debug_synchronous);
if (retval) return retval;
// util::PrintMsg("num_items = " + std::to_string(num_items)
// + ", request_bytes = " + std::to_string(request_bytes));

retval = temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval = cub::DeviceRadixSort::
SortPairs(temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
values_in.GetPointer(util::DEVICE),
values_out.GetPointer(util::DEVICE), num_items, begin_bit,
end_bit, stream, debug_synchronous);
retval = cub::DispatchRadixSort<false, KeyT, ValueT, SizeT>::Dispatch(
temp_space.GetPointer(util::DEVICE), request_bytes, keys, values,
num_items, begin_bit, end_bit, false, stream, debug_synchronous);
if (retval) return retval;

if (keys.Current() != keys_out.GetPointer(util::DEVICE)) {
Expand Down Expand Up @@ -347,27 +342,17 @@ cudaError_t cubSortPairsDescending(util::Array1D<uint64_t, char> &temp_space,
values_out.GetPointer(util::DEVICE));

size_t request_bytes = 0;

retval = cub::DeviceRadixSort::
SortPairsDescending(NULL, request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
values_in.GetPointer(util::DEVICE),
values_out.GetPointer(util::DEVICE), num_items,
begin_bit, end_bit, stream, debug_synchronous);
retval = cub::DispatchRadixSort<true, KeyT, ValueT, SizeT>::Dispatch(
NULL, request_bytes, keys, values, num_items, begin_bit, end_bit, false,
stream, debug_synchronous);
if (retval) return retval;

retval = temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval = cub::DeviceRadixSort::
SortPairsDescending((void*)temp_space.GetPointer(util::DEVICE),
request_bytes, keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
values_in.GetPointer(util::DEVICE),
values_out.GetPointer(util::DEVICE),
num_items, begin_bit, end_bit, stream,
debug_synchronous);
retval = cub::DispatchRadixSort<true, KeyT, ValueT, SizeT>::Dispatch(
temp_space.GetPointer(util::DEVICE), request_bytes, keys, values,
num_items, begin_bit, end_bit, false, stream, debug_synchronous);
if (retval) return retval;

if (keys.Current() != keys_out.GetPointer(util::DEVICE)) {
Expand Down Expand Up @@ -402,28 +387,31 @@ cudaError_t cubSegmentedSortPairs(
bool debug_synchronous = false) {
cudaError_t retval = cudaSuccess;

cub::DoubleBuffer<KeyT> keys(
const_cast<KeyT *>(keys_in.GetPointer(util::DEVICE)),
keys_out.GetPointer(util::DEVICE));
cub::DoubleBuffer<ValueT> values(
const_cast<ValueT *>(values_in.GetPointer(util::DEVICE)),
values_out.GetPointer(util::DEVICE));

size_t request_bytes = 0;
retval = cub::DeviceSegmentedRadixSort::
SortPairs(NULL, request_bytes,
keys_in.GetPointer(util::DEVICE),
keys_out.GetPointer(util::DEVICE),
values_in.GetPointer(util::DEVICE),
values_out.GetPointer(util::DEVICE), num_items, num_segments,
seg_offsets.GetPointer(util::DEVICE),
seg_offsets.GetPointer(util::DEVICE) + 1, begin_bit, end_bit,
stream, debug_synchronous);
retval =
cub::DispatchSegmentedRadixSort<false, KeyT, ValueT, SizeT *, SizeT>::
Dispatch(NULL, request_bytes, keys, values, num_items, num_segments,
seg_offsets.GetPointer(util::DEVICE),
seg_offsets.GetPointer(util::DEVICE) + 1, begin_bit, end_bit,
false, stream, debug_synchronous);
if (retval) return retval;

retval = temp_space.EnsureSize_(request_bytes, util::DEVICE);
if (retval) return retval;

retval = cub::DeviceSegmentedRadixSort::
SortPairs(temp_space.GetPointer(util::DEVICE), request_bytes,
keys_in, keys_out, values_in, values_out, num_items,
num_segments, num_items, num_segments,
seg_offsets.GetPointer(util::DEVICE),
seg_offsets.GetPointer(util::DEVICE) + 1, begin_bit,
end_bit, stream, debug_synchronous);
retval = cub::
DispatchSegmentedRadixSort<false, KeyT, ValueT, SizeT *, SizeT>::Dispatch(
temp_space.GetPointer(util::DEVICE), request_bytes, keys, values,
num_items, num_segments, seg_offsets.GetPointer(util::DEVICE),
seg_offsets.GetPointer(util::DEVICE) + 1, begin_bit, end_bit, false,
stream, debug_synchronous);
if (retval) return retval;

return retval;
Expand Down

0 comments on commit 9479665

Please sign in to comment.