Batch UpdatePosition calls #7925

Closed
RAMitchell wants to merge 38 commits
Changes from all commits
38 commits
2b4cf67
Remove single_precision_histogram
RAMitchell Apr 21, 2022
f140ebc
Batch nodes from driver
RAMitchell Apr 25, 2022
80a3e78
Categoricals broken
RAMitchell Apr 29, 2022
e1fb702
Refactor categoricals
RAMitchell May 1, 2022
dc100cf
Refactor categoricals 2
RAMitchell May 2, 2022
bc74458
Skip copy if no categoricals
RAMitchell May 2, 2022
c4f8eac
Review comment
RAMitchell May 5, 2022
2a53849
Merge branch 'master' of github.com:dmlc/xgboost into categorical
RAMitchell May 5, 2022
a1cddaa
Revert "Categoricals broken"
RAMitchell May 5, 2022
829bda6
Merge branch 'master' of github.com:dmlc/xgboost into fuse
RAMitchell May 5, 2022
0bc8745
Merge branch 'categorical' of github.com:RAMitchell/xgboost into fuse
RAMitchell May 5, 2022
fd0e25e
Lint
RAMitchell May 5, 2022
9fab64e
Merge branch 'master' of github.com:dmlc/xgboost into fuse
RAMitchell May 5, 2022
56785f3
Revert "Revert "Categoricals broken""
RAMitchell May 6, 2022
1dd1a6c
Limit concurrent nodes
RAMitchell May 10, 2022
8751d14
Lint
RAMitchell May 11, 2022
49809bf
Basic blockwise partitioning
RAMitchell May 11, 2022
181d7cf
Working block partition
RAMitchell May 12, 2022
666eb9b
Reduction
RAMitchell May 12, 2022
66173c7
Some failing tests
RAMitchell May 13, 2022
ec7fea8
Handle empty candidate
RAMitchell May 13, 2022
49c5f90
Cleanup
RAMitchell May 13, 2022
bd48082
experiments
RAMitchell May 14, 2022
c3ef1f6
Improvements
RAMitchell May 14, 2022
ba8bbdf
Fused scan
RAMitchell May 14, 2022
f4ef4ca
Register blocking
RAMitchell May 15, 2022
9c27dd0
Cleanup
RAMitchell May 17, 2022
0bcc84a
Working tests
RAMitchell May 18, 2022
723ff47
Transplanted new code
RAMitchell May 18, 2022
199bed9
Optimised
RAMitchell May 19, 2022
0e35e99
Do not initialise data structures to maximum possible tree size.
RAMitchell May 19, 2022
daa9b56
Comments, cleanup
RAMitchell May 19, 2022
8ab989e
Refactor FinalizePosition
RAMitchell May 20, 2022
d50ec4b
Remove redundant functions
RAMitchell May 20, 2022
c34c3ad
Lint
RAMitchell May 20, 2022
e534edc
Merge branch 'master' of github.com:dmlc/xgboost into batch-position-…
RAMitchell May 20, 2022
47bfc6e
Remove old kernel
RAMitchell May 20, 2022
a53ba87
Add tests for AtomicIncrement
RAMitchell May 20, 2022
127 changes: 5 additions & 122 deletions src/tree/gpu_hist/row_partitioner.cu
@@ -10,98 +10,18 @@

namespace xgboost {
namespace tree {
struct IndexFlagTuple {
size_t idx;
size_t flag;
};

struct IndexFlagOp {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
const IndexFlagTuple& b) const {
return {b.idx, a.flag + b.flag};
}
};

struct WriteResultsFunctor {
bst_node_t left_nidx;
common::Span<bst_node_t> position_in;
common::Span<bst_node_t> position_out;
common::Span<RowPartitioner::RowIndexT> ridx_in;
common::Span<RowPartitioner::RowIndexT> ridx_out;
int64_t* d_left_count;

__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
// the ex_scan_result represents how many rows have been assigned to left
// node so far during scan.
int scatter_address;
if (position_in[x.idx] == left_nidx) {
scatter_address = x.flag - 1; // -1 because inclusive scan
} else {
// current number of rows belong to right node + total number of rows
// belong to left node
scatter_address = (x.idx - x.flag) + *d_left_count;
}
// copy the node id to output
position_out[scatter_address] = position_in[x.idx];
ridx_out[scatter_address] = ridx_in[x.idx];

// Discard
return {};
}
};

// Implement partitioning via single scan operation using transform output to
// write the result
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
bst_node_t left_nidx, bst_node_t,
int64_t* d_left_count, cudaStream_t stream) {
WriteResultsFunctor write_results{left_nidx, position, position_out,
ridx, ridx_out, d_left_count};
auto discard_write_iterator =
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
counting, [=] __device__(size_t idx) {
return IndexFlagTuple{idx, static_cast<size_t>(position[idx] == left_nidx)};
});
size_t temp_bytes = 0;
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
dh::TemporaryArray<int8_t> temp(temp_bytes);
cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
}

void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
common::Span<bst_node_t> position) {
CHECK_EQ(ridx.size(), position.size());
dh::LaunchN(ridx.size(), [=] __device__(size_t idx) {
ridx[idx] = idx;
position[idx] = 0;
});
}

RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
ridx_b_(num_rows), position_b_(num_rows) {
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows), scan_inputs_(num_rows) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
ridx_segments_.emplace_back(Segment(0, num_rows));

Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
left_counts_.resize(256);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
streams_.resize(2);
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
}

RowPartitioner::~RowPartitioner() {
dh::safe_cuda(cudaSetDevice(device_idx_));
for (auto& stream : streams_) {
Expand All @@ -117,16 +37,13 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
if (segment.Size() == 0) {
return {};
}
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
}

common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return ridx_.CurrentSpan();
return dh::ToSpan(ridx_);
}

common::Span<const bst_node_t> RowPartitioner::GetPosition() {
return position_.CurrentSpan();
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
bst_node_t nidx) {
auto span = GetRows(nidx);
@@ -135,39 +52,5 @@ std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
return rows;
}

std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
auto span = GetPosition();
std::vector<bst_node_t> position(span.size());
dh::CopyDeviceSpanToVector(&position, span);
return position;
}

void RowPartitioner::SortPositionAndCopy(const Segment& segment,
bst_node_t left_nidx,
bst_node_t right_nidx,
int64_t* d_left_count,
cudaStream_t stream) {
SortPosition(
// position_in
common::Span<bst_node_t>(position_.Current() + segment.begin,
segment.Size()),
// position_out
common::Span<bst_node_t>(position_.Other() + segment.begin,
segment.Size()),
// row index in
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
// row index out
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position_.Current() + segment.begin;
const auto d_position_other = position_.Other() + segment.begin;
const auto d_ridx_current = ridx_.Current() + segment.begin;
const auto d_ridx_other = ridx_.Other() + segment.begin;
dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
}
}; // namespace tree
}; // namespace xgboost
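
For reference, the SortPosition code removed in this diff performs a stable partition of rows between the left and right child: a single inclusive scan over a "goes left" flag, with a transform output iterator (WriteResultsFunctor) scattering each row straight to its destination. The following is a minimal, sequential C++ sketch of that scatter-address logic only; it is illustrative, not part of this PR, and the names PartitionRows, position_in, left_count, etc. are hypothetical stand-ins for the kernel's device-side equivalents.

// Sequential sketch of the scan-based stable partition that the removed
// SortPosition kernel implemented with cub::DeviceScan::InclusiveScan.
#include <cstdint>
#include <iostream>
#include <vector>

using bst_node_t = int32_t;
using RowIndexT = uint32_t;

void PartitionRows(const std::vector<bst_node_t>& position_in,
                   const std::vector<RowIndexT>& ridx_in, bst_node_t left_nidx,
                   std::vector<bst_node_t>* position_out,
                   std::vector<RowIndexT>* ridx_out) {
  size_t n = position_in.size();
  position_out->resize(n);
  ridx_out->resize(n);
  // Total number of rows going to the left child (d_left_count in the kernel).
  size_t left_count = 0;
  for (auto p : position_in) left_count += (p == left_nidx);
  // The inclusive scan over the "goes left" flag yields, at index idx, the
  // number of left rows seen so far; that value fixes each row's scatter address.
  size_t flag = 0;  // running inclusive-scan value
  for (size_t idx = 0; idx < n; ++idx) {
    flag += (position_in[idx] == left_nidx);
    size_t scatter_address;
    if (position_in[idx] == left_nidx) {
      scatter_address = flag - 1;  // -1 because the scan is inclusive
    } else {
      // right rows are written after all left rows, preserving relative order
      scatter_address = (idx - flag) + left_count;
    }
    (*position_out)[scatter_address] = position_in[idx];
    (*ridx_out)[scatter_address] = ridx_in[idx];
  }
}

int main() {
  std::vector<bst_node_t> position{1, 2, 1, 2, 1};
  std::vector<RowIndexT> ridx{0, 1, 2, 3, 4};
  std::vector<bst_node_t> position_out;
  std::vector<RowIndexT> ridx_out;
  PartitionRows(position, ridx, /*left_nidx=*/1, &position_out, &ridx_out);
  for (auto r : ridx_out) std::cout << r << " ";  // prints: 0 2 4 1 3
  std::cout << std::endl;
  return 0;
}

Running the example partitions row indices {0, 1, 2, 3, 4} with positions {1, 2, 1, 2, 1} into {0, 2, 4, 1, 3}: the three left-child rows come first in their original order, followed by the right-child rows, which is the stable ordering the batched UpdatePosition path in this PR also has to preserve.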