From 48112fea88e07a1d272e0f1e8cfbeb2cfbd7685d Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 13 Sep 2022 18:41:18 +0800 Subject: [PATCH] Merge dispatching into median. --- src/common/stats.cu | 13 ++++++++----- src/common/stats.h | 26 ++++++++++++++++---------- src/objective/regression_obj.cu | 7 +------ 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/common/stats.cu b/src/common/stats.cu index 956c812c9346..dcb04ac4b5de 100644 --- a/src/common/stats.cu +++ b/src/common/stats.cu @@ -2,11 +2,14 @@ * Copyright 2022 by XGBoost Contributors */ -#include "common.h" -#include "stats.cuh" -#include "xgboost/generic_parameters.h" -#include "xgboost/host_device_vector.h" -#include "xgboost/linalg.h" +#include // thrust::make_counting_iterator + +#include "common.h" // common::OptionalWeights +#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend +#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile +#include "xgboost/generic_parameters.h" // Context +#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "xgboost/linalg.h" // linalg::TensorView, UnravelIndex, Apply namespace xgboost { namespace common { diff --git a/src/common/stats.h b/src/common/stats.h index 547063c43a56..f191deb21575 100644 --- a/src/common/stats.h +++ b/src/common/stats.h @@ -103,23 +103,29 @@ inline float Median(Context const*, linalg::TensorView, common:: #endif // !defined(XGBOOST_USE_CUDA) } // namespace cuda -inline float Median(Context const* ctx, linalg::TensorView t, - common::OptionalWeights weights) { +inline float Median(Context const* ctx, linalg::Tensor const& t, + HostDeviceVector const& weights) { if (!ctx->IsCPU()) { - return cuda::Median(ctx, t, weights); + weights.SetDevice(ctx->gpu_id); + auto opt_weights = OptionalWeights(weights.ConstDeviceSpan()); + auto t_v = t.View(ctx->gpu_id); + return cuda::Median(ctx, t_v, opt_weights); } + + auto opt_weights = OptionalWeights(weights.ConstHostSpan()); + auto t_v = t.HostView(); auto iter = common::MakeIndexTransformIter( - [&](size_t i) { return linalg::detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); }); + [&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); }); float q{0}; - if (weights.weights.empty()) { - q = common::Quantile(0.5, iter, iter + t.Size()); + if (opt_weights.Empty()) { + q = common::Quantile(0.5, iter, iter + t_v.Size()); } else { - CHECK_NE(t.Shape(1), 0); + CHECK_NE(t_v.Shape(1), 0); auto w_it = common::MakeIndexTransformIter([&](size_t i) { - auto sample_idx = i / t.Shape(1); - return weights[sample_idx]; + auto sample_idx = i / t_v.Shape(1); + return opt_weights[sample_idx]; }); - q = common::WeightedQuantile(0.5, iter, iter + t.Size(), w_it); + q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it); } return q; } diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index f0087a88d469..fe7e6c84251f 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -711,13 +711,8 @@ class MeanAbsoluteError : public ObjFunction { if (info.num_row_ == 0) { out(0) = 0; invalid++; - } else if (ctx_->IsCPU()) { - out(0) = common::Median(ctx_, info.labels.HostView(), - common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { - info.weights_.SetDevice(ctx_->gpu_id); - out(0) = common::Median(ctx_, info.labels.View(ctx_->gpu_id), - common::OptionalWeights{info.weights_.DeviceSpan()}); + out(0) = common::Median(ctx_, info.labels, info.weights_); } auto world = static_cast(rabit::GetWorldSize());