Skip to content
This repository has been archived by the owner on Dec 22, 2018. It is now read-only.

Commit

Permalink
stat/sample: add weighted sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jun 5, 2015
1 parent 11b7e99 commit ac73525
Show file tree
Hide file tree
Showing 2 changed files with 400 additions and 0 deletions.
133 changes: 133 additions & 0 deletions sample/weighted.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file

package sample

import (
"math/rand"

"github.com/gonum/floats"
)

// Weighted provides sampling without replacement from a collection of items with
// non-uniform probability.
//
// All methods use value receivers; copies of a Weighted share the same
// underlying weights and heap slices.
type Weighted struct {
// weights holds the current weight of each item, indexed by
// item. Take sets the weight of the item it returns to zero.
weights []float64
// heap is a weight heap.
//
// It keeps a heap-organised sum of remaining
// index weights that are available to be taken
// from.
//
// Each element holds the weight for the
// corresponding index, plus the sum of its
// children's elements, that is, the total
// weight of the subtree rooted at that
// element; the children of an element i can
// be found at positions 2*(i+1)-1 and 2*(i+1).
// The root of the weight heap is at element 0,
// so heap[0] is the total remaining weight.
//
// See comments in container/heap for an
// explanation of the layout of a heap.
heap []float64
// src is the random source consulted by Take. When src is nil,
// Take uses the package-level rand.Float64 instead.
src *rand.Rand
}

// NewWeighted returns a Weighted for the weights w. The weights are copied
// into the Weighted, so later changes to w do not affect it. If src is nil,
// the package-level functions of math/rand are used as the random source;
// otherwise all random numbers are drawn from src.
func NewWeighted(w []float64, src *rand.Rand) Weighted {
	s := Weighted{
		weights: make([]float64, len(w)),
		heap:    make([]float64, len(w)),
		// Retain the caller's source. Without this the src argument
		// is ignored and Take always uses the global source.
		src: src,
	}
	// ReweightAll copies w into s.weights and builds the weight heap.
	s.ReweightAll(w)
	return s
}

// Len reports how many items the Weighted holds. Items that have already
// been taken are still counted.
func (s Weighted) Len() int {
	count := len(s.weights)
	return count
}

// Take returns an index from the Weighted with probability proportional
// to the weight of the item. The weight of the item is then set to zero.
// Take returns -1 and false if there are no items remaining; this includes
// a Weighted that holds no items at all. It also returns -1 and false if
// the descent through the weight heap runs past the end of the tree, which
// can only happen when the heap sums have drifted from the item weights.
func (s Weighted) Take() (idx int, ok bool) {
	// A Weighted with no items has no heap root to inspect.
	if len(s.heap) == 0 {
		return -1, false
	}
	// heap[0] is the total remaining weight; a total within tolerance
	// of zero is treated as exhausted.
	if floats.EqualWithinAbsOrRel(s.heap[0], 0, 1e-12, 1e-12) {
		return -1, false
	}

	// Pick a point within the total remaining weight.
	var r float64
	if s.src == nil {
		r = s.heap[0] * rand.Float64()
	} else {
		r = s.heap[0] * s.src.Float64()
	}

	// Walk down the heap using 1-based positions, so that the children
	// of position i are 2i and 2i+1. At each node r is reduced by the
	// node's own weight and, when the left subtree is skipped, by the
	// whole left subtree sum.
	i := 1
	for {
		if r -= s.weights[i-1]; r <= 0 {
			break // Fall within item i-1.
		}
		i <<= 1 // Move to left child.
		if i > len(s.heap) {
			// Ran past a leaf: r was not used up by the
			// weights on the path, so no item can be chosen.
			return -1, false
		}
		if d := s.heap[i-1]; r > d {
			r -= d
			// If enough r to pass left child
			// move to right child state will
			// be caught at break above.
			i++
			if i > len(s.heap) {
				// There is no right child to move to.
				return -1, false
			}
		}
	}

	w, idx := s.weights[i-1], i-1

	// Remove the taken weight from the item and from every subtree
	// sum on the path back up to the root.
	s.weights[i-1] = 0
	for i > 0 {
		s.heap[i-1] -= w
		// The following condition is necessary to
		// handle floating point error. If we see
		// a heap value below zero, we know we need
		// to rebuild it.
		if s.heap[i-1] < 0 {
			s.reset()
			return idx, true
		}
		i >>= 1
	}

	return idx, true
}

// Reweight changes the weight of item idx to w and adjusts every heap
// sum that includes that item.
func (s Weighted) Reweight(idx int, w float64) {
	// delta is the amount by which the item's weight decreases.
	delta := s.weights[idx] - w
	s.weights[idx] = w

	// Walk from the item's node up to the root using 1-based
	// positions, where the parent of position p is p/2.
	for pos := idx + 1; pos > 0; pos >>= 1 {
		s.heap[pos-1] -= delta
	}
}

// ReweightAll replaces the weights of every item in the Weighted with the
// values in w and rebuilds the weight heap. It panics when len(w) differs
// from s.Len().
func (s Weighted) ReweightAll(w []float64) {
	if n := s.Len(); n != len(w) {
		panic("floats: length of the slices do not match")
	}
	// Copy rather than alias w, then recompute all heap sums.
	copy(s.weights, w)
	s.reset()
}

// reset rebuilds the weight heap from the current item weights. Each heap
// element starts as its own item's weight and is then added into its
// parent, working from the last element back towards the root, so that
// every element ends up holding the sum of its subtree.
func (s Weighted) reset() {
	copy(s.heap, s.weights)
	for child := len(s.heap) - 1; child >= 1; child-- {
		// In 1-based terms the parent of position p is p/2;
		// convert to and from 0-based indices around that.
		parent := (child+1)/2 - 1
		s.heap[parent] += s.heap[child]
	}
}
Loading

0 comments on commit ac73525

Please sign in to comment.