-
Notifications
You must be signed in to change notification settings - Fork 11
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ import ( | |
|
|
||
| // Copied from lapack/native. Keep in sync. | ||
| const ( | ||
| absIncNotOne = "lapack: increment not one or negative one" | ||
| badDirect = "lapack: bad direct" | ||
| badIpiv = "lapack: insufficient permutation length" | ||
| badLdA = "lapack: index of a out of range" | ||
|
|
@@ -76,7 +77,7 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok | |
| return clapack.Dpotrf(ul, n, a, lda) | ||
| } | ||
|
|
||
| // Dgetf2 computes the LU decomposition of the m×n matrix a. | ||
| // Dgetf2 computes the LU decomposition of the m×n matrix A. | ||
| // The LU decomposition is a factorization of a into | ||
| // A = P * L * U | ||
| // where P is a permutation matrix, L is a unit lower triangular matrix, and | ||
|
|
@@ -85,9 +86,9 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok | |
| // | ||
| // ipiv is a permutation vector. It indicates that row i of the matrix was | ||
| // changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic | ||
| // otherwise. | ||
| // otherwise. ipiv is zero-indexed. | ||
| // | ||
| // Dgetf2 returns whether the matrix a is singular. The LU decomposition will | ||
| // Dgetf2 returns whether the matrix A is singular. The LU decomposition will | ||
| // be computed regardless of the singularity of A, but division by zero | ||
| // will occur if the false is returned and the result is used to solve a | ||
| // system of equations. | ||
|
|
@@ -100,7 +101,38 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo | |
| ipiv32 := make([]int32, len(ipiv)) | ||
| ok = clapack.Dgetf2(m, n, a, lda, ipiv32) | ||
| for i, v := range ipiv32 { | ||
| ipiv[i] = int(v) - 1 // OpenBLAS returns one indexed. | ||
| ipiv[i] = int(v) - 1 // Transform to zero-indexed. | ||
| } | ||
| return ok | ||
| } | ||
|
|
||
| // Dgetrf computes the LU decomposition of the m×n matrix A. | ||
| // The LU decomposition is a factorization of a into | ||
| // A = P * L * U | ||
| // where P is a permutation matrix, L is a unit lower triangular matrix, and | ||
| // U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored | ||
| // in place into a. | ||
| // | ||
| // ipiv is a permutation vector. It indicates that row i of the matrix was | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we note that ipiv is zero-based here. It may be surprising that we don't do one-based when everyone (?) else does.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| // changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic | ||
| // otherwise. ipiv is zero-indexed. | ||
| // | ||
| // Dgetrf is the blocked version of the algorithm. | ||
| // | ||
| // Dgetrf returns whether the matrix A is singular. The LU decomposition will | ||
| // be computed regardless of the singularity of A, but division by zero | ||
| // will occur if the false is returned and the result is used to solve a | ||
| // system of equations. | ||
| func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) { | ||
| mn := min(m, n) | ||
| checkMatrix(m, n, a, lda) | ||
| if len(ipiv) < mn { | ||
| panic(badIpiv) | ||
| } | ||
| ipiv32 := make([]int32, len(ipiv)) | ||
| ok = clapack.Dgetrf(m, n, a, lda, ipiv32) | ||
| for i, v := range ipiv32 { | ||
| ipiv[i] = int(v) - 1 // Transform to zero-indexed. | ||
| } | ||
| return ok | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| package native | ||
|
|
||
| import ( | ||
| "github.com/gonum/blas" | ||
| "github.com/gonum/blas/blas64" | ||
| ) | ||
|
|
||
| // Dgetrf computes the LU decomposition of the m×n matrix a. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Capitalisation of "A". |
||
| // The LU decomposition is a factorization of a into | ||
| // A = P * L * U | ||
| // where P is a permutation matrix, L is a unit lower triangular matrix, and | ||
| // U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored | ||
| // in place into a. | ||
| // | ||
| // ipiv is a permutation vector. It indicates that row i of the matrix was | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that ipiv is zero-based (probably in other places also - maybe another PR?).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added the comment for the LU cases. |
||
| // changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic | ||
| // otherwise. ipiv is zero-indexed. | ||
| // | ||
| // Dgetrf is the blocked version of the algorithm. | ||
| // | ||
| // Dgetrf returns whether the matrix A is singular. The LU decomposition will | ||
| // be computed regardless of the singularity of A, but division by zero | ||
| // will occur if the false is returned and the result is used to solve a | ||
| // system of equations. | ||
| func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) { | ||
| mn := min(m, n) | ||
| checkMatrix(m, n, a, lda) | ||
| if len(ipiv) < mn { | ||
| panic(badIpiv) | ||
| } | ||
| if m == 0 || n == 0 { | ||
| return | ||
| } | ||
| bi := blas64.Implementation() | ||
| nb := impl.Ilaenv(1, "DGETRF", " ", m, n, -1, -1) | ||
| if nb <= 1 || nb >= min(m, n) { | ||
| // Use the unblocked algorithm. | ||
| return impl.Dgetf2(m, n, a, lda, ipiv) | ||
| } | ||
| ok = true | ||
| for j := 0; j < mn; j += nb { | ||
| jb := min(mn-j, nb) | ||
| blockOk := impl.Dgetf2(m-j, jb, a[j*lda+j:], lda, ipiv[j:]) | ||
| if !blockOk { | ||
| ok = false | ||
| } | ||
| for i := j; i <= min(m-1, j+jb-1); i++ { | ||
| ipiv[i] = j + ipiv[i] | ||
| } | ||
| impl.Dlaswp(j, a, lda, j, j+jb-1, ipiv, 1) | ||
| if j+jb < n { | ||
| impl.Dlaswp(n-j-jb, a[j+jb:], lda, j, j+jb-1, ipiv, 1) | ||
| bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit, | ||
| jb, n-j-jb, 1, | ||
| a[j*lda+j:], lda, | ||
| a[j*lda+j+jb:], lda) | ||
| if j+jb < m { | ||
| bi.Dgemm(blas.NoTrans, blas.NoTrans, m-j-jb, n-j-jb, jb, -1, | ||
| a[(j+jb)*lda+j:], lda, | ||
| a[j*lda+j+jb:], lda, | ||
| 1, a[(j+jb)*lda+j+jb:], lda) | ||
| } | ||
| } | ||
| } | ||
| return ok | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| package native | ||
|
|
||
| import "github.com/gonum/blas/blas64" | ||
|
|
||
| // Dlaswp swaps the rows k1 to k2 of a according to the indices in ipiv. | ||
| // a is a matrix with n columns and stride lda. incX is the increment for ipiv. | ||
| // k1 and k2 are zero-indexed. If incX is negative, then loops from k2 to k1 | ||
| func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []int, incX int) { | ||
| if incX != 1 && incX != -1 { | ||
| panic(absIncNotOne) | ||
| } | ||
| bi := blas64.Implementation() | ||
| if incX == 1 { | ||
| for k := k1; k <= k2; k++ { | ||
| bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1) | ||
| } | ||
| return | ||
| } | ||
| for k := k2; k >= k1; k-- { | ||
| bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1) | ||
| } | ||
| return | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,9 @@ type Implementation struct{} | |
|
|
||
| var _ lapack.Float64 = Implementation{} | ||
|
|
||
| // This list is duplicated in lapack/cgo. Keep in sync. | ||
| const ( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sync comment. |
||
| absIncNotOne = "lapack: increment not one or negative one" | ||
| badDirect = "lapack: bad direct" | ||
| badIpiv = "lapack: insufficient permutation length" | ||
| badLdA = "lapack: index of a out of range" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reflect this comment in the native location as well - I should have done that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, oops. I had this in at one point as well.