Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in interaction constraint. #4587

Merged
merged 1 commit into from Jun 20, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 21 additions & 19 deletions src/tree/constraints.cu
Expand Up @@ -134,9 +134,9 @@ void FeatureInteractionConstraint::Configure(
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);

// --- Initialize result buffers.
output_buffer_bits_storage_.resize(n_features);
output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(n_features);
input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_));
result_buffer_.resize(n_features);
s_result_buffer_ = dh::ToSpan(result_buffer_);
Expand All @@ -155,10 +155,10 @@ void FeatureInteractionConstraint::Reset() {
}

__global__ void ClearBuffersKernel(
BitField result_buffer_self, BitField result_buffer_input, BitField feature_buffer) {
BitField result_buffer_output, BitField result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_self.Size()) {
result_buffer_self.Clear(tid);
if (tid < result_buffer_output.Size()) {
result_buffer_output.Clear(tid);
}
if (tid < result_buffer_input.Size()) {
result_buffer_input.Clear(tid);
Expand All @@ -172,7 +172,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
const int n_grids = static_cast<int>(
dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
output_buffer_bits_, input_buffer_bits_, feature_buffer_);
output_buffer_bits_, input_buffer_bits_);
}

common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
Expand All @@ -199,18 +199,18 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
return {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
}

__global__ void QueryFeatureListKernel(common::Span<int32_t> feature_list_input,
common::Span<int32_t> node_feature_list,
BitField result_buffer_input,
BitField result_buffer_output) {
// Marks every feature id found in `feature_list_input` in `result_buffer_input`.
//
// Each thread derives a flat global index from its block and thread indices
// and handles at most one element of `feature_list_input`; threads whose index
// falls past the end of the list do nothing.
//
// NOTE(review): presumably `BitField::Set(i)` turns on bit `i` and is safe to
// call concurrently from many threads -- confirm against the BitField definition.
__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
BitField result_buffer_input) {
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
// Bounds guard: the launch may contain more threads than list elements.
if (tid < feature_list_input.size()) {
result_buffer_input.Set(feature_list_input[tid]);
}
}

if (tid < node_feature_list.size()) {
result_buffer_output.Set(node_feature_list[tid]);
}
// Combines the node's constraint bits with the input feature bits.
//
// `result_buffer_output` is first OR-ed with `node_constraints` and then
// AND-ed with `result_buffer_input`, both via BitField's compound-assignment
// operators.
//
// NOTE(review): no thread index is computed in this kernel; presumably
// BitField's `|=` / `&=` operators split the work across threads internally
// and bounds-check themselves -- confirm against the BitField definition.
__global__ void QueryFeatureListKernel(BitField node_constraints,
BitField result_buffer_input,
BitField result_buffer_output) {
result_buffer_output |= node_constraints;
result_buffer_output &= result_buffer_input;
}

Expand All @@ -219,17 +219,19 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
if (!has_constraint_ || nid == 0) {
return feature_list;
}
auto selected = this->QueryNode(nid);

ClearBuffers();

BitField node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());

int constexpr kBlockThreads = 256;
const int n_grids = static_cast<int>(
dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);

QueryFeatureListKernel<<<n_grids, kBlockThreads>>>
(feature_list,
selected,
input_buffer_bits_,
output_buffer_bits_);
QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
node_constraints, input_buffer_bits_, output_buffer_bits_);

thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size());
Expand Down