Skip to content

Commit

Permalink
CPP tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 7, 2021
1 parent d0481b6 commit ed76b8f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/data/data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ auto SetDeviceToPtr(void *ptr) {
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};
} // anonymous namespace
}
} // anonymous namespace

void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
Expand Down Expand Up @@ -134,8 +134,11 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {

auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out, group_ptr_.begin() + 1);
thrust::inclusive_scan(group_ptr_.begin(), group_ptr_.end(), group_ptr_.begin());
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);
return;
} else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_);
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/data/test_array_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,

template <typename T>
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
thrust::device_vector<T>* p_data) {
thrust::device_vector<T> *p_data) {
auto& data = *p_data;
thrust::sequence(data.begin(), data.end());

Expand Down
18 changes: 18 additions & 0 deletions tests/cpp/data/test_metainfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,24 @@ TEST(MetaInfo, LoadQid) {
}
}

TEST(MetaInfo, CPUQid) {
xgboost::MetaInfo info;
info.num_row_ = 100;
std::vector<uint32_t> qid(info.num_row_, 0);
for (size_t i = 0; i < qid.size(); ++i) {
qid[i] = i;
}

info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
ASSERT_EQ(info.group_ptr_.front(), 0);
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);

for (size_t i = 0; i < info.num_row_ + 1; ++i) {
ASSERT_EQ(info.group_ptr_[i], i);
}
}

TEST(MetaInfo, Validate) {
xgboost::MetaInfo info;
info.num_row_ = 10;
Expand Down
23 changes: 23 additions & 0 deletions tests/cpp/data/test_metainfo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include "test_array_interface.h"
#include "../../../src/common/device_helpers.cuh"

namespace xgboost {
Expand Down Expand Up @@ -105,6 +106,28 @@ TEST(MetaInfo, Group) {
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
}

TEST(MetaInfo, GPUQid) {
xgboost::MetaInfo info;
info.num_row_ = 100;
thrust::device_vector<uint32_t> qid(info.num_row_, 0);
for (size_t i = 0; i < qid.size(); ++i) {
qid[i] = i;
}
auto column = Generate2dArrayInterface(info.num_row_, 1, "<u4", &qid);
Json array{std::vector<Json>{column}};
std::string array_str;
Json::Dump(array, &array_str);
info.SetInfo("qid", array_str.c_str());
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
ASSERT_EQ(info.group_ptr_.front(), 0);
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);

for (size_t i = 0; i < info.num_row_ + 1; ++i) {
ASSERT_EQ(info.group_ptr_[i], i);
}
}


TEST(MetaInfo, DeviceExtend) {
dh::safe_cuda(cudaSetDevice(0));
size_t const kRows = 100;
Expand Down

0 comments on commit ed76b8f

Please sign in to comment.