
Merge pull request #159 from gonum/statdist
Add distance functions between probability distributions
btracey committed May 6, 2017
2 parents 7fad93f + c33a2a3 commit cd35374
Showing 2 changed files with 365 additions and 0 deletions.
184 changes: 184 additions & 0 deletions distmv/statdist.go
@@ -0,0 +1,184 @@
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
"math"

"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
"github.com/gonum/stat"
)

// Bhattacharyya is a type for computing the Bhattacharyya distance between
// probability distributions.
//
// The Bhattacharyya distance is defined as
//  D_B = -ln(BC(l,r))
//  BC(l,r) = \int_x (l(x) r(x))^(1/2) dx
// where BC is known as the Bhattacharyya coefficient.
// The Bhattacharyya distance is related to the Hellinger distance by
// H = sqrt(1-BC)
// For more information, see
// https://en.wikipedia.org/wiki/Bhattacharyya_distance
type Bhattacharyya struct{}

// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For Normal distributions, the Bhattacharyya distance is
// Σ = (Σ_l + Σ_r)/2
// D_B = (1/8)*(μ_l - μ_r)^T*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
dim := l.Dim()
if dim != r.Dim() {
panic(badSizeMismatch)
}

var sigma mat64.SymDense
sigma.AddSym(&l.sigma, &r.sigma)
sigma.ScaleSym(0.5, &sigma)

var chol mat64.Cholesky
chol.Factorize(&sigma)

mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &chol)
mahalanobisSq := mahalanobis * mahalanobis

dl := l.chol.LogDet()
dr := r.chol.LogDet()
ds := chol.LogDet()

return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
}
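
// As a sanity check on the formula above, the univariate special case reduces
// to (a sketch, with σ_l^2 and σ_r^2 the scalar variances)
//  D_B = (μ_l-μ_r)^2/(4*(σ_l^2+σ_r^2)) + (1/2)*ln((σ_l^2+σ_r^2)/(2*σ_l*σ_r))
// so that, given some rnd *rand.Rand, N(0,1) and N(1,4) built as 1-dimensional
// distributions
//  l, _ := NewNormal([]float64{0}, mat64.NewSymDense(1, []float64{1}), rnd)
//  r, _ := NewNormal([]float64{1}, mat64.NewSymDense(1, []float64{4}), rnd)
// give DistNormal(l, r) = 1/20 + ln(5/4)/2 ≈ 0.1616.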

// CrossEntropy is a type for computing the cross-entropy between probability
// distributions.
//
// The cross-entropy is defined as
// - \int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
// where KL is the Kullback-Leibler divergence and H is the entropy.
// For more information, see
// https://en.wikipedia.org/wiki/Cross_entropy
type CrossEntropy struct{}

// DistNormal returns the cross-entropy between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
func (CrossEntropy) DistNormal(l, r *Normal) float64 {
if l.Dim() != r.Dim() {
panic(badSizeMismatch)
}
kl := KullbackLeibler{}.DistNormal(l, r)
return kl + l.Entropy()
}
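
// For univariate normals the decomposition above yields a closed form that is
// convenient as a sanity check (a sketch, with σ_l^2 and σ_r^2 the scalar
// variances):
//  CE(l,r) = (1/2)*ln(2*π*σ_r^2) + (σ_l^2 + (μ_l-μ_r)^2)/(2*σ_r^2)
// that is, the scalar KL divergence plus the scalar entropy (1/2)*ln(2*π*e*σ_l^2).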

// Hellinger is a type for computing the Hellinger distance between probability
// distributions.
//
// The Hellinger distance is defined as
//  H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
// and is bounded between 0 and 1.
// The Hellinger distance is related to the Bhattacharyya distance D_B by
//  H^2 = 1 - exp(-D_B)
// For more information, see
// https://en.wikipedia.org/wiki/Hellinger_distance
type Hellinger struct{}

// DistNormal returns the Hellinger distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// See the documentation of Bhattacharyya.DistNormal for the formula for Normal
// distributions.
func (Hellinger) DistNormal(l, r *Normal) float64 {
if l.Dim() != r.Dim() {
panic(badSizeMismatch)
}
db := Bhattacharyya{}.DistNormal(l, r)
bc := math.Exp(-db)
return math.Sqrt(1 - bc)
}
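
// Because BC lies in (0, 1] for normals, with BC = 1 exactly when the
// distributions coincide, the value returned here is always in [0, 1).
// As a sketch, reusing l and r from the Bhattacharyya example above,
//  Hellinger{}.DistNormal(l, l) // 0: a distribution is at distance 0 from itself
//  Hellinger{}.DistNormal(l, r) // equals Hellinger{}.DistNormal(r, l): symmetric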

// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r.
// The dimensions of the input distributions must match or the function will panic.
//
// The Kullback-Leibler divergence is defined as
//  D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx
// Note that the Kullback-Leibler divergence is not symmetric with respect to
// the order of the input arguments.
type KullbackLeibler struct{}

// DistNormal returns the Kullback-Leibler divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the KL divergence is computed as
// D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)^T*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l)-d]
func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
dim := l.Dim()
if dim != r.Dim() {
panic(badSizeMismatch)
}

mahalanobis := stat.Mahalanobis(mat64.NewVector(dim, l.mu), mat64.NewVector(dim, r.mu), &r.chol)
mahalanobisSq := mahalanobis * mahalanobis

// TODO(btracey): Optimize where there is a SolveCholeskySym
// TODO(btracey): There may be a more efficient way to just compute the trace
// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = U^T * U
var u mat64.TriDense
u.UFromCholesky(&l.chol)
var m mat64.Dense
err := m.SolveCholesky(&r.chol, u.T())
if err != nil {
return math.NaN()
}
m.Mul(&m, &u)
tr := mat64.Trace(&m)

return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim))
}

// Wasserstein is a type for computing the Wasserstein distance between two
// probability distributions.
//
// The Wasserstein distance is defined as
//  W(l,r) := (inf 𝔼[||X-Y||_2^2])^(1/2)
// where the infimum is taken over all joint distributions of X and Y whose
// marginal distributions are l and r respectively.
// For more information, see
//  https://en.wikipedia.org/wiki/Wasserstein_metric
type Wasserstein struct{}

// DistNormal returns the squared Wasserstein distance d^2 between normal
// distributions l and r; note that it is the square, not d itself, that is
// returned, matching the formula below.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For Normal distributions, the squared Wasserstein distance is
//  d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
// For more information, see
// http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
func (Wasserstein) DistNormal(l, r *Normal) float64 {
dim := l.Dim()
if dim != r.Dim() {
panic(badSizeMismatch)
}

d := floats.Distance(l.mu, r.mu, 2)
d = d * d

// Compute Σ_l^(1/2)
var ssl mat64.SymDense
ssl.PowPSD(&l.sigma, 0.5)
// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2)
var mean mat64.Dense
mean.Mul(&ssl, &r.sigma)
mean.Mul(&mean, &ssl)

// Reinterpret as symdense, and take Σ^(1/2)
meanSym := mat64.NewSymDense(dim, mean.RawMatrix().Data)
ssl.PowPSD(meanSym, 0.5)

tr := mat64.Trace(&r.sigma)
tl := mat64.Trace(&l.sigma)
tm := mat64.Trace(&ssl)

return d + tl + tr - 2*tm
}
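
Taken together, the new types share a single calling pattern: construct *Normal values and call DistNormal. The following minimal sketch exercises all five on one pair of distributions; it assumes this repository's import path, github.com/gonum/stat/distmv, and borrows the covariance data from the tests below.

package main

import (
	"fmt"
	"math/rand"

	"github.com/gonum/matrix/mat64"
	"github.com/gonum/stat/distmv" // assumed import path for this repository
)

func main() {
	rnd := rand.New(rand.NewSource(1))
	l, okL := distmv.NewNormal([]float64{2, 3}, mat64.NewSymDense(2, []float64{3, -1, -1, 2}), rnd)
	r, okR := distmv.NewNormal([]float64{-1, 1}, mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), rnd)
	if !okL || !okR {
		panic("covariance matrix not positive definite")
	}

	fmt.Println("Bhattacharyya:", distmv.Bhattacharyya{}.DistNormal(l, r))
	fmt.Println("CrossEntropy: ", distmv.CrossEntropy{}.DistNormal(l, r))
	fmt.Println("Hellinger:    ", distmv.Hellinger{}.DistNormal(l, r))
	fmt.Println("KL(l||r):     ", distmv.KullbackLeibler{}.DistNormal(l, r))
	// Per the implementation above, Wasserstein.DistNormal returns the
	// squared distance d^2.
	fmt.Println("Wasserstein²: ", distmv.Wasserstein{}.DistNormal(l, r))
}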
181 changes: 181 additions & 0 deletions distmv/statdist_test.go
@@ -0,0 +1,181 @@
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
"math"
"math/rand"
"testing"

"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
)

func TestBhattacharyyaNormal(t *testing.T) {
for cas, test := range []struct {
am, bm []float64
ac, bc *mat64.SymDense
samples int
tol float64
}{
{
am: []float64{2, 3},
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
bm: []float64{-1, 1},
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
samples: 100000,
tol: 1e-2,
},
} {
rnd := rand.New(rand.NewSource(1))
a, ok := NewNormal(test.am, test.ac, rnd)
if !ok {
panic("bad test")
}
b, ok := NewNormal(test.bm, test.bc, rnd)
if !ok {
panic("bad test")
}
lBhatt := make([]float64, test.samples)
x := make([]float64, a.Dim())
for i := 0; i < test.samples; i++ {
// Do importance sampling over a: \int sqrt(a*b)/a * a dx
a.Rand(x)
pa := a.LogProb(x)
pb := b.LogProb(x)
lBhatt[i] = 0.5*pb - 0.5*pa
}
logBc := floats.LogSumExp(lBhatt) - math.Log(float64(test.samples))
db := -logBc
got := Bhattacharyya{}.DistNormal(a, b)
if math.Abs(db-got) > test.tol {
t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, db)
}
}
}
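
// The estimator above works entirely in log space: it averages the per-sample
// log-ratios with LogSumExp rather than exponentiating each one first, and the
// same pattern recurs in the Hellinger test below. A small helper capturing it
// (hypothetical, for illustration only; it relies on the math and floats
// imports already in this file) could read:

// logMeanExp returns log((1/n) * Σ_i exp(v[i])) computed stably in log space.
func logMeanExp(v []float64) float64 {
	return floats.LogSumExp(v) - math.Log(float64(len(v)))
}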

func TestCrossEntropyNormal(t *testing.T) {
for cas, test := range []struct {
am, bm []float64
ac, bc *mat64.SymDense
samples int
tol float64
}{
{
am: []float64{2, 3},
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
bm: []float64{-1, 1},
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
samples: 100000,
tol: 1e-2,
},
} {
rnd := rand.New(rand.NewSource(1))
a, ok := NewNormal(test.am, test.ac, rnd)
if !ok {
panic("bad test")
}
b, ok := NewNormal(test.bm, test.bc, rnd)
if !ok {
panic("bad test")
}
var ce float64
x := make([]float64, a.Dim())
for i := 0; i < test.samples; i++ {
a.Rand(x)
ce -= b.LogProb(x)
}
ce /= float64(test.samples)
got := CrossEntropy{}.DistNormal(a, b)
if math.Abs(ce-got) > test.tol {
t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce)
}
}
}

func TestHellingerNormal(t *testing.T) {
for cas, test := range []struct {
am, bm []float64
ac, bc *mat64.SymDense
samples int
tol float64
}{
{
am: []float64{2, 3},
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
bm: []float64{-1, 1},
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
samples: 100000,
tol: 5e-1,
},
} {
rnd := rand.New(rand.NewSource(1))
a, ok := NewNormal(test.am, test.ac, rnd)
if !ok {
panic("bad test")
}
b, ok := NewNormal(test.bm, test.bc, rnd)
if !ok {
panic("bad test")
}
lHellingerSq := make([]float64, test.samples)
x := make([]float64, a.Dim())
for i := 0; i < test.samples; i++ {
// Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx
a.Rand(x)
pa := a.LogProb(x)
pb := b.LogProb(x)
d := math.Exp(0.5*pa) - math.Exp(0.5*pb)
d = d * d
// Store the log of each integrand sample; the mean is taken in log space below.
lHellingerSq[i] = math.Log(d) - pa
}
want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lHellingerSq)-math.Log(float64(test.samples))))
got := Hellinger{}.DistNormal(a, b)
if math.Abs(want-got) > test.tol {
t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
}
}
}

func TestKullbackLeiblerNormal(t *testing.T) {
for cas, test := range []struct {
am, bm []float64
ac, bc *mat64.SymDense
samples int
tol float64
}{
{
am: []float64{2, 3},
ac: mat64.NewSymDense(2, []float64{3, -1, -1, 2}),
bm: []float64{-1, 1},
bc: mat64.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
samples: 10000,
tol: 1e-2,
},
} {
rnd := rand.New(rand.NewSource(1))
a, ok := NewNormal(test.am, test.ac, rnd)
if !ok {
panic("bad test")
}
b, ok := NewNormal(test.bm, test.bc, rnd)
if !ok {
panic("bad test")
}
var klmc float64
x := make([]float64, a.Dim())
for i := 0; i < test.samples; i++ {
a.Rand(x)
pa := a.LogProb(x)
pb := b.LogProb(x)
klmc += pa - pb
}
klmc /= float64(test.samples)
kl := KullbackLeibler{}.DistNormal(a, b)
if !floats.EqualWithinAbsOrRel(kl, klmc, test.tol, test.tol) {
t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, kl, klmc)
}
}
}
