283 changes: 233 additions & 50 deletions include/boost/random/discrete_distribution.hpp
Expand Up @@ -20,7 +20,7 @@
#include <iterator>
#include <boost/assert.hpp>
#include <boost/random/uniform_01.hpp>
#include <boost/random/uniform_int.hpp>
#include <boost/random/uniform_int_distribution.hpp>
#include <boost/random/detail/config.hpp>
#include <boost/random/detail/operators.hpp>
#include <boost/random/detail/vector_io.hpp>
Expand All @@ -36,6 +36,186 @@

namespace boost {
namespace random {
namespace detail {

/**
 * Alias-table backend used when the weight type is integral.
 *
 * The total weight is represented exactly as
 * `_average * n + _excess` (with `n` the number of bins), so every bin
 * carries weight `_average`, except the first `_excess` bins which carry
 * `_average + 1`.  All arithmetic stays in integers; nothing is normalized.
 */
template<class IntType, class WeightType>
struct integer_alias_table {
    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;

    /** Returns the capacity of @c bin: `_average`, plus one for bins below `_excess`. */
    WeightType get_weight(IntType bin) const {
        WeightType result = _average;
        if(bin < _excess) ++result;
        return result;
    }

    /**
     * Computes the integer average and remainder of the weights in
     * [begin, end), sizes the alias table to the number of weights, and
     * stores the average/remainder in `_average` / `_excess`.
     *
     * @returns the (rounded-down) average weight.
     */
    template<class Iter>
    WeightType init_average(Iter begin, Iter end) {
        WeightType weight_average = 0;
        IntType excess = 0;
        IntType n = 0;
        // Invariant: weight_average * n + excess == current partial sum,
        // with 0 <= excess < n.  The running average is adjusted by the
        // difference to the next weight instead of accumulating the sum,
        // which is a bit messy, but it's guaranteed not to overflow.
        for(Iter iter = begin; iter != end; ++iter) {
            ++n;
            if(*iter < weight_average) {
                // New weight is below the average: pull the average down.
                WeightType diff = weight_average - *iter;
                weight_average -= diff / n;
                if(diff % n > excess) {
                    // Remainder would go negative: borrow one from the average.
                    --weight_average;
                    excess += n - diff % n;
                } else {
                    excess -= diff % n;
                }
            } else {
                // New weight is at or above the average: push the average up.
                WeightType diff = *iter - weight_average;
                weight_average += diff / n;
                if(diff % n < n - excess) {
                    excess += diff % n;
                } else {
                    // Remainder reached n: carry one into the average.
                    ++weight_average;
                    excess -= n - diff % n;
                }
            }
        }
        _alias_table.resize(static_cast<std::size_t>(n));
        _average = weight_average;
        _excess = excess;
        return weight_average;
    }

    /** Resets to a single bin (weight 1, alias 0) with average 1 and no excess. */
    void init_empty()
    {
        _alias_table.clear();
        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
                                              static_cast<IntType>(0)));
        _average = static_cast<WeightType>(1);
        _excess = static_cast<IntType>(0);
    }

    /** Two tables are equal when entries, average and excess all match. */
    bool operator==(const integer_alias_table& other) const
    {
        return _alias_table == other._alias_table &&
            _average == other._average && _excess == other._excess;
    }

    /** Integral weights are used as-is; the average is ignored. */
    static WeightType normalize(WeightType val, WeightType /*average*/)
    {
        return val;
    }

    /** Integral weight vectors are left untouched. */
    static void normalize(std::vector<WeightType>&) {}

    /** Draws the per-sample test value, an integer in [0, _average]. */
    template<class URNG>
    WeightType test(URNG &urng) const
    {
        return uniform_int_distribution<WeightType>(0, _average)(urng);
    }

    /**
     * Decides whether a (bin, test value) pair may be used.  Bins below
     * `_excess` have capacity `_average + 1` and accept every test value;
     * the remaining bins reject the single value `_average`.
     */
    bool accept(IntType result, WeightType val) const
    {
        return result < _excess || val < _average;
    }

