Skip to content

Commit

Permalink
mat: allow nil destination for Cholesky extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jul 6, 2017
1 parent 1466bc5 commit 08e1f7d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
42 changes: 30 additions & 12 deletions mat/cholesky.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,38 +233,56 @@ func (c *Cholesky) SolveVec(v, b *Vector) error {
}

// UTo extracts the n×n upper triangular matrix U from a Cholesky
// decomposition into t.
// decomposition into dst and returns the result. If dst is nil a new
// TriDense is allocated.
// A = U^T * U.
func (c *Cholesky) UTo(t *TriDense) {
func (c *Cholesky) UTo(dst *TriDense) *TriDense {
if !c.valid() {
panic(badCholesky)
}
n := c.chol.mat.N
t.reuseAs(n, Upper)
t.Copy(c.chol)
if dst == nil {
dst = NewTriDense(n, Upper, make([]float64, n*n))
} else {
dst.reuseAs(n, Upper)
}
dst.Copy(c.chol)
return dst
}

// LTo extracts the n×n lower triangular matrix L from a Cholesky
// decomposition into t.
// decomposition into dst and returns the result. If dst is nil a new
// TriDense is allocated.
// A = L * L^T.
func (c *Cholesky) LTo(t *TriDense) {
func (c *Cholesky) LTo(dst *TriDense) *TriDense {
if !c.valid() {
panic(badCholesky)
}
n := c.chol.mat.N
t.reuseAs(n, Lower)
t.Copy(c.chol.TTri())
if dst == nil {
dst = NewTriDense(n, Lower, make([]float64, n*n))
} else {
dst.reuseAs(n, Lower)
}
dst.Copy(c.chol.TTri())
return dst
}

// To reconstructs the original positive definite matrix given its
// Cholesky decomposition into s.
func (c *Cholesky) To(s *SymDense) {
// Cholesky decomposition into dst and returns the result. If dst is nil
// a new SymDense is allocated.
func (c *Cholesky) To(dst *SymDense) *SymDense {
if !c.valid() {
panic(badCholesky)
}
n := c.chol.mat.N
s.reuseAs(n)
s.SymOuterK(1, c.chol.T())
if dst == nil {
dst = NewSymDense(n, make([]float64, n*n))
} else {
dst.reuseAs(n)
}
dst.SymOuterK(1, c.chol.T())
return dst
}

// InverseTo computes the inverse of the matrix represented by its Cholesky
Expand Down
10 changes: 4 additions & 6 deletions mat/cholesky_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ func ExampleCholesky() {
fmt.Printf("x = %0.4v\n", mat.Formatted(&x, mat.Prefix(" ")))

// Extract the factorization and check that it equals the original matrix.
var t mat.TriDense
chol.LTo(&t)
t := chol.LTo(nil)
var test mat.Dense
test.Mul(&t, t.T())
test.Mul(t, t.T())
fmt.Println()
fmt.Printf("L * L^T = %0.4v\n", mat.Formatted(&a, mat.Prefix(" ")))

Expand Down Expand Up @@ -92,13 +91,12 @@ func ExampleCholesky_SymRankOne() {
// Rank-1 update the matrix a.
a.SymRankOne(a, 1, x)

var au mat.SymDense
chol.To(&au)
au := chol.To(nil)

// Print the matrix that was updated directly.
fmt.Printf("\nA' = %0.4v\n", mat.Formatted(a, mat.Prefix(" ")))
// Print the matrix recovered from the factorization.
fmt.Printf("\nU'^T * U' = %0.4v\n", mat.Formatted(&au, mat.Prefix(" ")))
fmt.Printf("\nU'^T * U' = %0.4v\n", mat.Formatted(au, mat.Prefix(" ")))

// Output:
// A = ⎡ 1 1 1 1⎤
Expand Down
17 changes: 7 additions & 10 deletions mat/cholesky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,16 @@ func TestCholesky(t *testing.T) {
if math.Abs(test.cond-chol.cond) > 1e-13 {
t.Errorf("Condition number mismatch: Want %v, got %v", test.cond, chol.cond)
}
var U TriDense
chol.UTo(&U)
U := chol.UTo(nil)
aCopy := DenseCopyOf(test.a)
var a Dense
a.Mul(U.TTri(), &U)
a.Mul(U.TTri(), U)
if !EqualApprox(&a, aCopy, 1e-14) {
t.Error("unexpected Cholesky factor product")
}

var L TriDense
chol.LTo(&L)
a.Mul(&L, L.TTri())
L := chol.LTo(nil)
a.Mul(L, L.TTri())
if !EqualApprox(&a, aCopy, 1e-14) {
t.Error("unexpected Cholesky factor product")
}
Expand Down Expand Up @@ -234,11 +232,10 @@ func TestCholeskyTo(t *testing.T) {
if !ok {
t.Fatal("unexpected Cholesky factorization failure: not positive definite")
}
var s SymDense
chol.To(&s)
s := chol.To(nil)

if !EqualApprox(&s, test, 1e-12) {
t.Errorf("Cholesky reconstruction not equal to original matrix.\nWant:\n% v\nGot:\n% v\n", Formatted(test), Formatted(&s))
if !EqualApprox(s, test, 1e-12) {
t.Errorf("Cholesky reconstruction not equal to original matrix.\nWant:\n% v\nGot:\n% v\n", Formatted(test), Formatted(s))
}
}
}
Expand Down

0 comments on commit 08e1f7d

Please sign in to comment.