This repository has been archived by the owner on Dec 22, 2018. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit for univariate advanced sampling package. Contains Lat…
…inHypercube, Importance sampling, Rejection sampling, and MetropolisHastings Removed duplicated interface Improve MH comment Made MH examples Documentation fixes Fix documentation and permute LHC
- Loading branch information
Showing
6 changed files
with
379 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package dist | ||
|
||
type LogProber interface { | ||
LogProb(float64) float64 | ||
} | ||
|
||
type Rander interface { | ||
Rand() float64 | ||
} | ||
|
||
type RandLogProber interface { | ||
Rander | ||
LogProber | ||
} | ||
|
||
type Quantiler interface { | ||
Quantile(p float64) float64 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package sample | ||
|
||
import "github.com/gonum/stat/dist" | ||
|
||
type ProposalDist struct { | ||
Sigma float64 | ||
} | ||
|
||
func (p ProposalDist) ConditionalRand(y float64) float64 { | ||
return dist.Normal{Mu: y, Sigma: p.Sigma}.Rand() | ||
} | ||
|
||
func (p ProposalDist) ConditionalLogProb(x, y float64) float64 { | ||
return dist.Normal{Mu: y, Sigma: p.Sigma}.LogProb(x) | ||
} | ||
|
||
func ExampleMetropolisHastings_burnin() { | ||
n := 1000 // The number of samples to generate. | ||
burnin := 50 // Number of samples to ignore at the start. | ||
var initial float64 | ||
// target is the distribution from which we would like to sample. | ||
target := dist.Weibull{K: 5, Lambda: 0.5} | ||
// proposal is the proposal distribution. Here, we are choosing | ||
// a tight Gaussian distribution around the current location. In | ||
// typical problems, if Sigma is too small, it takes a lot of samples | ||
// to move around the distribution. If Sigma is too large, it can be hard | ||
// to find acceptable samples. | ||
proposal := ProposalDist{Sigma: 0.2} | ||
|
||
samples := make([]float64, n+burnin) | ||
MetropolisHastings(samples, initial, target, proposal, nil) | ||
|
||
// Remove the initial samples through slicing. | ||
samples = samples[burnin:] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package sample | ||
|
||
import "github.com/gonum/stat/dist" | ||
|
||
func max(a, b int) int { | ||
if a < b { | ||
return b | ||
} | ||
return a | ||
} | ||
|
||
func ExampleMetropolisHastings_samplingRate() { | ||
// See Burnin example for a description of these quantities. | ||
n := 1000 | ||
burnin := 300 | ||
var initial float64 | ||
target := dist.Weibull{K: 5, Lambda: 0.5} | ||
proposal := ProposalDist{Sigma: 0.2} | ||
|
||
// Successive samples are correlated with one another through the | ||
// Markov Chain defined by the proposal distribution. To get less | ||
// correlated samples, one may use a sampling rate, in which only | ||
// one sample from every few is accepted from the chain. This can | ||
// be accomplished through a for loop. | ||
rate := 50 | ||
|
||
tmp := make([]float64, max(rate, burnin)) | ||
|
||
// First deal with burnin | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
tmp = tmp[:burnin] | ||
MetropolisHastings(tmp, initial, target, proposal, nil) | ||
// The final sample in tmp in the final point in the chain. | ||
// Use it as the new initial location. | ||
initial = tmp[len(tmp)-1] | ||
|
||
// Now, generate samples by using one every rate samples. | ||
tmp = tmp[:rate] | ||
samples := make([]float64, n) | ||
samples[0] = initial | ||
for i := 1; i < len(samples); i++ { | ||
MetropolisHastings(tmp, initial, target, proposal, nil) | ||
initial = tmp[len(tmp)-1] | ||
samples[i] = initial | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
// package sample contains a set of advanced routines for sampling from | ||
// probability distributions. | ||
package sample | ||
|
||
import ( | ||
"math" | ||
"math/rand" | ||
|
||
"github.com/gonum/stat/dist" | ||
) | ||
|
||
var ( | ||
badLengthMismatch = "sample: slice length mismatch" | ||
) | ||
|
||
// LatinHypercube generates len(samples) samples using Latin hypercube sampling | ||
// from the given distribution. If src != nil, it will be used to generate | ||
// random numbers, otherwise rand.Float64 will be used. | ||
// | ||
// Latin hypercube sampling divides the cumulative distribution function into equally | ||
// spaced bins and guarantees that one sample is generated per bin. Within each bin, | ||
// the location is randomly sampled. The dist.UnitNormal variable can be used | ||
// for easy generation from the unit interval. | ||
func LatinHypercube(samples []float64, q dist.Quantiler, src *rand.Rand) { | ||
n := len(samples) | ||
var perm []int | ||
var f64 func() float64 | ||
if src != nil { | ||
f64 = src.Float64 | ||
perm = src.Perm(n) | ||
} else { | ||
f64 = rand.Float64 | ||
perm = rand.Perm(n) | ||
} | ||
for i := range samples { | ||
v := f64()/float64(n) + float64(i)/float64(n) | ||
samples[perm[i]] = q.Quantile(v) | ||
} | ||
} | ||
|
||
// Importance sampling generates len(x) samples from the proposal distribution, | ||
// and stores the locations and importance sampling weights in place. | ||
// | ||
// Importance sampling is a variance reduction technique where samples are | ||
// generated from a proposal distribution, q(x), instead of the target distribution | ||
// p(x). This allows relatively unlikely samples in p(x) to be generated more frequently | ||
// | ||
// The importance sampling weight at x is given by p(x)/q(x). To reduce variance, | ||
// a good proposal distribution will bound this sampling weight. This implies the | ||
// support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed" | ||
// than p(x). | ||
func Importance(samples, weights []float64, target dist.LogProber, proposal dist.RandLogProber) { | ||
if len(samples) != len(weights) { | ||
panic(badLengthMismatch) | ||
} | ||
for i := range samples { | ||
v := proposal.Rand() | ||
samples[i] = v | ||
weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v)) | ||
} | ||
} | ||
|
||
// Rejection generates len(x) samples using the rejection sampling algorithm and | ||
// stores them in place into samples. | ||
// Sampling continues until x is filled. Rejection the total number of proposed | ||
// locations and a boolean indicating if the rejection sampling assumption is | ||
// violated (see details below). If the returned boolean is false, all elements | ||
// of samples are set to NaN. If src != nil, it will be used to generate random | ||
// numbers, otherwise rand.Float64 will be used. | ||
// | ||
// Rejection sampling generates points from the target distribution by using | ||
// the proposal distribution. At each step of the algorithm, the proposaed point | ||
// is accepted with probability | ||
// p = target(x) / (proposal(x) * c) | ||
// where target(x) is the probability of the point according to the target distribution | ||
// and proposal(x) is the probability according to the proposal distribution. | ||
// The constant c must be chosen such that target(x) < proposal(x) * c for all x. | ||
// The expected number of proposed samples is len(samples) * c. | ||
// | ||
// Target may return the true (log of) the probablity of the location, or it may return | ||
// a value that is proportional to the probability (logprob + constant). This is | ||
// useful for cases where the probability distribution is only known up to a normalization | ||
// constant. | ||
func Rejection(samples []float64, target dist.LogProber, proposal dist.RandLogProber, c float64, src *rand.Rand) (nProposed int, ok bool) { | ||
if c < 1 { | ||
panic("rejection: acceptance constant must be greater than 1") | ||
} | ||
f64 := rand.Float64 | ||
if src != nil { | ||
f64 = src.Float64 | ||
} | ||
var idx int | ||
for { | ||
nProposed++ | ||
v := proposal.Rand() | ||
qx := proposal.LogProb(v) | ||
px := target.LogProb(v) | ||
accept := math.Exp(px-qx) / c | ||
if accept > 1 { | ||
// Invalidate the whole result and return a failure | ||
for i := range samples { | ||
samples[i] = math.NaN() | ||
} | ||
return nProposed, false | ||
} | ||
if accept > f64() { | ||
samples[idx] = v | ||
idx++ | ||
if idx == len(samples) { | ||
break | ||
} | ||
} | ||
} | ||
return nProposed, true | ||
} | ||
|
||
type MHProposal interface { | ||
// ConditionalDist returns the probability of the first argument conditioned on | ||
// being at the second argument | ||
// p(x|y) | ||
ConditionalLogProb(x, y float64) (prob float64) | ||
|
||
// ConditionalRand generates a new random location conditioned being at the | ||
// location y | ||
ConditionalRand(y float64) (x float64) | ||
} | ||
|
||
// MetropolisHastings generates len(samples) samples using the Metropolis Hastings | ||
// algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm), | ||
// with the given target and proposal distributions, starting at the intial location | ||
// and storing the results in-place into samples. If src != nil, it will be used to generate random | ||
// numbers, otherwise rand.Float64 will be used. | ||
// | ||
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates | ||
// samples according to the distribution specified by target by using the Markov | ||
// chain implicitly defined by the proposal distribution. At each | ||
// iteration, a proposal point is generated randomly from the current location. | ||
// This proposal point is accepted with probability | ||
// p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current))) | ||
// If the new location is accepted, it is stored into samples and becomes the | ||
// new current location. If it is rejected, the current location remains and | ||
// is stored into samples. Thus, a location is stored into samples at every iteration. | ||
// | ||
// The samples in Metropolis Hastings are correlated with one another through the | ||
// Markov-Chain. As a result, the initial value can have a significant influence | ||
// on the early samples, and so typically, the first sapmles generated by the chain. | ||
// are ignored. This is known as "burn-in", and can be accomplished with slicing. | ||
// The best choice for burn-in length will depend on the sampling and the target | ||
// distribution. | ||
// | ||
// Many choose to have a sampling "rate" where a number of samples | ||
// are ignored in between each kept sample. This helps decorrelate | ||
// the samples from one another, but also reduces the number of available samples. | ||
// A sampling rate can be implemented with successive calls to MetropolisHastings. | ||
func MetropolisHastings(samples []float64, initial float64, target dist.LogProber, proposal MHProposal, src *rand.Rand) { | ||
f64 := rand.Float64 | ||
if src != nil { | ||
f64 = src.Float64 | ||
} | ||
current := initial | ||
currentLogProb := target.LogProb(initial) | ||
for i := range samples { | ||
proposed := proposal.ConditionalRand(current) | ||
proposedLogProb := target.LogProb(proposed) | ||
probTo := proposal.ConditionalLogProb(proposed, current) | ||
probBack := proposal.ConditionalLogProb(current, proposed) | ||
|
||
accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb) | ||
if accept > f64() { | ||
current = proposed | ||
currentLogProb = proposedLogProb | ||
} | ||
samples[i] = current | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
package sample | ||
|
||
import ( | ||
"math" | ||
"sort" | ||
"testing" | ||
|
||
"github.com/gonum/stat" | ||
"github.com/gonum/stat/dist" | ||
) | ||
|
||
type lhDist interface { | ||
Quantile(float64) float64 | ||
CDF(float64) float64 | ||
} | ||
|
||
func TestLatinHypercube(t *testing.T) { | ||
for _, nSamples := range []int{1, 2, 5, 10, 20} { | ||
samples := make([]float64, nSamples) | ||
for _, dist := range []lhDist{ | ||
dist.Uniform{Min: 0, Max: 1}, | ||
dist.Uniform{Min: 0, Max: 10}, | ||
dist.Normal{Mu: 5, Sigma: 3}, | ||
} { | ||
LatinHypercube(samples, dist, nil) | ||
sort.Float64s(samples) | ||
for i, v := range samples { | ||
p := dist.CDF(v) | ||
if p < float64(i)/float64(nSamples) || p > float64(i+1)/float64(nSamples) { | ||
t.Errorf("probability out of bounds") | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
func TestImportance(t *testing.T) { | ||
// Test by finding the expected value of a Normal | ||
trueMean := 3.0 | ||
target := dist.Normal{Mu: trueMean, Sigma: 2} | ||
proposal := dist.Normal{Mu: 0, Sigma: 5} | ||
nSamples := 100000 | ||
x := make([]float64, nSamples) | ||
weights := make([]float64, nSamples) | ||
Importance(x, weights, target, proposal) | ||
ev := stat.Mean(x, weights) | ||
if math.Abs(ev-trueMean) > 1e-2 { | ||
t.Errorf("Mean mismatch: Want %v, got %v", trueMean, ev) | ||
} | ||
} | ||
|
||
func TestRejection(t *testing.T) { | ||
// Test by finding the expected value of a Normal | ||
trueMean := 3.0 | ||
target := dist.Normal{Mu: trueMean, Sigma: 2} | ||
proposal := dist.Normal{Mu: 0, Sigma: 5} | ||
|
||
nSamples := 100000 | ||
x := make([]float64, nSamples) | ||
Rejection(x, target, proposal, 100, nil) | ||
ev := stat.Mean(x, nil) | ||
if math.Abs(ev-trueMean) > 1e-2 { | ||
t.Errorf("Mean mismatch: Want %v, got %v", trueMean, ev) | ||
} | ||
} | ||
|
||
type condNorm struct { | ||
Sigma float64 | ||
} | ||
|
||
func (c condNorm) ConditionalRand(y float64) float64 { | ||
return dist.Normal{Mu: y, Sigma: c.Sigma}.Rand() | ||
} | ||
|
||
func (c condNorm) ConditionalLogProb(x, y float64) float64 { | ||
return dist.Normal{Mu: y, Sigma: c.Sigma}.LogProb(x) | ||
} | ||
|
||
func TestMetropolisHastings(t *testing.T) { | ||
// Test by finding the expected value of a Normal | ||
trueMean := 3.0 | ||
target := dist.Normal{Mu: trueMean, Sigma: 2} | ||
proposal := condNorm{Sigma: 5} | ||
|
||
burnin := 500 | ||
nSamples := 100000 + burnin | ||
x := make([]float64, nSamples) | ||
MetropolisHastings(x, 100, target, proposal, nil) | ||
// Remove burnin | ||
x = x[burnin:] | ||
ev := stat.Mean(x, nil) | ||
if math.Abs(ev-trueMean) > 1e-2 { | ||
t.Errorf("Mean mismatch: Want %v, got %v", trueMean, ev) | ||
} | ||
} |
Full stop.