Skip to content

Commit

Permalink
Reimplement pkg/random on top of math/rand/v2
Browse files Browse the repository at this point in the history
Now that math/rand no longer has a global Seed() method, it's become
safe for us to let FastThreadSafeGenerator call into
  • Loading branch information
EdSchouten committed Jun 22, 2024
1 parent 8abc648 commit b818f9d
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 63 deletions.
2 changes: 1 addition & 1 deletion pkg/eviction/rr_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func NewRRSet[T any]() Set[T] {
func (s *rrSet[T]) Insert(value T) {
// Insert element into a random location in the list, opening up
// space by moving an existing element to the end of the list.
index := s.generator.Intn(len(s.elements) + 1)
index := s.generator.IntN(len(s.elements) + 1)
if index == len(s.elements) {
s.elements = append(s.elements, value)
} else {
Expand Down
24 changes: 7 additions & 17 deletions pkg/random/crypto_thread_safe_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
crypto_rand "crypto/rand"
"encoding/binary"
"fmt"
math_rand "math/rand"
math_rand "math/rand/v2"
)

func mustCryptoRandRead(p []byte) (int, error) {
Expand All @@ -15,38 +15,28 @@ func mustCryptoRandRead(p []byte) (int, error) {
return n, nil
}

type cryptoSource64 struct{}
type cryptoSource struct{}

func (s cryptoSource64) Int63() int64 {
return int64(s.Uint64() >> 1)
}

func (s cryptoSource64) Uint64() uint64 {
func (s cryptoSource) Uint64() uint64 {
var b [8]byte
mustCryptoRandRead(b[:])
return binary.LittleEndian.Uint64(b[:])
}

func (s cryptoSource64) Seed(seed int64) {
panic("Crypto source cannot be seeded")
}

var _ math_rand.Source64 = cryptoSource64{}
var _ math_rand.Source = cryptoSource{}

type cryptoThreadSafeGenerator struct {
*math_rand.Rand
}

func (g cryptoThreadSafeGenerator) IsThreadSafe() {}
func (cryptoThreadSafeGenerator) IsThreadSafe() {}

func (g cryptoThreadSafeGenerator) Read(p []byte) (int, error) {
// Call into crypto_rand.Read() directly, as opposed to using
// math_rand.Rand.Read().
func (cryptoThreadSafeGenerator) Read(p []byte) (int, error) {
return mustCryptoRandRead(p)
}

// CryptoThreadSafeGenerator is an instance of ThreadSafeGenerator that is
// suitable for cryptographic purposes.
var CryptoThreadSafeGenerator ThreadSafeGenerator = cryptoThreadSafeGenerator{
Rand: math_rand.New(cryptoSource64{}),
Rand: math_rand.New(cryptoSource{}),
}
21 changes: 17 additions & 4 deletions pkg/random/fast_single_threaded_generator.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
package random

import (
"math/rand"
"math/rand/v2"
)

type fastSingleThreadedGenerator struct {
*rand.Rand
}

// NewFastSingleThreadedGenerator creates a new SingleThreadedGenerator
// that is not suitable for cryptographic purposes. The generator is
// randomly seeded.
func NewFastSingleThreadedGenerator() SingleThreadedGenerator {
return rand.New(
rand.NewSource(
int64(CryptoThreadSafeGenerator.Uint64())))
return fastSingleThreadedGenerator{
Rand: rand.New(
rand.NewPCG(
CryptoThreadSafeGenerator.Uint64(),
CryptoThreadSafeGenerator.Uint64(),
),
),
}
}

func (fastSingleThreadedGenerator) Read(p []byte) (int, error) {
return mustCryptoRandRead(p)
}
46 changes: 15 additions & 31 deletions pkg/random/fast_thread_safe_generator.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,39 @@
package random

import (
"sync"
"math/rand/v2"
)

type fastThreadSafeGenerator struct {
lock sync.Mutex
generator SingleThreadedGenerator
}

func (g *fastThreadSafeGenerator) IsThreadSafe() {}
func (fastThreadSafeGenerator) IsThreadSafe() {}

func (g *fastThreadSafeGenerator) Float64() float64 {
g.lock.Lock()
defer g.lock.Unlock()
return g.generator.Float64()
func (fastThreadSafeGenerator) Float64() float64 {
return rand.Float64()
}

func (g *fastThreadSafeGenerator) Int63n(n int64) int64 {
g.lock.Lock()
defer g.lock.Unlock()
return g.generator.Int63n(n)
func (fastThreadSafeGenerator) Int64N(n int64) int64 {
return rand.Int64N(n)
}

func (g *fastThreadSafeGenerator) Intn(n int) int {
g.lock.Lock()
defer g.lock.Unlock()
return g.generator.Intn(n)
func (fastThreadSafeGenerator) IntN(n int) int {
return rand.IntN(n)
}

func (g *fastThreadSafeGenerator) Read(p []byte) (int, error) {
g.lock.Lock()
defer g.lock.Unlock()
return g.generator.Read(p)
func (fastThreadSafeGenerator) Read(p []byte) (int, error) {
return mustCryptoRandRead(p)
}

func (g *fastThreadSafeGenerator) Shuffle(n int, swap func(i, j int)) {
g.lock.Lock()
defer g.lock.Unlock()
g.generator.Shuffle(n, swap)
func (fastThreadSafeGenerator) Shuffle(n int, swap func(i, j int)) {
rand.Shuffle(n, swap)
}

func (g *fastThreadSafeGenerator) Uint64() uint64 {
g.lock.Lock()
defer g.lock.Unlock()
return g.generator.Uint64()
func (fastThreadSafeGenerator) Uint64() uint64 {
return rand.Uint64()
}

// FastThreadSafeGenerator is an instance of ThreadSafeGenerator that is
// not suitable for cryptographic purposes. The generator is randomly
// seeded on startup.
var FastThreadSafeGenerator ThreadSafeGenerator = &fastThreadSafeGenerator{
generator: NewFastSingleThreadedGenerator(),
}
var FastThreadSafeGenerator ThreadSafeGenerator = fastThreadSafeGenerator{}
9 changes: 3 additions & 6 deletions pkg/random/single_threaded_generator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package random

import (
"math/rand"
"time"
)

Expand All @@ -12,9 +11,9 @@ type SingleThreadedGenerator interface {
// Generates a number in range [0.0, 1.0).
Float64() float64
// Generates a number in range [0, n), where n is of type int64.
Int63n(n int64) int64
Int64N(n int64) int64
// Generates a number in range [0, n), where n is of type int.
Intn(n int) int
IntN(n int) int
// Generates arbitrary bytes of data. This method is guaranteed
// to succeed.
Read(p []byte) (int, error)
Expand All @@ -24,9 +23,7 @@ type SingleThreadedGenerator interface {
Uint64() uint64
}

var _ SingleThreadedGenerator = (*rand.Rand)(nil)

// Duration that is randomly generated that lies between [0, maximum).
func Duration(generator SingleThreadedGenerator, maximum time.Duration) time.Duration {
return time.Duration(generator.Int63n(maximum.Nanoseconds())) * time.Nanosecond
return time.Duration(generator.Int64N(maximum.Nanoseconds())) * time.Nanosecond
}
8 changes: 4 additions & 4 deletions pkg/random/single_threaded_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ func TestSingleThreadedGenerator(t *testing.T) {
}
})

t.Run("Int63n", func(t *testing.T) {
t.Run("Int64N", func(t *testing.T) {
for i := 0; i < 100; i++ {
v := generator.Int63n(42)
v := generator.Int64N(42)
require.LessOrEqual(t, int64(0), v)
require.Greater(t, int64(42), v)
}
})

t.Run("Intn", func(t *testing.T) {
t.Run("IntN", func(t *testing.T) {
for i := 0; i < 100; i++ {
v := generator.Intn(42)
v := generator.IntN(42)
require.LessOrEqual(t, 0, v)
require.Greater(t, 42, v)
}
Expand Down

0 comments on commit b818f9d

Please sign in to comment.