From 91563a5929c6a17a7e0ddc24997876c1a1cd1bae Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 26 Feb 2020 16:39:22 +1300 Subject: [PATCH 1/7] Sketching from adapters --- src/common/hist_util.cu | 664 ++++++++++--------------- src/common/hist_util.h | 18 +- src/data/adapter.h | 1 + src/data/ellpack_page.cu | 22 +- src/data/ellpack_page_source.cu | 22 +- tests/cpp/common/test_gpu_hist_util.cu | 105 ---- tests/cpp/common/test_hist_util.cu | 323 ++++++++++++ tests/cpp/common/test_hist_util.h | 51 +- 8 files changed, 657 insertions(+), 549 deletions(-) delete mode 100644 tests/cpp/common/test_gpu_hist_util.cu create mode 100644 tests/cpp/common/test_hist_util.cu diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 89a59ec8b19b..7368cb7e10aa 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -10,84 +10,27 @@ #include #include #include +#include -#include -#include #include #include +#include +#include -#include "hist_util.h" -#include "xgboost/host_device_vector.h" +#include "../data/adapter.h" +#include "../data/device_adapter.cuh" +#include "../tree/param.h" #include "device_helpers.cuh" +#include "hist_util.h" +#include "math.h" #include "quantile.h" -#include "../tree/param.h" +#include "xgboost/host_device_vector.h" namespace xgboost { namespace common { using WQSketch = DenseCuts::WQSketch; - -__global__ void FindCutsK(WQSketch::Entry* __restrict__ cuts, - const bst_float* __restrict__ data, - const float* __restrict__ cum_weights, - int nsamples, - int ncuts) { - // ncuts < nsamples - int icut = threadIdx.x + blockIdx.x * blockDim.x; - if (icut >= ncuts) { - return; - } - int isample = 0; - if (icut == 0) { - isample = 0; - } else if (icut == ncuts - 1) { - isample = nsamples - 1; - } else { - bst_float rank = cum_weights[nsamples - 1] / static_cast(ncuts - 1) - * static_cast(icut); - // -1 is used because cum_weights is an inclusive sum - isample = dh::UpperBound(cum_weights, nsamples, rank); - isample = max(0, min(isample, nsamples - 1)); - } - // repeated values will be filtered out on the CPU - bst_float rmin = isample > 0 ? cum_weights[isample - 1] : 0; - bst_float rmax = cum_weights[isample]; - cuts[icut] = WQSketch::Entry(rmin, rmax, rmax - rmin, data[isample]); -} - -// predictate for thrust filtering that returns true if the element is not a NaN -struct IsNotNaN { - __device__ bool operator()(float a) const { return !isnan(a); } -}; - -__global__ void UnpackFeaturesK(float* __restrict__ fvalues, - float* __restrict__ feature_weights, - const size_t* __restrict__ row_ptrs, - const float* __restrict__ weights, - Entry* entries, - size_t nrows_array, - size_t row_begin_ptr, - size_t nrows) { - size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x; - if (irow >= nrows) { - return; - } - size_t row_length = row_ptrs[irow + 1] - row_ptrs[irow]; - int icol = threadIdx.y + blockIdx.y * blockDim.y; - if (icol >= row_length) { - return; - } - Entry entry = entries[row_ptrs[irow] - row_begin_ptr + icol]; - size_t ind = entry.index * nrows_array + irow; - // if weights are present, ensure that a non-NaN value is written to weights - // if and only if it is also written to features - if (!isnan(entry.fvalue) && (weights == nullptr || !isnan(weights[irow]))) { - fvalues[ind] = entry.fvalue; - if (feature_weights != nullptr && weights != nullptr) { - feature_weights[ind] = weights[irow]; - } - } -} +using SketchEntry = WQSketch::Entry; /*! 
* \brief A container that holds the device sketches across all @@ -98,379 +41,280 @@ __global__ void UnpackFeaturesK(float* __restrict__ fvalues, */ struct SketchContainer { std::vector sketches_; // NOLINT - std::vector col_locks_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; - SketchContainer(int max_bin, DMatrix* dmat) : col_locks_(dmat->Info().num_col_) { - const MetaInfo& info = dmat->Info(); + SketchContainer(int max_bin, size_t num_columns, size_t num_rows) { // Initialize Sketches for this dmatrix - sketches_.resize(info.num_col_); -#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \ -if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT - for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT - sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin)); + sketches_.resize(num_columns); +#pragma omp parallel for default(none) shared(max_bin) \ + schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT + for (int icol = 0; icol < num_columns; ++icol) { // NOLINT + sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin)); } } - // Prevent copying/assigning/moving this as its internals can't be assigned/copied/moved - SketchContainer(const SketchContainer &) = delete; - SketchContainer(const SketchContainer &&) = delete; - SketchContainer &operator=(const SketchContainer &) = delete; - SketchContainer &operator=(const SketchContainer &&) = delete; -}; - -// finds quantiles on the GPU -class GPUSketcher { - public: - GPUSketcher(int device, int max_bin, int gpu_nrows) - : device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {} - - ~GPUSketcher() { // NOLINT - dh::safe_cuda(cudaSetDevice(device_)); - } - - void SketchBatch(const SparsePage &batch, const MetaInfo &info) { - n_rows_ = batch.Size(); - - Init(batch, info, gpu_batch_nrows_); - Sketch(batch, info); - ComputeRowStride(); + /** + * \brief Pushes cuts to the sketches. + * + * \param entries_per_column The entries per column. + * \param entries Vector of cuts from all columns, length + * entries_per_column * num_columns. \param column_scan Exclusive scan + * of column sizes. Used to detect cases where there are fewer entries than we + * have storage for. 
+ */ + void Push(size_t entries_per_column, + const thrust::host_vector& entries, + const thrust::host_vector& column_scan) { +#pragma omp parallel for default(none) schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT + for (int icol = 0; icol < sketches_.size(); ++icol) { + size_t column_size = column_scan[icol + 1] - column_scan[icol]; + if (column_size == 0) continue; + WQuantileSketch::SummaryContainer summary; + size_t num_available_cuts = + std::min(size_t(entries_per_column), column_size); + summary.Reserve(num_available_cuts); + summary.MakeFromSorted(&entries[entries_per_column * icol], + num_available_cuts); + + sketches_[icol].PushSummary(summary); + } } - /* Builds the sketches on the GPU for the dmatrix and returns the row stride - * for the entire dataset */ - size_t Sketch(DMatrix *dmat, DenseCuts *hmat) { - const MetaInfo& info = dmat->Info(); + // Prevent copying/assigning/moving this as its internals can't be + // assigned/copied/moved + SketchContainer(const SketchContainer&) = delete; + SketchContainer(const SketchContainer&&) = delete; + SketchContainer& operator=(const SketchContainer&) = delete; + SketchContainer& operator=(const SketchContainer&&) = delete; +}; - row_stride_ = 0; - sketch_container_.reset(new SketchContainer(max_bin_, dmat)); - for (const auto& batch : dmat->GetBatches()) { - this->SketchBatch(batch, info); +struct EntryCompareOp { + __device__ bool operator()(const Entry& a, const Entry& b) { + if (a.index == b.index) { + return a.fvalue < b.fvalue; } - - hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_); - return row_stride_; + return a.index < b.index; } +}; - // This needs to be public because of the __device__ lambda. - void ComputeRowStride() { - // Find the row stride for this batch - auto row_iter = row_ptrs_.begin(); - // Functor for finding the maximum row size for this batch - auto get_size = [=] __device__(size_t row) { - return row_iter[row + 1] - row_iter[row]; - }; // NOLINT - - auto counting = thrust::make_counting_iterator(size_t(0)); - using TransformT = thrust::transform_iterator; - TransformT row_size_iter = TransformT(counting, get_size); - size_t batch_row_stride = - thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum()); - row_stride_ = std::max(row_stride_, batch_row_stride); - } +/** + * \brief Extracts the cuts from sorted data. + * + * \param device The device. + * \param cuts Output cuts + * \param num_cuts_per_feature Number of cuts per feature. 
+ * \param sorted_data Sorted entries in segments of columns + * \param column_sizes_scan Describes the boundaries of column segments in + * sorted data + */ +void ExtractCuts(int device, Span cuts, + size_t num_cuts_per_feature, Span sorted_data, + Span column_sizes_scan) { + dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { + // Each thread is responsible for obtaining one cut from the sorted input + size_t column_idx = idx / num_cuts_per_feature; + size_t column_size = + column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; + size_t num_available_cuts = + std::min(size_t(num_cuts_per_feature), column_size); + size_t cut_idx = idx % num_cuts_per_feature; + if (cut_idx >= num_available_cuts) return; + + Span column_entries = + sorted_data.subspan(column_sizes_scan[column_idx], column_size); + + size_t rank = (column_entries.size() * cut_idx) / num_available_cuts; + auto value = column_entries[rank].fvalue; + cuts[idx] = SketchEntry(rank, rank + 1, 1, value); + }); +} - // This needs to be public because of the __device__ lambda. - void FindColumnCuts(size_t batch_nrows, size_t icol) { - size_t tmp_size = tmp_storage_.size(); - // filter out NaNs in feature values - auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_; - cub::DeviceSelect::If(tmp_storage_.data().get(), - tmp_size, - fvalues_begin, - fvalues_cur_.data(), - num_elements_.begin(), - batch_nrows, - IsNotNaN()); - size_t nfvalues_cur = 0; - thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur); - - // compute cumulative weights using a prefix scan - if (has_weights_) { - // filter out NaNs in weights; - // since cub::DeviceSelect::If performs stable filtering, - // the weights are stored in the correct positions - auto feature_weights_begin = feature_weights_.data() + icol * gpu_batch_nrows_; - cub::DeviceSelect::If(tmp_storage_.data().get(), - tmp_size, - feature_weights_begin, - weights_.data().get(), - num_elements_.begin(), - batch_nrows, - IsNotNaN()); - - // sort the values and weights - cub::DeviceRadixSort::SortPairs(tmp_storage_.data().get(), - tmp_size, - fvalues_cur_.data().get(), - fvalues_begin.get(), - weights_.data().get(), - weights2_.data().get(), - nfvalues_cur); - - // sum the weights to get cumulative weight values - cub::DeviceScan::InclusiveSum(tmp_storage_.data().get(), - tmp_size, - weights2_.begin(), - weights_.begin(), - nfvalues_cur); - } else { - // sort the batch values - cub::DeviceRadixSort::SortKeys(tmp_storage_.data().get(), - tmp_size, - fvalues_cur_.data().get(), - fvalues_begin.get(), - nfvalues_cur); - - // fill in cumulative weights with counting iterator - thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, weights_.begin()); - } +void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, + SketchContainer* sketch_container, int num_cuts, + size_t num_columns) { + dh::XGBCachingDeviceAllocator alloc; + const auto& host_data = page.data.ConstHostVector(); + dh::device_vector sorted_entries(host_data.begin() + begin, + host_data.begin() + end); + thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), + sorted_entries.end(), EntryCompareOp()); + dh::caching_device_vector column_sizes_scan(num_columns + 1, 0); + auto d_column_sizes_scan = column_sizes_scan.data().get(); + auto d_sorted_entries = sorted_entries.data().get(); + dh::LaunchN(device, sorted_entries.size(), [=] __device__(size_t idx) { + auto& e = d_sorted_entries[idx]; + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes_scan[e.index]), + 
static_cast(1)); // NOLINT + }); + thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(), + column_sizes_scan.end(), column_sizes_scan.begin()); + thrust::host_vector host_column_sizes_scan(column_sizes_scan); + + dh::caching_device_vector cuts(num_columns * num_cuts); + ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts, + {sorted_entries.data().get(), sorted_entries.size()}, + {column_sizes_scan.data().get(), column_sizes_scan.size()}); + + // add cuts into sketches + thrust::host_vector host_cuts(cuts); + sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); +} - // remove repeated items and sum the weights across them; - // non-negative weights are assumed - cub::DeviceReduce::ReduceByKey(tmp_storage_.data().get(), - tmp_size, - fvalues_begin, - fvalues_cur_.begin(), - weights_.begin(), - weights2_.begin(), - num_elements_.begin(), - thrust::maximum(), - nfvalues_cur); - size_t n_unique = 0; - thrust::copy_n(num_elements_.begin(), 1, &n_unique); - - // extract cuts - n_cuts_cur_[icol] = std::min(n_cuts_, n_unique); - // if less elements than cuts: copy all elements with their weights - if (n_cuts_ > n_unique) { - float* weights2_ptr = weights2_.data().get(); - float* fvalues_ptr = fvalues_cur_.data().get(); - WQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_; - dh::LaunchN(device_, n_unique, [=]__device__(size_t i) { - bst_float rmax = weights2_ptr[i]; - bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0; - cuts_ptr[i] = WQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]); - }); - } else if (n_cuts_cur_[icol] > 0) { - // if more elements than cuts: use binary search on cumulative weights - uint32_t constexpr kBlockThreads = 256; - uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads); - dh::LaunchKernel {kGrids, kBlockThreads} ( - FindCutsK, - cuts_d_.data().get() + icol * n_cuts_, - fvalues_cur_.data().get(), - weights2_.data().get(), - n_unique, - n_cuts_cur_[icol]); - dh::safe_cuda(cudaGetLastError()); // NOLINT +HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, + size_t sketch_batch_num_elements) { + HistogramCuts cuts; + DenseCuts dense_cuts(&cuts); + SketchContainer sketch_container(max_bins, dmat->Info().num_col_, + dmat->Info().num_row_); + + constexpr int kFactor = 8; + double eps = 1.0 / (kFactor * max_bins); + size_t dummy_nlevel; + size_t num_cuts; + WQuantileSketch::LimitSizeLevel( + dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts); + num_cuts = std::min(num_cuts, dmat->Info().num_row_); + if (sketch_batch_num_elements == 0) { + sketch_batch_num_elements = dmat->Info().num_nonzero_; + } + for (const auto& batch : dmat->GetBatches()) { + size_t batch_nnz = batch.data.ConstHostVector().size(); + for (auto begin = 0ull; begin < batch_nnz; + begin += sketch_batch_num_elements) { + size_t end = std::min(batch_nnz, begin + sketch_batch_num_elements); + ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts, + dmat->Info().num_col_); } } - private: - void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) { - num_cols_ = info.num_col_; - has_weights_ = info.weights_.Size() > 0; - - // find the batch size - if (gpu_batch_nrows == 0) { - // By default, use no more than 1/16th of GPU memory - gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry)); - } else if (gpu_batch_nrows == -1) { - gpu_batch_nrows_ = n_rows_; - } else { - gpu_batch_nrows_ = gpu_batch_nrows; - } - if (gpu_batch_nrows_ > n_rows_) { - 
gpu_batch_nrows_ = n_rows_; - } + dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_); + return cuts; +} - constexpr int kFactor = 8; - double eps = 1.0 / (kFactor * max_bin_); - size_t dummy_nlevel; - WQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); - - // allocate necessary GPU buffers - dh::safe_cuda(cudaSetDevice(device_)); - - entries_.resize(gpu_batch_nrows_ * num_cols_); - fvalues_.resize(gpu_batch_nrows_ * num_cols_); - fvalues_cur_.resize(gpu_batch_nrows_); - cuts_d_.resize(n_cuts_ * num_cols_); - cuts_h_.resize(n_cuts_ * num_cols_); - weights_.resize(gpu_batch_nrows_); - weights2_.resize(gpu_batch_nrows_); - num_elements_.resize(1); - - if (has_weights_) { - feature_weights_.resize(gpu_batch_nrows_ * num_cols_); - } - n_cuts_cur_.resize(num_cols_); - - // allocate storage for CUB algorithms; the size is the maximum of the sizes - // required for various algorithm - size_t tmp_size = 0, cur_tmp_size = 0; - // size for sorting - if (has_weights_) { - cub::DeviceRadixSort::SortPairs(nullptr, - cur_tmp_size, - fvalues_cur_.data().get(), - fvalues_.data().get(), - weights_.data().get(), - weights2_.data().get(), - gpu_batch_nrows_); - } else { - cub::DeviceRadixSort::SortKeys(nullptr, - cur_tmp_size, - fvalues_cur_.data().get(), - fvalues_.data().get(), - gpu_batch_nrows_); - } - tmp_size = std::max(tmp_size, cur_tmp_size); - // size for inclusive scan - if (has_weights_) { - cub::DeviceScan::InclusiveSum(nullptr, - cur_tmp_size, - weights2_.begin(), - weights_.begin(), - gpu_batch_nrows_); - tmp_size = std::max(tmp_size, cur_tmp_size); +struct IsValidFunctor : public thrust::unary_function { + explicit IsValidFunctor(float missing) : missing(missing) {} + + float missing; + __device__ bool operator()(const data::COOTuple& e) const { + if (common::CheckNAN(e.value) || e.value == missing) { + return false; } - // size for reduction by key - cub::DeviceReduce::ReduceByKey(nullptr, - cur_tmp_size, - fvalues_.begin(), - fvalues_cur_.begin(), - weights_.begin(), - weights2_.begin(), - num_elements_.begin(), - thrust::maximum(), - gpu_batch_nrows_); - tmp_size = std::max(tmp_size, cur_tmp_size); - // size for filtering - cub::DeviceSelect::If(nullptr, - cur_tmp_size, - fvalues_.begin(), - fvalues_cur_.begin(), - num_elements_.begin(), - gpu_batch_nrows_, - IsNotNaN()); - tmp_size = std::max(tmp_size, cur_tmp_size); - - tmp_storage_.resize(tmp_size); + return true; } - - void Sketch(const SparsePage& row_batch, const MetaInfo& info) { - // copy rows to the device - dh::safe_cuda(cudaSetDevice(device_)); - const auto& offset_vec = row_batch.offset.HostVector(); - row_ptrs_.resize(n_rows_ + 1); - thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin()); - size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_); - for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { - SketchBatch(row_batch, info, gpu_batch); + __device__ bool operator()(const Entry& e) const { + if (common::CheckNAN(e.fvalue) || e.fvalue == missing) { + return false; } + return true; } +}; - void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, size_t gpu_batch) { - // compute start and end indices - size_t batch_row_begin = gpu_batch * gpu_batch_nrows_; - size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_, - static_cast(n_rows_)); - size_t batch_nrows = batch_row_end - batch_row_begin; - - const auto& offset_vec = row_batch.offset.HostVector(); - const auto& data_vec = 
row_batch.data.HostVector(); - - size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin]; - // copy the batch to the GPU - dh::safe_cuda(cudaMemcpyAsync(entries_.data().get(), - data_vec.data() + offset_vec[batch_row_begin], - n_entries * sizeof(Entry), - cudaMemcpyDefault)); - // copy the weights if necessary - if (has_weights_) { - const auto& weights_vec = info.weights_.HostVector(); - dh::safe_cuda(cudaMemcpyAsync(weights_.data().get(), - weights_vec.data() + batch_row_begin, - batch_nrows * sizeof(bst_float), - cudaMemcpyDefault)); - } +template +thrust::transform_iterator MakeTransformIterator( + IterT iter, FuncT func) { + return thrust::transform_iterator(iter, func); +} - // unpack the features; also unpack weights if present - thrust::fill(fvalues_.begin(), fvalues_.end(), NAN); - if (has_weights_) { - thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN); - } +template +void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, + SketchContainer* sketch_container, int num_cuts) { + dh::XGBCachingDeviceAllocator alloc; + adapter->BeforeFirst(); + adapter->Next(); + auto& batch = adapter->Value(); + // Enforce single batch + CHECK(!adapter->Next()); + + auto batch_iter = MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); + auto entry_iter = MakeTransformIterator( + thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { + return Entry(batch.GetElement(idx).column_idx, + batch.GetElement(idx).value); + }); - dim3 block3(16, 64, 1); - // NOTE: This will typically support ~ 4M features - 64K*64 - dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), - common::DivRoundUp(num_cols_, block3.y), 1); - dh::LaunchKernel {grid3, block3} ( - UnpackFeaturesK, - fvalues_.data().get(), - has_weights_ ? feature_weights_.data().get() : nullptr, - row_ptrs_.data().get() + batch_row_begin, - has_weights_ ? 
weights_.data().get() : nullptr, entries_.data().get(), - gpu_batch_nrows_, - offset_vec[batch_row_begin], - batch_nrows); - - for (int icol = 0; icol < num_cols_; ++icol) { - FindColumnCuts(batch_nrows, icol); + // Work out how many valid entries we have in each column + dh::caching_device_vector column_sizes_scan(adapter->NumColumns() + 1, + 0); + auto d_column_sizes_scan = column_sizes_scan.data().get(); + IsValidFunctor is_valid(missing); + dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) { + auto& e = batch_iter[begin + idx]; + if (is_valid(e)) { + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes_scan[e.column_idx]), + static_cast(1)); // NOLINT } + }); + thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(), + column_sizes_scan.end(), column_sizes_scan.begin()); + thrust::host_vector host_column_sizes_scan(column_sizes_scan); + size_t num_valid = host_column_sizes_scan.back(); + + // Copy current subset of valid elements into temporary storage and sort + thrust::device_vector sorted_entries(num_valid); + thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin, + entry_iter + end, sorted_entries.begin(), is_valid); + thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), + sorted_entries.end(), EntryCompareOp()); + + // Extract the cuts from all columns concurrently + dh::caching_device_vector cuts(adapter->NumColumns() * num_cuts); + ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts, + {sorted_entries.data().get(), sorted_entries.size()}, + {column_sizes_scan.data().get(), column_sizes_scan.size()}); + + // Push cuts into sketches stored in host memory + thrust::host_vector host_cuts(cuts); + sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); +} - // add cuts into sketches - thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); -#pragma omp parallel for default(none) schedule(static) \ -if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT - for (int icol = 0; icol < num_cols_; ++icol) { - WQSketch::SummaryContainer summary; - summary.Reserve(n_cuts_); - summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]); - - std::lock_guard lock(sketch_container_->col_locks_[icol]); - sketch_container_->sketches_[icol].PushSummary(summary); - } +template +HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, + float missing, + size_t sketch_batch_num_elements) { + CHECK(adapter->NumRows() != data::kAdapterUnknownSize); + CHECK(adapter->NumColumns() != data::kAdapterUnknownSize); + + adapter->BeforeFirst(); + adapter->Next(); + auto& batch = adapter->Value(); + + // Enforce single batch + CHECK(!adapter->Next()); + + HistogramCuts cuts; + DenseCuts dense_cuts(&cuts); + SketchContainer sketch_container(num_bins, adapter->NumColumns(), + adapter->NumRows()); + + constexpr int kFactor = 8; + double eps = 1.0 / (kFactor * num_bins); + size_t dummy_nlevel; + size_t num_cuts; + WQuantileSketch::LimitSizeLevel( + adapter->NumRows(), eps, &dummy_nlevel, &num_cuts); + num_cuts = std::min(num_cuts, adapter->NumRows()); + if (sketch_batch_num_elements == 0) { + sketch_batch_num_elements = batch.Size(); + } + for (auto begin = 0ull; begin < batch.Size(); + begin += sketch_batch_num_elements) { + size_t end = std::min(batch.Size(), begin + sketch_batch_num_elements); + ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts); } - const int device_; - const int max_bin_; - int gpu_batch_nrows_; - size_t row_stride_; - std::unique_ptr 
sketch_container_; - - bst_uint n_rows_{}; - int num_cols_{0}; - size_t n_cuts_{0}; - bool has_weights_{false}; - - dh::device_vector row_ptrs_{}; - dh::device_vector entries_{}; - dh::device_vector fvalues_{}; - dh::device_vector feature_weights_{}; - dh::device_vector fvalues_cur_{}; - dh::device_vector cuts_d_{}; - thrust::host_vector cuts_h_{}; - dh::device_vector weights_{}; - dh::device_vector weights2_{}; - std::vector n_cuts_cur_{}; - dh::device_vector num_elements_{}; - dh::device_vector tmp_storage_{}; -}; - -size_t DeviceSketch(int device, - int max_bin, - int gpu_batch_nrows, - DMatrix* dmat, - HistogramCuts* hmat) { - GPUSketcher sketcher(device, max_bin, gpu_batch_nrows); - // We only need to return the result in HistogramCuts container, so it is safe to - // use a pointer of local HistogramCutsDense - DenseCuts dense_cuts(hmat); - auto res = sketcher.Sketch(dmat, &dense_cuts); - return res; + dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows()); + return cuts; } +template HistogramCuts AdapterDeviceSketch(data::CudfAdapter* adapter, + int num_bins, float missing, + size_t sketch_batch_size); +template HistogramCuts AdapterDeviceSketch(data::CupyAdapter* adapter, + int num_bins, float missing, + size_t sketch_batch_size); } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index aa0c57ab4034..2c3cb3d4a798 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -179,16 +179,14 @@ class DenseCuts : public CutsBuilder { void Build(DMatrix* p_fmat, uint32_t max_num_bins) override; }; -// FIXME(trivialfis): Merge this into generic cut builder. -/*! \brief Builds the cut matrix on the GPU. - * - * \return The row stride across the entire dataset. - */ -size_t DeviceSketch(int device, - int max_bin, - int gpu_batch_nrows, - DMatrix* dmat, - HistogramCuts* hmat); + +HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, + size_t sketch_batch_num_elements = 10000000); + +template +HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, + float missing, + size_t sketch_batch_num_elements = 10000000); /*! * \brief preprocessed global index matrix, in CSR format diff --git a/src/data/adapter.h b/src/data/adapter.h index ad40e390246f..9d26d5e14126 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -71,6 +71,7 @@ namespace data { constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); struct COOTuple { + COOTuple() = default; XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value) : row_idx(row_idx), column_idx(column_idx), value(value) {} diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index ff86d77aada2..ccff62bea230 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -78,6 +78,20 @@ EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) { monitor_.StopCuda("InitCompressedData"); } +size_t GetRowStride(DMatrix* dmat) { + if (dmat->IsDense()) return dmat->Info().num_col_; + + size_t row_stride = 0; + for (const auto& batch : dmat->GetBatches()) { + const auto& row_offset = batch.offset.ConstHostVector(); + for (auto i = 1ull; i < row_offset.size(); i++) { + row_stride = std::max( + row_stride, static_cast(row_offset[i] - row_offset[i - 1])); + } + } + return row_stride; +} + // Construct an ELLPACK matrix in memory. 
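// Usage sketch (illustrative only; the names below are taken from the
// constructor that follows): with the reworked sketching API the quantile
// cuts are returned directly and the row stride is computed separately, e.g.
//   size_t row_stride = GetRowStride(dmat);
//   auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
//                                    param.gpu_batch_nrows);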
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.Init("ellpack_page"); @@ -87,13 +101,13 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) { monitor_.StartCuda("Quantiles"); // Create the quantile sketches for the dmatrix and initialize HistogramCuts. - common::HistogramCuts hmat; - size_t row_stride = - common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat); + size_t row_stride = GetRowStride(dmat); + auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin, + param.gpu_batch_nrows); monitor_.StopCuda("Quantiles"); monitor_.StartCuda("InitEllpackInfo"); - InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat); + InitInfo(param.gpu_id, dmat->IsDense(), row_stride, cuts); monitor_.StopCuda("InitEllpackInfo"); monitor_.StartCuda("InitCompressedData"); diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index b16bde6a4cd2..f1befaf2a5b3 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -70,6 +70,20 @@ const EllpackPage& EllpackPageSource::Value() const { return impl_->Value(); } +size_t GetRowStride(DMatrix* dmat) { + if (dmat->IsDense()) return dmat->Info().num_col_; + + size_t row_stride = 0; + for (const auto& batch : dmat->GetBatches()) { + const auto& row_offset = batch.offset.ConstHostVector(); + for (auto i = 1ull; i < row_offset.size(); i++) { + row_stride = std::max( + row_stride, static_cast(row_offset[i] - row_offset[i - 1])); + } + } + return row_stride; +} + // Build the quantile sketch across the whole input data, then use the histogram cuts to compress // each CSR page, and write the accumulated ELLPACK pages to disk. EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat, @@ -85,13 +99,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat, dh::safe_cuda(cudaSetDevice(device_)); monitor_.StartCuda("Quantiles"); - common::HistogramCuts hmat; - size_t row_stride = - common::DeviceSketch(device_, param.max_bin, param.gpu_batch_nrows, dmat, &hmat); + size_t row_stride = GetRowStride(dmat); + auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin, + param.gpu_batch_nrows); monitor_.StopCuda("Quantiles"); monitor_.StartCuda("CreateEllpackInfo"); - ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_); + ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, cuts, &ba_); monitor_.StopCuda("CreateEllpackInfo"); monitor_.StartCuda("WriteEllpackPages"); diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu deleted file mode 100644 index 9bbd404c7dc8..000000000000 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ /dev/null @@ -1,105 +0,0 @@ -#include -#include - -#include -#include - - -#include -#include - -#include "xgboost/c_api.h" - -#include "../../../src/common/device_helpers.cuh" -#include "../../../src/common/hist_util.h" - -#include "../helpers.h" - -namespace xgboost { -namespace common { - -void TestDeviceSketch(bool use_external_memory) { - // create the data - int nrows = 10001; - std::shared_ptr *dmat = nullptr; - - size_t num_cols = 1; - dmlc::TemporaryDirectory tmpdir; - std::string file = tmpdir.path + "/big.libsvm"; - if (use_external_memory) { - auto sp_dmat = CreateSparsePageDMatrix(nrows * 3, 128UL, file); // 3 entries/row - dmat = new std::shared_ptr(std::move(sp_dmat)); - num_cols = 5; - } else { - std::vector test_data(nrows); - auto count_iter = thrust::make_counting_iterator(0); - // fill in reverse 
order - std::copy(count_iter, count_iter + nrows, test_data.rbegin()); - - // create the DMatrix - DMatrixHandle dmat_handle; - XGDMatrixCreateFromMat(test_data.data(), nrows, 1, -1, - &dmat_handle); - dmat = static_cast *>(dmat_handle); - } - - int device{0}; - int max_bin{20}; - int gpu_batch_nrows{0}; - - // find quantiles on the CPU - HistogramCuts hmat_cpu; - hmat_cpu.Build((*dmat).get(), max_bin); - - // find the cuts on the GPU - HistogramCuts hmat_gpu; - size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu); - - // compare the row stride with the one obtained from the dmatrix - bst_row_t expected_row_stride = 0; - for (const auto &batch : dmat->get()->GetBatches()) { - const auto &offset_vec = batch.offset.ConstHostVector(); - for (int i = 1; i <= offset_vec.size() -1; ++i) { - expected_row_stride = std::max(expected_row_stride, offset_vec[i] - offset_vec[i-1]); - } - } - - ASSERT_EQ(expected_row_stride, row_stride); - - // compare the cuts - double eps = 1e-2; - ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols); - ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1); - ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size()); - ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows); - for (int i = 0; i < hmat_gpu.Values().size(); ++i) { - ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows); - } - - // Determinstic - size_t constexpr kRounds { 100 }; - for (size_t r = 0; r < kRounds; ++r) { - HistogramCuts new_sketch; - DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &new_sketch); - ASSERT_EQ(hmat_gpu.Values().size(), new_sketch.Values().size()); - for (size_t i = 0; i < hmat_gpu.Values().size(); ++i) { - ASSERT_EQ(hmat_gpu.Values()[i], new_sketch.Values()[i]); - } - for (size_t i = 0; i < hmat_gpu.MinValues().size(); ++i) { - ASSERT_EQ(hmat_gpu.MinValues()[i], new_sketch.MinValues()[i]); - } - } - - delete dmat; -} - -TEST(gpu_hist_util, DeviceSketch) { - TestDeviceSketch(false); -} - -TEST(gpu_hist_util, DeviceSketch_ExternalMemory) { - TestDeviceSketch(true); -} - -} // namespace common -} // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu new file mode 100644 index 000000000000..dceac3ebf124 --- /dev/null +++ b/tests/cpp/common/test_hist_util.cu @@ -0,0 +1,323 @@ +#include +#include + +#include +#include + + +#include +#include + +#include "xgboost/c_api.h" + +#include "../../../src/common/device_helpers.cuh" +#include "../../../src/common/hist_util.h" + +#include "../helpers.h" +#include +#include "../../../src/data/device_adapter.cuh" +#include "../data/test_array_interface.h" +#include "../../../src/common/math.h" +#include "../../../src/data/simple_dmatrix.h" +#include "test_hist_util.h" + +namespace xgboost { +namespace common { + +template +HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { + HistogramCuts cuts; + DenseCuts builder(&cuts); + data::SimpleDMatrix dmat(adapter, missing, 1); + builder.Build(&dmat, num_bins); + return cuts; +} + +TEST(hist_util, DeviceSketch) { + int num_rows = 5; + int num_columns = 1; + int num_bins = 4; + std::vector x = {1.0, 2.0, 3.0, 4.0, 5.0}; + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + + auto device_cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + HistogramCuts host_cuts; + DenseCuts builder(&host_cuts); + builder.Build(dmat.get(), num_bins); + + EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); + EXPECT_EQ(device_cuts.Ptrs(), 
host_cuts.Ptrs()); + EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); +} + +TEST(hist_util, DeviceSketchDeterminism) { + int num_rows = 500; + int num_columns = 5; + int num_bins = 256; + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins); + size_t constexpr kRounds{ 100 }; + for (size_t r = 0; r < kRounds; ++r) { + auto new_sketch = DeviceSketch(0, dmat.get(), num_bins); + ASSERT_EQ(reference_sketch.Values(), new_sketch.Values()); + ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues()); + } +} + + TEST(hist_util, DeviceSketchCategorical) { + int categorical_sizes[] = {2, 6, 8, 12}; + int num_bins = 256; + int sizes[] = {25, 100, 1000}; + for (auto n : sizes) { + for (auto num_categories : categorical_sizes) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + auto dmat = GetDMatrixFromData(x, n, 1); + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + ValidateCuts(cuts, x, n, 1, num_bins); + } + } +} + +TEST(hist_util, DeviceSketchMultipleColumns) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 2; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + for (auto num_bins : bin_sizes) { + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} + +TEST(hist_util, DeviceSketchBatches) { + int num_bins = 256; + int num_rows = 5000; + int batch_sizes[] = {0, 100, 1500, 6000}; + int num_columns = 5; + for (auto batch_size : batch_sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } +} + +TEST(hist_util, DeviceSketchMultipleColumnsExternal) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 2; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + dmlc::TemporaryDirectory temp; + auto dmat = + GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 100, temp); + for (auto num_bins : bin_sizes) { + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} + + +TEST(hist_util, AdapterDeviceSketch) +{ + int rows = 5; + int cols = 1; + int num_bins = 4; + float missing = - 1.0; + thrust::device_vector< float> data(rows*cols); + auto json_array_interface = Generate2dArrayInterface(rows, cols, "{ 1.0,2.0,3.0,4.0,5.0 }; + std::stringstream ss; + Json::Dump(json_array_interface, &ss); + std::string str = ss.str(); + data::CupyAdapter adapter(str); + + auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing); + auto host_cuts = GetHostCuts(&adapter, num_bins, missing); + + EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); + EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); + EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); +} + + TEST(hist_util, AdapterDeviceSketchCategorical) { + int categorical_sizes[] = {2, 6, 8, 12}; + int num_bins = 256; + int sizes[] = {25, 100, 1000}; + for (auto n : sizes) { + for (auto num_categories : categorical_sizes) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + auto x_device = thrust::device_vector(x); + auto adapter = 
AdapterFromData(x_device, n, 1); + auto cuts = AdapterDeviceSketch(&adapter, num_bins, + std::numeric_limits::quiet_NaN()); + ValidateCuts(cuts, x, n, 1, num_bins); + } + } +} + +TEST(hist_util, AdapterDeviceSketchMultipleColumns) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + for (auto num_bins : bin_sizes) { + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + auto cuts = AdapterDeviceSketch(&adapter, num_bins, + std::numeric_limits::quiet_NaN()); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } + } +} +TEST(hist_util, AdapterDeviceSketchBatches) { + int num_bins = 256; + int num_rows = 5000; + int batch_sizes[] = {0, 100, 1500, 6000}; + int num_columns = 5; + for (auto batch_size : batch_sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + auto cuts = AdapterDeviceSketch(&adapter, num_bins, + std::numeric_limits::quiet_NaN(), + batch_size); + ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + } +} + +TEST(hist_util, Benchmark) { + int num_bins = 256; + std::vector sizes; + for (auto i = 8ull; i < 26; i += 2) { + sizes.push_back(1 << i); + } + + std::cout << "Num rows, "; + for (auto n : sizes) { + std::cout << n << ", "; + } + std::cout << "\n"; + int num_columns = 5; + std::cout << "AdapterDeviceSketch, "; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + Timer t; + t.Start(); + auto cuts = AdapterDeviceSketch(&adapter, num_bins, + std::numeric_limits::quiet_NaN()); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; + + std::cout << "DeviceSketch, "; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetDMatrixFromData(x, num_rows, num_columns); + Timer t; + t.Start(); + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; + + std::cout << "WQSketch, "; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetDMatrixFromData(x, num_rows, num_columns); + HistogramCuts cuts; + DenseCuts dense(&cuts); + Timer t; + t.Start(); + dense.Build(dmat.get(), num_bins); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; +} +TEST(hist_util, BenchmarkNumColumns) { + int num_bins = 256; + int num_rows = 10; + std::vector num_columns; + for (auto i = 4ull; i < 16; i += 2) { + num_columns.push_back(1 << i); + } + + std::cout << "Num columns, "; + for (auto n : num_columns) { + std::cout << n << ", "; + } + std::cout << "\n"; + std::cout << "AdapterDeviceSketch, "; + for (auto num_column : num_columns) { + auto x = GenerateRandom(num_rows, num_column); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_column); + Timer t; + t.Start(); + auto cuts = AdapterDeviceSketch(&adapter, num_bins, + std::numeric_limits::quiet_NaN()); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; + std::cout << "DeviceSketch, "; + for (auto num_column : num_columns) { + auto x = 
GenerateRandom(num_rows, num_column); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetDMatrixFromData(x, num_rows, num_column); + Timer t; + t.Start(); + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; + std::cout << "SparseCuts, "; + for (auto num_column : num_columns) { + auto x = GenerateRandom(num_rows, num_column); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetDMatrixFromData(x, num_rows, num_column); + HistogramCuts cuts; + SparseCuts sparse(&cuts); + Timer t; + t.Start(); + sparse.Build(dmat.get(), num_bins); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; + std::cout << "DenseCuts, "; + for (auto num_column : num_columns) { + auto x = GenerateRandom(num_rows, num_column); + dmlc::TemporaryDirectory tmpdir; + auto dmat = + GetDMatrixFromData(x, num_rows, num_column); + HistogramCuts cuts; + DenseCuts dense(&cuts); + Timer t; + t.Start(); + dense.Build(dmat.get(), num_bins); + t.Stop(); + std::cout << t.ElapsedSeconds() << ", "; + } + std::cout << "\n"; +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 663fbc4cf474..9b3fa4ac73d1 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -28,6 +28,26 @@ inline std::vector GenerateRandom(int num_rows, int num_columns) { return x; } +#ifdef __CUDACC__ +inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, + int num_rows, int num_columns) { + Json array_interface{Object()}; + std::vector shape = {Json(static_cast(num_rows)), + Json(static_cast(num_columns))}; + array_interface["shape"] = Array(shape); + std::vector j_data{ + Json(Integer(reinterpret_cast(x.data().get()))), + Json(Boolean(false))}; + array_interface["data"] = j_data; + array_interface["version"] = Integer(static_cast(1)); + array_interface["typestr"] = String(" GenerateRandomCategoricalSingleColumn(int n, int num_categories) { std::vector x(n); @@ -126,26 +146,25 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, EXPECT_EQ(std::set(cuts_begin, cuts_end).size(), cuts_end - cuts_begin); - if (sorted_column.size() <= num_bins) { + auto unique = std::set(sorted_column.begin(), sorted_column.end()); + if (unique.size() <= num_bins) { // Less unique values than number of bins // Each value should get its own bin - - // First check the inputs are unique - int num_unique = - std::set(sorted_column.begin(), sorted_column.end()).size(); - EXPECT_EQ(num_unique, sorted_column.size()); - for (auto i = 0ull; i < sorted_column.size(); i++) { - ASSERT_EQ(cuts.SearchBin(sorted_column[i], column_idx), - cuts.Ptrs()[column_idx] + i); + int i = 0; + for (auto v : unique) { + ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i); + i++; } } - int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; - std::vector column_cuts(num_cuts_column); - std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], - cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], - column_cuts.begin()); - TestBinDistribution(cuts, column_idx, sorted_column, num_bins); - TestRank(column_cuts, sorted_column); + else { + int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; + std::vector column_cuts(num_cuts_column); + std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], + cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], + column_cuts.begin()); + 
TestBinDistribution(cuts, column_idx, sorted_column, num_bins); + TestRank(column_cuts, sorted_column); + } } // x is dense and row major From c37c6f8662a73fddc7f9977cce1dad93a0338b40 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 2 Mar 2020 14:19:32 +1300 Subject: [PATCH 2/7] Add weights test --- tests/cpp/common/test_hist_util.cc | 28 +++++++-- tests/cpp/common/test_hist_util.cu | 18 +++--- tests/cpp/common/test_hist_util.h | 94 ++++++++++++++++++++---------- 3 files changed, 97 insertions(+), 43 deletions(-) diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 33729e32a58e..254c969e9903 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -261,7 +261,27 @@ TEST(hist_util, DenseCutsAccuracyTest) { HistogramCuts cuts; DenseCuts dense(&cuts); dense.Build(dmat.get(), num_bins); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); + } + } +} + +TEST(hist_util, DenseCutsAccuracyTestWeights) { + //int bin_sizes[] = {2, 16, 256, 512}; + //int sizes[] = {100, 1000, 1500}; + int bin_sizes[] = {2}; + int sizes[] = {100}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto w = GenerateRandomWeights(num_rows); + dmat->Info().weights_.HostVector() = w; + for (auto num_bins : bin_sizes) { + HistogramCuts cuts; + DenseCuts dense(&cuts); + dense.Build(dmat.get(), num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -279,7 +299,7 @@ TEST(hist_util, DenseCutsExternalMemory) { HistogramCuts cuts; DenseCuts dense(&cuts); dense.Build(dmat.get(), num_bins); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -295,7 +315,7 @@ TEST(hist_util, SparseCutsAccuracyTest) { HistogramCuts cuts; SparseCuts sparse(&cuts); sparse.Build(dmat.get(), num_bins); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -335,7 +355,7 @@ TEST(hist_util, SparseCutsExternalMemory) { HistogramCuts cuts; SparseCuts dense(&cuts); dense.Build(dmat.get(), num_bins); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index dceac3ebf124..5a81a203f3bd 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -32,7 +32,6 @@ HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { builder.Build(&dmat, num_bins); return cuts; } - TEST(hist_util, DeviceSketch) { int num_rows = 5; int num_columns = 1; @@ -74,7 +73,7 @@ TEST(hist_util, DeviceSketchDeterminism) { auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); auto dmat = GetDMatrixFromData(x, n, 1); auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); - ValidateCuts(cuts, x, n, 1, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -88,7 +87,7 @@ TEST(hist_util, DeviceSketchMultipleColumns) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -102,7 +101,7 @@ TEST(hist_util, DeviceSketchBatches) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, 
num_columns); auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -117,7 +116,7 @@ TEST(hist_util, DeviceSketchMultipleColumnsExternal) { GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 100, temp); for (auto num_bins : bin_sizes) { auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -152,11 +151,12 @@ TEST(hist_util, AdapterDeviceSketch) for (auto n : sizes) { for (auto num_categories : categorical_sizes) { auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + auto dmat = GetDMatrixFromData(x, n, 1); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, n, 1); auto cuts = AdapterDeviceSketch(&adapter, num_bins, std::numeric_limits::quiet_NaN()); - ValidateCuts(cuts, x, n, 1, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -167,12 +167,13 @@ TEST(hist_util, AdapterDeviceSketchMultipleColumns) { int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto x_device = thrust::device_vector(x); for (auto num_bins : bin_sizes) { auto adapter = AdapterFromData(x_device, num_rows, num_columns); auto cuts = AdapterDeviceSketch(&adapter, num_bins, std::numeric_limits::quiet_NaN()); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } } @@ -183,12 +184,13 @@ TEST(hist_util, AdapterDeviceSketchBatches) { int num_columns = 5; for (auto batch_size : batch_sizes) { auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); auto cuts = AdapterDeviceSketch(&adapter, num_bins, std::numeric_limits::quiet_NaN(), batch_size); - ValidateCuts(cuts, x, num_rows, num_columns, num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); } } diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 9b3fa4ac73d1..efb7871fbf64 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -28,6 +28,14 @@ inline std::vector GenerateRandom(int num_rows, int num_columns) { return x; } +inline std::vector GenerateRandomWeights(int num_rows) { + std::vector w(num_rows); + std::mt19937 rng(1); + std::uniform_real_distribution dist(0.0, 1.0); + std::generate(w.begin(), w.end(), [&]() { return dist(rng); }); + return w; +} + #ifdef __CUDACC__ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, int num_rows, int num_columns) { @@ -89,21 +97,22 @@ inline std::shared_ptr GetExternalMemoryDMatrixFromData( // Test that elements are approximately equally distributed among bins inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, - const std::vector& column, + const std::vector& sorted_column,const std::vector&sorted_weights, int num_bins) { - std::map counts; - for (auto& v : column) { - counts[cuts.SearchBin(v, column_idx)]++; + std::map bin_weights; + for (auto i = 0ull; i < sorted_column.size(); i++) { + bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i]; } int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; - int expected_num_elements = column.size() / local_num_bins; - // Allow 
about 30% deviation. This test is not very strict, it only ensures + auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0); + int expected_bin_weight = total_weight / local_num_bins; + // Allow up to 30% deviation. This test is not very strict, it only ensures // roughly equal distribution - int allowable_error = std::max(2, int(expected_num_elements * 0.3)); + int allowable_error = std::max(2, int(expected_bin_weight * 0.3)); // First and last bin can have smaller - for (auto& kv : counts) { - EXPECT_LE(std::abs(counts[kv.first] - expected_num_elements), + for (auto& kv : bin_weights) { + EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight), allowable_error ); } } @@ -111,26 +120,28 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, // Test sketch quantiles against the real quantiles // Not a very strict test inline void TestRank(const std::vector& cuts, - const std::vector& sorted_x) { - float eps = 0.05; + const std::vector& sorted_x,const std::vector&sorted_weights) { + double eps = 0.05; + auto total_weight = + std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0); // Ignore the last cut, its special + double sum_weight = 0.0; size_t j = 0; for (auto i = 0; i < cuts.size() - 1; i++) { - int expected_rank = ((i+1) * sorted_x.size()) / cuts.size(); while (cuts[i] > sorted_x[j]) { + sum_weight += sorted_weights[j]; j++; } - int actual_rank = j; - int acceptable_error = std::max(2, int(sorted_x.size() * eps)); - ASSERT_LE(std::abs(expected_rank - actual_rank), acceptable_error); + double expected_rank = ((i + 1) * total_weight) / cuts.size(); + double acceptable_error = std::max(2.0, total_weight * eps); + ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error); } } inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, - const std::vector& column, + const std::vector& sorted_column, + const std::vector& sorted_weights, int num_bins) { - std::vector sorted_column(column); - std::sort(sorted_column.begin(), sorted_column.end()); // Check the endpoints are correct EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front()); @@ -162,23 +173,44 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], column_cuts.begin()); - TestBinDistribution(cuts, column_idx, sorted_column, num_bins); - TestRank(column_cuts, sorted_column); + TestBinDistribution(cuts, column_idx, sorted_column,sorted_weights, num_bins); + TestRank(column_cuts, sorted_column,sorted_weights); } } -// x is dense and row major -inline void ValidateCuts(const HistogramCuts& cuts, std::vector& x, - int num_rows, int num_columns, +inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) { - for (auto i = 0; i < num_columns; i++) { - // Extract the column - std::vector column(num_rows); - for (auto j = 0; j < num_rows; j++) { - column[j] = x[j*num_columns + i]; - } - ValidateColumn(cuts,i, column, num_bins); - } + // Collect data into columns + std::vector> columns(dmat->Info().num_col_); + for (auto& batch : dmat->GetBatches()) { + for (auto i = 0ull; i < batch.Size(); i++) { + for (auto e : batch[i]) { + columns[e.index].push_back(e.fvalue); + } + } + } + // Sort + for (auto i = 0ull; i < columns.size(); i++) { + auto& col = columns.at(i); + const auto& w = dmat->Info().weights_.HostVector(); + std::vector index(col.size()); + std::iota(index.begin(), index.end(), 0); + 
std::sort(index.begin(), index.end(),[=](size_t a,size_t b) + { + return col[a] < col[b]; + }); + + std::vector sorted_column(col.size()); + std::vector sorted_weights(col.size(), 1.0); + for (auto i = 0ull; i < col.size(); i++) { + sorted_column[i] = col[index[i]]; + if (w.size() == col.size()) { + sorted_weights[i] = w[index[i]]; + } + } + + ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins); + } } } // namespace common From 788360c7de2617c36d7f5718ed8e6ccd357f31bc Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 3 Mar 2020 17:12:02 +1300 Subject: [PATCH 3/7] Reintroduce weights for DeviceSketch --- src/common/hist_util.cu | 166 ++++++++++++++++++++++++++--- tests/cpp/common/test_hist_util.cc | 6 +- tests/cpp/common/test_hist_util.cu | 21 +++- tests/cpp/common/test_hist_util.h | 3 +- 4 files changed, 173 insertions(+), 23 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 7368cb7e10aa..caab20913ebc 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -9,8 +9,9 @@ #include #include #include -#include #include +#include +#include #include #include @@ -22,10 +23,11 @@ #include "../tree/param.h" #include "device_helpers.cuh" #include "hist_util.h" -#include "math.h" +#include "math.h" // NOLINT #include "quantile.h" #include "xgboost/host_device_vector.h" + namespace xgboost { namespace common { @@ -97,6 +99,24 @@ struct EntryCompareOp { } }; +// Count the entries in each column and exclusive scan +void GetColumnSizesScan(int device, + dh::caching_device_vector* column_sizes_scan, + Span entries, size_t num_columns) { + column_sizes_scan->resize(num_columns + 1, 0); + auto d_column_sizes_scan = column_sizes_scan->data().get(); + auto d_entries = entries.data(); + dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) { + auto& e = d_entries[idx]; + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes_scan[e.index]), + static_cast(1)); // NOLINT + }); + dh::XGBCachingDeviceAllocator alloc; + thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), + column_sizes_scan->end(), column_sizes_scan->begin()); +} + /** * \brief Extracts the cuts from sorted data. * @@ -129,6 +149,64 @@ void ExtractCuts(int device, Span cuts, }); } +/** +* \brief Extracts the cuts from sorted data, considering weights. +* +* \param device The device. +* \param cuts Output cuts +* \param num_cuts_per_feature Number of cuts per feature. 
+* \param sorted_data Sorted entries in segments of columns +* \param column_sizes_scan Describes the boundaries of column segments in +* sorted data +*/ +void ExtractWeightedCuts(int device, Span cuts, + size_t num_cuts_per_feature, Span sorted_data, + Span weights_scan, + Span column_sizes_scan) { + dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { + // Each thread is responsible for obtaining one cut from the sorted input + size_t column_idx = idx / num_cuts_per_feature; + size_t column_size = + column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; + size_t num_available_cuts = + std::min(size_t(num_cuts_per_feature), column_size); + size_t cut_idx = idx % num_cuts_per_feature; + if (cut_idx >= num_available_cuts) return; + + Span column_entries = + sorted_data.subspan(column_sizes_scan[column_idx], column_size); + Span column_weights = + weights_scan.subspan(column_sizes_scan[column_idx], column_size); + + float total_column_weight = column_weights.back(); + size_t sample_idx = 0; + if (cut_idx == 0) { + // First cut + sample_idx = 0; + } else if (cut_idx == num_available_cuts - 1) { + // Last cut + sample_idx = column_entries.size() - 1; + } else if (num_available_cuts == column_size) { + // There are less samples available than our buffer + // Take every available sample + sample_idx = cut_idx; + } else { + bst_float rank = (total_column_weight * cut_idx) / + static_cast(num_available_cuts); + sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(), + column_weights.end(), rank) - + column_weights.begin() - 1; + sample_idx = + std::max(0llu, std::min(sample_idx, column_entries.size() - 1)); + } + // repeated values will be filtered out on the CPU + bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0; + bst_float rmax = column_weights[sample_idx]; + cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, + column_entries[sample_idx].fvalue); + }); +} + void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, SketchContainer* sketch_container, int num_cuts, size_t num_columns) { @@ -138,17 +216,11 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, host_data.begin() + end); thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), EntryCompareOp()); - dh::caching_device_vector column_sizes_scan(num_columns + 1, 0); - auto d_column_sizes_scan = column_sizes_scan.data().get(); - auto d_sorted_entries = sorted_entries.data().get(); - dh::LaunchN(device, sorted_entries.size(), [=] __device__(size_t idx) { - auto& e = d_sorted_entries[idx]; - atomicAdd(reinterpret_cast( // NOLINT - &d_column_sizes_scan[e.index]), - static_cast(1)); // NOLINT - }); - thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(), - column_sizes_scan.end(), column_sizes_scan.begin()); + + dh::caching_device_vector column_sizes_scan; + GetColumnSizesScan(device, &column_sizes_scan, + {sorted_entries.data().get(), sorted_entries.size()}, + num_columns); thrust::host_vector host_column_sizes_scan(column_sizes_scan); dh::caching_device_vector cuts(num_columns * num_cuts); @@ -161,6 +233,63 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); } +void ProcessWeightedBatch(int device, const SparsePage& page, + Span weights, size_t begin, size_t end, + SketchContainer* sketch_container, int num_cuts, + size_t num_columns) { + dh::XGBCachingDeviceAllocator alloc; + const auto& 
host_data = page.data.ConstHostVector(); + dh::device_vector sorted_entries(host_data.begin() + begin, + host_data.begin() + end); + + // Binary search to assign weights to each element + dh::device_vector temp_weights(sorted_entries.size()); + auto d_temp_weights = temp_weights.data().get(); + page.offset.SetDevice(device); + auto row_ptrs = page.offset.ConstDeviceSpan(); + size_t base_rowid = page.base_rowid; + dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { + size_t element_idx = idx + begin; + size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), + row_ptrs.end(), element_idx) - + row_ptrs.begin() - 1; + d_temp_weights[idx] = weights[ridx + base_rowid]; + }); + + // Sort + thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(), + sorted_entries.end(), temp_weights.begin(), + EntryCompareOp()); + std::vector entries_t(sorted_entries.begin(), sorted_entries.end()); + + // Scan weights + thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), + sorted_entries.begin(), sorted_entries.end(), + temp_weights.begin(), temp_weights.begin(), + [=] __device__(const Entry& a, const Entry& b) { + return a.index == b.index; + }); + std::vector weights_t(temp_weights.begin(), temp_weights.end()); + + dh::caching_device_vector column_sizes_scan; + GetColumnSizesScan(device, &column_sizes_scan, + {sorted_entries.data().get(), sorted_entries.size()}, + num_columns); + thrust::host_vector host_column_sizes_scan(column_sizes_scan); + + // Extract cuts + dh::caching_device_vector cuts(num_columns * num_cuts); + ExtractWeightedCuts( + device, {cuts.data().get(), cuts.size()}, num_cuts, + {sorted_entries.data().get(), sorted_entries.size()}, + {temp_weights.data().get(), temp_weights.size()}, + {column_sizes_scan.data().get(), column_sizes_scan.size()}); + + // add cuts into sketches + thrust::host_vector host_cuts(cuts); + sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); +} + HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements) { HistogramCuts cuts; @@ -178,13 +307,20 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, if (sketch_batch_num_elements == 0) { sketch_batch_num_elements = dmat->Info().num_nonzero_; } + dmat->Info().weights_.SetDevice(device); for (const auto& batch : dmat->GetBatches()) { size_t batch_nnz = batch.data.ConstHostVector().size(); for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, begin + sketch_batch_num_elements); - ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts, - dmat->Info().num_col_); + if (dmat->Info().weights_.Size() > 0) { + ProcessWeightedBatch( + device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, + &sketch_container, num_cuts, dmat->Info().num_col_); + } else { + ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts, + dmat->Info().num_col_); + } } } diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 254c969e9903..e9ef5c3cdc02 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -267,10 +267,8 @@ TEST(hist_util, DenseCutsAccuracyTest) { } TEST(hist_util, DenseCutsAccuracyTestWeights) { - //int bin_sizes[] = {2, 16, 256, 512}; - //int sizes[] = {100, 1000, 1500}; - int bin_sizes[] = {2}; - int sizes[] = {100}; + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = 
GenerateRandom(num_rows, num_columns); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 5a81a203f3bd..d03b106f7bd3 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -81,10 +81,26 @@ TEST(hist_util, DeviceSketchDeterminism) { TEST(hist_util, DeviceSketchMultipleColumns) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; - int num_columns = 2; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + for (auto num_bins : bin_sizes) { + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + ValidateCuts(cuts, dmat.get(), num_bins); + } + } + +} + +TEST(hist_util, DeviceSketchMultipleColumnsWeights) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); for (auto num_bins : bin_sizes) { auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); ValidateCuts(cuts, dmat.get(), num_bins); @@ -108,7 +124,7 @@ TEST(hist_util, DeviceSketchBatches) { TEST(hist_util, DeviceSketchMultipleColumnsExternal) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; - int num_columns = 2; + int num_columns =5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); dmlc::TemporaryDirectory temp; @@ -121,7 +137,6 @@ TEST(hist_util, DeviceSketchMultipleColumnsExternal) { } } - TEST(hist_util, AdapterDeviceSketch) { int rows = 5; diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index efb7871fbf64..6c1a0aa48002 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -120,7 +120,8 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, // Test sketch quantiles against the real quantiles // Not a very strict test inline void TestRank(const std::vector& cuts, - const std::vector& sorted_x,const std::vector&sorted_weights) { + const std::vector& sorted_x, + const std::vector& sorted_weights) { double eps = 0.05; auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0); From d7fb199a5f56858de03b5c9f54c3067b4bcbf9b4 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 4 Mar 2020 11:38:10 +1300 Subject: [PATCH 4/7] Linux build --- src/common/hist_util.cu | 17 ++++++++--------- tests/cpp/common/test_hist_util.h | 6 +++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index caab20913ebc..386f77dbc9e2 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -48,8 +48,7 @@ struct SketchContainer { SketchContainer(int max_bin, size_t num_columns, size_t num_rows) { // Initialize Sketches for this dmatrix sketches_.resize(num_columns); -#pragma omp parallel for default(none) shared(max_bin) \ - schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT +#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < num_columns; ++icol) { // NOLINT sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin)); } @@ -67,7 +66,7 @@ struct SketchContainer { void Push(size_t entries_per_column, const thrust::host_vector& entries, const thrust::host_vector& column_scan) { 
-#pragma omp parallel for default(none) schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT +#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < sketches_.size(); ++icol) { size_t column_size = column_scan[icol + 1] - column_scan[icol]; if (column_size == 0) continue; @@ -136,7 +135,7 @@ void ExtractCuts(int device, Span cuts, size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; size_t num_available_cuts = - std::min(size_t(num_cuts_per_feature), column_size); + min(size_t(num_cuts_per_feature), column_size); size_t cut_idx = idx % num_cuts_per_feature; if (cut_idx >= num_available_cuts) return; @@ -169,7 +168,7 @@ void ExtractWeightedCuts(int device, Span cuts, size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; size_t num_available_cuts = - std::min(size_t(num_cuts_per_feature), column_size); + min(size_t(num_cuts_per_feature), column_size); size_t cut_idx = idx % num_cuts_per_feature; if (cut_idx >= num_available_cuts) return; @@ -197,7 +196,7 @@ void ExtractWeightedCuts(int device, Span cuts, column_weights.end(), rank) - column_weights.begin() - 1; sample_idx = - std::max(0llu, std::min(sample_idx, column_entries.size() - 1)); + max(size_t(0), min(sample_idx, column_entries.size() - 1)); } // repeated values will be filtered out on the CPU bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0; @@ -312,7 +311,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t batch_nnz = batch.data.ConstHostVector().size(); for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { - size_t end = std::min(batch_nnz, begin + sketch_batch_num_elements); + size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); if (dmat->Info().weights_.Size() > 0) { ProcessWeightedBatch( device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, @@ -377,7 +376,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, auto d_column_sizes_scan = column_sizes_scan.data().get(); IsValidFunctor is_valid(missing); dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) { - auto& e = batch_iter[begin + idx]; + auto e = batch_iter[begin + idx]; if (is_valid(e)) { atomicAdd(reinterpret_cast( // NOLINT &d_column_sizes_scan[e.column_idx]), @@ -438,7 +437,7 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, } for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { - size_t end = std::min(batch.Size(), begin + sketch_batch_num_elements); + size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts); } diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 6c1a0aa48002..c54296b1ce71 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -203,10 +203,10 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, std::vector sorted_column(col.size()); std::vector sorted_weights(col.size(), 1.0); - for (auto i = 0ull; i < col.size(); i++) { - sorted_column[i] = col[index[i]]; + for (auto j = 0ull; j < col.size(); j++) { + sorted_column[j] = col[index[j]]; if (w.size() == col.size()) { - sorted_weights[i] = w[index[i]]; + sorted_weights[j] = w[index[j]]; } } From 
3910f0b447497544e82144466e607b8792028870 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 4 Mar 2020 12:17:38 +1300 Subject: [PATCH 5/7] Remove incorrect test --- tests/cpp/data/test_sparse_page_dmatrix.cu | 22 ---------------------- tests/cpp/tree/test_gpu_hist.cu | 1 - 2 files changed, 23 deletions(-) diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 5f9603850eb8..59b2df31a15d 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -158,28 +158,6 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) { EXPECT_EQ(impl_ext->matrix.base_rowid, current_row); current_row += impl_ext->matrix.n_rows; } - - current_row = 0; - thrust::device_vector row_d(kCols); - thrust::device_vector row_ext_d(kCols); - std::vector row(kCols); - std::vector row_ext(kCols); - for (auto& page : dmat_ext->GetBatches(param)) { - auto impl_ext = page.Impl(); - EXPECT_EQ(impl_ext->matrix.base_rowid, current_row); - - for (size_t i = 0; i < impl_ext->Size(); i++) { - dh::LaunchN(0, kCols, ReadRowFunction(impl->matrix, current_row, row_d.data().get())); - thrust::copy(row_d.begin(), row_d.end(), row.begin()); - - dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->matrix, current_row, row_ext_d.data().get())); - thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin()); - - EXPECT_EQ(row, row_ext) << "for row " << current_row; - - current_row++; - } - } } } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 8c38c09b8660..324241cab912 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -284,7 +284,6 @@ void TestHistogramIndexImpl() { ASSERT_EQ(maker->page->matrix.info.n_bins, maker_ext->page->matrix.info.n_bins); ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size()); - ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext); } TEST(GpuHist, TestHistogramIndex) { From 5852f480f2d4f793cd27a7b56f948f592b2a2c95 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 5 Mar 2020 09:57:40 +1300 Subject: [PATCH 6/7] Address review comments --- src/common/hist_util.cu | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 386f77dbc9e2..1916ffd6fbf1 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -20,7 +20,6 @@ #include "../data/adapter.h" #include "../data/device_adapter.cuh" -#include "../tree/param.h" #include "device_helpers.cuh" #include "hist_util.h" #include "math.h" // NOLINT @@ -149,15 +148,15 @@ void ExtractCuts(int device, Span cuts, } /** -* \brief Extracts the cuts from sorted data, considering weights. -* -* \param device The device. -* \param cuts Output cuts -* \param num_cuts_per_feature Number of cuts per feature. -* \param sorted_data Sorted entries in segments of columns -* \param column_sizes_scan Describes the boundaries of column segments in -* sorted data -*/ + * \brief Extracts the cuts from sorted data, considering weights. + * + * \param device The device. + * \param cuts Output cuts. + * \param num_cuts_per_feature Number of cuts per feature. + * \param sorted_data Sorted entries in segments of columns. + * \param weights_scan Inclusive scan of weights for each entry in sorted_data. + * \param column_sizes_scan Describes the boundaries of column segments in sorted data. 
+ */ void ExtractWeightedCuts(int device, Span cuts, size_t num_cuts_per_feature, Span sorted_data, Span weights_scan, @@ -259,7 +258,6 @@ void ProcessWeightedBatch(int device, const SparsePage& page, thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), temp_weights.begin(), EntryCompareOp()); - std::vector entries_t(sorted_entries.begin(), sorted_entries.end()); // Scan weights thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), @@ -268,7 +266,6 @@ void ProcessWeightedBatch(int device, const SparsePage& page, [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; }); - std::vector weights_t(temp_weights.begin(), temp_weights.end()); dh::caching_device_vector column_sizes_scan; GetColumnSizesScan(device, &column_sizes_scan, @@ -308,7 +305,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, } dmat->Info().weights_.SetDevice(device); for (const auto& batch : dmat->GetBatches()) { - size_t batch_nnz = batch.data.ConstHostVector().size(); + size_t batch_nnz = batch.data.Size(); for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); @@ -345,9 +342,10 @@ struct IsValidFunctor : public thrust::unary_function { } }; +// Thrust version of this function causes error on Windows template thrust::transform_iterator MakeTransformIterator( - IterT iter, FuncT func) { + IterT iter, FuncT func) { return thrust::transform_iterator(iter, func); } @@ -357,19 +355,17 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, dh::XGBCachingDeviceAllocator alloc; adapter->BeforeFirst(); adapter->Next(); - auto& batch = adapter->Value(); + auto &batch = adapter->Value(); // Enforce single batch CHECK(!adapter->Next()); - auto batch_iter = MakeTransformIterator( - thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) { return batch.GetElement(idx); }); + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); auto entry_iter = MakeTransformIterator( thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return Entry(batch.GetElement(idx).column_idx, batch.GetElement(idx).value); }); - // Work out how many valid entries we have in each column dh::caching_device_vector column_sizes_scan(adapter->NumColumns() + 1, 0); From 9479cf79f96ca63e5e5ad553a046a8a5cb043898 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 5 Mar 2020 11:05:07 +1300 Subject: [PATCH 7/7] Remove benchmark code --- tests/cpp/common/test_hist_util.cu | 128 ----------------------------- 1 file changed, 128 deletions(-) diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index d03b106f7bd3..503f839d96db 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -208,133 +208,5 @@ TEST(hist_util, AdapterDeviceSketchBatches) { ValidateCuts(cuts, dmat.get(), num_bins); } } - -TEST(hist_util, Benchmark) { - int num_bins = 256; - std::vector sizes; - for (auto i = 8ull; i < 26; i += 2) { - sizes.push_back(1 << i); - } - - std::cout << "Num rows, "; - for (auto n : sizes) { - std::cout << n << ", "; - } - std::cout << "\n"; - int num_columns = 5; - std::cout << "AdapterDeviceSketch, "; - for (auto num_rows : sizes) { - auto x = GenerateRandom(num_rows, num_columns); - auto x_device = thrust::device_vector(x); - auto adapter = AdapterFromData(x_device, num_rows, num_columns); - Timer t; - t.Start(); - auto cuts 
= AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN()); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; - - std::cout << "DeviceSketch, "; - for (auto num_rows : sizes) { - auto x = GenerateRandom(num_rows, num_columns); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetDMatrixFromData(x, num_rows, num_columns); - Timer t; - t.Start(); - auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; - - std::cout << "WQSketch, "; - for (auto num_rows : sizes) { - auto x = GenerateRandom(num_rows, num_columns); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetDMatrixFromData(x, num_rows, num_columns); - HistogramCuts cuts; - DenseCuts dense(&cuts); - Timer t; - t.Start(); - dense.Build(dmat.get(), num_bins); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; -} -TEST(hist_util, BenchmarkNumColumns) { - int num_bins = 256; - int num_rows = 10; - std::vector num_columns; - for (auto i = 4ull; i < 16; i += 2) { - num_columns.push_back(1 << i); - } - - std::cout << "Num columns, "; - for (auto n : num_columns) { - std::cout << n << ", "; - } - std::cout << "\n"; - std::cout << "AdapterDeviceSketch, "; - for (auto num_column : num_columns) { - auto x = GenerateRandom(num_rows, num_column); - auto x_device = thrust::device_vector(x); - auto adapter = AdapterFromData(x_device, num_rows, num_column); - Timer t; - t.Start(); - auto cuts = AdapterDeviceSketch(&adapter, num_bins, - std::numeric_limits::quiet_NaN()); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; - std::cout << "DeviceSketch, "; - for (auto num_column : num_columns) { - auto x = GenerateRandom(num_rows, num_column); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetDMatrixFromData(x, num_rows, num_column); - Timer t; - t.Start(); - auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; - std::cout << "SparseCuts, "; - for (auto num_column : num_columns) { - auto x = GenerateRandom(num_rows, num_column); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetDMatrixFromData(x, num_rows, num_column); - HistogramCuts cuts; - SparseCuts sparse(&cuts); - Timer t; - t.Start(); - sparse.Build(dmat.get(), num_bins); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; - std::cout << "DenseCuts, "; - for (auto num_column : num_columns) { - auto x = GenerateRandom(num_rows, num_column); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetDMatrixFromData(x, num_rows, num_column); - HistogramCuts cuts; - DenseCuts dense(&cuts); - Timer t; - t.Start(); - dense.Build(dmat.get(), num_bins); - t.Stop(); - std::cout << t.ElapsedSeconds() << ", "; - } - std::cout << "\n"; -} } // namespace common } // namespace xgboost
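
For reference, the weighted path introduced in these patches hinges on two small pieces of logic. The first, in ProcessWeightedBatch, gives every CSR entry the weight of the row it belongs to by binary-searching the row offsets. Below is a minimal host-side sketch of that step in plain C++; EntryWeights is a hypothetical name used only for illustration, and the batch slicing (the begin offset) and base_rowid handling present in the device code are omitted.

    // Minimal host-side sketch (illustration only, not part of the patches):
    // every CSR entry inherits the weight of its owning row, located by
    // binary search over the row offsets, mirroring the thrust::upper_bound
    // lookup in ProcessWeightedBatch.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    std::vector<float> EntryWeights(const std::vector<std::size_t>& row_ptrs,
                                    const std::vector<float>& row_weights) {
      // row_ptrs holds num_rows + 1 offsets; the last offset is the entry count.
      std::vector<float> entry_weights(row_ptrs.back());
      for (std::size_t idx = 0; idx < entry_weights.size(); ++idx) {
        // The first offset strictly greater than idx, minus one, is the owning row.
        auto it = std::upper_bound(row_ptrs.begin(), row_ptrs.end(), idx);
        std::size_t ridx = static_cast<std::size_t>(it - row_ptrs.begin()) - 1;
        entry_weights[idx] = row_weights[ridx];
      }
      return entry_weights;
    }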
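
The second piece, ExtractWeightedCuts, selects one sample per cut by spacing target ranks evenly across a column's total weight and binary-searching the inclusive weight scan. The sketch below is an illustrative CPU rendering for a single pre-sorted column, not the project's API: the real kernel works on device spans and emits WQSketch::Entry(rmin, rmax, rmax - rmin, value) summaries, whereas this version returns only the chosen cut values and leaves duplicate filtering to the later CPU pass, as the kernel comment notes.

    // Illustrative CPU rendering of the per-column weighted cut selection,
    // assuming the column's entries are already sorted by value with their
    // weights attached. SelectWeightedCuts and WeightedValue are hypothetical
    // names used only for this example.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct WeightedValue {
      float value;
      float weight;
    };

    std::vector<float> SelectWeightedCuts(const std::vector<WeightedValue>& column,
                                          std::size_t num_cuts_per_feature) {
      std::vector<float> cuts;
      if (column.empty() || num_cuts_per_feature == 0) return cuts;

      // Inclusive scan of weights: the single-column analogue of
      // thrust::inclusive_scan_by_key keyed on the column index.
      std::vector<float> weight_scan(column.size());
      float running = 0.0f;
      for (std::size_t i = 0; i < column.size(); ++i) {
        running += column[i].weight;
        weight_scan[i] = running;
      }
      const float total_weight = weight_scan.back();
      const std::size_t num_available =
          std::min(num_cuts_per_feature, column.size());

      for (std::size_t cut_idx = 0; cut_idx < num_available; ++cut_idx) {
        std::size_t sample_idx = 0;
        if (cut_idx == 0) {
          sample_idx = 0;                  // first cut takes the smallest value
        } else if (cut_idx == num_available - 1) {
          sample_idx = column.size() - 1;  // last cut takes the largest value
        } else if (num_available == column.size()) {
          sample_idx = cut_idx;            // fewer samples than cuts: keep all
        } else {
          // Target rank spaced evenly across the total column weight, then the
          // last entry whose cumulative weight does not exceed that rank.
          const float rank = (total_weight * cut_idx) /
                             static_cast<float>(num_available);
          auto it = std::upper_bound(weight_scan.begin(), weight_scan.end(), rank);
          sample_idx = static_cast<std::size_t>(it - weight_scan.begin());
          sample_idx = sample_idx == 0 ? 0 : sample_idx - 1;
          sample_idx = std::min(sample_idx, column.size() - 1);
        }
        cuts.push_back(column[sample_idx].value);
      }
      return cuts;
    }

With unit weights the rank targets fall back to approximately evenly spaced element positions, so this path degenerates to the unweighted behaviour; that is also why the test helpers default sorted_weights to 1.0 when the DMatrix carries no weights.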