/
weighted.go
52 lines (44 loc) · 883 Bytes
/
weighted.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package selector
import (
"math/rand"
"time"
)
// randomWeightedItem pairs a candidate value with the weight that determines
// how often it is selected relative to the other candidates.
type randomWeightedItem[T any] struct {
	item   T
	weight int
}

// randomWeighted selects items at random with probability proportional to
// their weight.
//
// The type carries no locking: Add and Reset mutate the item list and Next
// draws from an unsynchronized *rand.Rand, so callers must serialize access.
type randomWeighted[T any] struct {
	items []*randomWeightedItem[T]
	// sum is the total weight of all stored items.
	sum int
	// r is the random source used by Next; it is created lazily if nil.
	r *rand.Rand
}

// newRandomWeighted returns an empty selector whose random source is seeded
// from the current time.
func newRandomWeighted[T any]() *randomWeighted[T] {
	return &randomWeighted[T]{
		r: rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// Add registers item as a candidate with the given weight.
//
// Items with a non-positive weight are ignored: a zero weight can never be
// selected, and a negative weight would shrink the running total and skew
// the selection probabilities of every other item.
func (rw *randomWeighted[T]) Add(item T, weight int) {
	if weight <= 0 {
		return
	}
	rw.items = append(rw.items, &randomWeightedItem[T]{item: item, weight: weight})
	rw.sum += weight
}

// Next returns a randomly chosen item, where each item's chance of being
// picked is its weight divided by the total weight. It returns the zero value
// of T when there is nothing to choose from.
func (rw *randomWeighted[T]) Next() T {
	var zero T
	if len(rw.items) == 0 || rw.sum <= 0 {
		return zero
	}
	// Allow a zero-value randomWeighted (not built by newRandomWeighted) to
	// work instead of dereferencing a nil generator.
	if rw.r == nil {
		rw.r = rand.New(rand.NewSource(time.Now().UnixNano()))
	}

	// Draw a ticket in [1, sum] and walk the items, spending each item's
	// weight until the ticket is used up; the item that exhausts it wins.
	ticket := rw.r.Intn(rw.sum) + 1
	for _, candidate := range rw.items {
		ticket -= candidate.weight
		if ticket <= 0 {
			return candidate.item
		}
	}

	// Not reached while sum equals the total of the stored weights; kept as a
	// defensive fallback in case the fields were modified directly.
	return rw.items[len(rw.items)-1].item
}

// Reset removes every item and clears the accumulated weight. The random
// source is kept.
func (rw *randomWeighted[T]) Reset() {
	rw.items = nil
	rw.sum = 0
}