Skip to content

Commit

Permalink
Merge de6a1fa into eac07bb
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Jun 21, 2017
2 parents eac07bb + de6a1fa commit b91ece2
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 2 deletions.
72 changes: 71 additions & 1 deletion stat/distmv/statdist.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// Bhattacharyya is a type for computing the Bhattacharyya distance between
// probability distributions.
//
// The Battachara distance is defined as
// The Bhattacharyya distance is defined as
// D_B = -ln(BC(l,r))
// BC = \int_x (p(x)q(x))^(1/2) dx
// Where BC is known as the Bhattacharyya coefficient.
Expand Down Expand Up @@ -207,6 +207,76 @@ func (KullbackLeibler) DistUniform(l, r *Uniform) float64 {
return logPx - logQx
}

// Renyi is a type for computing the Rényi divergence of order α from l to r.
// The dimensions of the input distributions must match or the function will panic.
//
// The Rényi divergence with α > 0, α ≠ 1 is defined as
//  D_α(l || r) = 1/(α-1) log(\int_x l(x)^α r(x)^(1-α)dx)
// The Rényi divergence has special forms for α = 0 and α = 1. This type does
// not implement α = ∞. For α = 0,
//  D_0(l || r) = -log \int_x r(x)1{l(x)>0} dx
// that is, the negative log probability under r(x) that l(x) > 0.
// When α = 1, the Rényi divergence is equal to the Kullback-Leibler divergence.
// The Rényi divergence is also equal to twice the Bhattacharyya distance when
// α = 0.5, since D_0.5 = -2 ln(BC) while D_B = -ln(BC).
//
// The parameter α must be in 0 ≤ α < ∞ or the distance functions will panic.
type Renyi struct {
	Alpha float64
}

// DistNormal returns the Rényi divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the Rényi divergence is computed as
//  Σ_α = (1-α) Σ_l + αΣ_r
//  D_α(l||r) = α/2 * (μ_l - μ_r)'*Σ_α^-1*(μ_l - μ_r) + 1/(2(1-α))*ln(|Σ_α|/(|Σ_l|^(1-α)*|Σ_r|^α))
//
// For a more nicely formatted version of the formula, see Eq. 15 of
//  Kolchinsky, Artemy, and Brendan D. Tracey. "Estimating Mixture Entropy
//  with Pairwise Distances." arXiv preprint arXiv:1706.02419 (2017).
// Note that this formula is for the Chernoff divergence, which differs from
// the Rényi divergence by a factor of 1 - α. Also be aware that most sources in
// the literature report this formula incorrectly.
func (renyi Renyi) DistNormal(l, r *Normal) float64 {
	if renyi.Alpha < 0 {
		panic("renyi: alpha < 0")
	}
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}
	if renyi.Alpha == 0 {
		// Normal distributions have full support, so the probability under r
		// of the event l(x) > 0 is 1, and D_0 = -log(1) = 0.
		return 0
	}
	if renyi.Alpha == 1 {
		// The Rényi divergence converges to the Kullback-Leibler divergence
		// as α → 1.
		return KullbackLeibler{}.DistNormal(l, r)
	}

	logDetL := l.chol.LogDet()
	logDetR := r.chol.LogDet()

	// Σ_α = (1-α)Σ_l + αΣ_r.
	sigA := mat.NewSymDense(dim, nil)
	for i := 0; i < dim; i++ {
		for j := i; j < dim; j++ {
			v := (1-renyi.Alpha)*l.sigma.At(i, j) + renyi.Alpha*r.sigma.At(i, j)
			sigA.SetSym(i, j, v)
		}
	}

	var chol mat.Cholesky
	ok := chol.Factorize(sigA)
	if !ok {
		// Σ_α is not positive definite, so the divergence is undefined.
		return math.NaN()
	}
	logDetA := chol.LogDet()

	// Mahalanobis distance between the means under the blended covariance Σ_α.
	mahalanobis := stat.Mahalanobis(mat.NewVector(dim, l.mu), mat.NewVector(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	// α/2 * Δμ'Σ_α^-1 Δμ + 1/(2(1-α)) * ln(|Σ_α| / (|Σ_l|^(1-α) |Σ_r|^α)).
	return (renyi.Alpha/2)*mahalanobisSq + 1/(2*(1-renyi.Alpha))*(logDetA-(1-renyi.Alpha)*logDetL-renyi.Alpha*logDetR)
}

// Wasserstein is a type for computing the Wasserstein distance between two
// probability distributions.
//
Expand Down
71 changes: 70 additions & 1 deletion stat/distmv/statdist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func TestKullbackLeiblerUniform(t *testing.T) {
}
}

