This repository has been archived by the owner on Dec 22, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this code is governed by a BSD-style | ||
// license that can be found in the LICENSE file | ||
|
||
package sample | ||
|
||
import ( | ||
"math/rand" | ||
|
||
"github.com/gonum/floats" | ||
) | ||
|
||
// Weighted provides sampling without replacement from a collection of items with | ||
// non-uniform probability. | ||
type Weighted struct { | ||
weights []float64 | ||
// heap is a weight heap. See comments | ||
// in contained/heap for an explanation. | ||
heap []float64 | ||
src *rand.Rand | ||
} | ||
|
||
// NewWeighted returns a Weighted for the weights w. Elements of w are modified | ||
// during sampling, so a copy should be passed if the user needs to keep the values. | ||
// If src is nil, rand.Rand is used as the random source. | ||
func NewWeighted(w []float64, src *rand.Rand) Weighted { | ||
s := Weighted{ | ||
weights: make([]float64, len(w)), | ||
heap: make([]float64, len(w)), | ||
} | ||
s.ReweightAll(w) | ||
return s | ||
} | ||
|
||
// Len returns the number of items held by the Weighted, including items | ||
// already taken. | ||
func (s Weighted) Len() int { return len(s.weights) } | ||
|
||
// Take returns an index from the Weighted with probability proportional | ||
// to the weight of the item. The weight of the item is then set to zero. | ||
// Take returns false if there are no items remaining. | ||
func (s Weighted) Take() (idx int, ok bool) { | ||
if floats.EqualWithinAbsOrRel(s.heap[0], 0, 1e-12, 1e-12) { | ||
return -1, false | ||
} | ||
|
||
var r float64 | ||
if s.src == nil { | ||
r = s.heap[0] * rand.Float64() | ||
} else { | ||
r = s.heap[0] * s.src.Float64() | ||
} | ||
i := 1 | ||
for { | ||
if r -= s.weights[i-1]; r <= 0 { | ||
break // fall within item i-1 | ||
} | ||
i <<= 1 // move to left child | ||
if d := s.heap[i-1]; r > d { | ||
r -= d | ||
// if enough r to pass left child | ||
// move to right child | ||
// state will be caught at break above | ||
i++ | ||
} | ||
} | ||
|
||
w, idx := s.weights[i-1], i-1 | ||
|
||
s.weights[i-1] = 0 | ||
for i > 0 { | ||
s.heap[i-1] -= w | ||
// The following condition is necessary to | ||
// handle floating point error. If we see | ||
// a heap value below zero, we know we need | ||
// to rebuild it. | ||
if s.heap[i-1] < 0 { | ||
s.reset() | ||
return idx, true | ||
} | ||
i >>= 1 | ||
} | ||
|
||
return idx, true | ||
} | ||
|
||
// Reweight sets the weight of item idx to w. | ||
func (s Weighted) Reweight(idx int, w float64) { | ||
w, s.weights[idx] = s.weights[idx]-w, w | ||
idx++ | ||
for idx > 0 { | ||
s.heap[idx-1] -= w | ||
idx >>= 1 | ||
} | ||
} | ||
|
||
// ReweightAll sets the weight of all items in the Weighted. ReweightAll | ||
// panics if len(w) != s.Len. | ||
func (s Weighted) ReweightAll(w []float64) { | ||
if len(w) != s.Len() { | ||
panic("floats: length of the slices do not match") | ||
} | ||
copy(s.weights, w) | ||
s.reset() | ||
} | ||
|
||
func (s Weighted) reset() { | ||
copy(s.heap, s.weights) | ||
for i := len(s.heap) - 1; i > 0; i-- { | ||
// Sometimes 1-based counting makes sense. | ||
s.heap[((i+1)>>1)-1] += s.heap[i] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this code is governed by a BSD-style | ||
// license that can be found in the LICENSE file | ||
|
||
package sample | ||
|
||
import ( | ||
"flag" | ||
"math/rand" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gonum/floats" | ||
) | ||
|
||
var prob = flag.Bool("prob", false, "enables probabilistic testing of the random weighted sampler") | ||
|
||
const sigChi2 = 16.92 // p = 0.05 df = 9 | ||
|
||
var ( | ||
newExp = func() []float64 { | ||
return []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8, 1 << 9} | ||
} | ||
exp = newExp() | ||
|
||
obt = []float64{973, 1937, 3898, 7897, 15769, 31284, 62176, 125408, 250295, 500363} | ||
) | ||
|
||
func newTestWeighted() Weighted { | ||
weights := make([]float64, len(obt)) | ||
for i := range weights { | ||
weights[i] = float64(int(1) << uint(i)) | ||
} | ||
return NewWeighted(weights, nil) | ||
} | ||
|
||
func TestWeightedUnseeded(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[6], | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5] + exp[6], | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
exp[6], | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
item, ok := newTestWeighted().Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if !reflect.DeepEqual(f, obt) { | ||
t.Fatalf("unexpected selection:\ngot: %#v\nwant:%#v", f, obt) | ||
} | ||
|
||
// Check that this is within statistical expectations - we know this is true for this set. | ||
X := chi2(f, exp) | ||
if X >= sigChi2 { | ||
t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(f)-1, sigChi2, X) | ||
} | ||
} | ||
|
||
func TestWeightedTimeSeeded(t *testing.T) { | ||
if !*prob { | ||
t.Skip("probabilistic testing not requested") | ||
} | ||
t.Log("Note: This test is stochastic and is expected to fail with probability ≈ 0.05.") | ||
|
||
rand.Seed(time.Now().Unix()) | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
item, ok := newTestWeighted().Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
// Check that our obtained values are within statistical expectations for p = 0.05. | ||
// This will not be true approximately 1 in 20 tests. | ||
X := chi2(f, exp) | ||
if X >= sigChi2 { | ||
t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(f)-1, sigChi2, X) | ||
} | ||
} | ||
|
||
func TestWeightZero(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 0, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5], | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5], | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
0, | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
ts.Reweight(6, 0) | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
ts := newTestWeighted() | ||
ts.Reweight(6, 0) | ||
item, ok := ts.Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if f[6] != 0 { | ||
t.Errorf("unexpected selection rate for zero-weighted item: got: %v want:%v", f[6], 0) | ||
} | ||
if reflect.DeepEqual(f[:6], obt[:6]) { | ||
t.Fatal("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v", | ||
f[:6], obt[:6]) | ||
} | ||
if reflect.DeepEqual(f[7:], obt[7:]) { | ||
t.Fatal("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v", | ||
f[7:], obt[7:]) | ||
} | ||
} | ||
|
||
func TestWeightIncrease(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 9 * 2, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[9]*2, | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5] + exp[9]*2, | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
exp[9] * 2, | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
ts.Reweight(6, ts.weights[len(ts.weights)-1]*2) | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
ts := newTestWeighted() | ||
ts.Reweight(6, ts.weights[len(ts.weights)-1]*2) | ||
item, ok := ts.Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if f[6] < f[9] { | ||
t.Errorf("unexpected selection rate for re-weighted item: got: %v want:%v", f[6], f[9]) | ||
} | ||
if reflect.DeepEqual(f[:6], obt[:6]) { | ||
t.Fatal("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v", | ||
f[:6], obt[:6]) | ||
} | ||
if reflect.DeepEqual(f[7:], obt[7:]) { | ||
t.Fatal("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v", | ||
f[7:], obt[7:]) | ||
} | ||
} | ||
|
||
func chi2(ob, ex []float64) (sum float64) { | ||
for i := range ob { | ||
x := ob[i] - ex[i] | ||
sum += (x * x) / ex[i] | ||
} | ||
|
||
return sum | ||
} | ||
|
||
func TestWeightedNoResample(t *testing.T) { | ||
const ( | ||
tries = 10 | ||
n = 10e5 | ||
) | ||
ts := NewWeighted(make([]float64, n), nil) | ||
w := make([]float64, n) | ||
for i := 0; i < tries; i++ { | ||
for j := range w { | ||
w[j] = rand.Float64() * n | ||
} | ||
ts.ReweightAll(w) | ||
taken := make(map[int]struct{}) | ||
var c int | ||
for { | ||
item, ok := ts.Take() | ||
if !ok { | ||
if c != n { | ||
t.Errorf("unexpected number of items: got: %d want: %d", c, n) | ||
} | ||
break | ||
} | ||
c++ | ||
if _, exists := taken[item]; exists { | ||
panic("here") | ||
t.Errorf("unexpected duplicate sample for item: %d", item) | ||
} | ||
taken[item] = struct{}{} | ||
} | ||
} | ||
} |