Skip to content

Commit

Permalink
Use batched copy if. (#6826)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 6, 2021
1 parent aa0d8f2 commit 7bcc8b3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
23 changes: 19 additions & 4 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,21 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
num_items, nullptr, false)));
}

template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
// We loop over batches because thrust::copy_if cant deal with sizes > 2^31
// See thrust issue #1302, #6822
size_t max_copy_size = std::numeric_limits<int>::max() / 2;
size_t length = std::distance(in_first, in_second);
XGBCachingDeviceAllocator<char> alloc;
for (size_t offset = 0; offset < length; offset += max_copy_size) {
auto begin_input = in_first + offset;
auto end_input = in_first + std::min(offset + max_copy_size, length);
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
end_input, out_first, pred);
}
}

template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
Expand All @@ -1311,14 +1326,14 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i

if (accending) {
void *d_temp_storage = nullptr;
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
sizeof(KeyT) * 8, false, nullptr, false)));
} else {
void *d_temp_storage = nullptr;
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
Expand Down
5 changes: 2 additions & 3 deletions src/common/hist_util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(),
entry_iter + range.end(), sorted_entries->begin(), is_valid);
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
sorted_entries->begin(), is_valid);
}

void SortByWeight(dh::device_vector<float>* weights,
Expand Down
13 changes: 2 additions & 11 deletions src/data/simple_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,9 @@ void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
COOToEntryOp<decltype(batch)> transform_op{batch};
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
transform_iter(counting, transform_op);
// We loop over batches because thrust::copy_if cant deal with sizes > 2^31
// See thrust issue #1302
size_t max_copy_size = std::numeric_limits<int>::max() / 2;
auto begin_output = thrust::device_pointer_cast(data.data());
for (size_t offset = 0; offset < batch.Size(); offset += max_copy_size) {
auto begin_input = transform_iter + offset;
auto end_input =
transform_iter + std::min(offset + max_copy_size, batch.Size());
begin_output =
thrust::copy_if(thrust::cuda::par(alloc), begin_input, end_input,
begin_output, IsValidFunctor(missing));
}
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
}

// Does not currently support metainfo as no on-device data source contains this
Expand Down

0 comments on commit 7bcc8b3

Please sign in to comment.