Fix memory usage of device sketching #5407
Changes from 1 commit
@@ -97,6 +97,19 @@ struct EntryCompareOp {
   }
 };
 
+// Compute number of sample cuts needed on local node to maintain accuracy
+// We take more cuts than needed and then reduce them later
+size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
+  constexpr int kFactor = 8;
+  double eps = 1.0 / (kFactor * max_bins);
+  size_t dummy_nlevel;
+  size_t num_cuts;
+  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
+      num_rows, eps, &dummy_nlevel, &num_cuts);
+  return std::min(num_cuts, num_rows);
+}
+
 // Count the entries in each column and exclusive scan
 void GetColumnSizesScan(int device,
                         dh::caching_device_vector<size_t>* column_sizes_scan,
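For a sense of scale, the error target above works out as follows; a minimal standalone sketch with hypothetical values (WQuantileSketch::LimitSizeLevel itself is XGBoost-internal and not reproduced here):

#include <cstdio>

int main() {
  // Same over-sampling factor as the patch, with a typical max_bins.
  const int kFactor  = 8;
  const int max_bins = 256;
  // Target summary error; over-sampling by kFactor leaves enough
  // resolution to prune the merged sketch back down to max_bins later.
  double eps = 1.0 / (kFactor * max_bins);
  // A quantile summary with error eps needs on the order of 1/eps entries,
  // which is why RequiredSampleCuts also clamps the result by num_rows.
  printf("eps = %.6f, ~%d candidate cuts per column\n", eps,
         static_cast<int>(1.0 / eps));
  return 0;
}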
@@ -208,7 +221,8 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
 void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
                   SketchContainer* sketch_container, int num_cuts,
                   size_t num_columns) {
-  dh::XGBCachingDeviceAllocator<char> alloc;
+  dh::XGBCachingDeviceAllocator<char> caching_alloc;
+  dh::XGBDeviceAllocator<char> alloc;
   const auto& host_data = page.data.ConstHostVector();
   dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                           host_data.begin() + end);
@@ -221,7 +235,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
                      num_columns);
   thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
 
-  dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
+  dh::device_vector<SketchEntry> cuts(num_columns * num_cuts);
 
   ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts,
               {sorted_entries.data().get(), sorted_entries.size()},
               {column_sizes_scan.data().get(), column_sizes_scan.size()});
@@ -235,7 +249,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                           Span<const float> weights, size_t begin, size_t end,
                           SketchContainer* sketch_container, int num_cuts,
                           size_t num_columns) {
-  dh::XGBCachingDeviceAllocator<char> alloc;
+  dh::XGBCachingDeviceAllocator<char> caching_alloc;
   const auto& host_data = page.data.ConstHostVector();
   dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                           host_data.begin() + end);
@@ -255,12 +269,12 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
   });
 
   // Sort
-  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
+  thrust::sort_by_key(thrust::cuda::par(caching_alloc), sorted_entries.begin(),
                       sorted_entries.end(), temp_weights.begin(),
                       EntryCompareOp());
 
   // Scan weights
-  thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
+  thrust::inclusive_scan_by_key(thrust::cuda::par(caching_alloc),
                                 sorted_entries.begin(), sorted_entries.end(),
                                 temp_weights.begin(), temp_weights.begin(),
                                 [=] __device__(const Entry& a, const Entry& b) {
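The sort-then-segmented-scan idiom above is standard Thrust. A toy standalone sketch of the same pattern, accumulating weights within runs of equal keys (toy data; XGBoost's Entry type and EntryCompareOp are not reproduced here):

#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <cstdio>

int main() {
  int   h_keys[]    = {1, 0, 1, 0, 1};
  float h_weights[] = {1.f, 2.f, 3.f, 4.f, 5.f};
  thrust::device_vector<int>   keys(h_keys, h_keys + 5);
  thrust::device_vector<float> weights(h_weights, h_weights + 5);

  // Group equal keys together, carrying the weights along ...
  thrust::sort_by_key(keys.begin(), keys.end(), weights.begin());
  // ... then prefix-sum the weights within each equal-key segment.
  thrust::inclusive_scan_by_key(keys.begin(), keys.end(),
                                weights.begin(), weights.begin());

  for (int i = 0; i < 5; ++i) {
    printf("key=%d cumulative_weight=%.1f\n",
           static_cast<int>(keys[i]), static_cast<float>(weights[i]));
  }
  return 0;
}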
@@ -288,28 +302,29 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
 
 HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                            size_t sketch_batch_num_elements) {
+  // Configure batch size based on available memory
+  bool has_weights = dmat->Info().weights_.Size() > 0;
+  size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
+  if (sketch_batch_num_elements == 0) {
+    int bytes_per_element = has_weights ? 24 : 16;
+    size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry);
+    // use up to 80% of available space
+    sketch_batch_num_elements =
+        (dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
+  }
+
   HistogramCuts cuts;
   DenseCuts dense_cuts(&cuts);
   SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
                                    dmat->Info().num_row_);
 
-  constexpr int kFactor = 8;
-  double eps = 1.0 / (kFactor * max_bins);
-  size_t dummy_nlevel;
-  size_t num_cuts;
-  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
-      dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
-  num_cuts = std::min(num_cuts, dmat->Info().num_row_);
-  if (sketch_batch_num_elements == 0) {
-    sketch_batch_num_elements = dmat->Info().num_nonzero_;
-  }
   dmat->Info().weights_.SetDevice(device);
   for (const auto& batch : dmat->GetBatches<SparsePage>()) {
     size_t batch_nnz = batch.data.Size();
     for (auto begin = 0ull; begin < batch_nnz;
          begin += sketch_batch_num_elements) {
       size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
-      if (dmat->Info().weights_.Size() > 0) {
+      if (has_weights) {
         ProcessWeightedBatch(
             device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
             &sketch_container, num_cuts, dmat->Info().num_col_);
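To make the auto-sizing concrete, here is a standalone back-of-envelope with made-up numbers (the 16/24 bytes-per-element figures come from the patch; the free-memory value and sizeof(SketchEntry) are assumptions):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t available_bytes    = 16ull << 30;  // assume 16 GiB free on device
  const size_t num_columns        = 100;
  const size_t num_cuts           = 10000;  // per column, from RequiredSampleCuts
  const size_t bytes_per_element  = 16;     // unweighted case; 24 with weights
  const size_t sketch_entry_bytes = 16;     // assumed sizeof(SketchEntry)

  // Reserve the cut buffer first, then spend 80% of the rest on the batch.
  size_t bytes_cuts = num_cuts * num_columns * sketch_entry_bytes;
  size_t batch_elements = static_cast<size_t>(
      (available_bytes - bytes_cuts) * 0.8 / bytes_per_element);
  printf("cut buffer: %zu MiB, batch: ~%zu M elements\n",
         bytes_cuts >> 20, batch_elements / 1000000);
  return 0;
}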
@@ -369,6 +384,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
   // Work out how many valid entries we have in each column
   dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
                                                       0);
+
   auto d_column_sizes_scan = column_sizes_scan.data().get();
   IsValidFunctor is_valid(missing);
   dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
@@ -385,7 +401,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
   size_t num_valid = host_column_sizes_scan.back();
 
   // Copy current subset of valid elements into temporary storage and sort
-  thrust::device_vector<Entry> sorted_entries(num_valid);
+  dh::device_vector<Entry> sorted_entries(num_valid);
   thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
                   entry_iter + end, sorted_entries.begin(), is_valid);
   thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),

Review comment: I think a caching_device_vector can be used everywhere. What is precluding its usage?

Reply: Yep, that should be used everywhere here. Allocations larger than 1 GB will just be standard allocations anyway. The only danger of caching_device_vector is that it does not default-initialise memory.
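A minimal CUDA sketch of the initialisation hazard mentioned in the reply, with plain cudaMalloc/cudaFree standing in for the caching allocator (the block-recycling behaviour shown is illustrative, not guaranteed):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const size_t n = 256 * sizeof(int);
  int* buf = nullptr;
  cudaMalloc(&buf, n);       // contents are indeterminate, never zeroed
  cudaMemset(buf, 0xAB, n);  // dirty the block
  cudaFree(buf);
  cudaMalloc(&buf, n);       // a caching pool would hand the same block back
  int first = 0;
  cudaMemcpy(&first, buf, sizeof(int), cudaMemcpyDeviceToHost);
  printf("first word: 0x%x (not guaranteed to be zero)\n", first);
  cudaMemset(buf, 0, n);     // the safe pattern when zeroed memory is needed
  cudaFree(buf);
  return 0;
}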
@@ -406,6 +422,17 @@ template <typename AdapterT>
 HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
                                   float missing,
                                   size_t sketch_batch_num_elements) {
+  size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
+  if (sketch_batch_num_elements == 0) {
+    int bytes_per_element = 16;
+    size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry);
+    size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t);
+    // use up to 80% of available space
+    sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) -
+                                 bytes_cuts - bytes_num_columns) *
+                                0.8 / bytes_per_element;
+  }
+
   CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
   CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
@@ -421,16 +448,6 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
   SketchContainer sketch_container(num_bins, adapter->NumColumns(),
                                    adapter->NumRows());
 
-  constexpr int kFactor = 8;
-  double eps = 1.0 / (kFactor * num_bins);
-  size_t dummy_nlevel;
-  size_t num_cuts;
-  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
-      adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
-  num_cuts = std::min(num_cuts, adapter->NumRows());
-  if (sketch_batch_num_elements == 0) {
-    sketch_batch_num_elements = batch.Size();
-  }
   for (auto begin = 0ull; begin < batch.Size();
        begin += sketch_batch_num_elements) {
     size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
@@ -44,20 +44,13 @@ struct GPUHistMakerTrainParam
     : public XGBoostParameter<GPUHistMakerTrainParam> {
   bool single_precision_histogram;
   bool deterministic_histogram;
-  // number of rows in a single GPU batch
-  int gpu_batch_nrows;
   bool debug_synchronize;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
     DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
         "Use single precision to build histograms.");
     DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
         "Pre-round the gradient for obtaining deterministic gradient histogram.");
-    DMLC_DECLARE_FIELD(gpu_batch_nrows)
-        .set_lower_bound(-1)
-        .set_default(0)
-        .describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
-                  "-1 to use all rows assigned to a GPU, and 0 to auto-deduce");
     DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
         "Check if all distributed tree are identical after tree construction.");
   }

Review comment: Are there existing use-cases that use this? It looks like the (breaking) new behavior is to auto-deduce, and I'm wondering if there are configs that use -1 to pull everything in one shot as opposed to looping (with perhaps better latencies).

Reply: The current implementation will use up to 80% of available memory, so the 'do everything in one batch' approach would only be slightly better in the case where more than 80% of memory is used. The current autodetect behaviour is able to use more available memory than the old implementation and would have faster latencies. The …
@@ -1018,7 +1011,6 @@ class GPUHistMakerSpecialised {
     BatchParam batch_param{
       device_,
       param_.max_bin,
-      hist_maker_param_.gpu_batch_nrows,
      generic_param_->gpu_page_size
     };
     auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
Review comment: Should we get rid of this and use the caching_alloc throughout?

Reply: Yes.