Skip to content
This repository has been archived by the owner on Dec 22, 2018. It is now read-only.

Commit

Permalink
stat/sample: add weighted sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jun 4, 2015
1 parent 11b7e99 commit 20ea4c6
Show file tree
Hide file tree
Showing 2 changed files with 381 additions and 0 deletions.
113 changes: 113 additions & 0 deletions sample/weighted.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file

package sample

import (
"math/rand"

"github.com/gonum/floats"
)

// Weighted provides sampling without replacement from a collection of items with
// non-uniform probability.
type Weighted struct {
// weights holds the current weight of each item, indexed by item.
// Take sets the weight of the item it returns to zero.
weights []float64
// heap is a weight heap. Element i holds weights[i] plus the
// heap values of its children, which live at positions
// 2*(i+1)-1 and 2*(i+1); the root, heap[0], therefore holds
// the total remaining weight. See comments
// in container/heap for an explanation.
heap []float64
// src is the random source consulted by Take. When it is nil the
// package-level math/rand functions are used instead.
src *rand.Rand
}

// NewWeighted returns a Weighted for the weights w. The values of w are
// copied into the Weighted, so w itself is not modified by sampling and
// may be reused by the caller.
//
// src is the random source used by Take. If src is nil, the package-level
// math/rand functions are used as the random source.
func NewWeighted(w []float64, src *rand.Rand) Weighted {
	s := Weighted{
		weights: make([]float64, len(w)),
		heap:    make([]float64, len(w)),
		// Keep the caller's source. Leaving this field unset makes
		// Take fall back to the global generator regardless of src.
		src: src,
	}
	s.ReweightAll(w)
	return s
}

// Len reports how many items the Weighted holds. Items that have
// already been taken still count towards the total.
func (s Weighted) Len() int {
	n := len(s.weights)
	return n
}

// Take returns an index from the Weighted with probability proportional
// to the weight of the item. The weight of the item is then set to zero.
// Take returns -1 and false if there are no items remaining, which
// includes the case of a Weighted that holds no items at all.
func (s Weighted) Take() (idx int, ok bool) {
	// An empty Weighted has no heap root to inspect; report it as
	// exhausted instead of indexing out of range.
	if len(s.heap) == 0 {
		return -1, false
	}
	// The root holds the total remaining weight. A total that is zero
	// to within tolerance means every item has been taken.
	if floats.EqualWithinAbsOrRel(s.heap[0], 0, 1e-12, 1e-12) {
		return -1, false
	}

	// Scale a uniform variate by the total remaining weight, using
	// the configured source when there is one.
	var r float64
	if s.src == nil {
		r = s.heap[0] * rand.Float64()
	} else {
		r = s.heap[0] * s.src.Float64()
	}

	// Walk down the heap using 1-based node numbers, consuming r
	// as we pass over weights.
	i := 1
	for {
		if r -= s.weights[i-1]; r <= 0 {
			break // fall within item i-1
		}
		i <<= 1 // move to left child
		if d := s.heap[i-1]; r > d {
			r -= d
			// if enough r to pass left child
			// move to right child
			// state will be caught at break above
			i++
		}
	}

	w, idx := s.weights[i-1], i-1

	// Remove the item's weight from itself and from every ancestor
	// up to the root.
	s.weights[i-1] = 0
	for i > 0 {
		s.heap[i-1] -= w
		// The following condition is necessary to
		// handle floating point error. If we see
		// a heap value below zero, we know we need
		// to rebuild it.
		if s.heap[i-1] < 0 {
			s.reset()
			return idx, true
		}
		i >>= 1
	}

	return idx, true
}

// Reweight assigns the weight w to item idx and propagates the change
// in weight from the item's heap node up to the root.
func (s Weighted) Reweight(idx int, w float64) {
	// delta is the amount by which the item's weight shrinks; it is
	// negative when the weight grows.
	delta := s.weights[idx] - w
	s.weights[idx] = w

	// node is the 1-based heap position; halving it moves to the parent.
	for node := idx + 1; node > 0; node >>= 1 {
		s.heap[node-1] -= delta
	}
}

// ReweightAll replaces the weight of every item with the corresponding
// value in w and rebuilds the weight heap. It panics if len(w) differs
// from s.Len().
func (s Weighted) ReweightAll(w []float64) {
	if n := s.Len(); n != len(w) {
		panic("floats: length of the slices do not match")
	}

	copy(s.weights, w)
	s.reset()
}

// reset rebuilds the weight heap from the current item weights: each
// node starts as its own weight and then, working from the last node
// back towards the root, is added into its parent.
func (s Weighted) reset() {
	copy(s.heap, s.weights)

	// node is a 1-based position, so its parent is simply node/2.
	for node := len(s.heap); node > 1; node-- {
		parent := node >> 1
		s.heap[parent-1] += s.heap[node-1]
	}
}
268 changes: 268 additions & 0 deletions sample/weighted_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file

