diff --git a/stat/distmv/statdist.go b/stat/distmv/statdist.go index a51bd24cb..3d310b0d0 100644 --- a/stat/distmv/statdist.go +++ b/stat/distmv/statdist.go @@ -7,6 +7,8 @@ package distmv import ( "math" + "gonum.org/v1/gonum/mathext" + "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/mat" "gonum.org/v1/gonum/stat" @@ -142,6 +144,36 @@ func (Hellinger) DistNormal(l, r *Normal) float64 { // the order of the input arguments. type KullbackLeibler struct{} +// DistDirichlet returns the KullbackLeibler distance between Dirichlet +// distributions l and r. The dimensions of the input distributions must match +// or DistDirichlet will panic. +// +// For two Dirichlet distributions, the KL divergence is computed as +// D_KL(l || r) = log Γ(α_0_l) - \sum_i log Γ(α_i_l) - log Γ(α_0_r) + \sum_i log Γ(α_i_r) +// + \sum_i (α_i_l - α_i_r)(ψ(α_i_l)- ψ(α_0_l)) +// Where Γ is the gamma function, ψ is the digamma function, and α_0 is the +// sum of the Dirichlet parameters. +func (KullbackLeibler) DistDirichlet(l, r *Dirichlet) float64 { + // http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/ + if l.Dim() != r.Dim() { + panic(badSizeMismatch) + } + l0, _ := math.Lgamma(l.sumAlpha) + r0, _ := math.Lgamma(r.sumAlpha) + dl := mathext.Digamma(l.sumAlpha) + + var l1, r1, c float64 + for i, al := range l.alpha { + ar := r.alpha[i] + vl, _ := math.Lgamma(al) + l1 += vl + vr, _ := math.Lgamma(ar) + r1 += vr + c += (al - ar) * (mathext.Digamma(al) - dl) + } + return l0 - l1 - r0 + r1 + c +} + // DistNormal returns the KullbackLeibler distance between normal distributions l and r. // The dimensions of the input distributions must match or DistNormal will panic. // diff --git a/stat/distmv/statdist_test.go b/stat/distmv/statdist_test.go index becdde23a..9361f09f2 100644 --- a/stat/distmv/statdist_test.go +++ b/stat/distmv/statdist_test.go @@ -186,6 +186,35 @@ func TestHellingerNormal(t *testing.T) { } } +func TestKullbackLeiblerDirichlet(t *testing.T) { + rnd := rand.New(rand.NewSource(1)) + for cas, test := range []struct { + a, b *Dirichlet + samples int + tol float64 + }{ + { + a: NewDirichlet([]float64{2, 3, 4}, rnd), + b: NewDirichlet([]float64{4, 2, 1.1}, rnd), + samples: 100000, + tol: 1e-2, + }, + { + a: NewDirichlet([]float64{2, 3, 4, 0.1, 8}, rnd), + b: NewDirichlet([]float64{2, 2, 6, 0.5, 9}, rnd), + samples: 100000, + tol: 1e-2, + }, + } { + a, b := test.a, test.b + want := klSample(a.Dim(), test.samples, a, b) + got := KullbackLeibler{}.DistDirichlet(a, b) + if !floats.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { + t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want) + } + } +} + func TestKullbackLeiblerNormal(t *testing.T) { for cas, test := range []struct { am, bm []float64