Skip to content

Commit

Permalink
Add helpers for implementing uniform integer distribution
Browse files Browse the repository at this point in the history
* Utility for extended mult n x n bits -> 2n bits
* Utility to adapt output from URBG to target (unsigned) integral
  type
* Utility to reorder signed values into unsigned type while keeping
  the order.
  • Loading branch information
horenmar committed Dec 10, 2023
1 parent ab1b079 commit 04a829b
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ set(IMPL_HEADERS
${SOURCES_DIR}/internal/catch_preprocessor.hpp
${SOURCES_DIR}/internal/catch_preprocessor_remove_parens.hpp
${SOURCES_DIR}/internal/catch_random_floating_point_helpers.hpp
${SOURCES_DIR}/internal/catch_random_integer_helpers.hpp
${SOURCES_DIR}/internal/catch_random_number_generator.hpp
${SOURCES_DIR}/internal/catch_random_seed_generation.hpp
${SOURCES_DIR}/internal/catch_reporter_registry.hpp
Expand Down
1 change: 1 addition & 0 deletions src/catch2/catch_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
#include <catch2/internal/catch_preprocessor_internal_stringify.hpp>
#include <catch2/internal/catch_preprocessor_remove_parens.hpp>
#include <catch2/internal/catch_random_floating_point_helpers.hpp>
#include <catch2/internal/catch_random_integer_helpers.hpp>
#include <catch2/internal/catch_random_number_generator.hpp>
#include <catch2/internal/catch_random_seed_generation.hpp>
#include <catch2/internal/catch_reporter_registry.hpp>
Expand Down
202 changes: 202 additions & 0 deletions src/catch2/internal/catch_random_integer_helpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@

// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)

// SPDX-License-Identifier: BSL-1.0

#ifndef CATCH_RANDOM_INTEGER_HELPERS_HPP_INCLUDED
#define CATCH_RANDOM_INTEGER_HELPERS_HPP_INCLUDED

#include <climits>
#include <cstddef>
#include <cstdint>
#include <type_traits>

namespace Catch {
namespace Detail {

template <std::size_t>
struct SizedUnsignedType;
#define SizedUnsignedTypeHelper( TYPE ) \
template <> \
struct SizedUnsignedType<sizeof( TYPE )> { \
using type = TYPE; \
}

SizedUnsignedTypeHelper( std::uint8_t );
SizedUnsignedTypeHelper( std::uint16_t );
SizedUnsignedTypeHelper( std::uint32_t );
SizedUnsignedTypeHelper( std::uint64_t );
#undef SizedUnsignedTypeHelper

template <std::size_t sz>
using SizedUnsignedType_t = typename SizedUnsignedType<sz>::type;

template <typename T>
using DoubleWidthUnsignedType_t = SizedUnsignedType_t<2 * sizeof( T )>;

template <typename T>
struct ExtendedMultResult {
T upper;
T lower;
friend bool operator==( ExtendedMultResult const& lhs,
ExtendedMultResult const& rhs ) {
return lhs.upper == rhs.upper && lhs.lower == rhs.lower;
}
};

// Returns 128 bit result of multiplying lhs and rhs
constexpr ExtendedMultResult<std::uint64_t>
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
// We use the simple long multiplication approach for
// correctness, we can use platform specific builtins
// for performance later.

// Split the lhs and rhs into two 32bit "digits", so that we can
// do 64 bit arithmetic to handle carry bits.
// 32b 32b 32b 32b
// lhs L1 L2
// * rhs R1 R2
// ------------------------
// | R2 * L2 |
// | R2 * L1 |
// | R1 * L2 |
// | R1 * L1 |
// -------------------------
// | a | b | c | d |

#define CarryBits( x ) ( x >> 32 )
#define Digits( x ) ( x & 0xFF'FF'FF'FF )

auto r2l2 = Digits( rhs ) * Digits( lhs );
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );

// Sum to columns first
auto d = Digits( r2l2 );
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
auto a = CarryBits( r1l1 );

// Propagate carries between columns
c += CarryBits( d );
b += CarryBits( c );
a += CarryBits( b );

// Remove the used carries
c = Digits( c );
b = Digits( b );
a = Digits( a );

#undef CarryBits
#undef Digits

return {
a << 32 | b, // upper 64 bits
c << 32 | d // lower 64 bits
};
}

template <typename UInt>
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
static_assert( std::is_unsigned<UInt>::value,
"extendedMult can only handle unsigned integers" );
static_assert( sizeof( UInt ) < sizeof( std::uint64_t ),
"Generic extendedMult can only handle types smaller "
"than uint64_t" );
using WideType = DoubleWidthUnsignedType_t<UInt>;

auto result = WideType( lhs ) * WideType( rhs );
return {
static_cast<UInt>( result >> ( CHAR_BIT * sizeof( UInt ) ) ),
static_cast<UInt>( result & UInt( -1 ) ) };
}