    /**
     * Returns the sum of @c weights, or zero when that sum cannot be
     * represented in @c WeightType.  A zero result tells the caller that
     * no usable sum is available.
     */
    static WeightType try_get_sum(const std::vector<WeightType>& weights)
    {
        WeightType result = static_cast<WeightType>(0);
        for(typename std::vector<WeightType>::const_iterator
                iter = weights.begin(), end = weights.end();
            iter != end; ++iter)
        {
            // Overflow happens when the remaining headroom is smaller than
            // the next weight; check before adding so the addition itself
            // never overflows.
            if((std::numeric_limits<WeightType>::max)() - result < *iter) {
                return static_cast<WeightType>(0);
            }
            result += *iter;
        }
        return result;
    }

    /** Draws an integer uniformly from [0, max - 1]. */
    template<class URNG>
    static WeightType generate_in_range(URNG &urng, WeightType max)
    {
        return uniform_int_distribution<WeightType>(
            static_cast<WeightType>(0), max-1)(urng);
    }

    alias_table_t _alias_table;  ///< (threshold, alias index) per bin
    WeightType _average;         ///< rounded-down average weight
    IntType _excess;             ///< number of leading bins holding one extra unit
};

/**
 * Alias-table backend used when the weight type is not integral.
 *
 * Weights are divided by their average, so every bin has capacity 1 and
 * the per-sample test value comes from uniform_01.
 */
template<class IntType, class WeightType>
struct real_alias_table {
    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;

    /** Every bin has the same capacity, 1. */
    WeightType get_weight(IntType) const
    {
        return WeightType(1.0);
    }

    /**
     * Sizes the alias table to the length of [first, last) and returns the
     * arithmetic mean of the weights in that range.
     */
    template<class Iter>
    WeightType init_average(Iter first, Iter last)
    {
        std::size_t count = std::distance(first, last);
        WeightType total =
            std::accumulate(first, last, static_cast<WeightType>(0));
        _alias_table.resize(count);
        return total / count;
    }

    /** Resets to a single bin with threshold 1 and alias 0. */
    void init_empty()
    {
        _alias_table.assign(
            1,
            std::make_pair(static_cast<WeightType>(1),
                           static_cast<IntType>(0)));
    }

    /** Two tables are equal when their entries match. */
    bool operator==(const real_alias_table& other) const
    {
        return _alias_table == other._alias_table;
    }

    /** Scales one weight so that the average weight maps to 1. */
    static WeightType normalize(WeightType val, WeightType average)
    {
        return val / average;
    }

    /** Divides every element of @c weights by the sum of all elements. */
    static void normalize(std::vector<WeightType>& weights)
    {
        const WeightType total =
            std::accumulate(weights.begin(), weights.end(),
                            static_cast<WeightType>(0));
        for(std::size_t idx = 0; idx != weights.size(); ++idx) {
            weights[idx] /= total;
        }
    }

    /** Draws the per-sample test value from uniform_01. */
    template<class URNG>
    WeightType test(URNG &urng) const
    {
        return uniform_01<WeightType>()(urng);
    }

    /** No rejection step is needed: every draw is usable. */
    bool accept(IntType, WeightType) const
    {
        return true;
    }

    /** Always reports 1 as the sum; the weights themselves are not read. */
    static WeightType try_get_sum(const std::vector<WeightType>& weights)
    {
        return static_cast<WeightType>(1);
    }

    /** Draws from uniform_01; the upper bound argument is ignored. */
    template<class URNG>
    static WeightType generate_in_range(URNG &urng, WeightType)
    {
        return uniform_01<WeightType>()(urng);
    }

