
[BREAKING] Cleanup gpu_id configuration. #6971

Closed
wants to merge 17 commits
22 changes: 22 additions & 0 deletions src/gbm/gblinear.cc
@@ -67,11 +67,32 @@ class GBLinear : public GradientBooster {
sum_weight_complete_(false),
is_converged_(false) {}

void ValidateUpdater() {
if (generic_param_->gpu_id != GenericParameter::kCpuId) {
// On GPU.
CHECK(param_.updater == "gpu_coord_descent" || param_.updater == "coord_descent")
<< "`gpu_id` is set to: " << generic_param_->gpu_id << ". "
<< "Only coordinate descent supports GPU training. "
<< "Set `updater` to `coord_descent` or `gpu_coord_descent` along with "
<< "`gpu_id` to enable GPU acceleration.";
} else {
CHECK_NE(param_.updater, "gpu_coord_descent")
<< "[Internal Error]: `gpu_coord_descent` is used but `gpu_id` is "
"configured: "
<< generic_param_->gpu_id;
}
}

void Configure(const Args& cfg) override {
if (model_.weight.size() == 0) {
model_.Configure(cfg);
}
param_.UpdateAllowUnknown(cfg);
if (generic_param_->gpu_id != GenericParameter::kCpuId &&
param_.updater == "coord_descent") {
param_.updater = "gpu_coord_descent";
}

updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
updater_->Configure(cfg);
monitor_.Init("GBLinear");
@@ -126,6 +147,7 @@ class GBLinear : public GradientBooster {
HostDeviceVector<GradientPair> *in_gpair,
PredictionCacheEntry*) override {
monitor_.Start("DoBoost");
this->ValidateUpdater();

model_.LazyInitModel();
this->LazySumWeights(p_fmat);
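
For reference, a minimal Python sketch of the user-facing behavior these gblinear hunks target (illustration only: synthetic data, a CUDA-enabled build, and a visible GPU are assumed):

import numpy as np
import xgboost as xgb

X = np.random.rand(256, 16)
y = np.random.rand(256)
dtrain = xgb.DMatrix(X, label=y)

# With `gpu_id` set, GBLinear::Configure promotes "coord_descent" to
# "gpu_coord_descent", so the caller keeps the generic updater name.
params = {
    "booster": "gblinear",
    "updater": "coord_descent",
    "gpu_id": 0,
}
booster = xgb.train(params, dtrain, num_boost_round=5)

# Any other updater (e.g. "shotgun") combined with `gpu_id` now fails the
# ValidateUpdater() check in DoBoost with the error message above.
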
282 changes: 215 additions & 67 deletions src/gbm/gbtree.cc

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions src/gbm/gbtree.h
@@ -289,7 +289,7 @@ class GBTree : public GradientBooster {
}
LOG(FATAL) << msg;
} else {
bool success = this->GetPredictor()->InplacePredict(
bool success = this->GetPredictor(false)->InplacePredict(
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
CHECK(success) << msg;
}
@@ -312,7 +312,7 @@
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
"n_iteration), use model slicing instead.";
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
this->GetPredictor(false)->PredictLeaf(p_fmat, out_preds, model_, tree_end);
}

void PredictContribution(DMatrix* p_fmat,
@@ -325,7 +325,7 @@
CHECK_EQ(tree_begin, 0)
<< "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictContribution(
this->GetPredictor(false)->PredictContribution(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}

@@ -338,7 +338,7 @@
CHECK_EQ(tree_begin, 0)
<< "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictInteractionContributions(
this->GetPredictor(false)->PredictInteractionContributions(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}

@@ -358,8 +358,10 @@
int bst_group,
std::vector<std::unique_ptr<RegTree> >* ret);

std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
DMatrix* f_dmat = nullptr) const;
std::unique_ptr<Predictor> const &
GetPredictor(bool is_training,
HostDeviceVector<float> const *out_pred = nullptr,
DMatrix *f_dmat = nullptr) const;

// commit new trees all at once
virtual void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
@@ -378,6 +380,7 @@
Args cfg_;
// the updaters that can be applied to each of tree
std::vector<std::unique_ptr<TreeUpdater>> updaters_;

// Predictors
std::unique_ptr<Predictor> cpu_predictor_;
#if defined(XGBOOST_USE_CUDA)
@@ -386,6 +389,7 @@
#if defined(XGBOOST_USE_ONEAPI)
std::unique_ptr<Predictor> oneapi_predictor_;
#endif // defined(XGBOOST_USE_ONEAPI)

common::Monitor monitor_;
};

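
`GetPredictor` now takes an explicit `is_training` flag, and the leaf, contribution, and in-place prediction paths pass `false`. Roughly, the intended effect from Python looks like the sketch below (illustration only, synthetic data; the actual dispatch logic lives in src/gbm/gbtree.cc, whose diff is not rendered above):

import numpy as np
import xgboost as xgb

X = np.random.rand(512, 8)
y = np.random.rand(512)
dtrain = xgb.DMatrix(X, label=y)

# Train on the GPU.
booster = xgb.train({"tree_method": "gpu_hist", "gpu_id": 0},
                    dtrain, num_boost_round=10)

# Prediction-time entry points go through GetPredictor(false, ...): the
# predictor follows the `predictor` parameter and the data location rather
# than being forced by the training-time device.
booster.set_param({"predictor": "cpu_predictor"})
leaf = booster.predict(dtrain, pred_leaf=True)
contribs = booster.predict(dtrain, pred_contribs=True)
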
6 changes: 4 additions & 2 deletions src/predictor/gpu_predictor.cu
@@ -811,7 +811,6 @@ class GPUPredictor : public xgboost::Predictor {
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
const gbm::GBTreeModel &model,
unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);

const MetaInfo& info = p_fmat->Info();
@@ -877,9 +876,12 @@ class GPUPredictor : public xgboost::Predictor {
/*! \brief Reconfigure the device when GPU is changed. */
static size_t ConfigureDevice(int device) {
if (device >= 0) {
dh::safe_cuda(cudaSetDevice(device));
return dh::MaxSharedMemory(device);
} else {
dh::safe_cuda(cudaSetDevice(0));
return dh::MaxSharedMemory(0);
}
return 0;
}
};

3 changes: 3 additions & 0 deletions tests/cpp/predictor/test_predictor.cc
@@ -99,6 +99,9 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
learner->SetParam("subsample", "0.5");
learner->SetParam("gpu_id", std::to_string(device));
learner->SetParam("predictor", predictor);
if (predictor == "gpu_predictor") {
learner->SetParam("tree_method", "gpu_hist");
}
for (int32_t it = 0; it < 4; ++it) {
learner->UpdateOneIter(it, m);
}
5 changes: 5 additions & 0 deletions tests/cpp/test_learner.cc
@@ -314,6 +314,11 @@ TEST(Learner, GPUConfiguration) {
learner->SetParams({Arg{"tree_method", "hist"},
Arg{"gpu_id", "0"}});
learner->UpdateOneIter(0, p_dmat);
Json config {Object{}};
learner->SaveConfig(&config);
CHECK_EQ(get<String>(config["learner"]["gradient_booster"]
["gbtree_train_param"]["tree_method"]),
"gpu_hist");
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}
{
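
The new assertion verifies that `tree_method` is promoted from `hist` to `gpu_hist` once `gpu_id` is set. The same check can be expressed from Python with the `save_config()` API (sketch only; assumes a CUDA-enabled build with a visible GPU):

import json
import numpy as np
import xgboost as xgb

X, y = np.random.rand(128, 4), np.random.rand(128)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"tree_method": "hist", "gpu_id": 0},
                    dtrain, num_boost_round=1)

config = json.loads(booster.save_config())
tree_method = config["learner"]["gradient_booster"][
    "gbtree_train_param"]["tree_method"]
assert tree_method == "gpu_hist"  # `hist` + `gpu_id` is promoted to `gpu_hist`
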
7 changes: 3 additions & 4 deletions tests/cpp/test_serialization.cc
@@ -166,7 +166,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
learner->Save(&mem_out);
ASSERT_EQ(model_at_kiter, serialised_model_tmp);

learner->SetParam("gpu_id", "0");
// learner->SetParam("gpu_id", "0");
// Pull data to device
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(0);
@@ -184,9 +184,8 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr

Json m_0 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
Json m_1 = Json::Load(StringView{serialised_model_tmp.c_str(), serialised_model_tmp.size()});
// GPU ID is changed as data is coming from device.
ASSERT_EQ(get<Object>(m_0["Config"]["learner"]["generic_param"]).erase("gpu_id"),
get<Object>(m_1["Config"]["learner"]["generic_param"]).erase("gpu_id"));
ASSERT_EQ(get<Object>(m_0["Config"]["learner"]["generic_param"]),
get<Object>(m_1["Config"]["learner"]["generic_param"]));
}
}

11 changes: 8 additions & 3 deletions tests/python-gpu/test_gpu_basic_models.py
@@ -6,11 +6,13 @@
sys.path.append("tests/python")
# Don't import the test classes, otherwise the tests will run twice.
import test_callback as test_cb # noqa
import test_basic_models as test_bm
rng = np.random.RandomState(1994)


class TestGPUBasicModels:
cputest = test_cb.TestCallbacks()
cpu_test_cb = test_cb.TestCallbacks()
cpu_test_bm = test_bm.TestModels()

def run_cls(self, X, y, deterministic):
cls = xgb.XGBClassifier(tree_method='gpu_hist',
@@ -35,9 +37,12 @@ def run_cls(self, X, y, deterministic):

return hash(model_0), hash(model_1)

def test_custom_objective(self):
self.cpu_test_bm.run_custom_objective("gpu_hist")

def test_eta_decay_gpu_hist(self):
self.cputest.run_eta_decay('gpu_hist', True)
self.cputest.run_eta_decay('gpu_hist', False)
self.cpu_test_cb.run_eta_decay('gpu_hist', True)
self.cpu_test_cb.run_eta_decay('gpu_hist', False)

def test_deterministic_gpu_hist(self):
kRows = 1000
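
`run_custom_objective` is now shared with the GPU suite and called with "gpu_hist". A stripped-down version of what that exercises (sketch only: synthetic data and a hand-written squared-error objective stand in for the logistic objective used by the real test; assumes a CUDA-enabled build):

import numpy as np
import xgboost as xgb

X, y = np.random.rand(256, 8), np.random.rand(256)
dtrain = xgb.DMatrix(X, label=y)

def squared_error(preds, dtrain):
    # Gradient and Hessian of 0.5 * (pred - label)^2.
    grad = preds - dtrain.get_label()
    hess = np.ones_like(preds)
    return grad, hess

booster = xgb.train({"tree_method": "gpu_hist", "gpu_id": 0},
                    dtrain, num_boost_round=10, obj=squared_error)
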
6 changes: 5 additions & 1 deletion tests/python-gpu/test_gpu_prediction.py
@@ -248,13 +248,16 @@ def predict_df(x):
tm.dataset_strategy, shap_parameter_strategy)
@settings(deadline=None)
def test_shap(self, num_rounds, dataset, param):
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param.update({"gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)

test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst.set_param({"predictor": "gpu_predictor"})
shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True)

assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)

@@ -332,6 +335,7 @@ def test_predict_categorical_split(self, df):
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)

@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.parametrize("n_classes", [2, 3])
def test_predict_dart(self, n_classes):
from sklearn.datasets import make_classification
12 changes: 10 additions & 2 deletions tests/python/test_basic_models.py
@@ -138,8 +138,13 @@ def test_boost_from_existing_model(self):
# behaviour is considered sub-optimal, feel free to change.
assert booster.num_boosted_rounds() == 4

def test_custom_objective(self):
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
def run_custom_objective(self, tree_method=None):
param = {
'max_depth': 2,
'eta': 1,
'objective': 'reg:logistic',
"tree_method": tree_method
}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 10

@@ -181,6 +186,9 @@ def neg_evalerror(preds, dtrain):
if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2))
assert err == err2

def test_custom_objective(self):
self.run_custom_objective()

def test_multi_eval_metric(self):
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
param = {'max_depth': 2, 'eta': 0.2, 'verbosity': 1,
4 changes: 3 additions & 1 deletion tests/python/test_predict.py
@@ -49,13 +49,14 @@ def run_predict_leaf(predictor):
{
"num_parallel_tree": num_parallel_tree,
"num_class": classes,
"predictor": predictor,
"tree_method": "hist",
},
m,
num_boost_round=num_boost_round,
)

booster.set_param({"predictor": predictor})

empty = xgb.DMatrix(np.ones(shape=(0, cols)))
empty_leaf = booster.predict(empty, pred_leaf=True)
assert empty_leaf.shape[0] == 0
@@ -78,6 +79,7 @@ def run_predict_leaf(predictor):

# When there's only 1 tree, the output is a 1 dim vector
booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m)
booster.set_param({"predictor": predictor})
assert booster.predict(m, pred_leaf=True).shape == (rows, )

return leaf