template <typename TargetType,
typename Generator>
std::enable_if_t<sizeof(typename Generator::result_type) >= sizeof(TargetType),
TargetType> fillBitsFrom(Generator& gen) {
using gresult_type = typename Generator::result_type;
static_assert( std::is_unsigned<TargetType>::value, "Only unsigned integers are supported" );
static_assert( Generator::min() == 0 &&
Generator::max() == static_cast<gresult_type>( -1 ),
"Generator must be able to output all numbers in its result type (effectively it must be a random bit generator)" );

// We want to return the top bits from a generator, as they are
// usually considered higher quality.
constexpr auto generated_bits = sizeof( gresult_type ) * CHAR_BIT;
constexpr auto return_bits = sizeof( TargetType ) * CHAR_BIT;

return static_cast<TargetType>( gen() >>
( generated_bits - return_bits) );
}

template <typename TargetType,
typename Generator>
std::enable_if_t<sizeof(typename Generator::result_type) < sizeof(TargetType),
TargetType> fillBitsFrom(Generator& gen) {
using gresult_type = typename Generator::result_type;
static_assert( std::is_unsigned<TargetType>::value,
"Only unsigned integers are supported" );
static_assert( Generator::min() == 0 &&
Generator::max() == static_cast<gresult_type>( -1 ),
"Generator must be able to output all numbers in its result type (effectively it must be a random bit generator)" );

constexpr auto generated_bits = sizeof( gresult_type ) * CHAR_BIT;
constexpr auto return_bits = sizeof( TargetType ) * CHAR_BIT;
std::size_t filled_bits = 0;
TargetType ret = 0;
do {
ret <<= generated_bits;
ret |= gen();
filled_bits += generated_bits;
} while ( filled_bits < return_bits );

return ret;
}

/*
* Transposes numbers into unsigned type while keeping their ordering
*
* This means that signed types are changed so that the ordering is
* [INT_MIN, ..., -1, 0, ..., INT_MAX], rather than order we would
* get by simple casting ([0, ..., INT_MAX, INT_MIN, ..., -1])
*/
template <typename OriginalType, typename UnsignedType>
std::enable_if_t<std::is_signed<OriginalType>::value, UnsignedType>
transposeToNaturalOrder( UnsignedType in ) {
static_assert(
sizeof( OriginalType ) == sizeof( UnsignedType ),
"reordering requires the same sized types on both sides" );
static_assert( std::is_unsigned<UnsignedType>::value,
"Input type must be unsigned" );
// Assuming 2s complement (standardized in current C++), the
// positive and negative numbers are already internally ordered,
// and their difference is in the top bit. Swapping it orders
// them the desired way.
constexpr auto highest_bit =
UnsignedType( 1 ) << ( sizeof( UnsignedType ) * CHAR_BIT - 1 );
return static_cast<UnsignedType>( in ^ highest_bit );
}



template <typename OriginalType,
typename UnsignedType>
std::enable_if_t<std::is_unsigned<OriginalType>::value, UnsignedType>
transposeToNaturalOrder(UnsignedType in) {
static_assert(
sizeof( OriginalType ) == sizeof( UnsignedType ),
"reordering requires the same sized types on both sides" );
static_assert( std::is_unsigned<UnsignedType>::value, "Input type must be unsigned" );
// No reordering is needed for unsigned -> unsigned
return in;
}
} // namespace Detail
} // namespace Catch

