diff --git a/mat64/cholesky.go b/mat64/cholesky.go index f33346b..cd2a189 100644 --- a/mat64/cholesky.go +++ b/mat64/cholesky.go @@ -76,111 +76,71 @@ func (t *Triangular) Cholesky(a *SymDense, upper bool) (ok bool) { return true } -// Solve returns a matrix x that solves A * X = B where A = L * L^T or A = U^T * U, -// and U or L are represented by the receiver. The matrix a must be symmetric and -// positive definite. If b is mutable it is overwritten by the operation. -func (t *Triangular) Solve(b Matrix) (x Matrix) { +// SolveTri finds the matrix x that solves A * X = B where A = L * L^T or +// A = U^T * U, and U or L are represented by t. The matrix A must be symmetric +// and positive definite. +func (m *Dense) SolveTri(t *Triangular, b Matrix) error { _, n := t.Dims() bm, bn := b.Dims() if n != bm { - panic(ErrShape) + return ErrShape } - nx := bn - - switch b := b.(type) { - case Mutable: - x := b - - if t.mat.Uplo == blas.Upper { - // Solve U'*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < nx; j++ { - for i := 0; i < k; i++ { - x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(i, k)) - } - x.Set(k, j, x.At(k, j)/t.at(k, k)) - } - } - // Solve U*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < nx; j++ { - for i := k + 1; i < n; i++ { - x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(k, i)) - } - x.Set(k, j, x.At(k, j)/t.at(k, k)) - } - } - } else { - // Solve L*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < nx; j++ { - for i := 0; i < k; i++ { - x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(k, i)) - } - x.Set(k, j, x.At(k, j)/t.at(k, k)) - } - } - - // Solve L'*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < nx; j++ { - for i := k + 1; i < n; i++ { - x.Set(k, j, x.At(k, j)-x.At(i, j)*t.at(i, k)) - } - x.Set(k, j, x.At(k, j)/t.at(k, k)) - } - } + if m.isZero() { + m.mat = blas64.General{ + Rows: bm, + Cols: bn, + Stride: bn, + Data: use(m.mat.Data, bm*bn), } + } else if bm != m.mat.Rows || bn != m.mat.Cols { + return ErrShape + } + if b != m { + m.Copy(b) + } - return x - - default: - x := NewDense(bm, bn, nil) - nx := bn - - if t.mat.Uplo == blas.Upper { - // Solve U'*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < nx; j++ { - for i := 0; i < k; i++ { - x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(i, k)) - } - x.set(k, j, x.at(k, j)/t.at(k, k)) + if t.mat.Uplo == blas.Upper { + // Solve U'*Y = B; + for k := 0; k < n; k++ { + for j := 0; j < bn; j++ { + for i := 0; i < k; i++ { + m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k)) } + m.set(k, j, m.at(k, j)/t.at(k, k)) } + } - // Solve U*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < nx; j++ { - for i := k + 1; i < n; i++ { - x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(k, i)) - } - x.set(k, j, x.at(k, j)/t.at(k, k)) + // Solve U*X = Y; + for k := n - 1; k >= 0; k-- { + for j := 0; j < bn; j++ { + for i := k + 1; i < n; i++ { + m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i)) } + m.set(k, j, m.at(k, j)/t.at(k, k)) } - } else { - // Solve L*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < nx; j++ { - for i := 0; i < k; i++ { - x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(k, i)) - } - x.set(k, j, x.at(k, j)/t.at(k, k)) + } + } else { + // Solve L*Y = B; + for k := 0; k < n; k++ { + for j := 0; j < bn; j++ { + for i := 0; i < k; i++ { + m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i)) } + m.set(k, j, m.at(k, j)/t.at(k, k)) } + } - // Solve L'*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < nx; j++ { - for i := k + 1; i < n; i++ { - x.set(k, j, x.at(k, j)-x.at(i, j)*t.at(i, k)) - } - x.set(k, j, x.at(k, j)/t.at(k, k)) + // Solve L'*X = Y; + for k := n - 1; k >= 0; k-- { + for j := 0; j < bn; j++ { + for i := k + 1; i < n; i++ { + m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k)) } + m.set(k, j, m.at(k, j)/t.at(k, k)) } } - - return x } + + return nil } diff --git a/mat64/cholesky_test.go b/mat64/cholesky_test.go index a296a0d..71f779e 100644 --- a/mat64/cholesky_test.go +++ b/mat64/cholesky_test.go @@ -95,10 +95,11 @@ func (s *S) TestCholesky(c *check.C) { } c.Check(fc.EqualsApprox(t.a, 1e-12), check.Equals, true) - x := t.f.Solve(eye()) + var x Dense + c.Check(x.SolveTri(t.f, eye()), check.Equals, nil) var res Dense - res.Mul(t.a, x) + res.Mul(t.a, &x) c.Check(res.EqualsApprox(eye(), 1e-12), check.Equals, true) } } @@ -121,7 +122,9 @@ func (s *S) TestCholeskySolve(c *check.C) { var f Triangular ok := f.Cholesky(t.a, false) c.Assert(ok, check.Equals, true) - ans := DenseCopyOf(f.Solve(t.b)) - c.Check(ans.EqualsApprox(t.ans, 1e-12), check.Equals, true) + + var x Dense + c.Check(x.SolveTri(&f, t.b), check.Equals, nil) + c.Check(x.EqualsApprox(t.ans, 1e-12), check.Equals, true) } }