package sample

import (
"flag"
"math/rand"
"reflect"
"testing"
"time"

"github.com/gonum/floats"
)

// prob gates the stochastic, time-seeded test, which is skipped unless
// the -prob flag is given.
var prob = flag.Bool("prob", false, "enables probabilistic testing of the random weighted sampler")

// sigChi2 is the chi-squared value at or above which the tests report
// that the sampled distribution differs from the expected one.
const sigChi2 = 16.92 // p = 0.05 df = 9

var (
// newExp returns a fresh slice of the ten item weights 2^0 through 2^9.
// Tests scale a copy of it in place to obtain expected counts.
newExp = func() []float64 {
return []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8, 1 << 9}
}
// exp is an unscaled instance of those weights, used to spell out the
// expected heap contents.
exp = newExp()

// obt holds the per-item selection counts that TestWeightedUnseeded
// requires its 1e6 draws after rand.Seed(0) to reproduce exactly.
obt = []float64{973, 1937, 3898, 7897, 15769, 31284, 62176, 125408, 250295, 500363}
)

// newTestWeighted builds a Weighted with one item per entry of obt, where
// item i has weight 2^i, using the default random source.
func newTestWeighted() Weighted {
	n := len(obt)
	w := make([]float64, 0, n)
	for i := 0; i < n; i++ {
		w = append(w, float64(int(1)<<uint(i)))
	}
	return NewWeighted(w, nil)
}

// TestWeightedUnseeded checks the initial heap layout of a new Weighted and
// then, with the global source seeded to 0, that the first item taken from
// 1e6 fresh samplers reproduces the recorded counts in obt and passes the
// chi-squared test against the weights.
func TestWeightedUnseeded(t *testing.T) {
	rand.Seed(0)

	want := Weighted{
		weights: newExp(),
		heap: []float64{
			exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[6],
			exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9],
			exp[2] + exp[5] + exp[6],
			exp[3] + exp[7] + exp[8],
			exp[4] + exp[9],
			exp[5],
			exp[6],
			exp[7],
			exp[8],
			exp[9],
		},
	}

	if got := newTestWeighted(); !reflect.DeepEqual(got, want) {
		t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", got, want)
	}

	// Tally which item comes out first from each fresh sampler.
	counts := make([]float64, len(obt))
	for n := 0; n < 1e6; n++ {
		item, ok := newTestWeighted().Take()
		if !ok {
			t.Fatal("Weighted unexpectedly empty")
		}
		counts[item]++
	}

	// Scale the weights so that they sum to the number of draws.
	expected := newExp()
	scale := floats.Sum(counts) / floats.Sum(expected)
	for i := range counts {
		expected[i] *= scale
	}

	if !reflect.DeepEqual(counts, obt) {
		t.Fatalf("unexpected selection:\ngot: %#v\nwant:%#v", counts, obt)
	}

	// Check that this is within statistical expectations - we know this is true for this set.
	if X := chi2(counts, expected); X >= sigChi2 {
		t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(counts)-1, sigChi2, X)
	}
}

// TestWeightedTimeSeeded repeats the first-item sampling experiment with a
// time-based seed and applies the chi-squared test to the result. It only
// runs when the -prob flag is set because it is expected to fail by chance
// in roughly one run out of twenty.
func TestWeightedTimeSeeded(t *testing.T) {
	if !*prob {
		t.Skip("probabilistic testing not requested")
	}
	t.Log("Note: This test is stochastic and is expected to fail with probability ≈ 0.05.")

	rand.Seed(time.Now().Unix())

	// Tally which item comes out first from each fresh sampler.
	counts := make([]float64, len(obt))
	for n := 0; n < 1e6; n++ {
		item, ok := newTestWeighted().Take()
		if !ok {
			t.Fatal("Weighted unexpectedly empty")
		}
		counts[item]++
	}

	// Scale the weights so that they sum to the number of draws.
	expected := newExp()
	scale := floats.Sum(counts) / floats.Sum(expected)
	for i := range counts {
		expected[i] *= scale
	}

	// Check that our obtained values are within statistical expectations for p = 0.05.
	// This will not be true approximately 1 in 20 tests.
	if X := chi2(counts, expected); X >= sigChi2 {
		t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(counts)-1, sigChi2, X)
	}
}

