Skip to content

Commit

Permalink
Respond to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Jul 1, 2017
1 parent e1f479d commit 6d799d8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
23 changes: 11 additions & 12 deletions stat/distmv/normal.go
Expand Up @@ -277,36 +277,35 @@ func (n *Normal) Rand(x []float64) []float64 {
return x
}

// ScoreInput returns the score function with respect to the input of the
// distribution at the location specified by x. The score function is the
// derivative of the log-probability at the input.
// ∂/∂x log(p(x))
// If deriv is nil, a new slice will be allocated and returned. If deriv is of
// ScoreInput returns the gradient of the log-probability with respect to the
// input x. That is, ScoreInput computes
// ∇_x log(p(x))
// If score is nil, a new slice will be allocated and returned. If score is of
// length the dimension of Normal, then the result will be put in-place into deriv.
// If neither of these is true, ScoreInput will panic.
func (n *Normal) ScoreInput(deriv, x []float64) []float64 {
func (n *Normal) ScoreInput(score, x []float64) []float64 {
// Normal log probability is
// c - 0.5*(x-μ)' Σ^-1 (x-μ).
// So the derivative is just
// -Σ^-1 (x-μ).
if len(x) != n.Dim() {
panic(badInputLength)
}
if deriv == nil {
deriv = make([]float64, len(x))
if score == nil {
score = make([]float64, len(x))
}
if len(deriv) != len(x) {
if len(score) != len(x) {
panic(badSizeMismatch)
}
tmp := make([]float64, len(x))
copy(tmp, x)
floats.Sub(tmp, n.mu)

dv := mat.NewVector(len(deriv), deriv)
dv := mat.NewVector(len(score), score)
dt := mat.NewVector(len(tmp), tmp)
dv.SolveCholeskyVec(&n.chol, dt)
floats.Scale(-1, deriv)
return deriv
floats.Scale(-1, score)
return score
}

// SetMean changes the mean of the normal distribution. SetMean panics if len(mu)
Expand Down
8 changes: 4 additions & 4 deletions stat/distmv/normal_test.go
Expand Up @@ -561,13 +561,13 @@ func TestNormalScoreInput(t *testing.T) {
}
x := make([]float64, len(test.x))
copy(x, test.x)
deriv := normal.ScoreInput(nil, x)
score := normal.ScoreInput(nil, x)
if !floats.Equal(x, test.x) {
t.Errorf("x modified during call to ScoreInput")
}
derivFD := fd.Gradient(nil, normal.LogProb, x, nil)
if !floats.EqualApprox(deriv, derivFD, 1e-4) {
t.Errorf("Case %d: derivative mismatch. Got %v, want %v", cas, deriv, derivFD)
scoreFD := fd.Gradient(nil, normal.LogProb, x, nil)
if !floats.EqualApprox(score, scoreFD, 1e-4) {
t.Errorf("Case %d: derivative mismatch. Got %v, want %v", cas, score, scoreFD)
}
}
}

0 comments on commit 6d799d8

Please sign in to comment.