This repository has been archived by the owner on Dec 22, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
400 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this code is governed by a BSD-style | ||
// license that can be found in the LICENSE file | ||
|
||
package sample | ||
|
||
import ( | ||
"math/rand" | ||
|
||
"github.com/gonum/floats" | ||
) | ||
|
||
// Weighted provides sampling without replacement from a collection of items with | ||
// non-uniform probability. | ||
type Weighted struct { | ||
weights []float64 | ||
// heap is a weight heap. | ||
// | ||
// It keeps a heap-organised sum of remaining | ||
// index weights that are available to be taken | ||
// from. | ||
// | ||
// Each elements holds the sum of weights for | ||
// the corresponding index, plus the sum of | ||
// of its children's weights; the children | ||
// of an element i can be found at positions | ||
// 2*(i+1)-1 and 2*(i+1). The root of the | ||
// weight heap is at element 0. | ||
// | ||
// See comments in container/heap for an | ||
// explanation of the layout of a heap. | ||
heap []float64 | ||
src *rand.Rand | ||
} | ||
|
||
// NewWeighted returns a Weighted for the weights w. If src is nil, rand.Rand is | ||
// used as the random source. | ||
func NewWeighted(w []float64, src *rand.Rand) Weighted { | ||
s := Weighted{ | ||
weights: make([]float64, len(w)), | ||
heap: make([]float64, len(w)), | ||
} | ||
s.ReweightAll(w) | ||
return s | ||
} | ||
|
||
// Len returns the number of items held by the Weighted, including items | ||
// already taken. | ||
func (s Weighted) Len() int { return len(s.weights) } | ||
|
||
// Take returns an index from the Weighted with probability proportional | ||
// to the weight of the item. The weight of the item is then set to zero. | ||
// Take returns false if there are no items remaining. | ||
func (s Weighted) Take() (idx int, ok bool) { | ||
if floats.EqualWithinAbsOrRel(s.heap[0], 0, 1e-12, 1e-12) { | ||
return -1, false | ||
} | ||
|
||
var r float64 | ||
if s.src == nil { | ||
r = s.heap[0] * rand.Float64() | ||
} else { | ||
r = s.heap[0] * s.src.Float64() | ||
} | ||
i := 1 | ||
last := -1 | ||
left := len(s.weights) | ||
for { | ||
if r -= s.weights[i-1]; r <= 0 { | ||
break // Fall within item i-1. | ||
} | ||
i <<= 1 // Move to left child. | ||
if d := s.heap[i-1]; r > d { | ||
r -= d | ||
// If enough r to pass left child | ||
// move to right child state will | ||
// be caught at break above. | ||
i++ | ||
} | ||
if i == last || left < 0 { | ||
// No progression. | ||
return -1, false | ||
} | ||
last = i | ||
left-- | ||
} | ||
|
||
w, idx := s.weights[i-1], i-1 | ||
|
||
s.weights[i-1] = 0 | ||
for i > 0 { | ||
s.heap[i-1] -= w | ||
// The following condition is necessary to | ||
// handle floating point error. If we see | ||
// a heap value below zero, we know we need | ||
// to rebuild it. | ||
if s.heap[i-1] < 0 { | ||
s.reset() | ||
return idx, true | ||
} | ||
i >>= 1 | ||
} | ||
|
||
return idx, true | ||
} | ||
|
||
// Reweight sets the weight of item idx to w. | ||
func (s Weighted) Reweight(idx int, w float64) { | ||
w, s.weights[idx] = s.weights[idx]-w, w | ||
idx++ | ||
for idx > 0 { | ||
s.heap[idx-1] -= w | ||
idx >>= 1 | ||
} | ||
} | ||
|
||
// ReweightAll sets the weight of all items in the Weighted. ReweightAll | ||
// panics if len(w) != s.Len. | ||
func (s Weighted) ReweightAll(w []float64) { | ||
if len(w) != s.Len() { | ||
panic("floats: length of the slices do not match") | ||
} | ||
copy(s.weights, w) | ||
s.reset() | ||
} | ||
|
||
func (s Weighted) reset() { | ||
copy(s.heap, s.weights) | ||
for i := len(s.heap) - 1; i > 0; i-- { | ||
// Sometimes 1-based counting makes sense. | ||
s.heap[((i+1)>>1)-1] += s.heap[i] | ||
} | ||
} |
Oops, something went wrong.