diff --git a/distmv/statdist.go b/distmv/statdist.go
new file mode 100644
index 0000000..9dab26b
--- /dev/null
+++ b/distmv/statdist.go
@@ -0,0 +1,184 @@
+// Copyright ©2016 The gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package distmv
+
+import (
+	"math"
+
+	"github.com/gonum/floats"
+	"github.com/gonum/matrix/mat64"
+	"github.com/gonum/stat"
+)
+
+// Bhattacharyya is a type for computing the Bhattacharyya distance between
+// probability distributions.
+//
+// The Bhattacharyya distance is defined as
+//  D_B = -ln(BC(l,r))
+//  BC(l,r) = \int_x (l(x)r(x))^(1/2) dx
+// where BC is known as the Bhattacharyya coefficient.
+// The Bhattacharyya distance is related to the Hellinger distance by
+//  H = sqrt(1-BC)
+// For more information, see
+//  https://en.wikipedia.org/wiki/Bhattacharyya_distance
+type Bhattacharyya struct{}
+
+// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+//
+// For Normal distributions, the Bhattacharyya distance is
+//  Σ = (Σ_l + Σ_r)/2
+//  D_B = (1/8)*(μ_l - μ_r)^T*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
+func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
+	dim := l.Dim()
+	if dim != r.Dim() {
+		panic(badSizeMismatch)
+	}
+
+	var sigma mat64.SymDense
+	sigma.AddSym(&l.sigma, &r.sigma)
+	sigma.ScaleSym(0.5, &sigma)
+
+	var chol mat64.Cholesky
+	chol.Factorize(&sigma)
+
+	mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &chol)
+	mahalanobisSq := mahalanobis * mahalanobis
+
+	dl := l.chol.LogDet()
+	dr := r.chol.LogDet()
+	ds := chol.LogDet()
+
+	return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
+}
+
+// CrossEntropy is a type for computing the cross-entropy between probability
+// distributions.
+//
+// The cross-entropy is defined as
+//  - \int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
+// where KL is the Kullback-Leibler divergence and H is the entropy.
+// For more information, see
+//  https://en.wikipedia.org/wiki/Cross_entropy
+type CrossEntropy struct{}
+
+// DistNormal returns the cross-entropy between normal distributions l and r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+func (CrossEntropy) DistNormal(l, r *Normal) float64 {
+	if l.Dim() != r.Dim() {
+		panic(badSizeMismatch)
+	}
+	kl := KullbackLeibler{}.DistNormal(l, r)
+	return kl + l.Entropy()
+}
+
+// Hellinger is a type for computing the Hellinger distance between probability
+// distributions.
+//
+// The Hellinger distance is defined as
+//  H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
+// and is bounded between 0 and 1.
+// The Hellinger distance is related to the Bhattacharyya distance by
+//  H^2 = 1 - exp(-D_B)
+// For more information, see
+//  https://en.wikipedia.org/wiki/Hellinger_distance
+type Hellinger struct{}
+
+// DistNormal returns the Hellinger distance between normal distributions l and r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+//
+// See the documentation of Bhattacharyya.DistNormal for the formula for Normal
+// distributions.
+func (Hellinger) DistNormal(l, r *Normal) float64 {
+	if l.Dim() != r.Dim() {
+		panic(badSizeMismatch)
+	}
+	db := Bhattacharyya{}.DistNormal(l, r)
+	bc := math.Exp(-db)
+	return math.Sqrt(1 - bc)
+}
+
+// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+//
+// The Kullback-Leibler divergence is defined as
+//  D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx
+// Note that the Kullback-Leibler divergence is not symmetric with respect to
+// the order of the input arguments.
+type KullbackLeibler struct{}
+
+// DistNormal returns the Kullback-Leibler divergence between normal distributions l and r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+//
+// For two normal distributions, the KL divergence is computed as
+//  D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)^T*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l)-d]
+func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
+	dim := l.Dim()
+	if dim != r.Dim() {
+		panic(badSizeMismatch)
+	}
+
+	mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &r.chol)
+	mahalanobisSq := mahalanobis * mahalanobis
+
+	// TODO(btracey): Optimize where there is a SolveCholeskySym.
+	// TODO(btracey): There may be a more efficient way to just compute the trace.
+	// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = U^T * U.
+	var u mat64.TriDense
+	u.UFromCholesky(&l.chol)
+	var m mat64.Dense
+	err := m.SolveCholesky(&r.chol, u.T())
+	if err != nil {
+		return math.NaN()
+	}
+	m.Mul(&m, &u)
+	tr := mat64.Trace(&m)
+
+	return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim))
+}
+
+// Wasserstein is a type for computing the Wasserstein distance between two
+// probability distributions.
+//
+// The Wasserstein distance is defined as
+//  W(l,r) := inf 𝔼(||X-Y||_2^2)^1/2, where the infimum is over all joint distributions of X and Y with marginals l and r.
+// For more information, see
+//  https://en.wikipedia.org/wiki/Wasserstein_metric
+type Wasserstein struct{}
+
+// DistNormal returns the Wasserstein distance between normal distributions l and r.
+// The dimensions of the input distributions must match or DistNormal will panic.
+//
+// The Wasserstein distance for Normal distributions is
+//  d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
+// For more information, see
+//  http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
+func (Wasserstein) DistNormal(l, r *Normal) float64 {
+	dim := l.Dim()
+	if dim != r.Dim() {
+		panic(badSizeMismatch)
+	}
+
+	d := floats.Distance(l.mu, r.mu, 2)
+	d = d * d
+
+	// Compute Σ_l^(1/2).
+	var ssl mat64.SymDense
+	ssl.PowPSD(&l.sigma, 0.5)
+	// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2).
+	var mean mat64.Dense
+	mean.Mul(&ssl, &r.sigma)
+	mean.Mul(&mean, &ssl)
+
+	// Reinterpret as a SymDense and take the square root, Σ^(1/2).
+	meanSym := mat64.NewSymDense(dim, mean.RawMatrix().Data)
+	ssl.PowPSD(meanSym, 0.5)
+
+	tr := mat64.Trace(&r.sigma)
+	tl := mat64.Trace(&l.sigma)
+	tm := mat64.Trace(&ssl)
+
+	return d + tl + tr - 2*tm
+}
diff --git a/distmv/statdist_test.go b/distmv/statdist_test.go
new file mode 100644
index 0000000..72985c3
--- /dev/null
+++ b/distmv/statdist_test.go
@@ -0,0 +1,181 @@
+// Copyright ©2016 The gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package distmv
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+
+	"github.com/gonum/floats"
+	"github.com/gonum/matrix/mat64"
+)
+
+func TestBhattacharyyaNormal(t *testing.T) {
+	for cas, test := range []struct {
+		am, bm  []float64
+		ac, bc  *mat64.SymDense
+		samples int
+		tol     float64
+	}{
+		{
+			am:      []float64{2, 3},
+			ac:      mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
+			bm:      []float64{-1, 1},
+			bc:      mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
+			samples: 100000,
+			tol:     1e-2,
+		},
+	} {
+		rnd := rand.New(rand.NewSource(1))
+		a, ok := NewNormal(test.am, test.ac, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		b, ok := NewNormal(test.bm, test.bc, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		lBhatt := make([]float64, test.samples)
+		x := make([]float64, a.Dim())
+		for i := 0; i < test.samples; i++ {
+			// Do importance sampling over a: \int sqrt(a*b)/a * a dx.
+			a.Rand(x)
+			pa := a.LogProb(x)
+			pb := b.LogProb(x)
+			lBhatt[i] = 0.5*pb - 0.5*pa
+		}
+		logBc := floats.LogSumExp(lBhatt) - math.Log(float64(test.samples))
+		db := -logBc
+		got := Bhattacharyya{}.DistNormal(a, b)
+		if math.Abs(db-got) > test.tol {
+			t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, db)
+		}
+	}
+}
+
+func TestCrossEntropyNormal(t *testing.T) {
+	for cas, test := range []struct {
+		am, bm  []float64
+		ac, bc  *mat64.SymDense
+		samples int
+		tol     float64
+	}{
+		{
+			am:      []float64{2, 3},
+			ac:      mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
+			bm:      []float64{-1, 1},
+			bc:      mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
+			samples: 100000,
+			tol:     1e-2,
+		},
+	} {
+		rnd := rand.New(rand.NewSource(1))
+		a, ok := NewNormal(test.am, test.ac, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		b, ok := NewNormal(test.bm, test.bc, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		var ce float64
+		x := make([]float64, a.Dim())
+		for i := 0; i < test.samples; i++ {
+			a.Rand(x)
+			ce -= b.LogProb(x)
+		}
+		ce /= float64(test.samples)
+		got := CrossEntropy{}.DistNormal(a, b)
+		if math.Abs(ce-got) > test.tol {
+			t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce)
+		}
+	}
+}
+
+func TestHellingerNormal(t *testing.T) {
+	for cas, test := range []struct {
+		am, bm  []float64
+		ac, bc  *mat64.SymDense
+		samples int
+		tol     float64
+	}{
+		{
+			am:      []float64{2, 3},
+			ac:      mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
+			bm:      []float64{-1, 1},
+			bc:      mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
+			samples: 100000,
+			tol:     5e-1,
+		},
+	} {
+		rnd := rand.New(rand.NewSource(1))
+		a, ok := NewNormal(test.am, test.ac, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		b, ok := NewNormal(test.bm, test.bc, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		lAitchEDoubleHockeySticks := make([]float64, test.samples)
+		x := make([]float64, a.Dim())
+		for i := 0; i < test.samples; i++ {
+			// Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx.
+			a.Rand(x)
+			pa := a.LogProb(x)
+			pb := b.LogProb(x)
+			d := math.Exp(0.5*pa) - math.Exp(0.5*pb)
+			d = d * d
+			lAitchEDoubleHockeySticks[i] = math.Log(d) - pa
+		}
+		want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lAitchEDoubleHockeySticks)-math.Log(float64(test.samples))))
+		got := Hellinger{}.DistNormal(a, b)
+		if math.Abs(want-got) > test.tol {
+			t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
+		}
+	}
+}
+
+func TestKullbackLeiblerNormal(t *testing.T) {
+	for cas, test := range []struct {
+		am, bm  []float64
+		ac, bc  *mat64.SymDense
+		samples int
+		tol     float64
+	}{
+		{
+			am:      []float64{2, 3},
+			ac:      mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
+			bm:      []float64{-1, 1},
+			bc:      mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
+			samples: 10000,
+			tol:     1e-2,
+		},
+	} {
+		rnd := rand.New(rand.NewSource(1))
+		a, ok := NewNormal(test.am, test.ac, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		b, ok := NewNormal(test.bm, test.bc, rnd)
+		if !ok {
+			panic("bad test")
+		}
+		var klmc float64
+		x := make([]float64, a.Dim())
+		for i := 0; i < test.samples; i++ {
+			a.Rand(x)
+			pa := a.LogProb(x)
+			pb := b.LogProb(x)
+			klmc += pa - pb
+		}
+		klmc /= float64(test.samples)
+		kl := KullbackLeibler{}.DistNormal(a, b)
+		if !floats.EqualWithinAbsOrRel(kl, klmc, test.tol, test.tol) {
+			t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, kl, klmc)
+		}
+	}
+}
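
For reviewers who want to try the new API, here is a minimal usage sketch. It is not part of the patch. It exercises Bhattacharyya, Hellinger, and KullbackLeibler together, checking the documented identity H^2 = 1 - exp(-D_B) and illustrating that the KL divergence is not symmetric in its arguments. The import path github.com/gonum/stat/distmv is an assumption; substitute whatever path this distmv package has in your tree.

package main

import (
	"fmt"
	"math"
	"math/rand"

	"github.com/gonum/matrix/mat64"
	"github.com/gonum/stat/distmv" // assumed import path for this package
)

func main() {
	rnd := rand.New(rand.NewSource(1))

	// Two bivariate normals with different means and covariances.
	a, ok := distmv.NewNormal([]float64{2, 3}, mat64.NewSymDense(2, []float64{3, -1, -1, 2}), rnd)
	if !ok {
		panic("covariance not positive definite")
	}
	b, ok := distmv.NewNormal([]float64{-1, 1}, mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), rnd)
	if !ok {
		panic("covariance not positive definite")
	}

	// Hellinger and Bhattacharyya are related by H^2 = 1 - exp(-D_B),
	// so the two printed values should agree.
	db := distmv.Bhattacharyya{}.DistNormal(a, b)
	h := distmv.Hellinger{}.DistNormal(a, b)
	fmt.Println(h*h, 1-math.Exp(-db))

	// KL is a divergence, not a metric: swapping the arguments
	// generally gives a different value.
	fmt.Println(distmv.KullbackLeibler{}.DistNormal(a, b))
	fmt.Println(distmv.KullbackLeibler{}.DistNormal(b, a))
}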