Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Mar 4, 2020
1 parent 3910f0b commit 5852f48
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include "../data/adapter.h"
#include "../data/device_adapter.cuh"
#include "../tree/param.h"
#include "device_helpers.cuh"
#include "hist_util.h"
#include "math.h" // NOLINT
Expand Down Expand Up @@ -149,15 +148,15 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
}

/**
* \brief Extracts the cuts from sorted data, considering weights.
*
* \param device The device.
* \param cuts Output cuts
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns
* \param column_sizes_scan Describes the boundaries of column segments in
* sorted data
*/
* \brief Extracts the cuts from sorted data, considering weights.
*
* \param device The device.
* \param cuts Output cuts.
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns.
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
Span<float> weights_scan,
Expand Down Expand Up @@ -259,7 +258,6 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());
std::vector<Entry> entries_t(sorted_entries.begin(), sorted_entries.end());

// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
Expand All @@ -268,7 +266,6 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
std::vector<float> weights_t(temp_weights.begin(), temp_weights.end());

dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
Expand Down Expand Up @@ -308,7 +305,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
}
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.ConstHostVector().size();
size_t batch_nnz = batch.data.Size();
for (auto begin = 0ull; begin < batch_nnz;
begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
Expand Down Expand Up @@ -345,9 +342,10 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
}
};

// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

Expand All @@ -357,19 +355,17 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
dh::XGBCachingDeviceAllocator<char> alloc;
adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
auto &batch = adapter->Value();
// Enforce single batch
CHECK(!adapter->Next());

auto batch_iter = MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto entry_iter = MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
batch.GetElement(idx).value);
});

// Work out how many valid entries we have in each column
dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
0);
Expand Down

0 comments on commit 5852f48

Please sign in to comment.