Skip to content
This repository has been archived by the owner on Dec 22, 2018. It is now read-only.

Commit

Permalink
Add tests for Students T
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Oct 25, 2016
1 parent dbc04b6 commit d3c96ed
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 3 deletions.
227 changes: 227 additions & 0 deletions distmv/studentst_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// Copyright ©2016 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
"math/rand"
"testing"

"github.com/gonum/floats"
"github.com/gonum/matrix/mat64"
"github.com/gonum/stat"
)

func TestStudentTProbs(t *testing.T) {
src := rand.New(rand.NewSource(1))
for _, test := range []struct {
nu float64
mu []float64
sigma *mat64.SymDense

x [][]float64
probs []float64
}{
{
nu: 3,
mu: []float64{0, 0},
sigma: mat64.NewSymDense(2, []float64{1, 0, 0, 1}),

x: [][]float64{
{0, 0},
{1, -1},
{3, 4},
{-1, -2},
},
// Outputs compared with WolframAlpha.
probs: []float64{
0.159154943091895335768883,
0.0443811199724279860006777747927,
0.0005980371870904696541052658,
0.01370560783418571283428283,
},
},
{
nu: 4,
mu: []float64{2, -3},
sigma: mat64.NewSymDense(2, []float64{8, -1, -1, 5}),

x: [][]float64{
{0, 0},
{1, -1},
{3, 4},
{-1, -2},
{2, -3},
},
// Outputs compared with WolframAlpha.
probs: []float64{
0.007360810111491788657953608191001,
0.0143309905845607117740440592999,
0.0005307774290578041397794096037035009801668903,
0.0115657422475668739943625904793879,
0.0254851872062589062995305736215,
},
},
} {
s, ok := NewStudentsT(test.nu, test.mu, test.sigma, src)
if !ok {
t.Fatal("bad test")
}
for i, x := range test.x {
xcpy := make([]float64, len(x))
copy(xcpy, x)
p := s.Prob(x)
if !floats.Same(x, xcpy) {
t.Errorf("X modified during call to prob, %v, %v", x, xcpy)
}
if !floats.EqualWithinAbsOrRel(p, test.probs[i], 1e-10, 1e-10) {
t.Errorf("Probability mismatch. X = %v. Got %v, want %v.", x, p, test.probs[i])
}
}
}
}

func TestStudentsTRand(t *testing.T) {
src := rand.New(rand.NewSource(1))
for _, test := range []struct {
mean []float64
cov *mat64.SymDense
nu float64
tolcov float64
}{
{
mean: []float64{0, 0},
cov: mat64.NewSymDense(2, []float64{1, 0, 0, 1}),
nu: 3,
tolcov: 5e-2,
},
{
mean: []float64{3, 4},
cov: mat64.NewSymDense(2, []float64{5, 1.2, 1.2, 6}),
nu: 8,
tolcov: 1e-2,
},
{
mean: []float64{3, 4, -2},
cov: mat64.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
nu: 8,
tolcov: 1e-2,
},
} {
s, ok := NewStudentsT(test.nu, test.mean, test.cov, src)
if !ok {
t.Fatal("bad test")
}
nSamples := 1000000
dim := len(test.mean)
samps := mat64.NewDense(nSamples, dim, nil)
for i := 0; i < nSamples; i++ {
s.Rand(samps.RawRowView(i))
}
estMean := make([]float64, dim)
for i := range estMean {
estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil)
}
mean := s.Mean(nil)
if !floats.EqualApprox(estMean, mean, 1e-2) {
t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
}
cov := s.CovarianceMatrix(nil)
estCov := stat.CovarianceMatrix(nil, samps, nil)
if !mat64.EqualApprox(estCov, cov, test.tolcov) {
t.Errorf("Cov mismatch: want: %v, got %v", cov, estCov)
}
}
}