// TestWeightZero checks that setting the weight of item 6 to zero updates
// the heap as expected, that the item is then never selected, and that the
// selection counts of the other items differ from the unmodified baseline
// recorded in obt.
func TestWeightZero(t *testing.T) {
	rand.Seed(0)

	want := Weighted{
		weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 0, 1 << 7, 1 << 8, 1 << 9},
		heap: []float64{
			exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5],
			exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9],
			exp[2] + exp[5],
			exp[3] + exp[7] + exp[8],
			exp[4] + exp[9],
			exp[5],
			0,
			exp[7],
			exp[8],
			exp[9],
		},
	}

	ts := newTestWeighted()
	ts.Reweight(6, 0)
	if !reflect.DeepEqual(ts, want) {
		t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want)
	}

	// Tally which item comes out first from each fresh, reweighted sampler.
	f := make([]float64, len(obt))
	for i := 0; i < 1e6; i++ {
		ts := newTestWeighted()
		ts.Reweight(6, 0)
		item, ok := ts.Take()
		if !ok {
			t.Fatal("Weighted unexpectedly empty")
		}
		f[item]++
	}

	if f[6] != 0 {
		t.Errorf("unexpected selection rate for zero-weighted item: got: %v want:%v", f[6], 0)
	}
	// The draws that went to item 6 in the baseline must now land on
	// other items, so neither range may match obt exactly. Fatalf is
	// required here so that the format verbs are expanded.
	if reflect.DeepEqual(f[:6], obt[:6]) {
		t.Fatalf("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v",
			f[:6], obt[:6])
	}
	if reflect.DeepEqual(f[7:], obt[7:]) {
		t.Fatalf("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v",
			f[7:], obt[7:])
	}
}

// TestWeightIncrease checks that raising the weight of item 6 to twice the
// weight of the last item updates the heap as expected, that item 6 is then
// selected at least as often as item 9, and that the selection counts of
// the other items differ from the unmodified baseline recorded in obt.
func TestWeightIncrease(t *testing.T) {
	rand.Seed(0)

	want := Weighted{
		weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 9 * 2, 1 << 7, 1 << 8, 1 << 9},
		heap: []float64{
			exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[9]*2,
			exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9],
			exp[2] + exp[5] + exp[9]*2,
			exp[3] + exp[7] + exp[8],
			exp[4] + exp[9],
			exp[5],
			exp[9] * 2,
			exp[7],
			exp[8],
			exp[9],
		},
	}

	ts := newTestWeighted()
	ts.Reweight(6, ts.weights[len(ts.weights)-1]*2)
	if !reflect.DeepEqual(ts, want) {
		t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want)
	}

	// Tally which item comes out first from each fresh, reweighted sampler.
	f := make([]float64, len(obt))
	for i := 0; i < 1e6; i++ {
		ts := newTestWeighted()
		ts.Reweight(6, ts.weights[len(ts.weights)-1]*2)
		item, ok := ts.Take()
		if !ok {
			t.Fatal("Weighted unexpectedly empty")
		}
		f[item]++
	}

	if f[6] < f[9] {
		t.Errorf("unexpected selection rate for re-weighted item: got: %v want:%v", f[6], f[9])
	}
	// Item 6 now draws selections away from the other items, so neither
	// range may match obt exactly. Fatalf is required here so that the
	// format verbs are expanded.
	if reflect.DeepEqual(f[:6], obt[:6]) {
		t.Fatalf("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v",
			f[:6], obt[:6])
	}
	if reflect.DeepEqual(f[7:], obt[7:]) {
		t.Fatalf("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v",
			f[7:], obt[7:])
	}
}

// chi2 returns the chi-squared statistic for the observed counts ob
// against the expected counts ex: the sum over all categories of the
// squared difference divided by the expected count. ex must be at least
// as long as ob.
func chi2(ob, ex []float64) (sum float64) {
	for i, observed := range ob {
		expected := ex[i]
		diff := observed - expected
		sum += diff * diff / expected
	}
	return sum
}

// TestWeightedNoResample checks that, for several sets of random weights,
// repeatedly calling Take until the Weighted is exhausted yields every item
// exactly once: the number of items taken must equal n and no index may be
// returned twice.
func TestWeightedNoResample(t *testing.T) {
	const (
		tries = 10
		// n is an integer constant so that it formats correctly with
		// %d; it still converts implicitly in the float64 expression
		// below.
		n = 1000000
	)
	ts := NewWeighted(make([]float64, n), nil)
	w := make([]float64, n)
	for i := 0; i < tries; i++ {
		for j := range w {
			w[j] = rand.Float64() * n
		}
		ts.ReweightAll(w)
		taken := make(map[int]struct{})
		var c int
		for {
			item, ok := ts.Take()
			if !ok {
				if c != n {
					t.Errorf("unexpected number of items: got: %d want: %d", c, n)
				}
				break
			}
			c++
			// Report duplicates through the test framework rather
			// than aborting the whole test binary.
			if _, exists := taken[item]; exists {
				t.Errorf("unexpected duplicate sample for item: %d", item)
			}
			taken[item] = struct{}{}
		}
	}
}

0 comments on commit 20ea4c6

Please sign in to comment.