Skip to content

Commit

Permalink
feat: method to set randomness source
Browse files Browse the repository at this point in the history
The internal mechanics of this are a bit inelegant, since unfortunately
the global randomness source is not exported, necessitating these nil
check methods instead.

The API here needs some user feedback. I believe the majority case will
want to set this once and not on a per-call basis (cf. the deprecated
PickSource method in the previous version), but that needs be validated.
  • Loading branch information
mroth committed Apr 17, 2024
1 parent a5fa29d commit cd42f24
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions weightedrand.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Chooser[T any, W integer] struct {
data []Choice[T, W]
totals []uint64
max uint64

customRand *rand.Rand
}

// NewChooser initializes a new Chooser for picking from the provided choices.
Expand Down Expand Up @@ -64,7 +66,13 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro
return nil, errNoValidChoices
}

return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil
return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal, customRand: nil}, nil
}

// SetRand applies an optional custom randomness source r for the Chooser. If
// set to nil nil, global rand will be used.
func (c *Chooser[T, W]) SetRand(r *rand.Rand) {
c.customRand = r
}

// Possible errors returned by NewChooser, preventing the creation of a Chooser
Expand All @@ -82,9 +90,17 @@ var (

// Pick returns a single weighted random Choice.Item from the Chooser.
//
// Utilizes global rand as the source of randomness. Safe for concurrent usage.
// Utilizes global rand as the source of randomness by default, which is safe
// for concurrent usage. If a custom rand source was set with SetRand, that
// source will be used instead.
func (c Chooser[T, W]) Pick() T {
r := rand.Uint64N(c.max) + 1
var r uint64
if c.customRand == nil {
r = rand.Uint64N(c.max) + 1
} else {
r = c.customRand.Uint64N(c.max) + 1
}

i, _ := slices.BinarySearch(c.totals, r)
return c.data[i].Item
}

0 comments on commit cd42f24

Please sign in to comment.