From d3c96ed0e13320351268ef2b71850767dc72de6f Mon Sep 17 00:00:00 2001 From: Brendan Tracey Date: Fri, 21 Oct 2016 14:04:43 -0600 Subject: [PATCH] Add tests for Students T --- distmv/studentst_test.go | 227 +++++++++++++++++++++++++++++++++++++++ distmv/studentt.go | 7 +- 2 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 distmv/studentst_test.go diff --git a/distmv/studentst_test.go b/distmv/studentst_test.go new file mode 100644 index 0000000..48ee1cc --- /dev/null +++ b/distmv/studentst_test.go @@ -0,0 +1,227 @@ +// Copyright ©2016 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package distmv + +import ( + "math/rand" + "testing" + + "github.com/gonum/floats" + "github.com/gonum/matrix/mat64" + "github.com/gonum/stat" +) + +func TestStudentTProbs(t *testing.T) { + src := rand.New(rand.NewSource(1)) + for _, test := range []struct { + nu float64 + mu []float64 + sigma *mat64.SymDense + + x [][]float64 + probs []float64 + }{ + { + nu: 3, + mu: []float64{0, 0}, + sigma: mat64.NewSymDense(2, []float64{1, 0, 0, 1}), + + x: [][]float64{ + {0, 0}, + {1, -1}, + {3, 4}, + {-1, -2}, + }, + // Outputs compared with WolframAlpha. + probs: []float64{ + 0.159154943091895335768883, + 0.0443811199724279860006777747927, + 0.0005980371870904696541052658, + 0.01370560783418571283428283, + }, + }, + { + nu: 4, + mu: []float64{2, -3}, + sigma: mat64.NewSymDense(2, []float64{8, -1, -1, 5}), + + x: [][]float64{ + {0, 0}, + {1, -1}, + {3, 4}, + {-1, -2}, + {2, -3}, + }, + // Outputs compared with WolframAlpha. + probs: []float64{ + 0.007360810111491788657953608191001, + 0.0143309905845607117740440592999, + 0.0005307774290578041397794096037035009801668903, + 0.0115657422475668739943625904793879, + 0.0254851872062589062995305736215, + }, + }, + } { + s, ok := NewStudentsT(test.nu, test.mu, test.sigma, src) + if !ok { + t.Fatal("bad test") + } + for i, x := range test.x { + xcpy := make([]float64, len(x)) + copy(xcpy, x) + p := s.Prob(x) + if !floats.Same(x, xcpy) { + t.Errorf("X modified during call to prob, %v, %v", x, xcpy) + } + if !floats.EqualWithinAbsOrRel(p, test.probs[i], 1e-10, 1e-10) { + t.Errorf("Probability mismatch. X = %v. Got %v, want %v.", x, p, test.probs[i]) + } + } + } +} + +func TestStudentsTRand(t *testing.T) { + src := rand.New(rand.NewSource(1)) + for _, test := range []struct { + mean []float64 + cov *mat64.SymDense + nu float64 + tolcov float64 + }{ + { + mean: []float64{0, 0}, + cov: mat64.NewSymDense(2, []float64{1, 0, 0, 1}), + nu: 3, + tolcov: 5e-2, + }, + { + mean: []float64{3, 4}, + cov: mat64.NewSymDense(2, []float64{5, 1.2, 1.2, 6}), + nu: 8, + tolcov: 1e-2, + }, + { + mean: []float64{3, 4, -2}, + cov: mat64.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}), + nu: 8, + tolcov: 1e-2, + }, + } { + s, ok := NewStudentsT(test.nu, test.mean, test.cov, src) + if !ok { + t.Fatal("bad test") + } + nSamples := 1000000 + dim := len(test.mean) + samps := mat64.NewDense(nSamples, dim, nil) + for i := 0; i < nSamples; i++ { + s.Rand(samps.RawRowView(i)) + } + estMean := make([]float64, dim) + for i := range estMean { + estMean[i] = stat.Mean(mat64.Col(nil, i, samps), nil) + } + mean := s.Mean(nil) + if !floats.EqualApprox(estMean, mean, 1e-2) { + t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) + } + cov := s.CovarianceMatrix(nil) + estCov := stat.CovarianceMatrix(nil, samps, nil) + if !mat64.EqualApprox(estCov, cov, test.tolcov) { + t.Errorf("Cov mismatch: want: %v, got %v", cov, estCov) + } + } +} + +func TestStudentsTConditional(t *testing.T) { + src := rand.New(rand.NewSource(1)) + for _, test := range []struct { + mean []float64 + cov *mat64.SymDense + nu float64 + + idx []int + value []float64 + tolcov float64 + }{ + { + mean: []float64{3, 4, -2}, + cov: mat64.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}), + nu: 8, + idx: []int{0}, + value: []float64{6}, + + tolcov: 1e-2, + }, + } { + s, ok := NewStudentsT(test.nu, test.mean, test.cov, src) + if !ok { + t.Fatal("bad test") + } + + sUp, ok := s.ConditionStudentsT(test.idx, test.value, src) + + // Compute the other values by hand the inefficient way to compare + newNu := test.nu + float64(len(test.idx)) + if newNu != sUp.nu { + t.Errorf("Updated nu mismatch. Got %v, want %v", s.nu, newNu) + } + dim := len(test.mean) + unob := findUnob(test.idx, dim) + ob := test.idx + + muUnob := make([]float64, len(unob)) + for i, v := range unob { + muUnob[i] = test.mean[v] + } + muOb := make([]float64, len(ob)) + for i, v := range ob { + muOb[i] = test.mean[v] + } + + s.setSigma() + sUp.setSigma() + + var sig11, sig22 mat64.SymDense + sig11.SubsetSym(s.sigma, unob) + sig22.SubsetSym(s.sigma, ob) + + sig12 := mat64.NewDense(len(unob), len(ob), nil) + for i := range unob { + for j := range ob { + sig12.Set(i, j, s.sigma.At(unob[i], ob[j])) + } + } + + shift := make([]float64, len(ob)) + copy(shift, test.value) + floats.Sub(shift, muOb) + + newMu := make([]float64, len(muUnob)) + newMuVec := mat64.NewVector(len(muUnob), newMu) + shiftVec := mat64.NewVector(len(shift), shift) + var tmp mat64.Vector + tmp.SolveVec(&sig22, shiftVec) + newMuVec.MulVec(sig12, &tmp) + floats.Add(newMu, muUnob) + + if !floats.EqualApprox(newMu, sUp.mu, 1e-10) { + t.Errorf("Mu mismatch. Got %v, want %v", sUp.mu, newMu) + } + + var tmp2 mat64.Dense + tmp2.Solve(&sig22, sig12.T()) + + var tmp3 mat64.Dense + tmp3.Mul(sig12, &tmp2) + tmp3.Sub(&sig11, &tmp3) + + dot := mat64.Dot(shiftVec, &tmp) + tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3) + if !mat64.EqualApprox(&tmp3, sUp.sigma, 1e-10) { + t.Errorf("Sigma mismatch") + } + } +} diff --git a/distmv/studentt.go b/distmv/studentt.go index 2cf701b..a362b22 100644 --- a/distmv/studentt.go +++ b/distmv/studentt.go @@ -59,6 +59,7 @@ func NewStudentsT(nu float64, mu []float64, sigma mat64.Symmetric, src *rand.Ran s := &StudentsT{ nu: nu, mu: make([]float64, dim), + dim: dim, src: src, } copy(s.mu, mu) @@ -69,7 +70,7 @@ func NewStudentsT(nu float64, mu []float64, sigma mat64.Symmetric, src *rand.Ran } s.lower.LFromCholesky(&s.chol) s.logSqrtDet = 0.5 * s.chol.LogDet() - return s, false + return s, true } // ConditionStudentsT returns the Student's T distribution that is the receiver @@ -252,7 +253,7 @@ func (s *StudentsT) LogProb(y []float64) float64 { shift := make([]float64, len(y)) copy(shift, y) - floats.Sub(y, s.mu) + floats.Sub(shift, s.mu) x := mat64.NewVector(s.dim, shift) @@ -261,7 +262,7 @@ func (s *StudentsT) LogProb(y []float64) float64 { dot := mat64.Dot(&tmp, x) - return t1 - ((nu+n)/2)*(1+dot/nu) + return t1 - ((nu+n)/2)*math.Log(1+dot/nu) } // MarginalStudentsT returns the marginal distribution of the given input variables.