Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update tests to make stochastic failure less likely
- Loading branch information
Showing
2 changed files
with
79 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
FactCheck | ||
Distributions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,93 @@ | ||
using RouletteWheels | ||
using Base.Test | ||
using FactCheck | ||
using Distributions | ||
using Compat | ||
|
||
const algos = [LinearWalk, BisectingSearch, StochasticAcceptance] | ||
|
||
############################################################################### | ||
# First, I'll use an intuitive test. I tally 1,000 samples over frequencies | ||
# [5,4,3,2,1]. The tallies should be descending. | ||
############################################################################### | ||
|
||
const freqs = [5,4,3,2,1] | ||
const props = freqs / sum(freqs) | ||
|
||
function test_descending(tallies) | ||
sliding_windows = zip(tallies[1:end-1], tallies[2:end]) | ||
@test all(pair -> pair[1] > pair[2], sliding_windows) | ||
end | ||
|
||
for algo in algos | ||
test_descending(rand_tally(algo(freqs), 1000)) | ||
test_descending(rand_tally(algo(props), 1000)) | ||
end | ||
|
||
############################################################################### | ||
# Now, I'll test sampling following growth. Remember, to call `normalize!` | ||
# prior to sampling, following a `push!`. I should add a statistical test | ||
# for distribution comparisons. | ||
############################################################################### | ||
|
||
function test_ascending(tallies) | ||
sliding_windows = zip(tallies[1:end-1], tallies[2:end]) | ||
@test all(pair -> pair[1] < pair[2], sliding_windows) | ||
end | ||
|
||
xs = collect(1:5) | ||
|
||
for algo in algos | ||
selector = algo(xs) | ||
test_ascending(rand_tally(selector, 1000)) | ||
|
||
# Cheap test of iterator interface conformity. | ||
@test collect(selector) == xs | ||
function in_ci(wheel, n) | ||
ps = collect(wheel) | ||
ps /= sum(ps) | ||
tallies = rand_tally(wheel, n) | ||
|
||
# push! then normalize. | ||
push!(selector, 1) | ||
@test length(selector) == 6 | ||
normalize!(selector) | ||
for (p, tally) in zip(ps, tallies) | ||
(l_b, u_b) = quantile(Binomial(n, p), [0.00001, 0.99999]) | ||
if u_b < tally < l_b | ||
return false | ||
end | ||
end | ||
|
||
tally = rand_tally(selector, 300000) | ||
test_ascending(tally[1:5]) | ||
@test tally[6] < tally[2] | ||
return true | ||
end | ||
|
||
############################################################################### | ||
# Test the WheelFromDict wrapper. | ||
############################################################################### | ||
|
||
d = @compat Dict{Symbol, Int}(:red => 1, :green => 2, :blue => 3) | ||
wheel = WheelFromDict(d) | ||
facts("A RouletteWheel") do | ||
context("iterates over the underlying frequency") do | ||
for algo in algos | ||
@fact collect(algo([1,9,9,4])) --> [1, 9, 9, 4] | ||
end | ||
end | ||
|
||
@test length(wheel) == 3 | ||
context("samples in proportion to the given frequency") do | ||
freqs = [5,4,9,2,4] | ||
for algo in algos | ||
@fact in_ci(algo(freqs), 10000) --> true | ||
end | ||
end | ||
|
||
for (k, v) in wheel | ||
@test k ∈ keys(d) | ||
@test v == d[k] | ||
context("samples in proportion to the given proportion") do | ||
ps = [5,4,9,2,4] | ||
ps /= ps | ||
for algo in algos | ||
@fact in_ci(algo(ps), 10000) --> true | ||
end | ||
end | ||
|
||
context("allows for push! given a normalize!") do | ||
freqs = [5,4,9,2,4] | ||
for algo in algos | ||
wheel = algo(freqs) | ||
|
||
push!(wheel, 1) | ||
@fact length(wheel) --> 6 | ||
normalize!(wheel) | ||
@fact wheel[6] --> 1 | ||
|
||
@fact in_ci(wheel, 10000) --> true | ||
end | ||
end | ||
end | ||
|
||
sampled_d = rand_dict(wheel, 1000) | ||
@test sampled_d[:red] < sampled_d[:green] < sampled_d[:blue] | ||
|
||
wheel[:gold] = 4 | ||
normalize!(wheel) | ||
sampled_d = rand_dict(wheel, 1000) | ||
@test sampled_d[:red] < sampled_d[:green] < sampled_d[:blue] < sampled_d[:gold] | ||
|
||
############################################################################### | ||
# Test select_fastest. | ||
############################################################################### | ||
fastest, all_timings = select_fastest(1:3) do wheel | ||
for _ in 1:1000 | ||
rand(wheel) | ||
facts("The WheelFromDict wrapper") do | ||
context("samples underlying values instead of an index") do | ||
d = @compat Dict{Symbol, Int}(:red => 1, :green => 5, :blue => 20) | ||
wheel = WheelFromDict(d) | ||
@fact length(wheel) --> 3 | ||
|
||
for (k, v) in wheel | ||
@fact k ∈ keys(d) --> true | ||
@fact v --> d[k] | ||
end | ||
|
||
sampled_d = rand_dict(wheel, 10000) | ||
@fact sampled_d[:red] < sampled_d[:green] < sampled_d[:blue] --> true | ||
|
||
wheel[:gold] = 90 | ||
normalize!(wheel) | ||
sampled_d = rand_dict(wheel, 10000) | ||
@fact sampled_d[:red] < sampled_d[:green] < sampled_d[:blue] --> true | ||
@fact sampled_d[:blue] < sampled_d[:gold] --> true | ||
end | ||
end | ||
@test length(all_timings) == 3 | ||
@test all_timings[fastest] == minimum(values(all_timings)) | ||
|
||
facts("The select_fastest function") do | ||
context("selects the fastest wheel over some proportions") do | ||
fastest, all_timings = select_fastest(1:3) do wheel | ||
for _ in 1:1000 | ||
rand(wheel) | ||
end | ||
end | ||
@fact length(all_timings) --> 3 | ||
@fact all_timings[fastest] --> minimum(values(all_timings)) | ||
end | ||
end |