// klSample finds an estimate of the Kullback-Leibler Divergence through sampling.
// klSample finds an estimate of the Kullback-Leibler divergence through sampling.
func klSample(dim, samples int, l RandLogProber, r LogProber) float64 {
var klmc float64
x := make([]float64, dim)
Expand All @@ -259,3 +259,72 @@ func klSample(dim, samples int, l RandLogProber, r LogProber) float64 {
}
return klmc / float64(samples)
}

// TestRenyiNormal checks the closed-form normal Rényi divergence against a
// Monte Carlo estimate, and against its known relationships to the
// Bhattacharyya distance (α = 0.5) and the Kullback-Leibler divergence (α → 1).
func TestRenyiNormal(t *testing.T) {
	cases := []struct {
		meanA, meanB []float64
		covA, covB   *mat.SymDense
		alpha        float64
		samples      int
		tol          float64
	}{
		{
			meanA:   []float64{2, 3},
			covA:    mat.NewSymDense(2, []float64{3, -1, -1, 2}),
			meanB:   []float64{-1, 1},
			covB:    mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
			alpha:   0.3,
			samples: 10000,
			tol:     1e-2,
		},
	}
	for cas, c := range cases {
		src := rand.New(rand.NewSource(1))
		a, ok := NewNormal(c.meanA, c.covA, src)
		if !ok {
			panic("bad test")
		}
		b, ok := NewNormal(c.meanB, c.covB, src)
		if !ok {
			panic("bad test")
		}

		// Closed form vs. sampled estimate.
		want := renyiSample(a.Dim(), c.samples, c.alpha, a, b)
		got := Renyi{Alpha: c.alpha}.DistNormal(a, b)
		if !floats.EqualWithinAbsOrRel(want, got, c.tol, c.tol) {
			t.Errorf("Case %d: Renyi sampling mismatch: got %v, want %v", cas, got, want)
		}

		// Compare with Bhattacharyya.
		want = 2 * Bhattacharyya{}.DistNormal(a, b)
		got = Renyi{Alpha: 0.5}.DistNormal(a, b)
		if math.Abs(want-got) > 1e-10 {
			t.Errorf("Case %d: Renyi mismatch with Bhattacharyya: got %v, want %v", cas, got, want)
		}

		// Compare with KL in both directions.
		want = KullbackLeibler{}.DistNormal(a, b)
		got = Renyi{Alpha: 0.9999999}.DistNormal(a, b) // very close to 1 but not equal to 1.
		if math.Abs(want-got) > 1e-6 {
			t.Errorf("Case %d: Renyi mismatch with KL(a||b): got %v, want %v", cas, got, want)
		}
		want = KullbackLeibler{}.DistNormal(b, a)
		got = Renyi{Alpha: 0.9999999}.DistNormal(b, a) // very close to 1 but not equal to 1.
		if math.Abs(want-got) > 1e-6 {
			t.Errorf("Case %d: Renyi mismatch with KL(b||a): got %v, want %v", cas, got, want)
		}
	}
}

// renyiSample finds an estimate of the Rényi divergence through sampling.
// Note that this sampling procedure only works if l has broader support than
// r.
func renyiSample(dim, samples int, alpha float64, l RandLogProber, r LogProber) float64 {
	// D_α = 1/(α-1) log E_l[ l(x)^(α-1) r(x)^(1-α) ]; accumulate the per-sample
	// log terms and combine them with LogSumExp for numerical stability.
	logTerms := make([]float64, samples)
	pt := make([]float64, dim)
	for i := range logTerms {
		l.Rand(pt)
		logTerms[i] = (alpha-1)*l.LogProb(pt) + (1-alpha)*r.LogProb(pt)
	}
	return 1 / (alpha - 1) * (floats.LogSumExp(logTerms) - math.Log(float64(samples)))
}

0 comments on commit b91ece2

Please sign in to comment.