    alias_table_t _alias_table;  ///< (threshold, alias index) per bin
};

/**
 * Compile-time switch that picks the alias-table backend:
 * `select_alias_table<B>::apply<I, W>::type` is integer_alias_table<I, W>
 * when B is true and real_alias_table<I, W> when B is false.
 */
template<bool IsIntegral>
struct select_alias_table;

/// Non-integral weights: use the backend that normalizes weights.
template<>
struct select_alias_table<false> {
    template<class Index, class Weight>
    struct apply {
        typedef real_alias_table<Index, Weight> type;
    };
};

/// Integral weights: use the exact integer backend.
template<>
struct select_alias_table<true> {
    template<class Index, class Weight>
    struct apply {
        typedef integer_alias_table<Index, Weight> type;
    };
};

}

/**
* The class @c discrete_distribution models a \random_distribution.
Expand Down Expand Up @@ -155,16 +335,7 @@ class discrete_distribution {
{}
void normalize()
{
    // Delegate to the selected backend so the policy lives in one place:
    // the real-valued backend divides by the sum, the integral backend
    // leaves the weights unchanged.
    impl_type::normalize(_probabilities);
}
std::vector<WeightType> _probabilities;
/// @endcond
Expand All @@ -176,8 +347,7 @@ class discrete_distribution {
*/
discrete_distribution()
{
    // The backend owns the table; let it set up its single-bin state
    // instead of writing to the table directly.
    _impl.init_empty();
}
/**
* Constructs a discrete_distribution from an iterator range.
Expand Down Expand Up @@ -257,13 +427,17 @@ class discrete_distribution {
template<class URNG>
IntType operator()(URNG& urng) const
{
    BOOST_ASSERT(!_impl._alias_table.empty());
    IntType result;
    WeightType test;
    // Pick a bin and a test value; redraw until the backend accepts the
    // pair (the real-valued backend accepts every draw).
    do {
        result = uniform_int_distribution<IntType>((min)(), (max)())(urng);
        test = _impl.test(urng);
    } while(!_impl.accept(result, test));
    // Below the bin's threshold the bin itself is returned, otherwise
    // its alias.
    if(test < _impl._alias_table[result].first) {
        return result;
    } else {
        return(_impl._alias_table[result].second);
    }
}

Expand All @@ -274,28 +448,36 @@ class discrete_distribution {
template<class URNG>
IntType operator()(URNG& urng, const param_type& parm) const
{
    // A zero "sum" means the backend could not produce a usable total.
    if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) {
        // Linear scan of the cumulative weights for the first bin whose
        // running total exceeds the drawn value.
        WeightType val = impl_type::generate_in_range(urng, limit);
        WeightType sum = 0;
        std::size_t result = 0;
        for(typename std::vector<WeightType>::const_iterator
                iter = parm._probabilities.begin(),
                end = parm._probabilities.end();
            iter != end; ++iter, ++result)
        {
            sum += *iter;
            if(sum > val) {
                return static_cast<IntType>(result);
            }
        }
        // This shouldn't be reachable, but round-off error
        // can prevent any match from being found when val is
        // very close to 1.
        return static_cast<IntType>(parm._probabilities.size() - 1);
    } else {
        // WeightType is integral and sum(parm._probabilities)
        // would overflow.  Just use the easy solution.
        return discrete_distribution(parm)(urng);
    }
}

/** Returns the smallest value that the distribution can produce. */
result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const
{
    return 0;
}
/** Returns the largest value that the distribution can produce. */
result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
{
    // The table lives in the backend; the largest outcome is its last index.
    return static_cast<result_type>(_impl._alias_table.size() - 1);
}

/**
* Returns a vector containing the probabilities of each
Expand All @@ -307,22 +489,24 @@ class discrete_distribution {
* @endcode
*
* the vector, p will contain {0.1, 0.4, 0.5}.
*
* If @c WeightType is integral, then the weights
* will be returned unchanged.
*/
std::vector<WeightType> probabilities() const
{
    // Rebuild per-outcome weights from the alias table: each bin gives
    // its threshold to itself and the rest of its capacity to its alias.
    std::vector<WeightType> result(_impl._alias_table.size());
    std::size_t i = 0;
    for(typename impl_type::alias_table_t::const_iterator
            iter = _impl._alias_table.begin(),
            end = _impl._alias_table.end();
        iter != end; ++iter, ++i)
    {
        WeightType val = iter->first;
        result[i] += val;
        result[iter->second] +=
            _impl.get_weight(static_cast<IntType>(i)) - val;
    }
    // Backend-specific: the real-valued backend rescales to sum to 1,
    // the integral backend leaves the weights as they are.
    impl_type::normalize(result);
    return(result);
}

