Skip to content

Commit

Permalink
Deterministic GPU histogram.
Browse files Browse the repository at this point in the history
* Use pre-rounding based method to obtain reproducible floating point
  summation.
* GPU Hist for regression and classification are bit-by-bit reproducible.
* Add doc.
* Switch to thrust reduce for `node_sum_gradient`.
  • Loading branch information
trivialfis committed Mar 3, 2020
1 parent 655cf17 commit 1065c72
Show file tree
Hide file tree
Showing 17 changed files with 378 additions and 97 deletions.
14 changes: 14 additions & 0 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information

Additional parameters for `gpu_hist` tree method
================================================

* ``single_precision_histogram``, [default=``false``]

- Use single precision to build histograms. See document for GPU support for more details.

* ``deterministic_histogram``, [default=``true``]

- Build histogram on GPU deterministically. Histogram building is not deterministic due
to the non-associative aspect of floating point summation. We employ a pre-rounding
routine to mitigate the issue, which may lead to slightly lower accuracy. Set to
``false`` to disable it.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================

Expand Down
33 changes: 5 additions & 28 deletions include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ class GradientPairInternal {
/*! \brief second order gradient statistics */
T hess_;

XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }

public:
using ValueT = T;

XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}

XGBOOST_DEVICE GradientPairInternal(float grad, float hess) {
XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
SetGrad(grad);
SetHess(hess);
}
Expand All @@ -160,8 +160,8 @@ class GradientPairInternal {
SetHess(g.GetHess());
}

XGBOOST_DEVICE float GetGrad() const { return grad_; }
XGBOOST_DEVICE float GetHess() const { return hess_; }
XGBOOST_DEVICE T GetGrad() const { return grad_; }
XGBOOST_DEVICE T GetHess() const { return hess_; }

XGBOOST_DEVICE GradientPairInternal<T> &operator+=(
const GradientPairInternal<T> &rhs) {
Expand Down Expand Up @@ -234,24 +234,6 @@ class GradientPairInternal {
return os;
}
};

template<>
inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetGrad() const {
return grad_ * 1e-4f;
}
template<>
inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetHess() const {
return hess_ * 1e-4f;
}
template<>
inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetGrad(float g) {
grad_ = static_cast<int64_t>(std::round(g * 1e4));
}
template<>
inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetHess(float h) {
hess_ = static_cast<int64_t>(std::round(h * 1e4));
}

} // namespace detail

/*! \brief gradient statistics pair usually needed in gradient boosting */
Expand All @@ -260,11 +242,6 @@ using GradientPair = detail::GradientPairInternal<float>;
/*! \brief High precision gradient statistics pair */
using GradientPairPrecise = detail::GradientPairInternal<double>;

/*! \brief High precision gradient statistics pair with integer backed
* storage. Operators are associative where floating point versions are not
* associative. */
using GradientPairInteger = detail::GradientPairInternal<int64_t>;

using Args = std::vector<std::pair<std::string, std::string> >;

/*! \brief small eps gap for minimum split decision. */
Expand Down
1 change: 0 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,6 @@ def get_score(self, fmap='', importance_type='weight'):
if importance_type == 'weight':
# do a simpler tree dump to save time
trees = self.get_dump(fmap, with_stats=False)

