diff --git a/mat64/cholesky.go b/mat64/cholesky.go index cd2a189..9343736 100644 --- a/mat64/cholesky.go +++ b/mat64/cholesky.go @@ -100,46 +100,15 @@ func (m *Dense) SolveTri(t *Triangular, b Matrix) error { m.Copy(b) } - if t.mat.Uplo == blas.Upper { - // Solve U'*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < bn; j++ { - for i := 0; i < k; i++ { - m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k)) - } - m.set(k, j, m.at(k, j)/t.at(k, k)) - } - } - - // Solve U*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < bn; j++ { - for i := k + 1; i < n; i++ { - m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i)) - } - m.set(k, j, m.at(k, j)/t.at(k, k)) - } - } - } else { - // Solve L*Y = B; - for k := 0; k < n; k++ { - for j := 0; j < bn; j++ { - for i := 0; i < k; i++ { - m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(k, i)) - } - m.set(k, j, m.at(k, j)/t.at(k, k)) - } - } - - // Solve L'*X = Y; - for k := n - 1; k >= 0; k-- { - for j := 0; j < bn; j++ { - for i := k + 1; i < n; i++ { - m.set(k, j, m.at(k, j)-m.at(i, j)*t.at(i, k)) - } - m.set(k, j, m.at(k, j)/t.at(k, k)) - } - } + switch t.mat.Uplo { + case blas.Upper: + blas64.Trsm(blas.Left, blas.Trans, 1, t.mat, m.mat) + blas64.Trsm(blas.Left, blas.NoTrans, 1, t.mat, m.mat) + case blas.Lower: + blas64.Trsm(blas.Left, blas.NoTrans, 1, t.mat, m.mat) + blas64.Trsm(blas.Left, blas.Trans, 1, t.mat, m.mat) + default: + panic(ErrUplo) } return nil