#endif // CATCH_RANDOM_INTEGER_HELPERS_HPP_INCLUDED
1 change: 1 addition & 0 deletions src/catch2/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ internal_headers = [
'internal/catch_preprocessor_internal_stringify.hpp',
'internal/catch_preprocessor_remove_parens.hpp',
'internal/catch_random_floating_point_helpers.hpp',
'internal/catch_random_integer_helpers.hpp',
'internal/catch_random_number_generator.hpp',
'internal/catch_random_seed_generation.hpp',
'internal/catch_reporter_registry.hpp',
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ set(TEST_SOURCES
${SELF_TEST_DIR}/IntrospectiveTests/Details.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/FloatingPoint.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/GeneratorsImpl.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Integer.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/InternalBenchmark.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Json.tests.cpp
${SELF_TEST_DIR}/IntrospectiveTests/Parse.tests.cpp
Expand Down
150 changes: 150 additions & 0 deletions tests/SelfTest/IntrospectiveTests/Integer.tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@

// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)

// SPDX-License-Identifier: BSL-1.0

#include <catch2/catch_test_macros.hpp>
#include <catch2/internal/catch_random_integer_helpers.hpp>

namespace {
template <typename Int>
static void
CommutativeMultCheck( Int a, Int b, Int upper_result, Int lower_result ) {
using Catch::Detail::extendedMult;
using Catch::Detail::ExtendedMultResult;
CHECK( extendedMult( a, b ) ==
ExtendedMultResult<Int>{ upper_result, lower_result } );
CHECK( extendedMult( b, a ) ==
ExtendedMultResult<Int>{ upper_result, lower_result } );
}
} // namespace

TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
// a x 0 == 0
CommutativeMultCheck<uint64_t>( 0x1234'5678'9ABC'DEFF, 0, 0, 0 );

// bit carried from low half to upper half
CommutativeMultCheck<uint64_t>( uint64_t( 1 ) << 63, 2, 1, 0 );

// bits in upper half on one side, bits in lower half on other side
CommutativeMultCheck<uint64_t>( 0xcdcd'dcdc'0000'0000,
0x0000'0000'aeae'aeae,
0x0000'0000'8c6e'5a77,
0x7391'a588'0000'0000 );

// Some input numbers without interesting patterns
CommutativeMultCheck<uint64_t>( 0xaaaa'aaaa'aaaa'aaaa,
0xbbbb'bbbb'bbbb'bbbb,
0x7d27'd27d'27d2'7d26,
0xd82d'82d8'2d82'd82e );

CommutativeMultCheck<uint64_t>( 0x7d27'd27d'27d2'7d26,
0xd82d'82d8'2d82'd82e,
0x69af'd991'8256'b953,
0x8724'8909'fcb6'8cd4 );

CommutativeMultCheck<uint64_t>( 0xdead'beef'dead'beef,
0xfeed'feed'feed'feef,
0xddbf'680b'2b0c'b558,
0x7a36'b06f'2ce9'6321 );

CommutativeMultCheck<uint64_t>( 0xddbf'680b'2b0c'b558,
0x7a36'b06f'2ce9'6321,
0x69dc'96c9'294b'fc7f,
0xd038'39fa'a3dc'6858 );

CommutativeMultCheck<uint64_t>( 0x61c8'8646'80b5'83eb,
0x61c8'8646'80b5'83eb,
0x2559'92d3'8220'8bbe,
0xdf44'2d22'ce48'59b9 );
}

TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
using Catch::Detail::SizedUnsignedType_t;
using Catch::Detail::DoubleWidthUnsignedType_t;

STATIC_REQUIRE( sizeof( SizedUnsignedType_t<1> ) == 1 );
STATIC_REQUIRE( sizeof( SizedUnsignedType_t<2> ) == 2 );
STATIC_REQUIRE( sizeof( SizedUnsignedType_t<4> ) == 4 );
STATIC_REQUIRE( sizeof( SizedUnsignedType_t<8> ) == 8 );

STATIC_REQUIRE( sizeof( DoubleWidthUnsignedType_t<std::uint8_t> ) == 2 );
STATIC_REQUIRE( std::is_unsigned<DoubleWidthUnsignedType_t<std::uint8_t>>::value );
STATIC_REQUIRE( sizeof( DoubleWidthUnsignedType_t<std::uint16_t> ) == 4 );
STATIC_REQUIRE( std::is_unsigned<DoubleWidthUnsignedType_t<std::uint16_t>>::value );
STATIC_REQUIRE( sizeof( DoubleWidthUnsignedType_t<std::uint32_t> ) == 8 );
STATIC_REQUIRE( std::is_unsigned<DoubleWidthUnsignedType_t<std::uint32_t>>::value );
}

TEST_CASE( "extendedMult 32x32", "[integer][approvals]" ) {
// a x 0 == 0
CommutativeMultCheck<uint32_t>( 0x1234'5678, 0, 0, 0 );

// bit carried from low half to upper half
CommutativeMultCheck<uint32_t>( uint32_t(1) << 31, 2, 1, 0 );

// bits in upper half on one side, bits in lower half on other side
CommutativeMultCheck<uint32_t>( 0xdcdc'0000, 0x0000'aabb, 0x0000'934b, 0x6cb4'0000 );

// Some input numbers without interesting patterns
CommutativeMultCheck<uint32_t>(
0xaaaa'aaaa, 0xbbbb'bbbb, 0x7d27'd27c, 0x2d82'd82e );

CommutativeMultCheck<uint32_t>(
0x7d27'd27c, 0x2d82'd82e, 0x163f'f7e8, 0xc5b8'7248 );

CommutativeMultCheck<uint32_t>(
0xdead'beef, 0xfeed'feed, 0xddbf'6809, 0x6f8d'e543 );

CommutativeMultCheck<uint32_t>(
0xddbf'6809, 0x6f8d'e543, 0x60a0'e71e, 0x751d'475b );
}

TEST_CASE( "extendedMult 8x8", "[integer][approvals]" ) {
// a x 0 == 0
CommutativeMultCheck<uint8_t>( 0xcd, 0, 0, 0 );

// bit carried from low half to upper half
CommutativeMultCheck<uint8_t>( uint8_t( 1 ) << 7, 2, 1, 0 );

// bits in upper half on one side, bits in lower half on other side
CommutativeMultCheck<uint8_t>( 0x80, 0x03, 0x01, 0x80 );

// Some input numbers without interesting patterns
CommutativeMultCheck<uint8_t>( 0xaa, 0xbb, 0x7c, 0x2e );
CommutativeMultCheck<uint8_t>( 0x7c, 0x2e, 0x16, 0x48 );
CommutativeMultCheck<uint8_t>( 0xdc, 0xcd, 0xb0, 0x2c );
CommutativeMultCheck<uint8_t>( 0xb0, 0x2c, 0x1e, 0x40 );
}


TEST_CASE( "negative and positive signed integers keep their order after transposeToNaturalOrder",
"[integer][approvals]") {
using Catch::Detail::transposeToNaturalOrder;
int32_t negative( -1 );
int32_t positive( 1 );
uint32_t adjusted_negative =
transposeToNaturalOrder<int32_t>( static_cast<uint32_t>( negative ) );
uint32_t adjusted_positive =
transposeToNaturalOrder<int32_t>( static_cast<uint32_t>( positive ) );
REQUIRE( adjusted_negative < adjusted_positive );
REQUIRE( adjusted_positive - adjusted_negative == 2 );

// Conversion has to be reversible
REQUIRE( negative == static_cast<int32_t>( transposeToNaturalOrder<int32_t>(
adjusted_negative ) ) );
REQUIRE( positive == static_cast<int32_t>( transposeToNaturalOrder<int32_t>(
adjusted_positive ) ) );
}

TEST_CASE( "unsigned integers are unchanged by transposeToNaturalOrder",
"[integer][approvals]") {
using Catch::Detail::transposeToNaturalOrder;
uint32_t max = std::numeric_limits<uint32_t>::max();
uint32_t zero = 0;
REQUIRE( max == transposeToNaturalOrder<uint32_t>( max ) );
REQUIRE( zero == transposeToNaturalOrder<uint32_t>( zero ) );
}
Loading

0 comments on commit 04a829b

Please sign in to comment.