fmap = {}
for tree in trees:
for line in tree.split('\n'):
Expand Down
4 changes: 3 additions & 1 deletion python-package/xgboost/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def plot_importance(booster, ax=None, height=0.2,
raise ValueError('tree must be Booster, XGBModel or dict instance')

if not importance:
raise ValueError('Booster.get_score() results in empty')
raise ValueError(
'Booster.get_score() results in empty. ' +
'This maybe caused by having all trees as decision dumps.')

tuples = [(k, importance[k]) for k in importance]
if max_num_features is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/common/observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#include "xgboost/base.h"
#include "xgboost/tree_model.h"

#if defined(XGBOOST_STRICT_R_MODE)
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
#define OBSERVER_PRINT LOG(INFO)
#define OBSERVER_ENDL ""
#define OBSERVER_NEWLINE ""
#else
#define OBSERVER_PRINT std::cout
#define OBSERVER_PRINT std::cout << std::setprecision(17)
#define OBSERVER_ENDL std::endl
#define OBSERVER_NEWLINE "\n"
#endif // defined(XGBOOST_STRICT_R_MODE)
Expand Down
4 changes: 2 additions & 2 deletions src/data/ellpack_page_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ bool EllpackPageSource::Next() {
EllpackPage& EllpackPageSource::Value() {
LOG(FATAL) << "Internal Error: "
"XGBoost is not compiled with CUDA but EllpackPageSource is required";
EllpackPage* page;
EllpackPage* page { nullptr };
return *page;
}

const EllpackPage& EllpackPageSource::Value() const {
LOG(FATAL) << "Internal Error: "
"XGBoost is not compiled with CUDA but EllpackPageSource is required";
EllpackPage* page;
EllpackPage* page { nullptr };
return *page;
}

Expand Down
1 change: 1 addition & 0 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ class LearnerImpl : public Learner {

monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true);
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
monitor_.Stop("PredictRaw");

monitor_.Start("GetGradient");
Expand Down
1 change: 1 addition & 0 deletions src/tree/gpu_hist/gradient_based_sampler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class SamplingStrategy {
public:
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
virtual ~SamplingStrategy() = default;
};

/*! \brief No sampling in in-memory mode. */
Expand Down
184 changes: 184 additions & 0 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <thrust/reduce.h>
#include <thrust/iterator/transform_iterator.h>
#include <algorithm>
#include <ctgmath>
#include <limits>

#include "xgboost/base.h"
#include "row_partitioner.cuh"

#include "histogram.cuh"

#include "../../data/ellpack_page.cuh"
#include "../../common/device_helpers.cuh"

namespace xgboost {
namespace tree {
// Following 2 functions are slightly modifed version of fbcuda.

/* \brief Constructs a rounding factor used to truncate elements in a sum such that the
sum of the truncated elements is the same no matter what the order of the sum is.
* Algorithm 5: Reproducible Sequential Sum in 'Fast Reproducible Floating-Point
* Summation' by Demmel and Nguyen
* In algorithm 5 the bound is calculated as $max(|v_i|) * n$. Here we use the bound
*
* \begin{equation}
* max( fl(\sum^{V}_{v_i>0}{v_i}), fl(\sum^{V}_{v_i<0}|v_i|) )
* \end{equation}
*
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
template <typename T>
DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());

// Calculate ceil(log_2(delta)).
// frexpf() calculates exp and returns `x` such that
// delta = x * 2^exp, where `x` in (-1.0, -0.5] U [0.5, 1).
// Because |x| < 1, exp is exactly ceil(log_2(delta)).
int exp;
std::frexp(delta, &exp);

// return M = 2 ^ ceil(log_2(delta))
return std::ldexp(static_cast<T>(1.0), exp);
}

namespace {
struct Pair {
GradientPair first;
GradientPair second;
};
DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
return {lhs.first + rhs.first, lhs.second + rhs.second};
}
} // anonymous namespace

struct Clip : public thrust::unary_function<GradientPair, Pair> {
static DEV_INLINE float Pclip(float v) {
return v > 0 ? v : 0;
}
static DEV_INLINE float Nclip(float v) {
return v < 0 ? abs(v) : 0;
}

DEV_INLINE Pair operator()(GradientPair x) const {
auto pg = Pclip(x.GetGrad());
auto ph = Pclip(x.GetHess());

auto ng = Nclip(x.GetGrad());
auto nh = Nclip(x.GetHess());

return { GradientPair{ pg, ph }, GradientPair{ ng, nh } };
}
};

template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
using T = typename GradientSumT::ValueT;
dh::XGBCachingDeviceAllocator<char> alloc;

thrust::device_ptr<GradientPair const> gpair_beg {gpair.data()};
thrust::device_ptr<GradientPair const> gpair_end {gpair.data() + gpair.size()};
auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
auto end = thrust::make_transform_iterator(gpair_end, Clip());
Pair p = thrust::reduce(thrust::cuda::par(alloc), beg, end, Pair{});
GradientPair positive_sum {p.first}, negative_sum {p.second};

auto histogram_rounding = GradientSumT {
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
gpair.size()),
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
gpair.size()) };
return histogram_rounding;
}

template GradientPairPrecise CreateRoundingFactor(common::Span<GradientPair const> gpair);
template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename GradientSumT>
__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
size_t n_elements,
GradientSumT const rounding,
bool use_shared_memory_histograms) {
using T = typename GradientSumT::ValueT;
extern __shared__ char smem[];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
__syncthreads();
}
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.info.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
if (gidx != matrix.info.n_bins) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), d_gpair[ridx].GetHess()),
};
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
}
}

if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + i, truncated);
}
}
}

template <typename GradientSumT>
void BuildGradientHistogram(EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding, bool shared) {
const size_t smem_size =
shared
? sizeof(GradientSumT) * matrix.info.n_bins
: 0;
auto n_elements = d_ridx.size() * matrix.info.row_stride;

uint32_t items_per_thread = 8;
uint32_t block_threads = 256;
auto grid_size = static_cast<uint32_t>(
common::DivRoundUp(n_elements, items_per_thread * block_threads));
dh::LaunchKernel {grid_size, block_threads, smem_size} (
SharedMemHistKernel<GradientSumT>,
matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
rounding, shared);
}

template void BuildGradientHistogram<GradientPair>(
EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
GradientPair rounding, bool shared);

template void BuildGradientHistogram<GradientPairPrecise>(
EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
GradientPairPrecise rounding, bool shared);
} // namespace tree
} // namespace xgboost
29 changes: 29 additions & 0 deletions src/tree/gpu_hist/histogram.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_
#include <thrust/transform.h>
#include "../../data/ellpack_page.cuh"

namespace xgboost {
namespace tree {

template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename T>
DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}

template <typename GradientSumT>
void BuildGradientHistogram(EllpackMatrix const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding, bool shared);
} // namespace tree
} // namespace xgboost

#endif // HISTOGRAM_CUH_
11 changes: 10 additions & 1 deletion src/tree/updater_gpu_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ struct DeviceSplitCandidate {
}
}
XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }

friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
os << "loss_chg:" << c.loss_chg << ", "
<< "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl;
return os;
}
};

struct DeviceSplitCandidateReduceOp {
Expand Down Expand Up @@ -186,6 +196,5 @@ struct SumCallbackOp {
XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
return (1 << (depth + 1)) - 1;
}

} // namespace tree
} // namespace xgboost

0 comments on commit 1065c72

Please sign in to comment.