func TestStudentsTConditional(t *testing.T) {
src := rand.New(rand.NewSource(1))
for _, test := range []struct {
mean []float64
cov *mat64.SymDense
nu float64

idx []int
value []float64
tolcov float64
}{
{
mean: []float64{3, 4, -2},
cov: mat64.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
nu: 8,
idx: []int{0},
value: []float64{6},

tolcov: 1e-2,
},
} {
s, ok := NewStudentsT(test.nu, test.mean, test.cov, src)
if !ok {
t.Fatal("bad test")
}

sUp, ok := s.ConditionStudentsT(test.idx, test.value, src)

// Compute the other values by hand the inefficient way to compare
newNu := test.nu + float64(len(test.idx))
if newNu != sUp.nu {
t.Errorf("Updated nu mismatch. Got %v, want %v", s.nu, newNu)
}
dim := len(test.mean)
unob := findUnob(test.idx, dim)
ob := test.idx

muUnob := make([]float64, len(unob))
for i, v := range unob {
muUnob[i] = test.mean[v]
}
muOb := make([]float64, len(ob))
for i, v := range ob {
muOb[i] = test.mean[v]
}

s.setSigma()
sUp.setSigma()

var sig11, sig22 mat64.SymDense
sig11.SubsetSym(s.sigma, unob)
sig22.SubsetSym(s.sigma, ob)

sig12 := mat64.NewDense(len(unob), len(ob), nil)
for i := range unob {
for j := range ob {
sig12.Set(i, j, s.sigma.At(unob[i], ob[j]))
}
}

shift := make([]float64, len(ob))
copy(shift, test.value)
floats.Sub(shift, muOb)

newMu := make([]float64, len(muUnob))
newMuVec := mat64.NewVector(len(muUnob), newMu)
shiftVec := mat64.NewVector(len(shift), shift)
var tmp mat64.Vector
tmp.SolveVec(&sig22, shiftVec)
newMuVec.MulVec(sig12, &tmp)
floats.Add(newMu, muUnob)

if !floats.EqualApprox(newMu, sUp.mu, 1e-10) {
t.Errorf("Mu mismatch. Got %v, want %v", sUp.mu, newMu)
}

var tmp2 mat64.Dense
tmp2.Solve(&sig22, sig12.T())

var tmp3 mat64.Dense
tmp3.Mul(sig12, &tmp2)
tmp3.Sub(&sig11, &tmp3)

dot := mat64.Dot(shiftVec, &tmp)
tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
if !mat64.EqualApprox(&tmp3, sUp.sigma, 1e-10) {
t.Errorf("Sigma mismatch")
}
}
}
7 changes: 4 additions & 3 deletions distmv/studentt.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func NewStudentsT(nu float64, mu []float64, sigma mat64.Symmetric, src *rand.Ran
s := &StudentsT{
nu: nu,
mu: make([]float64, dim),
dim: dim,
src: src,
}
copy(s.mu, mu)
Expand All @@ -69,7 +70,7 @@ func NewStudentsT(nu float64, mu []float64, sigma mat64.Symmetric, src *rand.Ran
}
s.lower.LFromCholesky(&s.chol)
s.logSqrtDet = 0.5 * s.chol.LogDet()
return s, false
return s, true
}

// ConditionStudentsT returns the Student's T distribution that is the receiver
Expand Down Expand Up @@ -252,7 +253,7 @@ func (s *StudentsT) LogProb(y []float64) float64 {

shift := make([]float64, len(y))
copy(shift, y)
floats.Sub(y, s.mu)
floats.Sub(shift, s.mu)

x := mat64.NewVector(s.dim, shift)

Expand All @@ -261,7 +262,7 @@ func (s *StudentsT) LogProb(y []float64) float64 {

dot := mat64.Dot(&tmp, x)

return t1 - ((nu+n)/2)*(1+dot/nu)
return t1 - ((nu+n)/2)*math.Log(1+dot/nu)
}

// MarginalStudentsT returns the marginal distribution of the given input variables.
Expand Down

0 comments on commit d3c96ed

Please sign in to comment.