Skip to content

Commit

Permalink
Merge e1f479d into e56ddb0
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Jun 30, 2017
2 parents e56ddb0 + e1f479d commit 75a340b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
32 changes: 32 additions & 0 deletions stat/distmv/normal.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,38 @@ func (n *Normal) Rand(x []float64) []float64 {
return x
}

// ScoreInput returns the score function with respect to the input of the
// distribution at the location specified by x. The score function is the
// derivative of the log-probability at the input.
// ∂/∂x log(p(x))
// If deriv is nil, a new slice will be allocated and returned. If deriv is of
// length the dimension of Normal, then the result will be put in-place into deriv.
// If neither of these is true, ScoreInput will panic.
func (n *Normal) ScoreInput(deriv, x []float64) []float64 {
// Normal log probability is
// c - 0.5*(x-μ)' Σ^-1 (x-μ).
// So the derivative is just
// -Σ^-1 (x-μ).
if len(x) != n.Dim() {
panic(badInputLength)
}
if deriv == nil {
deriv = make([]float64, len(x))
}
if len(deriv) != len(x) {
panic(badSizeMismatch)
}
tmp := make([]float64, len(x))
copy(tmp, x)
floats.Sub(tmp, n.mu)

dv := mat.NewVector(len(deriv), deriv)
dt := mat.NewVector(len(tmp), tmp)
dv.SolveCholeskyVec(&n.chol, dt)
floats.Scale(-1, deriv)
return deriv
}

// SetMean changes the mean of the normal distribution. SetMean panics if len(mu)
// does not equal the dimension of the normal distribution.
func (n *Normal) SetMean(mu []float64) {
Expand Down
35 changes: 35 additions & 0 deletions stat/distmv/normal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"math/rand"
"testing"

"gonum.org/v1/gonum/diff/fd"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
Expand Down Expand Up @@ -536,3 +537,37 @@ func TestMarginalSingle(t *testing.T) {
}
}
}

func TestNormalScoreInput(t *testing.T) {
for cas, test := range []struct {
mu []float64
sigma *mat.SymDense
x []float64
}{
{
mu: []float64{2, 3, 4},
sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
x: []float64{1, 3.1, -2},
},
{
mu: []float64{2, 3, 4, 5},
sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
x: []float64{1, 3.1, -2, 5},
},
} {
normal, ok := NewNormal(test.mu, test.sigma, nil)
if !ok {
t.Fatalf("Bad test, covariance matrix not positive definite")
}
x := make([]float64, len(test.x))
copy(x, test.x)
deriv := normal.ScoreInput(nil, x)
if !floats.Equal(x, test.x) {
t.Errorf("x modified during call to ScoreInput")
}
derivFD := fd.Gradient(nil, normal.LogProb, x, nil)
if !floats.EqualApprox(deriv, derivFD, 1e-4) {
t.Errorf("Case %d: derivative mismatch. Got %v, want %v", cas, deriv, derivFD)
}
}
}

0 comments on commit 75a340b

Please sign in to comment.