Expand Down Expand Up @@ -366,7 +550,7 @@ class discrete_distribution {
*/
BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs)
{
    // Compare the whole backend state, not only the table entries: the
    // integral backend's equality also includes its average and excess.
    return lhs._impl == rhs._impl;
}
/**
* Returns true if the two distributions may return different
Expand All @@ -389,59 +573,58 @@ class discrete_distribution {
{
std::vector<std::pair<WeightType, IntType> > below_average;
std::vector<std::pair<WeightType, IntType> > above_average;
std::size_t size = std::distance(first, last);
WeightType weight_sum =
std::accumulate(first, last, static_cast<WeightType>(0));
WeightType weight_average = weight_sum / size;
WeightType weight_average = _impl.init_average(first, last);
WeightType normalized_average = _impl.get_weight(0);
std::size_t i = 0;
for(; first != last; ++first, ++i) {
WeightType val = *first / weight_average;
WeightType val = impl_type::normalize(*first, weight_average);
std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i));
if(val < static_cast<WeightType>(1)) {
if(val < normalized_average) {
below_average.push_back(elem);
} else {
above_average.push_back(elem);
}
}

_alias_table.resize(size);
typename alias_table_t::iterator
typename impl_type::alias_table_t::iterator
b_iter = below_average.begin(),
b_end = below_average.end(),
a_iter = above_average.begin(),
a_end = above_average.end()
;
while(b_iter != b_end && a_iter != a_end) {
_alias_table[b_iter->second] =
_impl._alias_table[b_iter->second] =
std::make_pair(b_iter->first, a_iter->second);
a_iter->first -= (static_cast<WeightType>(1) - b_iter->first);
if(a_iter->first < static_cast<WeightType>(1)) {
a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first);
if(a_iter->first < normalized_average) {
*b_iter = *a_iter++;
} else {
++b_iter;
}
}
for(; b_iter != b_end; ++b_iter) {
_alias_table[b_iter->second].first = static_cast<WeightType>(1);
_impl._alias_table[b_iter->second].first =
_impl.get_weight(b_iter->second);
}
for(; a_iter != a_end; ++a_iter) {
_alias_table[a_iter->second].first = static_cast<WeightType>(1);
_impl._alias_table[a_iter->second].first =
_impl.get_weight(a_iter->second);
}
}
template<class Iter>
void init(Iter first, Iter last)
{
    if(first == last) {
        // No weights supplied: fall back to the backend's single-bin state.
        _impl.init_empty();
    } else {
        // Dispatch on the iterator category to the matching overload.
        typename std::iterator_traits<Iter>::iterator_category category;
        init(first, last, category);
    }
}
// NOTE(review): this typedef and the `_alias_table` member below look like
// the pre-change storage that `_impl` supersedes; this view appears to show
// both sides of a diff -- confirm before removing either declaration.
typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
alias_table_t _alias_table;
// Backend chosen at compile time from whether WeightType is integral:
// detail::select_alias_table<true> yields integer_alias_table,
// detail::select_alias_table<false> yields real_alias_table.
typedef typename detail::select_alias_table<
(::boost::is_integral<WeightType>::value)
>::template apply<IntType, WeightType>::type impl_type;
// Backend instance used by the distribution.
impl_type _impl;
/// @endcond
};

Expand Down
2 changes: 1 addition & 1 deletion test/test_discrete.cpp
Expand Up @@ -92,7 +92,7 @@ bool handle_option(int& argc, char**& argv, char opt, T& value) {

int main(int argc, char** argv) {
int repeat = 10;
int max_n = 100000;
int max_n = 10000;
long long trials = 1000000ll;

if(argc > 0) {
Expand Down