Skip to content

Commit

Permalink
distmv: Add KL divergence for two Dirichlet distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed May 24, 2018
1 parent f5d91d7 commit d20069c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
32 changes: 32 additions & 0 deletions stat/distmv/statdist.go
Expand Up @@ -7,6 +7,8 @@ package distmv
import (
"math"

"gonum.org/v1/gonum/mathext"

"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
Expand Down Expand Up @@ -142,6 +144,36 @@ func (Hellinger) DistNormal(l, r *Normal) float64 {
// the order of the input arguments.
type KullbackLeibler struct{}

// DistDirichlet computes the Kullback-Leibler divergence D_KL(l || r) from the
// Dirichlet distribution l to the Dirichlet distribution r. DistDirichlet
// panics if the dimensions of l and r differ.
//
// Writing α_0 for the sum of a distribution's Dirichlet parameters, Γ for the
// gamma function and ψ for the digamma function, the divergence is
//  D_KL(l || r) = log Γ(α_0_l) - \sum_i log Γ(α_i_l)
//               - log Γ(α_0_r) + \sum_i log Γ(α_i_r)
//               + \sum_i (α_i_l - α_i_r)(ψ(α_i_l) - ψ(α_0_l))
func (KullbackLeibler) DistDirichlet(l, r *Dirichlet) float64 {
	// Closed form taken from:
	// http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}

	// ψ(α_0_l) is shared by every term of the cross sum, so evaluate it once.
	digammaSumL := mathext.Digamma(l.sumAlpha)

	var sumLgammaL, sumLgammaR, cross float64
	for i := range l.alpha {
		alphaL, alphaR := l.alpha[i], r.alpha[i]

		// Only the log-magnitude from math.Lgamma is used; the sign result
		// is discarded.
		lgL, _ := math.Lgamma(alphaL)
		sumLgammaL += lgL
		lgR, _ := math.Lgamma(alphaR)
		sumLgammaR += lgR

		cross += (alphaL - alphaR) * (mathext.Digamma(alphaL) - digammaSumL)
	}

	lgammaSumL, _ := math.Lgamma(l.sumAlpha)
	lgammaSumR, _ := math.Lgamma(r.sumAlpha)
	return lgammaSumL - sumLgammaL - lgammaSumR + sumLgammaR + cross
}

// DistNormal returns the KullbackLeibler distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
Expand Down
29 changes: 29 additions & 0 deletions stat/distmv/statdist_test.go
Expand Up @@ -186,6 +186,35 @@ func TestHellingerNormal(t *testing.T) {
}
}

// TestKullbackLeiblerDirichlet compares the closed-form value returned by
// KullbackLeibler.DistDirichlet against the estimate produced by klSample for
// pairs of Dirichlet distributions, accepting agreement within a per-case
// absolute-or-relative tolerance.
func TestKullbackLeiblerDirichlet(t *testing.T) {
	// Fixed seed so the sampled estimate is reproducible between runs.
	src := rand.New(rand.NewSource(1))

	cases := []struct {
		a, b    *Dirichlet
		samples int
		tol     float64
	}{
		{
			a:       NewDirichlet([]float64{2, 3, 4}, src),
			b:       NewDirichlet([]float64{4, 2, 1.1}, src),
			samples: 100000,
			tol:     1e-2,
		},
		{
			a:       NewDirichlet([]float64{2, 3, 4, 0.1, 8}, src),
			b:       NewDirichlet([]float64{2, 2, 6, 0.5, 9}, src),
			samples: 100000,
			tol:     1e-2,
		},
	}

	for idx, tc := range cases {
		// The sampled estimate serves as the reference value.
		want := klSample(tc.a.Dim(), tc.samples, tc.a, tc.b)
		got := KullbackLeibler{}.DistDirichlet(tc.a, tc.b)
		if floats.EqualWithinAbsOrRel(want, got, tc.tol, tc.tol) {
			continue
		}
		t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", idx, got, want)
	}
}

func TestKullbackLeiblerNormal(t *testing.T) {
for cas, test := range []struct {
am, bm []float64
Expand Down

0 comments on commit d20069c

Please sign in to comment.