-
Notifications
You must be signed in to change notification settings - Fork 11
Adddgesvd #79
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ const ( | |
| badLdA = "lapack: index of a out of range" | ||
| badNorm = "lapack: bad norm" | ||
| badPivot = "lapack: bad pivot" | ||
| badS = "lapack: s has insufficient length" | ||
| badSide = "lapack: bad side" | ||
| badSlice = "lapack: bad input slice length" | ||
| badStore = "lapack: bad store" | ||
|
|
@@ -497,6 +498,85 @@ func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float | |
| return clapack.Dgels(trans, m, n, nrhs, a, lda, b, ldb) | ||
| } | ||
|
|
||
| const noSVDO = "dgesvd: not coded for overwrite" | ||
|
|
||
| // Dgesvd computes the singular value decomposition of the input matrix A. | ||
| // | ||
| // The singular value decomposition is | ||
| // A = U * Sigma * V^T | ||
| // where Sigma is an m×n diagonal matrix containing the singular values of A, | ||
| // U is an m×m orthogonal matrix and V is an n×n orthogonal matrix. The first | ||
| // min(m,n) columns of U and V are the left and right singular vectors of A | ||
| // respectively. | ||
| // | ||
| // jobU and jobVT are options for computing the singular vectors. The behavior | ||
| // is as follows | ||
| // jobU == lapack.SVDAll All m columns of U are returned in u | ||
| // jobU == lapack.SVDInPlace The first min(m,n) columns are returned in u | ||
| // jobU == lapack.SVDOverwrite The first min(m,n) columns of U are written into a | ||
| // jobU == lapack.SVDNone The columns of U are not computed. | ||
| // The behavior is the same for jobVT and the rows of V^T. At most one of jobU | ||
| // and jobVT can equal lapack.SVDOverwrite, and Dgesvd will panic otherwise. | ||
| // | ||
| // On entry, a contains the data for the m×n matrix A. During the call to Dgesvd | ||
| // the data is overwritten. On exit, A contains the appropriate singular vectors | ||
| // if either job is lapack.SVDOverwrite. | ||
| // | ||
| // s is a slice of length at least min(m,n) and on exit contains the singular | ||
| // values in decreasing order. | ||
| // | ||
| // u contains the left singular vectors on exit, stored columnwise. If | ||
| // jobU == lapack.SVDAll, u is of size m×m. If jobU == lapack.SVDInPlace u is | ||
| // of size m×min(m,n). If jobU == lapack.SVDOverwrite or lapack.SVDNone, u is | ||
| // not used. | ||
| // | ||
| // vt contains the left singular vectors on exit, stored rowwise. If | ||
| // jobV == lapack.SVDAll, vt is of size n×m. If jobVT == lapack.SVDInPlace vt is | ||
| // of size min(m,n)×n. If jobVT == lapack.SVDOverwrite or lapack.SVDNone, vt is | ||
| // not used. | ||
| // | ||
| // The C interface does not support providing temporary storage. To provide compatibility | ||
| // with native, lwork == -1 will not run Dgesvd but will instead write the minimum | ||
| // work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic. | ||
| // | ||
| // Dgesvd returns whether the decomposition successfully completed. | ||
| func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool) { | ||
| checkMatrix(m, n, a, lda) | ||
| if jobU == lapack.SVDAll { | ||
| checkMatrix(m, m, u, ldu) | ||
| } else if jobU == lapack.SVDInPlace { | ||
| checkMatrix(m, min(m, n), u, ldu) | ||
| } | ||
| if jobVT == lapack.SVDAll { | ||
| checkMatrix(n, n, vt, ldvt) | ||
| } else if jobVT == lapack.SVDInPlace { | ||
| checkMatrix(min(m, n), n, vt, ldvt) | ||
| } | ||
| if jobU == lapack.SVDOverwrite && jobVT == lapack.SVDOverwrite { | ||
| panic(noSVDO) | ||
| } | ||
| if len(s) < min(m, n) { | ||
| panic(badS) | ||
| } | ||
| if jobU == lapack.SVDOverwrite || jobVT == lapack.SVDOverwrite { | ||
| panic("lapack: SVD not coded to overwrite original matrix") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this because the LAPACKE doesn't support it or we don't copy back?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Neither. It's because it's about 2x the work to implement it, and I was trying to allow us to shift matrix over to this faster. There's a lot of trickiness with strides and col/row major. I can implement it if we want completeness for it to go in.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, so this is just to reflect the behaviour in the Go implementation? There's no need to make the implementation complete for this PR.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Yes. It seems to me like trying to have consistent behavior between implementations is more important. |
||
| } | ||
| minWork := max(5*min(m, n), 3*min(m, n)+max(m, n)) | ||
| if lwork != -1 { | ||
| if len(work) < lwork { | ||
| panic(badWork) | ||
| } | ||
| if lwork < minWork { | ||
| panic(badWork) | ||
| } | ||
| } | ||
| if lwork == -1 { | ||
| work[0] = float64(minWork) | ||
| return true | ||
| } | ||
| return clapack.Dgesvd(lapack.Job(jobU), lapack.Job(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work[1:]) | ||
| } | ||
|
|
||
| // Dgetf2 computes the LU decomposition of the m×n matrix A. | ||
| // The LU decomposition is a factorization of a into | ||
| // A = P * L * U | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,7 +86,7 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta | |
| if lwork >= (m+n)*nbmin { | ||
| nb = lwork / (m + n) | ||
| } else { | ||
| nb = 1 | ||
| nb = minmn | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is because our loop subtracts nb rather than nx. When this if statement is met, it restricts the block size.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| nx = minmn | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/Dgeqrf/Dgesvd/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.