From be2f28ec08a7b1f3bcc133200a9ece958b3108da Mon Sep 17 00:00:00 2001
From: RAMitchell
Date: Sat, 26 Nov 2016 06:43:22 +1300
Subject: [PATCH] Update build instructions, improve memory usage (#1811)

---
 plugin/updater_gpu/README.md | 39 +-
 plugin/updater_gpu/speed_test.py | 9 +-
 .../{cuda_helpers.cuh => device_helpers.cuh} | 277 +++++++++--
 plugin/updater_gpu/src/find_split.cuh | 30 +-
 .../updater_gpu/src/find_split_multiscan.cuh | 216 +++------
 plugin/updater_gpu/src/find_split_sorting.cuh | 197 ++++----
 plugin/updater_gpu/src/gpu_builder.cu | 433 +++++++++---------
 plugin/updater_gpu/src/gpu_builder.cuh | 10 +-
 plugin/updater_gpu/src/types.cuh | 13 +-
 src/tree/param.h | 2 +
 10 files changed, 653 insertions(+), 573 deletions(-)
 rename plugin/updater_gpu/src/{cuda_helpers.cuh => device_helpers.cuh} (57%)

diff --git a/plugin/updater_gpu/README.md b/plugin/updater_gpu/README.md
index e801fc354966..526f64d1bd96 100644
--- a/plugin/updater_gpu/README.md
+++ b/plugin/updater_gpu/README.md
@@ -3,25 +3,56 @@
 ## Usage
 Specify the updater parameter as 'grow_gpu'.
 
+This plugin currently works with the CLI and python versions.
+
 Python example:
 ```python
 param['updater'] = 'grow_gpu'
 ```
+## Memory usage
+Device memory usage can be calculated as approximately:
+```
+bytes = (10 x n_rows) + (44 x n_rows x n_columns x column_density)
+```
+Data is stored in a sparse format. For example, missing values produced by one-hot encoding are not stored. If a one-hot encoding separates a categorical variable into 5 columns, the column_density of these columns is 1/5 = 0.2.
+
+As a concrete example, 1 million rows with 50 dense columns (column_density = 1.0) require approximately (10 x 10^6) + (44 x 10^6 x 50 x 1.0) bytes, or about 2.2GB.
+
+A 4GB graphics card will process approximately 3.5 million rows of the well-known Kaggle Higgs dataset.
+
+The algorithm will automatically perform row subsampling if it detects that there is not enough memory on the device.
+
 ## Dependencies
 A CUDA-capable GPU with compute capability >= 3.5 (the algorithm depends on shuffle and vote instructions introduced in Kepler).
 
+Building the plugin requires CUDA Toolkit 7.5 or later.
+
 The plugin also depends on CUB 1.5.4 - http://nvlabs.github.io/cub/index.html. CUB is a header-only CUDA library which provides sort/reduce/scan primitives.
 
 ## Build
-The plugin can be built using cmake and specifying the option PLUGIN_UPDATER_GPU=ON.
+To use the plugin, xgboost must be built with cmake, specifying the option PLUGIN_UPDATER_GPU=ON. The location of the CUB library must also be specified with the cmake variable CUB_DIRECTORY. CMake will prepare a build system depending on which platform you are on.
 
-Specify the location of the CUB library with the cmake variable CUB_DIRECTORY.
+From the command line on Windows or Linux, starting from the xgboost directory:
 
-It is recommended to build with Cuda Toolkit 7.5 or greater.
+```bash
+$ mkdir build
+$ cd build
+$ cmake .. -DPLUGIN_UPDATER_GPU=ON -DCUB_DIRECTORY=<cub directory>
+```
+
+On Windows you may also need to specify your generator as 64 bit, so the cmake command becomes:
+```bash
+$ cmake .. -G"Visual Studio 12 2013 Win64" -DPLUGIN_UPDATER_GPU=ON -DCUB_DIRECTORY=<cub directory>
+```
+You may also be able to use a later version of Visual Studio, depending on whether the CUDA toolkit supports it.
+
+On Linux, cmake will generate a Makefile in the build directory; invoking 'make' from this directory will build the project. If the build fails, try invoking make again - there can occasionally be problems with the order in which items are built.
+
+On Windows, cmake will generate an xgboost.sln solution file in the build directory. Build this solution in release mode. This is also a good time to check that it is being built as x64; if not, make sure the cmake generator is set correctly.
+
+The build process generates an xgboost library and executable as normal, but containing the GPU tree construction algorithm.
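+
+As a quick check that the plugin built correctly, a small model can be trained from python with the GPU updater. This is a minimal sketch; the random dataset and parameter values are illustrative only:
+
+```python
+import numpy as np
+import xgboost as xgb
+
+# Small random binary classification problem, purely for smoke-testing.
+X = np.random.rand(1000, 10)
+y = np.random.randint(2, size=1000)
+dtrain = xgb.DMatrix(X, label=y)
+
+# 'grow_gpu' selects the GPU tree construction algorithm described above.
+param = {'objective': 'binary:logistic', 'updater': 'grow_gpu'}
+bst = xgb.train(param, dtrain, num_boost_round=10)
+```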
 
 ## Author
 Rory Mitchell
@@ -29,3 +60,5 @@ Rory Mitchell
 
 Report any bugs to r.a.mitchell.nz at google mail.
+
+
diff --git a/plugin/updater_gpu/speed_test.py b/plugin/updater_gpu/speed_test.py
index eaf8111b5608..fc76d98b266b 100644
--- a/plugin/updater_gpu/speed_test.py
+++ b/plugin/updater_gpu/speed_test.py
@@ -4,7 +4,6 @@
 import numpy as np
 import xgboost as xgb
 import time
-test_size = 550000
 
 # path to where the data lies
 dpath = '../../demo/data'
@@ -13,6 +12,9 @@
 dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } )
 dtrain = np.concatenate((dtrain, np.copy(dtrain)))
 dtrain = np.concatenate((dtrain, np.copy(dtrain)))
+dtrain = np.concatenate((dtrain, np.copy(dtrain)))
+test_size = len(dtrain)
+
 print(len(dtrain))
 
 print ('finish loading from csv ')
@@ -37,10 +39,9 @@
 # scale weight of positive examples
 param['scale_pos_weight'] = sum_wneg/sum_wpos
 param['bst:eta'] = 0.1
-param['max_depth'] = 16
+param['max_depth'] = 15
 param['eval_metric'] = 'auc'
-param['silent'] = 1
-param['nthread'] = 4
+param['nthread'] = 16
 
 plst = param.items()+[('eval_metric', 'ams@0.15')]
diff --git a/plugin/updater_gpu/src/cuda_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh
similarity index 57%
rename from plugin/updater_gpu/src/cuda_helpers.cuh
rename to plugin/updater_gpu/src/device_helpers.cuh
index bd6fae113718..c895c16e0cf4 100644
--- a/plugin/updater_gpu/src/cuda_helpers.cuh
+++ b/plugin/updater_gpu/src/device_helpers.cuh
@@ -7,21 +7,30 @@
 #include
 #include
 #include
-#include
 #include
+#include
 #include
 #include
+#include
 
 #ifdef _WIN32
 #include
 #endif
 
+// Uncomment to enable
+// #define DEVICE_TIMER
+// #define TIMERS
+
+namespace dh {
+
+/*
+ * Error handling functions
+ */
+
 #define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__)
 
 cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, int line) {
   if (code != cudaSuccess) {
-    std::cout << file;
-    std::cout << line;
     std::stringstream ss;
     ss << file << "(" << line << ")";
     std::string file_and_line;
@@ -44,36 +53,10 @@ inline void gpuAssert(cudaError_t code, const char *file, int line,
   }
 }
 
-// Keep track of cub library device allocation
-struct CubMemory {
-  void *d_temp_storage;
-  size_t temp_storage_bytes;
-
-  CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
-
-  ~CubMemory() {
-    if (d_temp_storage != NULL) {
-      safe_cuda(cudaFree(d_temp_storage));
-    }
-  }
+/*
+ * Timers
+ */
 
-  void Allocate() {
-    safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes));
-  }
-
-  bool IsAllocated() { return d_temp_storage != NULL; }
-};
-
-// Utility function: rounds up integer division.
-template T div_round_up(const T a, const T b) { - return static_cast(ceil(static_cast(a) / b)); -} - -template thrust::device_ptr dptr(T *d_ptr) { - return thrust::device_pointer_cast(d_ptr); -} - -// #define DEVICE_TIMER #define MAX_WARPS 32 // Maximum number of warps to time #define MAX_SLOTS 10 #define TIMER_BLOCKID 0 // Block to time @@ -135,10 +118,8 @@ struct DeviceTimer { #endif #ifdef DEVICE_TIMER - __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT - : - GTimer(GTimer), - start(clock()), slot(slot) {} + __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) // NOLINT + : GTimer(GTimer), start(clock()), slot(slot) {} #else __device__ DeviceTimer(DeviceTimerGlobal >imer, int slot) {} // NOLINT #endif @@ -155,7 +136,6 @@ struct DeviceTimer { } }; -// #define TIMERS struct Timer { volatile double start; Timer() { reset(); } @@ -190,6 +170,10 @@ struct Timer { } }; +/* + * Utility functions + */ + template void print(const thrust::device_vector &v, size_t max_items = 10) { thrust::host_vector h = v; @@ -211,6 +195,34 @@ void print(char *label, const thrust::device_vector &v, std::cout << "\n"; } +template T1 div_round_up(const T1 a, const T2 b) { + return static_cast(ceil(static_cast(a) / b)); +} + +template thrust::device_ptr dptr(T *d_ptr) { + return thrust::device_pointer_cast(d_ptr); +} + +template T *raw(thrust::device_vector &v) { // NOLINT + return raw_pointer_cast(v.data()); +} + +template size_t size_bytes(const thrust::device_vector &v) { + return sizeof(T) * v.size(); +} + +// Threadblock iterates over range, filling with value +template +__device__ void block_fill(IterT begin, size_t n, ValueT value) { + for (auto i : block_stride_range(static_cast(0), n)) { + begin[i] = value; + } +} + +/* + * Range iterator + */ + class range { public: class iterator { @@ -270,11 +282,192 @@ template __device__ range block_stride_range(T begin, T end) { return r; } -// Converts device_vector to raw pointer -template T *raw(thrust::device_vector &v) { // NOLINT - return raw_pointer_cast(v.data()); +/* + * Memory + */ + +class bulk_allocator; + +template class dvec { + friend bulk_allocator; + + private: + T *_ptr; + size_t _size; + + void external_allocate(void *ptr, size_t size) { + if (!empty()) { + throw std::runtime_error("Tried to allocate dvec but already allocated"); + } + + _ptr = static_cast(ptr); + _size = size; + } + + public: + dvec() : _ptr(NULL), _size(0) {} + size_t size() { return _size; } + bool empty() { return _ptr == NULL || _size == 0; } + T *data() { return _ptr; } + + std::vector as_vector() { + std::vector h_vector(size()); + safe_cuda(cudaMemcpy(h_vector.data(), _ptr, size() * sizeof(T), + cudaMemcpyDeviceToHost)); + return h_vector; + } + + void fill(T value) { + thrust::fill_n(thrust::device_pointer_cast(_ptr), size(), value); + } + + void print() { + auto h_vector = this->as_vector(); + + for (auto e : h_vector) { + std::cout << e << " "; + } + + std::cout << "\n"; + } + + thrust::device_ptr tbegin() { return thrust::device_pointer_cast(_ptr); } + + thrust::device_ptr tend() { + return thrust::device_pointer_cast(_ptr + size()); + } + + template dvec &operator=(const std::vector &other) { + if (other.size() != size()) { + throw std::runtime_error( + "Cannot copy assign vector to dvec, sizes are different"); + } + + thrust::copy(other.begin(), other.end(), this->tbegin()); + + return *this; + } + + dvec &operator=(dvec &other) { + if (other.size() != size()) { + throw std::runtime_error( + "Cannot copy assign dvec to dvec, sizes are 
different"); + } + + thrust::copy(other.tbegin(), other.tend(), this->tbegin()); + + return *this; + } +}; + +class bulk_allocator { + char *d_ptr; + size_t _size; + + const size_t align = 256; + + template size_t align_round_up(SizeT n) { + if (n % align == 0) { + return n; + } else { + return n + align - (n % align); + } + } + + template + size_t get_size_bytes(dvec *first_vec, SizeT first_size) { + return align_round_up(first_size * sizeof(T)); + } + + template + size_t get_size_bytes(dvec *first_vec, SizeT first_size, Args... args) { + return align_round_up(first_size * sizeof(T)) + get_size_bytes(args...); + } + + template + void allocate_dvec(char *ptr, dvec *first_vec, SizeT first_size) { + first_vec->external_allocate(static_cast(ptr), first_size); + } + + template + void allocate_dvec(char *ptr, dvec *first_vec, SizeT first_size, + Args... args) { + first_vec->external_allocate(static_cast(ptr), first_size); + ptr += align_round_up(first_size * sizeof(T)); + allocate_dvec(ptr, args...); + } + + public: + bulk_allocator() : _size(0), d_ptr(NULL) {} + + ~bulk_allocator() { + if (!d_ptr == NULL) { + safe_cuda(cudaFree(d_ptr)); + } + } + + size_t size() { return _size; } + + template void allocate(Args... args) { + if (d_ptr != NULL) { + throw std::runtime_error("Bulk allocator already allocated"); + } + + _size = get_size_bytes(args...); + + safe_cuda(cudaMalloc(&d_ptr, _size)); + + allocate_dvec(d_ptr, args...); + } +}; + +// Keep track of cub library device allocation +struct CubMemory { + void *d_temp_storage; + size_t temp_storage_bytes; + + CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {} + + ~CubMemory() { + if (d_temp_storage != NULL) { + safe_cuda(cudaFree(d_temp_storage)); + } + } + + void Allocate() { + safe_cuda(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + } + + bool IsAllocated() { return d_temp_storage != NULL; } +}; + +inline size_t available_memory() { + size_t device_free = 0; + size_t device_total = 0; + dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total)); + return device_free; } -template size_t size_bytes(const thrust::device_vector &v) { - return sizeof(T) * v.size(); +inline std::string device_name() { + cudaDeviceProp prop; + dh::safe_cuda(cudaGetDeviceProperties(&prop, 0)); + return std::string(prop.name); +} + +/* + * Kernel launcher + */ + +template __global__ void launch_n_kernel(size_t n, L lambda) { + for (auto i : grid_stride_range(static_cast(0), n)) { + lambda(i); + } +} + +template +inline void launch_n(size_t n, L lambda) { + const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); + + launch_n_kernel<<>>(n, lambda); } +} // namespace dh diff --git a/plugin/updater_gpu/src/find_split.cuh b/plugin/updater_gpu/src/find_split.cuh index d1b7958d8271..458d8fba71f1 100644 --- a/plugin/updater_gpu/src/find_split.cuh +++ b/plugin/updater_gpu/src/find_split.cuh @@ -4,7 +4,7 @@ #pragma once #include #include -#include "cuda_helpers.cuh" +#include "device_helpers.cuh" #include "find_split_multiscan.cuh" #include "find_split_sorting.cuh" #include "types_functions.cuh" @@ -54,29 +54,27 @@ void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes, int n_current_nodes = 1 << level; const int BLOCK_THREADS = 512; - const int GRID_SIZE = div_round_up(n_current_nodes, BLOCK_THREADS); + const int GRID_SIZE = dh::div_round_up(n_current_nodes, BLOCK_THREADS); reduce_split_candidates_kernel<<>>( d_split_candidates, d_current_nodes, d_new_nodes, n_current_nodes, n_features, param); - safe_cuda(cudaDeviceSynchronize()); + 
dh::safe_cuda(cudaDeviceSynchronize()); } -void find_split(const Item *d_items, Split *d_split_candidates, - const NodeIdT *d_node_id, Node *d_nodes, bst_uint num_items, - int num_features, const int *d_feature_offsets, - gpu_gpair *d_node_sums, int *d_node_offsets, - const GPUTrainingParam param, const int level, - bool multiscan_algorithm) { +void find_split(const ItemIter items_iter, Split *d_split_candidates, + Node *d_nodes, bst_uint num_items, int num_features, + const int *d_feature_offsets, gpu_gpair *d_node_sums, + int *d_node_offsets, const GPUTrainingParam param, + const int level, bool multiscan_algorithm) { if (multiscan_algorithm) { - find_split_candidates_multiscan(d_items, d_split_candidates, d_node_id, - d_nodes, num_items, num_features, - d_feature_offsets, param, level); + find_split_candidates_multiscan(items_iter, d_split_candidates, d_nodes, + num_items, num_features, d_feature_offsets, + param, level); } else { - find_split_candidates_sorted(d_items, d_split_candidates, d_node_id, - d_nodes, num_items, num_features, - d_feature_offsets, d_node_sums, d_node_offsets, - param, level); + find_split_candidates_sorted(items_iter, d_split_candidates, d_nodes, + num_items, num_features, d_feature_offsets, + d_node_sums, d_node_offsets, param, level); } // Find the best split for each node diff --git a/plugin/updater_gpu/src/find_split_multiscan.cuh b/plugin/updater_gpu/src/find_split_multiscan.cuh index fc49946d253c..0cd906e73497 100644 --- a/plugin/updater_gpu/src/find_split_multiscan.cuh +++ b/plugin/updater_gpu/src/find_split_multiscan.cuh @@ -4,7 +4,7 @@ #pragma once #include #include -#include "cuda_helpers.cuh" +#include "device_helpers.cuh" #include "types_functions.cuh" namespace xgboost { @@ -86,8 +86,7 @@ template struct ReduceEnactorMultiscan { struct Reduction : cub::Uninitialized<_Reduction> {}; // Thread local member variables - const Item *d_items; - const NodeIdT *d_node_id; + const ItemIter item_iter; _TempStorage &temp_storage; _Reduction &reduction; gpu_gpair gpair; @@ -95,12 +94,12 @@ template struct ReduceEnactorMultiscan { NodeIdT node_id_adjusted; const int node_begin; - __device__ __forceinline__ ReduceEnactorMultiscan( - TempStorage &temp_storage, // NOLINT - Reduction &reduction, // NOLINT - const Item *d_items, const NodeIdT *d_node_id, const int node_begin) + __device__ __forceinline__ + ReduceEnactorMultiscan(TempStorage &temp_storage, // NOLINT + Reduction &reduction, // NOLINT + const ItemIter item_iter, const int node_begin) : temp_storage(temp_storage.Alias()), reduction(reduction.Alias()), - d_items(d_items), d_node_id(d_node_id), node_begin(node_begin) {} + item_iter(item_iter), node_begin(node_begin) {} __device__ __forceinline__ void ResetPartials() { if (threadIdx.x < ParamsT::N_WARPS) { @@ -119,8 +118,11 @@ template struct ReduceEnactorMultiscan { __device__ __forceinline__ void LoadTile(const bst_uint &offset, const bst_uint &num_remaining) { if (threadIdx.x < num_remaining) { - gpair = d_items[offset + threadIdx.x].gpair; - node_id = d_node_id[offset + threadIdx.x]; + bst_uint i = offset + threadIdx.x; + gpair = thrust::get<0>(item_iter[i]); + // gpair = d_items[offset + threadIdx.x].gpair; + // node_id = d_node_id[offset + threadIdx.x]; + node_id = thrust::get<2>(item_iter[i]); node_id_adjusted = node_id - node_begin; } else { gpair = gpu_gpair(); @@ -231,12 +233,12 @@ struct FindSplitEnactorMultiscan { struct TempStorage : cub::Uninitialized<_TempStorage> {}; // Thread local member variables - const Item *d_items; + const ItemIter 
item_iter; Split *d_split_candidates_out; - const NodeIdT *d_node_id; const Node *d_nodes; _TempStorage &temp_storage; - Item item; + gpu_gpair gpair; + float fvalue; NodeIdT node_id; NodeIdT node_id_adjusted; const NodeIdT node_begin; @@ -246,15 +248,14 @@ struct FindSplitEnactorMultiscan { FlagPrefixCallbackOp flag_prefix_op; __device__ __forceinline__ FindSplitEnactorMultiscan( - TempStorage &temp_storage, const Item *d_items, // NOLINT - Split *d_split_candidates_out, const NodeIdT *d_node_id, - const Node *d_nodes, const NodeIdT node_begin, - const GPUTrainingParam ¶m, const ReductionT reduction, - const int level) - : temp_storage(temp_storage.Alias()), d_items(d_items), - d_split_candidates_out(d_split_candidates_out), d_node_id(d_node_id), - d_nodes(d_nodes), node_begin(node_begin), param(param), - reduction(reduction), level(level), flag_prefix_op() {} + TempStorage &temp_storage, const ItemIter item_iter, // NOLINT + Split *d_split_candidates_out, const Node *d_nodes, + const NodeIdT node_begin, const GPUTrainingParam ¶m, + const ReductionT reduction, const int level) + : temp_storage(temp_storage.Alias()), item_iter(item_iter), + d_split_candidates_out(d_split_candidates_out), d_nodes(d_nodes), + node_begin(node_begin), param(param), reduction(reduction), + level(level), flag_prefix_op() {} __device__ __forceinline__ void UpdateTileCarry() { if (threadIdx.x < ParamsT::N_NODES) { @@ -308,16 +309,17 @@ struct FindSplitEnactorMultiscan { __device__ __forceinline__ void LoadTile(bst_uint offset, bst_uint num_remaining) { - bst_uint index = offset + threadIdx.x; if (threadIdx.x < num_remaining) { - item = d_items[index]; - node_id = d_node_id[index]; + bst_uint i = offset + threadIdx.x; + gpair = thrust::get<0>(item_iter[i]); + fvalue = thrust::get<1>(item_iter[i]); + node_id = thrust::get<2>(item_iter[i]); node_id_adjusted = node_id - node_begin; } else { node_id = -1; node_id_adjusted = -1; - item.fvalue = -FLT_MAX; - item.gpair = gpu_gpair(); + fvalue = -FLT_MAX; + gpair = gpu_gpair(); } } @@ -333,10 +335,10 @@ struct FindSplitEnactorMultiscan { int left_index = offset + threadIdx.x - 1; float left_fvalue = left_index >= static_cast(segment_begin) && threadIdx.x < num_remaining - ? d_items[left_index].fvalue + ? thrust::get<1>(item_iter[left_index]) : -FLT_MAX; - return left_fvalue != item.fvalue; + return left_fvalue != fvalue; } // Prevent splitting in the middle of same valued instances @@ -434,9 +436,9 @@ struct FindSplitEnactorMultiscan { for (int warp = 0; warp < ParamsT::N_WARPS; warp++) { if (threadIdx.x / 32 == warp) { for (int lane = 0; lane < 32; lane++) { - gpu_gpair g = cub::ShuffleIndex(item.gpair, lane); + gpu_gpair g = cub::ShuffleIndex(gpair, lane); gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane); - float fvalue_broadcast = __shfl(item.fvalue, lane); + float fvalue_broadcast = __shfl(fvalue, lane); bool thread_active_broadcast = __shfl(thread_active, lane); float loss_chg_broadcast = __shfl(loss_chg, lane); NodeIdT node_id_broadcast = cub::ShuffleIndex(node_id, lane); @@ -476,7 +478,7 @@ struct FindSplitEnactorMultiscan { bool missing_left; float loss_chg = thread_active - ? loss_chg_missing(item.gpair, missing, parent_sum, + ? 
loss_chg_missing(gpair, missing, parent_sum, parent_gain, param, missing_left) : -FLT_MAX; @@ -488,16 +490,16 @@ struct FindSplitEnactorMultiscan { : 0.0f; if (QueryUpdateWarpSplit(loss_chg, warp_best_loss)) { - float fvalue_split = item.fvalue - FVALUE_EPS; + float fvalue_split = fvalue - FVALUE_EPS; if (missing_left) { - gpu_gpair left_sum = missing + item.gpair; + gpu_gpair left_sum = missing + gpair; gpu_gpair right_sum = parent_sum - left_sum; temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update( loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum, right_sum, param); } else { - gpu_gpair left_sum = item.gpair; + gpu_gpair left_sum = gpair; gpu_gpair right_sum = parent_sum - left_sum; temp_storage.warp_best_splits[node_id_adjusted][warp_id].Update( loss_chg, missing_left, fvalue_split, blockIdx.x, left_sum, @@ -506,30 +508,6 @@ struct FindSplitEnactorMultiscan { } } - /* - __device__ __forceinline__ void WarpExclusiveScan(bool active, gpu_gpair - input, gpu_gpair &output, gpu_gpair &sum) - { - - output = input; - - for (int offset = 1; offset < 32; offset <<= 1){ - float tmp1 = __shfl_up(output.grad(), offset); - - float tmp2 = __shfl_up(output.hess(), offset); - if (cub::LaneId() >= offset) - { - output.grad += tmp1; - output.hess += tmp2; - } - } - - sum.grad = __shfl(output.grad, 31); - sum.hess = __shfl(output.hess, 31); - - output -= input; - } - */ __device__ __forceinline__ void BlockExclusiveScan() { ResetPartials(); @@ -547,14 +525,12 @@ struct FindSplitEnactorMultiscan { if (ballot > 0) { WarpScanT(temp_storage.warp_gpair_scan[warp_id]) - .InclusiveScan(node_active ? item.gpair : gpu_gpair(), scan_result, + .InclusiveScan(node_active ? gpair : gpu_gpair(), scan_result, cub::Sum(), warp_sum); - // WarpExclusiveScan( node_active, node_active ? 
item.gpair : - // gpu_gpair(), scan_result, warp_sum); } if (node_active) { - item.gpair = scan_result - item.gpair; + gpair = scan_result - gpair; } if (lane_id == 0) { @@ -589,8 +565,8 @@ struct FindSplitEnactorMultiscan { __syncthreads(); if (NodeActive()) { - item.gpair += temp_storage.partial_sums[node_id_adjusted][warp_id] + - temp_storage.tile_carry[node_id_adjusted]; + gpair += temp_storage.partial_sums[node_id_adjusted][warp_id] + + temp_storage.tile_carry[node_id_adjusted]; } __syncthreads(); @@ -633,67 +609,12 @@ struct FindSplitEnactorMultiscan { } } - /* - __device__ void SequentialAlgorithm(bst_uint segment_begin, - bst_uint segment_end) { - if (threadIdx.x != 0) { - return; - } - - __shared__ Split best_split[ParamsT::N_NODES]; - - __shared__ gpu_gpair scan[ParamsT::N_NODES]; - - __shared__ Node nodes[ParamsT::N_NODES]; - - __shared__ gpu_gpair missing[ParamsT::N_NODES]; - - float previous_fvalue[ParamsT::N_NODES]; - - // Initialise counts - for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { - best_split[NODE] = Split(); - scan[NODE] = gpu_gpair(); - nodes[NODE] = d_nodes[node_begin + NODE]; - missing[NODE] = nodes[NODE].sum_gradients - reduction.node_sums[NODE]; - previous_fvalue[NODE] = FLT_MAX; - } - - for (bst_uint i = segment_begin; i < segment_end; i++) { - int8_t nodeid_adjusted = d_node_id[i] - node_begin; - float fvalue = d_items[i].fvalue; - - if (NodeActive(nodeid_adjusted)) { - if (fvalue != previous_fvalue[nodeid_adjusted]) { - float f_split; - if (previous_fvalue[nodeid_adjusted] != FLT_MAX) { - f_split = (previous_fvalue[nodeid_adjusted] + fvalue) * 0.5; - } else { - f_split = fvalue; - } - - best_split[nodeid_adjusted].UpdateCalcLoss( - f_split, scan[nodeid_adjusted], missing[nodeid_adjusted], - nodes[nodeid_adjusted], param); - } - - scan[nodeid_adjusted] += d_items[i].gpair; - previous_fvalue[nodeid_adjusted] = fvalue; - } - } - - for (int NODE = 0; NODE < ParamsT::N_NODES; NODE++) { - temp_storage.best_splits[NODE] = best_split[NODE]; - } - } - */ - __device__ __forceinline__ void ResetSplitCandidates() { const int max_nodes = 1 << level; const int begin = blockIdx.x * max_nodes; const int end = begin + max_nodes; - for (auto i : block_stride_range(begin, end)) { + for (auto i : dh::block_stride_range(begin, end)) { d_split_candidates_out[i] = Split(); } } @@ -730,9 +651,9 @@ __global__ void __launch_bounds__(1024, 2) #endif find_split_candidates_multiscan_kernel( - const Item *d_items, Split *d_split_candidates_out, - const NodeIdT *d_node_id, const Node *d_nodes, const int node_begin, - bst_uint num_items, int num_features, const int *d_feature_offsets, + const ItemIter items_iter, Split *d_split_candidates_out, + const Node *d_nodes, const int node_begin, bst_uint num_items, + int num_features, const int *d_feature_offsets, const GPUTrainingParam param, const int level) { if (num_items <= 0) { return; @@ -753,22 +674,22 @@ __launch_bounds__(1024, 2) __shared__ typename ReduceT::Reduction reduction; - ReduceT(temp_storage.reduce, reduction, d_items, d_node_id, node_begin) + ReduceT(temp_storage.reduce, reduction, items_iter, node_begin) .ProcessRegion(segment_begin, segment_end); __syncthreads(); - FindSplitT find_split(temp_storage.find_split, d_items, - d_split_candidates_out, d_node_id, d_nodes, node_begin, - param, reduction.Alias(), level); + FindSplitT find_split(temp_storage.find_split, items_iter, + d_split_candidates_out, d_nodes, node_begin, param, + reduction.Alias(), level); find_split.ProcessRegion(segment_begin, segment_end); } template 
void find_split_candidates_multiscan_variation( - const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, - const Node *d_nodes, int node_begin, int node_end, bst_uint num_items, - int num_features, const int *d_feature_offsets, - const GPUTrainingParam param, const int level) { + const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes, + int node_begin, int node_end, bst_uint num_items, int num_features, + const int *d_feature_offsets, const GPUTrainingParam param, + const int level) { const int BLOCK_THREADS = 512; @@ -786,47 +707,46 @@ void find_split_candidates_multiscan_variation( find_split_candidates_multiscan_kernel< find_split_params, reduce_params><<>>( - d_items, d_split_candidates, d_node_id, d_nodes, node_begin, num_items, + items_iter, d_split_candidates, d_nodes, node_begin, num_items, num_features, d_feature_offsets, param, level); - safe_cuda(cudaDeviceSynchronize()); + dh::safe_cuda(cudaDeviceSynchronize()); } void find_split_candidates_multiscan( - const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, - const Node *d_nodes, bst_uint num_items, int num_features, - const int *d_feature_offsets, const GPUTrainingParam param, - const int level) { + const ItemIter items_iter, Split *d_split_candidates, const Node *d_nodes, + bst_uint num_items, int num_features, const int *d_feature_offsets, + const GPUTrainingParam param, const int level) { // Select templated variation of split finding algorithm switch (level) { case 0: find_split_candidates_multiscan_variation<1>( - d_items, d_split_candidates, d_node_id, d_nodes, 0, 1, num_items, - num_features, d_feature_offsets, param, level); + items_iter, d_split_candidates, d_nodes, 0, 1, num_items, num_features, + d_feature_offsets, param, level); break; case 1: find_split_candidates_multiscan_variation<2>( - d_items, d_split_candidates, d_node_id, d_nodes, 1, 3, num_items, - num_features, d_feature_offsets, param, level); + items_iter, d_split_candidates, d_nodes, 1, 3, num_items, num_features, + d_feature_offsets, param, level); break; case 2: find_split_candidates_multiscan_variation<4>( - d_items, d_split_candidates, d_node_id, d_nodes, 3, 7, num_items, - num_features, d_feature_offsets, param, level); + items_iter, d_split_candidates, d_nodes, 3, 7, num_items, num_features, + d_feature_offsets, param, level); break; case 3: find_split_candidates_multiscan_variation<8>( - d_items, d_split_candidates, d_node_id, d_nodes, 7, 15, num_items, - num_features, d_feature_offsets, param, level); + items_iter, d_split_candidates, d_nodes, 7, 15, num_items, num_features, + d_feature_offsets, param, level); break; case 4: find_split_candidates_multiscan_variation<16>( - d_items, d_split_candidates, d_node_id, d_nodes, 15, 31, num_items, + items_iter, d_split_candidates, d_nodes, 15, 31, num_items, num_features, d_feature_offsets, param, level); break; case 5: find_split_candidates_multiscan_variation<32>( - d_items, d_split_candidates, d_node_id, d_nodes, 31, 63, num_items, + items_iter, d_split_candidates, d_nodes, 31, 63, num_items, num_features, d_feature_offsets, param, level); break; } diff --git a/plugin/updater_gpu/src/find_split_sorting.cuh b/plugin/updater_gpu/src/find_split_sorting.cuh index 0cf422e45fed..661e7df915d3 100644 --- a/plugin/updater_gpu/src/find_split_sorting.cuh +++ b/plugin/updater_gpu/src/find_split_sorting.cuh @@ -4,7 +4,7 @@ #pragma once #include #include -#include "cuda_helpers.cuh" +#include "device_helpers.cuh" #include "types_functions.cuh" namespace 
xgboost { @@ -59,30 +59,8 @@ struct GpairCallbackOp { } }; -template -struct FindSplitParamsSorting { - enum { - BLOCK_THREADS = _BLOCK_THREADS, - TILE_ITEMS = BLOCK_THREADS, - N_WARPS = _BLOCK_THREADS / 32, - DEBUG_VALIDATE = _DEBUG_VALIDATE, - ITEMS_PER_THREAD = 1 - }; -}; - -template struct ReduceParamsSorting { - enum { - BLOCK_THREADS = _BLOCK_THREADS, - ITEMS_PER_THREAD = 1, - TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, - N_WARPS = _BLOCK_THREADS / 32, - DEBUG_VALIDATE = _DEBUG_VALIDATE - }; -}; - -template struct ReduceEnactorSorting { - typedef cub::BlockScan GpairScanT; - +template struct ReduceEnactorSorting { + typedef cub::BlockScan GpairScanT; struct _TempStorage { typename GpairScanT::TempStorage gpair_scan; }; @@ -92,10 +70,9 @@ template struct ReduceEnactorSorting { // Thread local member variables gpu_gpair *d_block_node_sums; int *d_block_node_offsets; - const NodeIdT *d_node_id; - const Item *d_items; + const ItemIter item_iter; _TempStorage &temp_storage; - Item item; + gpu_gpair gpair; NodeIdT node_id; NodeIdT right_node_id; // Contains node_id relative to the current level only @@ -103,32 +80,24 @@ template struct ReduceEnactorSorting { GpairTupleCallbackOp callback_op; const int level; - __device__ __forceinline__ ReduceEnactorSorting( - TempStorage &temp_storage, // NOLINT - gpu_gpair *d_block_node_sums, int *d_block_node_offsets, - const Item *d_items, const NodeIdT *d_node_id, const int level) + __device__ __forceinline__ + ReduceEnactorSorting(TempStorage &temp_storage, // NOLINT + gpu_gpair *d_block_node_sums, int *d_block_node_offsets, + ItemIter item_iter, const int level) : temp_storage(temp_storage.Alias()), d_block_node_sums(d_block_node_sums), - d_block_node_offsets(d_block_node_offsets), d_items(d_items), - d_node_id(d_node_id), callback_op(), level(level) {} - - __device__ __forceinline__ void ResetSumsOffsets() { - const int max_nodes = 1 << level; - - for (auto i : block_stride_range(0, max_nodes)) { - d_block_node_sums[i] = gpu_gpair(); - d_block_node_offsets[i] = -1; - } - } + d_block_node_offsets(d_block_node_offsets), item_iter(item_iter), + callback_op(), level(level) {} __device__ __forceinline__ void LoadTile(const bst_uint &offset, const bst_uint &num_remaining) { if (threadIdx.x < num_remaining) { - item = d_items[offset + threadIdx.x]; - node_id = d_node_id[offset + threadIdx.x]; + bst_uint i = offset + threadIdx.x; + gpair = thrust::get<0>(item_iter[i]); + node_id = thrust::get<2>(item_iter[i]); right_node_id = threadIdx.x == num_remaining - 1 ? 
-1 - : d_node_id[offset + threadIdx.x + 1]; + : thrust::get<2>(item_iter[i + 1]); // Prevent overflow const int level_begin = (1 << level) - 1; node_id_adjusted = @@ -140,7 +109,7 @@ template struct ReduceEnactorSorting { const bst_uint &num_remaining) { LoadTile(offset, num_remaining); - ScanTuple t(item.gpair, node_id); + ScanTuple t(gpair, node_id); GpairScanT(temp_storage.gpair_scan).InclusiveSum(t, t, callback_op); __syncthreads(); @@ -156,33 +125,36 @@ template struct ReduceEnactorSorting { __device__ __forceinline__ void ProcessRegion(const bst_uint &segment_begin, const bst_uint &segment_end) { + const int max_nodes = 1 << level; + dh::block_fill(d_block_node_offsets, max_nodes, -1); + dh::block_fill(d_block_node_sums, max_nodes, gpu_gpair()); + // Current position bst_uint offset = segment_begin; - ResetSumsOffsets(); - __syncthreads(); // Process full tiles while (offset < segment_end) { ProcessTile(offset, segment_end - offset); - offset += ParamsT::TILE_ITEMS; + offset += BLOCK_THREADS; } } }; -template struct FindSplitEnactorSorting { - typedef cub::BlockScan GpairScanT; - typedef cub::BlockReduce SplitReduceT; +template +struct FindSplitEnactorSorting { + typedef cub::BlockScan GpairScanT; + typedef cub::BlockReduce SplitReduceT; typedef cub::WarpReduce WarpLossReduceT; struct _TempStorage { union { typename GpairScanT::TempStorage gpair_scan; typename SplitReduceT::TempStorage split_reduce; - typename WarpLossReduceT::TempStorage loss_reduce[ParamsT::N_WARPS]; + typename WarpLossReduceT::TempStorage loss_reduce[N_WARPS]; }; - Split warp_best_splits[ParamsT::N_WARPS]; + Split warp_best_splits[N_WARPS]; }; struct TempStorage : cub::Uninitialized<_TempStorage> {}; @@ -191,10 +163,10 @@ template struct FindSplitEnactorSorting { _TempStorage &temp_storage; gpu_gpair *d_block_node_sums; int *d_block_node_offsets; - const Item *d_items; - const NodeIdT *d_node_id; + const ItemIter item_iter; const Node *d_nodes; - Item item; + gpu_gpair gpair; + float fvalue; NodeIdT node_id; float left_fvalue; const GPUTrainingParam ¶m; @@ -203,27 +175,27 @@ template struct FindSplitEnactorSorting { __device__ __forceinline__ FindSplitEnactorSorting( TempStorage &temp_storage, gpu_gpair *d_block_node_sums, // NOLINT - int *d_block_node_offsets, const Item *d_items, const NodeIdT *d_node_id, - const Node *d_nodes, const GPUTrainingParam ¶m, - Split *d_split_candidates_out, const int level) + int *d_block_node_offsets, const ItemIter item_iter, const Node *d_nodes, + const GPUTrainingParam ¶m, Split *d_split_candidates_out, + const int level) : temp_storage(temp_storage.Alias()), d_block_node_sums(d_block_node_sums), - d_block_node_offsets(d_block_node_offsets), d_items(d_items), - d_node_id(d_node_id), d_nodes(d_nodes), - d_split_candidates_out(d_split_candidates_out), level(level), - param(param) {} + d_block_node_offsets(d_block_node_offsets), item_iter(item_iter), + d_nodes(d_nodes), d_split_candidates_out(d_split_candidates_out), + level(level), param(param) {} __device__ __forceinline__ void LoadTile(NodeIdT node_id_adjusted, const bst_uint &node_begin, const bst_uint &offset, const bst_uint &num_remaining) { if (threadIdx.x < num_remaining) { - node_id = d_node_id[offset + threadIdx.x]; - - item = d_items[offset + threadIdx.x]; + bst_uint i = offset + threadIdx.x; + gpair = thrust::get<0>(item_iter[i]); + fvalue = thrust::get<1>(item_iter[i]); + node_id = thrust::get<2>(item_iter[i]); bool first_item = offset + threadIdx.x == node_begin; - left_fvalue = first_item ? 
item.fvalue - FVALUE_EPS - : d_items[offset + threadIdx.x - 1].fvalue; + left_fvalue = + first_item ? fvalue - FVALUE_EPS : thrust::get<1>(item_iter[i - 1]); } } @@ -233,12 +205,12 @@ template struct FindSplitEnactorSorting { return; } - for (int warp = 0; warp < ParamsT::N_WARPS; warp++) { + for (int warp = 0; warp < N_WARPS; warp++) { if (threadIdx.x / 32 == warp) { for (int lane = 0; lane < 32; lane++) { - gpu_gpair g = cub::ShuffleIndex(item.gpair, lane); + gpu_gpair g = cub::ShuffleIndex(gpair, lane); gpu_gpair missing_broadcast = cub::ShuffleIndex(missing, lane); - float fvalue_broadcast = __shfl(item.fvalue, lane); + float fvalue_broadcast = __shfl(fvalue, lane); bool thread_active_broadcast = __shfl(thread_active, lane); float loss_chg_broadcast = __shfl(loss_chg, lane); if (threadIdx.x == 32 * warp) { @@ -278,7 +250,7 @@ template struct FindSplitEnactorSorting { } __device__ __forceinline__ bool LeftmostFvalue() { - return item.fvalue != left_fvalue; + return fvalue != left_fvalue; } __device__ __forceinline__ void @@ -293,19 +265,19 @@ template struct FindSplitEnactorSorting { : gpu_gpair(); bool missing_left; - float loss_chg = - thread_active ? loss_chg_missing(item.gpair, missing, n.sum_gradients, - n.root_gain, param, missing_left) - : -FLT_MAX; + float loss_chg = thread_active + ? loss_chg_missing(gpair, missing, n.sum_gradients, + n.root_gain, param, missing_left) + : -FLT_MAX; int warp_id = threadIdx.x / 32; volatile float warp_best_loss = temp_storage.warp_best_splits[warp_id].loss_chg; if (QueryUpdateWarpSplit(loss_chg, warp_best_loss, thread_active)) { - float fvalue_split = (item.fvalue + left_fvalue) / 2.0f; + float fvalue_split = (fvalue + left_fvalue) / 2.0f; - gpu_gpair left_sum = item.gpair; + gpu_gpair left_sum = gpair; if (missing_left) { left_sum += missing; } @@ -325,23 +297,16 @@ template struct FindSplitEnactorSorting { // Scan gpair const bool thread_active = threadIdx.x < num_remaining && node_id >= 0; GpairScanT(temp_storage.gpair_scan) - .ExclusiveSum(thread_active ? item.gpair : gpu_gpair(), item.gpair, - callback_op); + .ExclusiveSum(thread_active ? gpair : gpu_gpair(), gpair, callback_op); __syncthreads(); // Evaluate split EvaluateSplits(node_id_adjusted, node_begin, offset, num_remaining); } - __device__ __forceinline__ void ResetWarpSplits() { - if (threadIdx.x < ParamsT::N_WARPS) { - temp_storage.warp_best_splits[threadIdx.x] = Split(); - } - } - __device__ __forceinline__ void WriteBestSplit(const NodeIdT &node_id_adjusted) { if (threadIdx.x < 32) { - bool active = threadIdx.x < ParamsT::N_WARPS; + bool active = threadIdx.x < N_WARPS; float warp_loss = active ? 
temp_storage.warp_best_splits[threadIdx.x].loss_chg : -FLT_MAX; @@ -356,7 +321,7 @@ template struct FindSplitEnactorSorting { __device__ __forceinline__ void ProcessNode(const NodeIdT &node_id_adjusted, const bst_uint &node_begin, const bst_uint &node_end) { - ResetWarpSplits(); + dh::block_fill(temp_storage.warp_best_splits, N_WARPS, Split()); GpairCallbackOp callback_op = GpairCallbackOp(); @@ -365,7 +330,7 @@ template struct FindSplitEnactorSorting { while (offset < node_end) { ProcessTile(node_id_adjusted, node_begin, offset, node_end - offset, callback_op); - offset += ParamsT::TILE_ITEMS; + offset += BLOCK_THREADS; __syncthreads(); } @@ -375,11 +340,8 @@ template struct FindSplitEnactorSorting { __device__ __forceinline__ void ResetSplitCandidates() { const int max_nodes = 1 << level; const int begin = blockIdx.x * max_nodes; - const int end = begin + max_nodes; - for (auto i : block_stride_range(begin, end)) { - d_split_candidates_out[i] = Split(); - } + dh::block_fill(d_split_candidates_out + begin, max_nodes, Split()); } __device__ __forceinline__ void ProcessFeature(const bst_uint &segment_begin, @@ -410,13 +372,12 @@ template struct FindSplitEnactorSorting { } }; -template +template __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel( - const Item *d_items, Split *d_split_candidates_out, - const NodeIdT *d_node_id, const Node *d_nodes, bst_uint num_items, - const int num_features, const int *d_feature_offsets, - gpu_gpair *d_node_sums, int *d_node_offsets, const GPUTrainingParam param, - const int level) { + const ItemIter items_iter, Split *d_split_candidates_out, + const Node *d_nodes, bst_uint num_items, const int num_features, + const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets, + const GPUTrainingParam param, const int level) { if (num_items <= 0) { return; @@ -425,50 +386,48 @@ __global__ __launch_bounds__(1024, 1) void find_split_candidates_sorted_kernel( bst_uint segment_begin = d_feature_offsets[blockIdx.x]; bst_uint segment_end = d_feature_offsets[blockIdx.x + 1]; - typedef ReduceEnactorSorting ReduceT; - typedef FindSplitEnactorSorting FindSplitT; + typedef ReduceEnactorSorting ReduceT; + typedef FindSplitEnactorSorting FindSplitT; __shared__ union { typename ReduceT::TempStorage reduce; typename FindSplitT::TempStorage find_split; } temp_storage; - const int max_modes_level = 1 << level; gpu_gpair *d_block_node_sums = d_node_sums + blockIdx.x * max_modes_level; int *d_block_node_offsets = d_node_offsets + blockIdx.x * max_modes_level; - ReduceT(temp_storage.reduce, d_block_node_sums, d_block_node_offsets, d_items, - d_node_id, level) + ReduceT(temp_storage.reduce, d_block_node_sums, d_block_node_offsets, + items_iter, level) .ProcessRegion(segment_begin, segment_end); __syncthreads(); FindSplitT(temp_storage.find_split, d_block_node_sums, d_block_node_offsets, - d_items, d_node_id, d_nodes, param, d_split_candidates_out, level) + items_iter, d_nodes, param, d_split_candidates_out, level) .ProcessFeature(segment_begin, segment_end); } -void find_split_candidates_sorted( - const Item *d_items, Split *d_split_candidates, const NodeIdT *d_node_id, - Node *d_nodes, bst_uint num_items, int num_features, - const int *d_feature_offsets, gpu_gpair *d_node_sums, int *d_node_offsets, - const GPUTrainingParam param, const int level) { - +void find_split_candidates_sorted(const ItemIter items_iter, + Split *d_split_candidates, Node *d_nodes, + bst_uint num_items, int num_features, + const int *d_feature_offsets, + gpu_gpair 
*d_node_sums, int *d_node_offsets, + const GPUTrainingParam param, + const int level) { const int BLOCK_THREADS = 512; CHECK(BLOCK_THREADS / 32 < 32) << "Too many active warps."; - typedef FindSplitParamsSorting find_split_params; - typedef ReduceParamsSorting reduce_params; int grid_size = num_features; find_split_candidates_sorted_kernel< - reduce_params, find_split_params><<>>( - d_items, d_split_candidates, d_node_id, d_nodes, num_items, num_features, + BLOCK_THREADS><<>>( + items_iter, d_split_candidates, d_nodes, num_items, num_features, d_feature_offsets, d_node_sums, d_node_offsets, param, level); - safe_cuda(cudaGetLastError()); - safe_cuda(cudaDeviceSynchronize()); + dh::safe_cuda(cudaGetLastError()); + dh::safe_cuda(cudaDeviceSynchronize()); } } // namespace tree } // namespace xgboost diff --git a/plugin/updater_gpu/src/gpu_builder.cu b/plugin/updater_gpu/src/gpu_builder.cu index 266bb31d9757..4279b35e8636 100644 --- a/plugin/updater_gpu/src/gpu_builder.cu +++ b/plugin/updater_gpu/src/gpu_builder.cu @@ -1,20 +1,22 @@ /*! * Copyright 2016 Rory mitchell */ -#include "gpu_builder.cuh" +#include +#include +#include #include #include #include #include #include #include -#include -#include -#include #include +#include #include -#include "cuda_helpers.cuh" +#include "../../../src/common/random.h" +#include "device_helpers.cuh" #include "find_split.cuh" +#include "gpu_builder.cuh" #include "types_functions.cuh" namespace xgboost { @@ -26,29 +28,31 @@ struct GPUData { int n_features; int n_instances; + dh::bulk_allocator ba; GPUTrainingParam param; - CubMemory cub_mem; - - thrust::device_vector fvalues; - thrust::device_vector foffsets; - thrust::device_vector instance_id; - thrust::device_vector feature_id; - thrust::device_vector node_id; - thrust::device_vector node_id_temp; - thrust::device_vector node_id_instance; - thrust::device_vector node_id_instance_temp; - thrust::device_vector gpair; - thrust::device_vector nodes; - thrust::device_vector split_candidates; - - thrust::device_vector items; - thrust::device_vector items_temp; - - thrust::device_vector node_sums; - thrust::device_vector node_offsets; - thrust::device_vector sort_index_in; - thrust::device_vector sort_index_out; + dh::dvec fvalues; + dh::dvec fvalues_temp; + dh::dvec fvalues_cached; + dh::dvec foffsets; + dh::dvec instance_id; + dh::dvec instance_id_temp; + dh::dvec instance_id_cached; + dh::dvec feature_id; + dh::dvec node_id; + dh::dvec node_id_temp; + dh::dvec node_id_instance; + dh::dvec gpair; + dh::dvec nodes; + dh::dvec split_candidates; + dh::dvec node_sums; + dh::dvec node_offsets; + dh::dvec sort_index_in; + dh::dvec sort_index_out; + + dh::dvec cub_mem; + + ItemIter items_iter; void Init(const std::vector &in_fvalues, const std::vector &in_foffsets, @@ -56,100 +60,75 @@ struct GPUData { const std::vector &in_feature_id, const std::vector &in_gpair, bst_uint n_instances_in, bst_uint n_features_in, int max_depth, const TrainParam ¶m_in) { - Timer t; - - // Track allocated device memory - size_t n_bytes = 0; - n_features = n_features_in; n_instances = n_instances_in; + uint32_t max_nodes = (1 << (max_depth + 1)) - 1; + uint32_t max_nodes_level = 1 << max_depth; + + // Calculate memory for sort + size_t cub_mem_size = 0; + cub::DeviceSegmentedRadixSort::SortPairs( + cub_mem.data(), cub_mem_size, cub::DoubleBuffer(), + cub::DoubleBuffer(), in_fvalues.size(), n_features, + foffsets.data(), foffsets.data() + 1); + + // Allocate memory + size_t free_memory = dh::available_memory(); + ba.allocate(&fvalues, 
in_fvalues.size(), &fvalues_temp, in_fvalues.size(), + &fvalues_cached, in_fvalues.size(), &foffsets, + in_foffsets.size(), &instance_id, in_instance_id.size(), + &instance_id_temp, in_instance_id.size(), &instance_id_cached, + in_instance_id.size(), &feature_id, in_feature_id.size(), + &node_id, in_fvalues.size(), &node_id_temp, in_fvalues.size(), + &node_id_instance, n_instances, &gpair, n_instances, &nodes, + max_nodes, &split_candidates, max_nodes_level * n_features, + &node_sums, max_nodes_level * n_features, &node_offsets, + max_nodes_level * n_features, &sort_index_in, in_fvalues.size(), + &sort_index_out, in_fvalues.size(), &cub_mem, cub_mem_size); + + if (!param_in.silent) { + const int mb_size = 1048576; + LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/" + << free_memory / mb_size << " MB on " << dh::device_name(); + } + node_id.fill(0); + node_id_instance.fill(0); + fvalues = in_fvalues; - n_bytes += size_bytes(fvalues); + fvalues_cached = fvalues; foffsets = in_foffsets; - n_bytes += size_bytes(foffsets); instance_id = in_instance_id; - n_bytes += size_bytes(instance_id); + instance_id_cached = instance_id; feature_id = in_feature_id; - n_bytes += size_bytes(feature_id); param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda, param_in.reg_alpha, param_in.max_delta_step); - gpair = thrust::device_vector(in_gpair.begin(), in_gpair.end()); - n_bytes += size_bytes(gpair); + gpair = in_gpair; - uint32_t max_nodes_level = 1 << max_depth; - - node_sums = thrust::device_vector(max_nodes_level * n_features); - n_bytes += size_bytes(node_sums); - node_offsets = thrust::device_vector(max_nodes_level * n_features); - n_bytes += size_bytes(node_offsets); - - node_id_instance = thrust::device_vector(n_instances, 0); - n_bytes += size_bytes(node_id_instance); - - node_id = thrust::device_vector(fvalues.size(), 0); - n_bytes += size_bytes(node_id); - node_id_temp = thrust::device_vector(fvalues.size()); - n_bytes += size_bytes(node_id_temp); - - uint32_t max_nodes = (1 << (max_depth + 1)) - 1; - nodes = thrust::device_vector(max_nodes); - n_bytes += size_bytes(nodes); - - split_candidates = - thrust::device_vector(max_nodes_level * n_features); - n_bytes += size_bytes(split_candidates); - - // Init items - items = thrust::device_vector(fvalues.size()); - n_bytes += size_bytes(items); - items_temp = thrust::device_vector(fvalues.size()); - n_bytes += size_bytes(items_temp); - - sort_index_in = thrust::device_vector(fvalues.size()); - n_bytes += size_bytes(sort_index_in); - sort_index_out = thrust::device_vector(fvalues.size()); - n_bytes += size_bytes(sort_index_out); + nodes.fill(Node()); - // std::cout << "Device memory allocated: " << n_bytes << "\n"; + items_iter = thrust::make_zip_iterator(thrust::make_tuple( + thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()), + fvalues.tbegin(), node_id.tbegin())); - this->CreateItems(); allocated = true; + + dh::safe_cuda(cudaGetLastError()); } ~GPUData() {} - // Create items array using gpair, instaoce_id, fvalue - void CreateItems() { - auto d_items = items.data(); - auto d_instance_id = instance_id.data(); - auto d_gpair = gpair.data(); - auto d_fvalue = fvalues.data(); - - auto counting = thrust::make_counting_iterator(0); - thrust::for_each(counting, counting + fvalues.size(), - [=] __device__(bst_uint i) { - Item item; - item.instance_id = d_instance_id[i]; - item.fvalue = d_fvalue[i]; - item.gpair = d_gpair[item.instance_id]; - d_items[i] = item; - }); - } - // Reset memory for new boosting 
iteration - void Reset(const std::vector &in_gpair, - const std::vector &in_fvalues, - const std::vector &in_instance_id) { + void Reset(const std::vector &in_gpair) { CHECK(allocated); - thrust::copy(in_gpair.begin(), in_gpair.end(), gpair.begin()); - thrust::fill(nodes.begin(), nodes.end(), Node()); - thrust::fill(node_id_instance.begin(), node_id_instance.end(), 0); - thrust::fill(node_id.begin(), node_id.end(), 0); - - this->CreateItems(); + gpair = in_gpair; + instance_id = instance_id_cached; + fvalues = fvalues_cached; + nodes.fill(Node()); + node_id_instance.fill(0); + node_id.fill(0); } bool IsAllocated() { return allocated; } @@ -157,16 +136,14 @@ struct GPUData { // Gather from node_id_instance into node_id according to instance_id void GatherNodeId() { // Update node_id for each item - auto d_items = items.data(); auto d_node_id = node_id.data(); auto d_node_id_instance = node_id_instance.data(); + auto d_instance_id = instance_id.data(); - auto counting = thrust::make_counting_iterator(0); - thrust::for_each(counting, counting + fvalues.size(), - [=] __device__(bst_uint i) { - Item item = d_items[i]; - d_node_id[i] = d_node_id_instance[item.instance_id]; - }); + dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) { + // Item item = d_items[i]; + d_node_id[i] = d_node_id_instance[d_instance_id[i]]; + }); } }; @@ -174,20 +151,22 @@ GPUBuilder::GPUBuilder() { gpu_data = new GPUData(); } void GPUBuilder::Init(const TrainParam ¶m_in) { param = param_in; - CHECK(param.max_depth < 16) << "Max depth > 15 not supported."; + CHECK(param.max_depth < 16) << "Tree depth too large."; } GPUBuilder::~GPUBuilder() { delete gpu_data; } -template -__global__ void update_nodeid_missing_kernel(NodeIdT *d_node_id_instance, - Node *d_nodes, const OffsetT n) { - for (auto i : grid_stride_range(OffsetT(0), n)) { +void GPUBuilder::UpdateNodeId(int level) { + auto *d_node_id_instance = gpu_data->node_id_instance.data(); + Node *d_nodes = gpu_data->nodes.data(); + + dh::launch_n(gpu_data->node_id_instance.size(), [=] __device__(int i) { NodeIdT item_node_id = d_node_id_instance[i]; if (item_node_id < 0) { - continue; + return; } + Node node = d_nodes[item_node_id]; if (node.IsLeaf()) { @@ -197,132 +176,77 @@ __global__ void update_nodeid_missing_kernel(NodeIdT *d_node_id_instance, } else { d_node_id_instance[i] = item_node_id * 2 + 2; } - } -} + }); -__device__ void load_as_words(const int n_nodes, Node *d_nodes, Node *s_nodes) { - const int upper_range = n_nodes * (sizeof(Node) / sizeof(int)); - for (auto i : block_stride_range(0, upper_range)) { - reinterpret_cast(s_nodes)[i] = reinterpret_cast(d_nodes)[i]; - } -} + dh::safe_cuda(cudaDeviceSynchronize()); -template -__global__ void -update_nodeid_fvalue_kernel(NodeIdT *d_node_id, NodeIdT *d_node_id_instance, - Item *d_items, Node *d_nodes, const int n_nodes, - const int *d_foffsets, const int *d_feature_id, - const size_t n, const int n_features, - bool cache_nodes) { - // Load nodes into shared memory - extern __shared__ Node s_nodes[]; - - if (cache_nodes) { - load_as_words(n_nodes, d_nodes, s_nodes); - __syncthreads(); - } + auto *d_fvalues = gpu_data->fvalues.data(); + auto *d_instance_id = gpu_data->instance_id.data(); + auto *d_node_id = gpu_data->node_id.data(); + auto *d_feature_id = gpu_data->feature_id.data(); - for (auto i : grid_stride_range(size_t(0), n)) { - Item item = d_items[i]; + // Update node based on fvalue where exists + dh::launch_n(gpu_data->fvalues.size(), [=] __device__(int i) { NodeIdT item_node_id = d_node_id[i]; if 
(item_node_id < 0) { - continue; + return; } - Node node = cache_nodes ? s_nodes[item_node_id] : d_nodes[item_node_id]; + Node node = d_nodes[item_node_id]; if (node.IsLeaf()) { - continue; + return; } int feature_id = d_feature_id[i]; if (feature_id == node.split.findex) { - if (item.fvalue < node.split.fvalue) { - d_node_id_instance[item.instance_id] = item_node_id * 2 + 1; + float fvalue = d_fvalues[i]; + bst_uint instance_id = d_instance_id[i]; + + if (fvalue < node.split.fvalue) { + d_node_id_instance[instance_id] = item_node_id * 2 + 1; } else { - d_node_id_instance[item.instance_id] = item_node_id * 2 + 2; + d_node_id_instance[instance_id] = item_node_id * 2 + 2; } } - } -} + }); -void GPUBuilder::UpdateNodeId(int level) { - // Update all nodes based on missing direction - { - const bst_uint n = gpu_data->node_id_instance.size(); - const bst_uint ITEMS_PER_THREAD = 8; - const bst_uint BLOCK_THREADS = 256; - const bst_uint GRID_SIZE = - div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); - - update_nodeid_missing_kernel< - ITEMS_PER_THREAD><<>>( - raw(gpu_data->node_id_instance), raw(gpu_data->nodes), n); - - safe_cuda(cudaDeviceSynchronize()); - } - - // Update node based on fvalue where exists - { - const bst_uint n = gpu_data->fvalues.size(); - const bst_uint ITEMS_PER_THREAD = 4; - const bst_uint BLOCK_THREADS = 256; - const bst_uint GRID_SIZE = - div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS); - - // Use smem cache version if possible - const bool cache_nodes = level < 7; - int n_nodes = (1 << (level + 1)) - 1; - int smem_size = cache_nodes ? sizeof(Node) * n_nodes : 0; - update_nodeid_fvalue_kernel< - ITEMS_PER_THREAD><<>>( - raw(gpu_data->node_id), raw(gpu_data->node_id_instance), - raw(gpu_data->items), raw(gpu_data->nodes), n_nodes, - raw(gpu_data->foffsets), raw(gpu_data->feature_id), - gpu_data->fvalues.size(), gpu_data->n_features, cache_nodes); - - safe_cuda(cudaGetLastError()); - safe_cuda(cudaDeviceSynchronize()); - } + dh::safe_cuda(cudaDeviceSynchronize()); gpu_data->GatherNodeId(); } void GPUBuilder::Sort(int level) { - thrust::sequence(gpu_data->sort_index_in.begin(), - gpu_data->sort_index_in.end()); + thrust::sequence(gpu_data->sort_index_in.tbegin(), + gpu_data->sort_index_in.tend()); - cub::DoubleBuffer d_keys(raw(gpu_data->node_id), - raw(gpu_data->node_id_temp)); - cub::DoubleBuffer d_values(raw(gpu_data->sort_index_in), - raw(gpu_data->sort_index_out)); + cub::DoubleBuffer d_keys(gpu_data->node_id.data(), + gpu_data->node_id_temp.data()); + cub::DoubleBuffer d_values(gpu_data->sort_index_in.data(), + gpu_data->sort_index_out.data()); - if (!gpu_data->cub_mem.IsAllocated()) { - cub::DeviceSegmentedRadixSort::SortPairs( - gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes, - d_keys, d_values, gpu_data->fvalues.size(), gpu_data->n_features, - raw(gpu_data->foffsets), raw(gpu_data->foffsets) + 1); - gpu_data->cub_mem.Allocate(); - } + size_t temp_size = gpu_data->cub_mem.size(); cub::DeviceSegmentedRadixSort::SortPairs( - gpu_data->cub_mem.d_temp_storage, gpu_data->cub_mem.temp_storage_bytes, - d_keys, d_values, gpu_data->fvalues.size(), gpu_data->n_features, - raw(gpu_data->foffsets), raw(gpu_data->foffsets) + 1); - + gpu_data->cub_mem.data(), temp_size, d_keys, d_values, + gpu_data->fvalues.size(), gpu_data->n_features, gpu_data->foffsets.data(), + gpu_data->foffsets.data() + 1); + + auto zip = thrust::make_zip_iterator(thrust::make_tuple( + gpu_data->fvalues.tbegin(), gpu_data->instance_id.tbegin())); + auto zip_temp = 
thrust::make_zip_iterator(thrust::make_tuple( + gpu_data->fvalues_temp.tbegin(), gpu_data->instance_id_temp.tbegin())); thrust::gather(thrust::device_pointer_cast(d_values.Current()), thrust::device_pointer_cast(d_values.Current()) + gpu_data->sort_index_out.size(), - gpu_data->items.begin(), gpu_data->items_temp.begin()); - - thrust::copy(gpu_data->items_temp.begin(), gpu_data->items_temp.end(), - gpu_data->items.begin()); + zip, zip_temp); + thrust::copy(zip_temp, zip_temp + gpu_data->fvalues.size(), zip); - if (d_keys.Current() == raw(gpu_data->node_id_temp)) { - thrust::copy(gpu_data->node_id_temp.begin(), gpu_data->node_id_temp.end(), - gpu_data->node_id.begin()); + if (d_keys.Current() == gpu_data->node_id_temp.data()) { + thrust::copy(gpu_data->node_id_temp.tbegin(), gpu_data->node_id_temp.tend(), + gpu_data->node_id.tbegin()); } } @@ -330,8 +254,8 @@ void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree) { cudaProfilerStart(); try { - Timer update; - Timer t; + dh::Timer update; + dh::Timer t; this->InitData(gpair, *p_fmat, *p_tree); t.printElapsed("init data"); this->InitFirstNode(); @@ -341,24 +265,23 @@ void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat, t.reset(); if (level > 0) { - Timer update_node; + dh::Timer update_node; this->UpdateNodeId(level); update_node.printElapsed("node"); } if (level > 0 && !use_multiscan_algorithm) { - Timer s; + dh::Timer s; this->Sort(level); s.printElapsed("sort"); } - Timer split; - find_split(raw(gpu_data->items), raw(gpu_data->split_candidates), - raw(gpu_data->node_id), raw(gpu_data->nodes), - (bst_uint)gpu_data->fvalues.size(), gpu_data->n_features, - raw(gpu_data->foffsets), raw(gpu_data->node_sums), - raw(gpu_data->node_offsets), gpu_data->param, level, - use_multiscan_algorithm); + dh::Timer split; + find_split(gpu_data->items_iter, gpu_data->split_candidates.data(), + gpu_data->nodes.data(), (bst_uint)gpu_data->fvalues.size(), + gpu_data->n_features, gpu_data->foffsets.data(), + gpu_data->node_sums.data(), gpu_data->node_offsets.data(), + gpu_data->param, level, use_multiscan_algorithm); split.printElapsed("split"); @@ -379,30 +302,71 @@ void GPUBuilder::Update(const std::vector &gpair, DMatrix *p_fmat, cudaProfilerStop(); } +float GPUBuilder::GetSubsamplingRate(MetaInfo info) { + float subsample = 1.0; + size_t required = 10 * info.num_row + 44 * info.num_nonzero; + size_t available = dh::available_memory(); + while (available < required) { + subsample -= 0.05; + required = 10 * info.num_row + subsample * (44 * info.num_nonzero); + } + + return subsample; +} + void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, const RegTree &tree) { - CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) - << "ColMaker: can only grow new tree"; - CHECK(fmat.SingleColBlock()) << "GPUMaker: must have single column block"; if (gpu_data->IsAllocated()) { - gpu_data->Reset(gpair, fvalues, instance_id); + gpu_data->Reset(gpair); return; } - Timer t; + dh::Timer t; MetaInfo info = fmat.info(); - dmlc::DataIter *iter = fmat.ColIterator(); + + // Work out if dataset will fit on GPU + float subsample = this->GetSubsamplingRate(info); + CHECK(subsample > 0.0); + if (!param.silent && subsample < param.subsample) { + LOG(CONSOLE) << "Not enough device memory for entire dataset."; + } + + // Override subsample parameter if user-specified parameter is lower + subsample = std::min(param.subsample, subsample); + + std::vector row_flags; + + if (subsample < 1.0) { + if (!param.silent && subsample < 1.0) { + 
LOG(CONSOLE) << "Subsampling " << subsample * 100 << "% of rows."; + } + + const RowSet &rowset = fmat.buffered_rowset(); + row_flags.resize(info.num_row); + std::bernoulli_distribution coin_flip(subsample); + auto &rnd = common::GlobalRandom(); + for (size_t i = 0; i < rowset.size(); ++i) { + const bst_uint ridx = rowset[i]; + if (gpair[ridx].hess < 0.0f) + continue; + row_flags[ridx] = coin_flip(rnd); + } + } std::vector foffsets; foffsets.push_back(0); std::vector feature_id; + std::vector fvalues; + std::vector instance_id; fvalues.reserve(info.num_col * info.num_row); instance_id.reserve(info.num_col * info.num_row); feature_id.reserve(info.num_col * info.num_row); + dmlc::DataIter *iter = fmat.ColIterator(); + while (iter->Next()) { const ColBatch &batch = iter->Value(); @@ -411,9 +375,18 @@ void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, for (const ColBatch::Entry *it = col.data; it != col.data + col.length; it++) { - fvalues.push_back(it->fvalue); - instance_id.push_back(it->index); - feature_id.push_back(i); + bst_uint inst_id = it->index; + if (subsample < 1.0) { + if (row_flags[inst_id]) { + fvalues.push_back(it->fvalue); + instance_id.push_back(inst_id); + feature_id.push_back(i); + } + } else { + fvalues.push_back(it->fvalue); + instance_id.push_back(inst_id); + feature_id.push_back(i); + } } foffsets.push_back(fvalues.size()); } @@ -430,13 +403,15 @@ void GPUBuilder::InitData(const std::vector &gpair, DMatrix &fmat, void GPUBuilder::InitFirstNode() { // Build the root node on the CPU and copy to device gpu_gpair sum_gradients = - thrust::reduce(gpu_data->gpair.begin(), gpu_data->gpair.end(), + thrust::reduce(gpu_data->gpair.tbegin(), gpu_data->gpair.tend(), gpu_gpair(0, 0), cub::Sum()); - gpu_data->nodes[0] = Node( + Node tmp = Node( sum_gradients, CalcGain(gpu_data->param, sum_gradients.grad(), sum_gradients.hess()), CalcWeight(gpu_data->param, sum_gradients.grad(), sum_gradients.hess())); + + thrust::copy_n(&tmp, 1, gpu_data->nodes.tbegin()); } enum NodeType { @@ -469,7 +444,7 @@ void flag_nodes(const thrust::host_vector &nodes, // Copy gpu dense representation of tree to xgboost sparse representation void GPUBuilder::CopyTree(RegTree &tree) { - thrust::host_vector h_nodes = gpu_data->nodes; + std::vector h_nodes = gpu_data->nodes.as_vector(); std::vector node_flags(h_nodes.size(), UNUSED); flag_nodes(h_nodes, &node_flags, 0, NODE); diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh index abfecefcd720..fdb308285c3e 100644 --- a/plugin/updater_gpu/src/gpu_builder.cuh +++ b/plugin/updater_gpu/src/gpu_builder.cuh @@ -22,11 +22,12 @@ class GPUBuilder { void Update(const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree); + void UpdateNodeId(int level); private: void InitData(const std::vector &gpair, DMatrix &fmat, // NOLINT const RegTree &tree); - void UpdateNodeId(int level); + float GetSubsamplingRate(MetaInfo info); void Sort(int level); void InitFirstNode(); void CopyTree(RegTree &tree); // NOLINT @@ -34,13 +35,8 @@ class GPUBuilder { TrainParam param; GPUData *gpu_data; - // Keep host copies of these arrays as the device versions change between - // boosting iterations - std::vector fvalues; - std::vector instance_id; - int multiscan_levels = - 5; // Number of levels before switching to sorting algorithm + 0; // Number of levels before switching to sorting algorithm }; } // namespace tree } // namespace xgboost diff --git a/plugin/updater_gpu/src/types.cuh b/plugin/updater_gpu/src/types.cuh index 
63689b81272b..544d1f0e1235 100644
--- a/plugin/updater_gpu/src/types.cuh
+++ b/plugin/updater_gpu/src/types.cuh
@@ -3,6 +3,7 @@
  */
 #pragma once
 #include
+#include  // The linter is not very smart and thinks we need this
 
 namespace xgboost {
 namespace tree {
@@ -78,11 +79,13 @@ struct gpu_gpair {
   }
 };
 
-struct Item {
-  bst_uint instance_id;
-  float fvalue;
-  gpu_gpair gpair;
-};
+typedef thrust::device_vector<bst_uint>::iterator uint_iter;
+typedef thrust::device_vector<gpu_gpair>::iterator gpair_iter;
+typedef thrust::device_vector<float>::iterator float_iter;
+typedef thrust::device_vector<NodeIdT>::iterator node_id_iter;
+typedef thrust::permutation_iterator<gpair_iter, uint_iter> gpair_perm_iter;
+typedef thrust::tuple<gpair_perm_iter, float_iter, node_id_iter> ItemTuple;
+typedef thrust::zip_iterator<ItemTuple> ItemIter;
 
 struct GPUTrainingParam {
   // minimum amount of hessian(weight) allowed in a child
diff --git a/src/tree/param.h b/src/tree/param.h
index 5d247e23108b..ff05fe0e9b3d 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -7,6 +7,8 @@
 #ifndef XGBOOST_TREE_PARAM_H_
 #define XGBOOST_TREE_PARAM_H_
 
+#include
+#include
 #include
 #include
 #include
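
The memory heuristic above also drives the automatic row subsampling: GPUBuilder::GetSubsamplingRate in gpu_builder.cu lowers the subsampling rate in 5% steps until the estimated requirement fits in free device memory. A minimal standalone sketch of that loop, assuming the constants from the README formula (10 bytes per row plus 44 bytes per stored element) and purely illustrative dataset sizes:

```python
def get_subsampling_rate(num_row, num_nonzero, available_bytes):
    # Mirror of GPUBuilder::GetSubsamplingRate: shrink the subsample
    # fraction until the memory estimate fits on the device.
    subsample = 1.0
    required = 10 * num_row + 44 * num_nonzero
    while available_bytes < required and subsample > 0.0:
        subsample -= 0.05
        required = 10 * num_row + subsample * (44 * num_nonzero)
    return max(subsample, 0.0)

# e.g. 3.5 million rows with 30 stored elements each on a 4GB card
print(get_subsampling_rate(3500000, 3500000 * 30, 4 * 1024 ** 3))
```

The explicit `subsample > 0.0` guard is added here for safety; the device code instead CHECKs after the loop that the computed rate stayed positive.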