Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Make Solve a method on Dense
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Mar 21, 2015
1 parent d405b56 commit 13bd105
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 92 deletions.
136 changes: 48 additions & 88 deletions mat64/cholesky.go
Expand Up @@ -76,111 +76,71 @@ func (t *Triangular) Cholesky(a *SymDense, upper bool) (ok bool) {
return true
}

// Solve returns a matrix x that solves A * X = B where A = L * L^T or A = U^T * U,
// and U or L are represented by the receiver. The matrix a must be symmetric and
// positive definite. If b is mutable it is overwritten by the operation.
func (t *Triangular) Solve(b Matrix) (x Matrix) {
// SolveTri finds the matrix x that solves A * X = B where A = L * L^T or
// A = U^T * U, and U or L are represented by t. The matrix A must be symmetric
// and positive definite.
func (m *Dense) SolveTri(t *Triangular, b Matrix) error {
_, n := t.Dims()
bm, bn := b.Dims()
if n != bm {
panic(ErrShape)
return ErrShape
}
nx := bn

switch b := b.(type) {
case Mutable:
x := b

if t.mat.Uplo == blas.Upper {
// Solve U'*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
for i := 0; i < k; i++ {
x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(i, k))
}
x.Set(k, j, x.At(k, j)/t.at(k, k))
}
}

// Solve U*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
for i := k + 1; i < n; i++ {
x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(k, i))
}
x.Set(k, j, x.At(k, j)/t.at(k, k))
}
}
} else {
// Solve L*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
for i := 0; i < k; i++ {
x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(k, i))
}
x.Set(k, j, x.At(k, j)/t.at(k, k))
}
}

// Solve L'*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
for i := k + 1; i < n; i++ {
x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(i, k))
}
x.Set(k, j, x.At(k, j)/t.at(k, k))
}
}
if m.isZero() {
m.mat = blas64.General{
Rows: bm,
Cols: bn,
Stride: bn,
Data: use(m.mat.Data, bm*bn),
}
} else if bm != m.mat.Rows || bn != m.mat.Cols {
return ErrShape
}
if b != m {
m.Copy(b)
}

return x

default:
x := NewDense(bm, bn, nil)
nx := bn

if t.mat.Uplo == blas.Upper {
// Solve U'*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
for i := 0; i < k; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(i, k))
}
x.set(k, j, x.at(k, j)/t.at(k, k))
if t.mat.Uplo == blas.Upper {
// Solve U'*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < bn; j++ {
for i := 0; i < k; i++ {
m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k))
}
m.set(k, j, m.at(k, j)/t.at(k, k))
}
}

// Solve U*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
for i := k + 1; i < n; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(k, i))
}
x.set(k, j, x.at(k, j)/t.at(k, k))
// Solve U*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < bn; j++ {
for i := k + 1; i < n; i++ {
m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i))
}
m.set(k, j, m.at(k, j)/t.at(k, k))
}
} else {
// Solve L*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
for i := 0; i < k; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(k, i))
}
x.set(k, j, x.at(k, j)/t.at(k, k))
}
} else {
// Solve L*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < bn; j++ {
for i := 0; i < k; i++ {
m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i))
}
m.set(k, j, m.at(k, j)/t.at(k, k))
}
}

// Solve L'*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
for i := k + 1; i < n; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(i, k))
}
x.set(k, j, x.at(k, j)/t.at(k, k))
// Solve L'*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < bn; j++ {
for i := k + 1; i < n; i++ {
m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k))
}
m.set(k, j, m.at(k, j)/t.at(k, k))
}
}

return x
}

return nil
}
11 changes: 7 additions & 4 deletions mat64/cholesky_test.go
Expand Up @@ -95,10 +95,11 @@ func (s *S) TestCholesky(c *check.C) {
}
c.Check(fc.EqualsApprox(t.a, 1e-12), check.Equals, true)

x := t.f.Solve(eye())
var x Dense
c.Check(x.SolveTri(t.f, eye()), check.Equals, nil)

var res Dense
res.Mul(t.a, x)
res.Mul(t.a, &x)
c.Check(res.EqualsApprox(eye(), 1e-12), check.Equals, true)
}
}
Expand All @@ -121,7 +122,9 @@ func (s *S) TestCholeskySolve(c *check.C) {
var f Triangular
ok := f.Cholesky(t.a, false)
c.Assert(ok, check.Equals, true)
ans := DenseCopyOf(f.Solve(t.b))
c.Check(ans.EqualsApprox(t.ans, 1e-12), check.Equals, true)

var x Dense
c.Check(x.SolveTri(&f, t.b), check.Equals, nil)
c.Check(x.EqualsApprox(t.ans, 1e-12), check.Equals, true)
}
}

0 comments on commit 13bd105

Please sign in to comment.