This repository has been archived by the owner on Dec 22, 2018. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #159 from gonum/statdist
Add distance functions between probability distributions
- Loading branch information
Showing
2 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
// Copyright ©2016 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package distmv | ||
|
||
import ( | ||
"math" | ||
|
||
"github.com/gonum/floats" | ||
"github.com/gonum/matrix/mat64" | ||
"github.com/gonum/stat" | ||
) | ||
|
||
// Bhattacharyya is a type for computing the Bhattacharyya distance between | ||
// probability distributions. | ||
// | ||
// The Bhattacharyya distance between distributions l and r is defined as | ||
// D_B = -ln(BC(l,r)) | ||
// BC(l,r) = \int_x (l(x)r(x))^(1/2) dx | ||
// where BC is known as the Bhattacharyya coefficient. | ||
// The Bhattacharyya distance is related to the Hellinger distance by | ||
// H = sqrt(1-BC) | ||
// For more information, see | ||
// https://en.wikipedia.org/wiki/Bhattacharyya_distance | ||
type Bhattacharyya struct{} | ||
|
||
// DistNormal computes the Bhattacharyya distance between normal distributions l and r. | ||
// The dimensions of the input distributions must match or DistNormal will panic. | ||
// | ||
// For Normal distributions, the Bhattacharyya distance is | ||
// Σ = (Σ_l + Σ_r)/2 | ||
// D_B = (1/8)*(μ_l - μ_r)^T*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2)) | ||
func (Bhattacharyya) DistNormal(l, r *Normal) float64 { | ||
dim := l.Dim() | ||
if dim != r.Dim() { | ||
panic(badSizeMismatch) | ||
} | ||
|
||
var sigma mat64.SymDense | ||
sigma.AddSym(&l.sigma, &r.sigma) | ||
sigma.ScaleSym(0.5, &sigma) | ||
|
||
var chol mat64.Cholesky | ||
chol.Factorize(&sigma) | ||
|
||
mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &chol) | ||
mahalanobisSq := mahalanobis * mahalanobis | ||
|
||
dl := l.chol.LogDet() | ||
dr := r.chol.LogDet() | ||
ds := chol.LogDet() | ||
|
||
return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr | ||
} | ||
|
||
// CrossEntropy is a type for computing the cross-entropy between probability | ||
// distributions. | ||
// | ||
// The cross-entropy between distributions l and r is defined as | ||
// - \int_x l(x) log(r(x)) dx = KL(l || r) + H(l) | ||
// where KL is the Kullback-Leibler divergence and H is the entropy. | ||
// For more information, see | ||
// https://en.wikipedia.org/wiki/Cross_entropy | ||
// | ||
// CrossEntropy has no fields; its methods are called on the zero value. | ||
type CrossEntropy struct{} | ||
|
||
// DistNormal returns the cross-entropy between normal distributions l and r. | ||
// The dimensions of the input distributions must match or DistNormal will panic. | ||
func (CrossEntropy) DistNormal(l, r *Normal) float64 { | ||
if l.Dim() != r.Dim() { | ||
panic(badSizeMismatch) | ||
} | ||
kl := KullbackLeibler{}.DistNormal(l, r) | ||
return kl + l.Entropy() | ||
} | ||
|
||
// Hellinger is a type for computing the Hellinger distance between probability | ||
// distributions. | ||
// | ||
// The Hellinger distance is defined as | ||
// H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx | ||
// and is bounded between 0 and 1. | ||
// The Hellinger distance is related to the Bhattacharyya distance D_B by | ||
// H^2 = 1 - exp(-D_B) | ||
// For more information, see | ||
// https://en.wikipedia.org/wiki/Hellinger_distance | ||
type Hellinger struct{} | ||
|
||
// DistNormal returns the Hellinger distance between normal distributions l and r. | ||
// The dimensions of the input distributions must match or DistNormal will panic. | ||
// | ||
// See the documentation of Bhattacharyya.DistNormal for the formula for Normal | ||
// distributions. | ||
func (Hellinger) DistNormal(l, r *Normal) float64 { | ||
if l.Dim() != r.Dim() { | ||
panic(badSizeMismatch) | ||
} | ||
db := Bhattacharyya{}.DistNormal(l, r) | ||
bc := math.Exp(-db) | ||
return math.Sqrt(1 - bc) | ||
} | ||
|
||
// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r. | ||
// The dimensions of the input distributions must match or the function will panic. | ||
// | ||
// The Kullback-Leibler divergence is defined as | ||
// D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx | ||
// Note that the Kullback-Leibler divergence is not symmetric with respect to | ||
// the order of the input arguments. | ||
type KullbackLeibler struct{} | ||
|
||
// DistNormal returns the KullbackLeibler distance between normal distributions l and r. | ||
// The dimensions of the input distributions must match or DistNormal will panic. | ||
// | ||
// For two normal distributions, the KL divergence is computed as | ||
// D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)^T*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l)-d] | ||
func (KullbackLeibler) DistNormal(l, r *Normal) float64 { | ||
dim := l.Dim() | ||
if dim != r.Dim() { | ||
panic(badSizeMismatch) | ||
} | ||
|
||
mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &r.chol) | ||
mahalanobisSq := mahalanobis * mahalanobis | ||
|
||
// TODO(btracey): Optimize where there is a SolveCholeskySym | ||
// TODO(btracey): There may be a more efficient way to just compute the trace | ||
// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = U^T * U | ||
var u mat64.TriDense | ||
u.UFromCholesky(&l.chol) | ||
var m mat64.Dense | ||
err := m.SolveCholesky(&r.chol, u.T()) | ||
if err != nil { | ||
return math.NaN() | ||
} | ||
m.Mul(&m, &u) | ||
tr := mat64.Trace(&m) | ||
|
||
return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim)) | ||
} | ||
|
||
// Wasserstein is a type for computing the Wasserstein distance between two | ||
// probability distributions. | ||
// | ||
// The Wasserstein distance is defined as | ||
// W(l,r) := inf 𝔼(||X-Y||_2^2)^(1/2) | ||
// where X and Y are random variables distributed according to l and r | ||
// respectively, and the infimum is over their joint distributions (see the | ||
// reference below for the precise definition). | ||
// For more information, see | ||
// https://en.wikipedia.org/wiki/Wasserstein_metric | ||
type Wasserstein struct{} | ||
|
||
// DistNormal returns the Wasserstein distance between normal distributions l and r. | ||
// The dimensions of the input distributions must match or DistNormal will panic. | ||
// | ||
// The Wasserstein distance for Normal distributions is | ||
// d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2)) | ||
// For more information, see | ||
// http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ | ||
func (Wasserstein) DistNormal(l, r *Normal) float64 { | ||
dim := l.Dim() | ||
if dim != r.Dim() { | ||
panic(badSizeMismatch) | ||
} | ||
|
||
d := floats.Distance(l.mu, r.mu, 2) | ||
d = d * d | ||
|
||
// Compute Σ_l^(1/2) | ||
var ssl mat64.SymDense | ||
ssl.PowPSD(&l.sigma, 0.5) | ||
// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2) | ||
var mean mat64.Dense | ||
mean.Mul(&ssl, &r.sigma) | ||
mean.Mul(&mean, &ssl) | ||
|
||
// Reinterpret as symdense, and take Σ^(1/2) | ||
meanSym := mat64.NewSymDense(dim, mean.RawMatrix().Data) | ||
ssl.PowPSD(meanSym, 0.5) | ||
|
||
tr := mat64.Trace(&r.sigma) | ||
tl := mat64.Trace(&l.sigma) | ||
tm := mat64.Trace(&ssl) | ||
|
||
return d + tl + tr - 2*tm | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
// Copyright ©2016 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package distmv | ||
|
||
import ( | ||
"math" | ||
"math/rand" | ||
"testing" | ||
|
||
"github.com/gonum/floats" | ||
"github.com/gonum/matrix/mat64" | ||
) | ||
|
||
func TestBhattacharyyaNormal(t *testing.T) { | ||
for cas, test := range []struct { | ||
am, bm []float64 | ||
ac, bc *mat64.SymDense | ||
samples int | ||
tol float64 | ||
}{ | ||
{ | ||
am: []float64{2, 3}, | ||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}), | ||
bm: []float64{-1, 1}, | ||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), | ||
samples: 100000, | ||
tol: 1e-2, | ||
}, | ||
} { | ||
rnd := rand.New(rand.NewSource(1)) | ||
a, ok := NewNormal(test.am, test.ac, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
b, ok := NewNormal(test.bm, test.bc, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
lBhatt := make([]float64, test.samples) | ||
x := make([]float64, a.Dim()) | ||
for i := 0; i < test.samples; i++ { | ||
// Do importance sampling over a: \int sqrt(a*b)/a * a dx | ||
a.Rand(x) | ||
pa := a.LogProb(x) | ||
pb := b.LogProb(x) | ||
lBhatt[i] = 0.5*pb - 0.5*pa | ||
} | ||
logBc := floats.LogSumExp(lBhatt) - math.Log(float64(test.samples)) | ||
db := -logBc | ||
got := Bhattacharyya{}.DistNormal(a, b) | ||
if math.Abs(db-got) > test.tol { | ||
t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, db) | ||
} | ||
} | ||
} | ||
|
||
func TestCrossEntropyNormal(t *testing.T) { | ||
for cas, test := range []struct { | ||
am, bm []float64 | ||
ac, bc *mat64.SymDense | ||
samples int | ||
tol float64 | ||
}{ | ||
{ | ||
am: []float64{2, 3}, | ||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}), | ||
bm: []float64{-1, 1}, | ||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), | ||
samples: 100000, | ||
tol: 1e-2, | ||
}, | ||
} { | ||
rnd := rand.New(rand.NewSource(1)) | ||
a, ok := NewNormal(test.am, test.ac, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
b, ok := NewNormal(test.bm, test.bc, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
var ce float64 | ||
x := make([]float64, a.Dim()) | ||
for i := 0; i < test.samples; i++ { | ||
a.Rand(x) | ||
ce -= b.LogProb(x) | ||
} | ||
ce /= float64(test.samples) | ||
got := CrossEntropy{}.DistNormal(a, b) | ||
if math.Abs(ce-got) > test.tol { | ||
t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce) | ||
} | ||
} | ||
} | ||
|
||
func TestHellingerNormal(t *testing.T) { | ||
for cas, test := range []struct { | ||
am, bm []float64 | ||
ac, bc *mat64.SymDense | ||
samples int | ||
tol float64 | ||
}{ | ||
{ | ||
am: []float64{2, 3}, | ||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}), | ||
bm: []float64{-1, 1}, | ||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), | ||
samples: 100000, | ||
tol: 5e-1, | ||
}, | ||
} { | ||
rnd := rand.New(rand.NewSource(1)) | ||
a, ok := NewNormal(test.am, test.ac, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
b, ok := NewNormal(test.bm, test.bc, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
lAitchEDoubleHockeySticks := make([]float64, test.samples) | ||
x := make([]float64, a.Dim()) | ||
for i := 0; i < test.samples; i++ { | ||
// Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx | ||
a.Rand(x) | ||
pa := a.LogProb(x) | ||
pb := b.LogProb(x) | ||
d := math.Exp(0.5*pa) - math.Exp(0.5*pb) | ||
d = d * d | ||
lAitchEDoubleHockeySticks[i] = math.Log(d) - pa | ||
} | ||
want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lAitchEDoubleHockeySticks)-math.Log(float64(test.samples)))) | ||
got := Hellinger{}.DistNormal(a, b) | ||
if math.Abs(want-got) > test.tol { | ||
t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want) | ||
} | ||
} | ||
} | ||
|
||
func TestKullbackLieblerNormal(t *testing.T) { | ||
for cas, test := range []struct { | ||
am, bm []float64 | ||
ac, bc *mat64.SymDense | ||
samples int | ||
tol float64 | ||
}{ | ||
{ | ||
am: []float64{2, 3}, | ||
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}), | ||
bm: []float64{-1, 1}, | ||
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), | ||
samples: 10000, | ||
tol: 1e-2, | ||
}, | ||
} { | ||
rnd := rand.New(rand.NewSource(1)) | ||
a, ok := NewNormal(test.am, test.ac, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
b, ok := NewNormal(test.bm, test.bc, rnd) | ||
if !ok { | ||
panic("bad test") | ||
} | ||
var klmc float64 | ||
x := make([]float64, a.Dim()) | ||
for i := 0; i < test.samples; i++ { | ||
a.Rand(x) | ||
pa := a.LogProb(x) | ||
pb := b.LogProb(x) | ||
klmc += pa - pb | ||
} | ||
klmc /= float64(test.samples) | ||
kl := KullbackLeibler{}.DistNormal(a, b) | ||
if !floats.EqualWithinAbsOrRel(kl, klmc, test.tol, test.tol) { | ||
t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, kl, klmc) | ||
} | ||
} | ||
} |