This repository has been archived by the owner on Dec 22, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this code is governed by a BSD-style | ||
// license that can be found in the LICENSE file | ||
|
||
package sample | ||
|
||
import "math/rand" | ||
|
||
// Weighted provides sampling without replacement from a collection of items with | ||
// weighted probabilites. | ||
type Weighted struct { | ||
weights []float64 | ||
heap []float64 | ||
src *rand.Rand | ||
} | ||
|
||
// NewWeighted returns a Weighted for the weights w. Elements of w are modified | ||
// during sampling, so a copy should be passed if the user needs to keep the values. | ||
// If src is nil, rand.Rand is used as the random source. | ||
func NewWeighted(w []float64, src *rand.Rand) Weighted { | ||
s := Weighted{ | ||
weights: w, | ||
heap: make([]float64, len(w)), | ||
} | ||
s.reset() | ||
return s | ||
} | ||
|
||
func (s Weighted) reset() { | ||
copy(s.heap, s.weights) | ||
for i := len(s.heap) - 1; i > 0; i-- { | ||
// Sometimes 1-based counting makes sense. | ||
s.heap[((i+1)>>1)-1] += s.heap[i] | ||
} | ||
} | ||
|
||
// Len returns the number of items held by the Weighted, including items | ||
// already taken. | ||
func (s Weighted) Len() int { return len(s.weights) } | ||
|
||
// Take returns an index from the Weighted based on the Weighted weighting, the | ||
// item is then re-weighted to zero. Take returns false if there is no item remaining. | ||
func (s Weighted) Take() (idx int, ok bool) { | ||
if s.heap[0] == 0 { | ||
return -1, false | ||
} | ||
|
||
var r float64 | ||
if s.src == nil { | ||
r = s.heap[0] * rand.Float64() | ||
} else { | ||
r = s.heap[0] * s.src.Float64() | ||
} | ||
i := 1 | ||
for { | ||
if r -= s.weights[i-1]; r <= 0 { | ||
break // fall within item i-1 | ||
} | ||
i <<= 1 // move to left child | ||
if d := s.heap[i-1]; r > d { | ||
r -= d | ||
// if enough r to pass left child | ||
// move to right child | ||
// state will be caught at break above | ||
i++ | ||
} | ||
} | ||
|
||
w, idx := s.weights[i-1], i-1 | ||
|
||
s.weights[i-1] = 0 | ||
for i > 0 { | ||
s.heap[i-1] -= w | ||
i >>= 1 | ||
} | ||
|
||
return idx, true | ||
} | ||
|
||
// Reweight alters the weight of item i in the Weighted. | ||
func (s Weighted) Reweight(idx int, w float64) { | ||
w, s.weights[idx] = s.weights[idx]-w, w | ||
idx++ | ||
for idx > 0 { | ||
s.heap[idx-1] -= w | ||
idx >>= 1 | ||
} | ||
} | ||
|
||
// ReweightAll alters the weight of all items in the Weighted. ReweightAll | ||
// panics if len(w) != s.Len. | ||
func (s Weighted) ReweightAll(w []float64) { | ||
if len(w) != s.Len() { | ||
panic("floats: length of the slices do not match") | ||
} | ||
copy(s.weights, w) | ||
s.reset() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this code is governed by a BSD-style | ||
// license that can be found in the LICENSE file | ||
|
||
package sample | ||
|
||
import ( | ||
"flag" | ||
"math/rand" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gonum/floats" | ||
) | ||
|
||
var prob = flag.Bool("prob", false, "enables probabilistic testing of the random weighted sampler") | ||
|
||
const sigChi2 = 16.92 // p = 0.05 df = 9 | ||
|
||
var ( | ||
newExp = func() []float64 { | ||
return []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8, 1 << 9} | ||
} | ||
exp = newExp() | ||
|
||
obt = []float64{973, 1937, 3898, 7897, 15769, 31284, 62176, 125408, 250295, 500363} | ||
) | ||
|
||
func newTestWeighted() Weighted { | ||
weights := make([]float64, len(obt)) | ||
for i := range weights { | ||
weights[i] = float64(int(1) << uint(i)) | ||
} | ||
return NewWeighted(weights, nil) | ||
} | ||
|
||
func TestWeightedUnseeded(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[6], | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5] + exp[6], | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
exp[6], | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
item, ok := newTestWeighted().Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if !reflect.DeepEqual(f, obt) { | ||
t.Fatalf("unexpected selection:\ngot: %#v\nwant:%#v", f, obt) | ||
} | ||
|
||
// Check that this is within statistical expectations - we know this is true for this set. | ||
X := chi2(f, exp) | ||
if X >= sigChi2 { | ||
t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(f)-1, sigChi2, X) | ||
} | ||
} | ||
|
||
func TestWeightedTimeSeeded(t *testing.T) { | ||
if !*prob { | ||
t.Skip("probabilistic testing not requested") | ||
} | ||
t.Log("Note: This test is stochastic and is expected to fail with probability ≈ 0.05.") | ||
|
||
rand.Seed(time.Now().Unix()) | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
item, ok := newTestWeighted().Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
// Check that our obtained values are within statistical expectations for p = 0.05. | ||
// This will not be true approximately 1 in 20 tests. | ||
X := chi2(f, exp) | ||
if X >= sigChi2 { | ||
t.Errorf("H₀: d(Sample) = d(Expect), H₁: d(S) ≠ d(Expect). df = %d, p = 0.05, X² threshold = %.2f, X² = %f", len(f)-1, sigChi2, X) | ||
} | ||
} | ||
|
||
func TestWeightZero(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 0, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5], | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5], | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
0, | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
ts.Reweight(6, 0) | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
ts := newTestWeighted() | ||
ts.Reweight(6, 0) | ||
item, ok := ts.Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if f[6] != 0 { | ||
t.Errorf("unexpected selection rate for zero-weighted item: got: %v want:%v", f[6], 0) | ||
} | ||
if reflect.DeepEqual(f[:6], obt[:6]) { | ||
t.Fatal("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v", | ||
f[:6], obt[:6]) | ||
} | ||
if reflect.DeepEqual(f[7:], obt[7:]) { | ||
t.Fatal("unexpected selection: too few elements chosen in range:\ngot: %v\nwant:%v", | ||
f[7:], obt[7:]) | ||
} | ||
} | ||
|
||
func TestWeightIncrease(t *testing.T) { | ||
rand.Seed(0) | ||
|
||
want := Weighted{ | ||
weights: []float64{1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 9 * 2, 1 << 7, 1 << 8, 1 << 9}, | ||
heap: []float64{ | ||
exp[0] + exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9] + exp[2] + exp[5] + exp[9]*2, | ||
exp[1] + exp[3] + exp[4] + exp[7] + exp[8] + exp[9], | ||
exp[2] + exp[5] + exp[9]*2, | ||
exp[3] + exp[7] + exp[8], | ||
exp[4] + exp[9], | ||
exp[5], | ||
exp[9] * 2, | ||
exp[7], | ||
exp[8], | ||
exp[9], | ||
}, | ||
} | ||
|
||
ts := newTestWeighted() | ||
ts.Reweight(6, ts.weights[len(ts.weights)-1]*2) | ||
if !reflect.DeepEqual(ts, want) { | ||
t.Fatalf("unexpected new Weighted value:\ngot: %#v\nwant:%#v", ts, want) | ||
} | ||
|
||
f := make([]float64, len(obt)) | ||
for i := 0; i < 1e6; i++ { | ||
ts := newTestWeighted() | ||
ts.Reweight(6, ts.weights[len(ts.weights)-1]*2) | ||
item, ok := ts.Take() | ||
if !ok { | ||
t.Fatal("Weighted unexpectedly empty") | ||
} | ||
f[item]++ | ||
} | ||
|
||
exp := newExp() | ||
fac := floats.Sum(f) / floats.Sum(exp) | ||
for i := range f { | ||
exp[i] *= fac | ||
} | ||
|
||
if f[6] < f[9] { | ||
t.Errorf("unexpected selection rate for re-weighted item: got: %v want:%v", f[6], f[9]) | ||
} | ||
if reflect.DeepEqual(f[:6], obt[:6]) { | ||
t.Fatal("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v", | ||
f[:6], obt[:6]) | ||
} | ||
if reflect.DeepEqual(f[7:], obt[7:]) { | ||
t.Fatal("unexpected selection: too many elements chosen in range:\ngot: %v\nwant:%v", | ||
f[7:], obt[7:]) | ||
} | ||
} | ||
|
||
func chi2(ob, ex []float64) (sum float64) { | ||
for i := range ob { | ||
x := ob[i] - ex[i] | ||
sum += (x * x) / ex[i] | ||
} | ||
|
||
return sum | ||
} |