forked from ava-labs/avalanchego
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weighted.go
116 lines (99 loc) · 2.87 KB
/
weighted.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package random
import (
"math"
"math/rand"
)
// Weighted implements the Sampler interface. It draws indices with probability
// proportional to their weights, using an implicit binary heap stored in a
// slice: the children of node i live at indices 2i+1 and 2i+2.
//
// A node's cumulative weight is its own weight plus the cumulative weights of
// its children. Once a node is sampled without replacement, its own weight is
// set to 0 until Replace is called.
//
// Replace runs in O(n) time; sampling runs in O(log(n)) time.
type Weighted struct {
	// Weights are the caller-supplied weights. Each weight, and their sum,
	// must fit in an int64.
	Weights []uint64

	// weights mirrors Weights but is kept separately because sampled entries
	// are zeroed here while Weights itself is never written. cumWeights holds
	// the cumulative (subtree) weight of every heap node.
	weights, cumWeights []int64
}
// init lazily rebuilds the internal sampling state by calling Replace, but
// only when len(Weights) no longer matches the internal copy (for example on
// first use, or after the caller resized Weights). Changing individual entries
// of Weights without changing its length is not detected here.
func (s *Weighted) init() {
	if len(s.weights) == len(s.Weights) {
		return
	}
	s.Replace()
}
// Sample returns an index in [0, len(Weights)) chosen with probability
// proportional to its current weight. It then excludes that index from later
// draws by zeroing its weight until the next Replace. Callers must ensure
// CanSample returns true. Sample takes O(log(len(Weights))) time.
func (s *Weighted) Sample() int {
	picked := s.SampleReplace()
	s.changeWeight(picked, 0) // remove it from subsequent draws
	return picked
}
// SampleReplace returns an index in [0, len(Weights)) chosen with probability
// proportional to that index's current weight. The index is not removed, so
// later calls may return it again. SampleReplace takes O(log(len(Weights)))
// time.
//
// Callers must ensure CanSample returns true; with no positive total weight
// left to draw from, SampleReplace panics.
func (s *Weighted) SampleReplace() int {
	s.init()

	// Draw a point in [0, total) and walk down the heap until the point lands
	// inside some node's own weight. At every step target is strictly below
	// the current node's cumulative weight, so whenever the walk descends, the
	// child it moves to exists.
	target := rand.Int63n(s.cumWeights[0])
	node := 0
	for {
		target -= s.weights[node]
		if target < 0 {
			return node
		}
		left := 2*node + 1
		leftSum := s.cumWeights[left]
		if target < leftSum {
			node = left
			continue
		}
		target -= leftSum
		node = left + 1
	}
}
// CanSample reports whether at least one index still has positive weight,
// i.e. whether Sample or SampleReplace can be called without panicking.
func (s *Weighted) CanSample() bool {
	s.init()
	if len(s.cumWeights) == 0 {
		return false
	}
	return s.cumWeights[0] > 0
}
// Replace restores all previously sampled elements by rebuilding the internal
// state from Weights. Takes O(len(Weights)) time.
//
// Replace panics with "Weight too large" if any weight, or the sum of all
// weights, exceeds math.MaxInt64. Every weight is validated before any
// internal state is modified. A panic therefore leaves the sampler exactly as
// it was before the call, rather than half-rebuilt with lengths that would
// stop init from ever rebuilding it.
func (s *Weighted) Replace() {
	n := len(s.Weights)

	// Validate up front. All weights are non-negative, so every subtree sum
	// in the heap is bounded by the total; checking the running total covers
	// every intermediate sum computed below. Each addend and the running total
	// are <= MaxInt64 here, so the uint64 addition itself cannot overflow.
	var total uint64
	for _, w := range s.Weights {
		if w > math.MaxInt64 {
			panic("Weight too large")
		}
		total += w
		if total > math.MaxInt64 {
			panic("Weight too large")
		}
	}

	// Reuse the existing backing arrays when they are large enough, so
	// repeated calls do not allocate.
	if cap(s.weights) < n {
		s.weights = make([]int64, n)
	} else {
		s.weights = s.weights[:n]
	}
	if cap(s.cumWeights) < n {
		s.cumWeights = make([]int64, n)
	} else {
		s.cumWeights = s.cumWeights[:n]
	}

	for i, w := range s.Weights {
		s.weights[i] = int64(w)
	}
	copy(s.cumWeights, s.weights)

	// Build the heap bottom-up. Afterwards cumWeights[i] holds the sum of the
	// weights in the subtree rooted at i, whose children are 2i+1 and 2i+2.
	for i := n - 1; i > 0; i-- {
		s.cumWeights[(i-1)/2] += s.cumWeights[i]
	}
}
// changeWeight sets the own weight of node i to newWeight. It then shifts the
// cumulative weight of i and of every ancestor of i by the same delta, so each
// node's cumulative weight stays equal to its subtree sum. Takes
// O(log(len(weights))) time.
func (s *Weighted) changeWeight(i int, newWeight int64) {
	delta := newWeight - s.weights[i]
	s.weights[i] = newWeight
	for node := i; ; node = (node - 1) / 2 {
		s.cumWeights[node] += delta
		if node == 0 {
			return
		}
	}
}