New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
lapack/gonum: fix various bugs in Dgetri #806
Changes from all commits
9540c89
4325e56
704a95b
3dcd135
e9c5252
5615c5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,71 +22,95 @@ import ( | |
// by the temporary space available. If lwork == -1, instead of performing Dgetri, | ||
// the optimal work length will be stored into work[0]. | ||
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) { | ||
checkMatrix(n, n, a, lda) | ||
if len(ipiv) < n { | ||
panic(badIpiv) | ||
switch { | ||
case n < 0: | ||
panic("lapack: has negative number of columns") | ||
case lda < max(1, n): | ||
panic("lapack: stride less than number of columns") | ||
case lwork < max(1, n) && lwork != -1: | ||
panic(badWork) | ||
case len(work) < max(1, lwork): | ||
panic(shortWork) | ||
} | ||
nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1) | ||
if lwork == -1 { | ||
work[0] = float64(n * nb) | ||
|
||
if n == 0 { | ||
work[0] = 1 | ||
return true | ||
} | ||
if lwork < n { | ||
panic(badWork) | ||
} | ||
if len(work) < lwork { | ||
panic(badWork) | ||
|
||
switch { | ||
case len(a) < (n-1)*lda+n: | ||
panic("lapack: insufficient matrix slice length") | ||
case len(ipiv) < n: | ||
panic(badIpiv) | ||
} | ||
if n == 0 { | ||
|
||
nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1) | ||
lworkopt := float64(n * nb) | ||
if lwork == -1 { | ||
work[0] = lworkopt | ||
return true | ||
} | ||
|
||
// Form inv(U). | ||
ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda) | ||
if !ok { | ||
return false | ||
} | ||
|
||
nbmin := 2 | ||
ldwork := nb | ||
if nb > 1 && nb < n { | ||
iws := max(ldwork*n, 1) | ||
if 1 < nb && nb < n { | ||
iws := max(n*ldwork, 1) | ||
if lwork < iws { | ||
nb = lwork / ldwork | ||
nb = lwork / n | ||
nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1)) | ||
} | ||
} | ||
|
||
bi := blas64.Implementation() | ||
// Solve the equation inv(A)*L = inv(U) for inv(A). | ||
// TODO(btracey): Replace this with a more row-major oriented algorithm. | ||
if nb < nbmin || nb >= n { | ||
if nb < nbmin || n <= nb { | ||
// Unblocked code. | ||
for j := n - 1; j >= 0; j-- { | ||
for i := j + 1; i < n; i++ { | ||
// Copy current column of L to work and replace with zeros. | ||
work[i*ldwork] = a[i*lda+j] | ||
a[i*lda+j] = 0 | ||
} | ||
if j < n { | ||
// Compute current column of inv(A). | ||
if j < n-1 { | ||
bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda) | ||
} | ||
} | ||
} else { | ||
// Blocked code. | ||
nn := ((n - 1) / nb) * nb | ||
for j := nn; j >= 0; j -= nb { | ||
jb := min(nb, n-j) | ||
for jj := j; jj < j+jb-1; jj++ { | ||
// Copy current block column of L to work and replace | ||
// with zeros. | ||
for jj := j; jj < j+jb; jj++ { | ||
for i := jj + 1; i < n; i++ { | ||
work[i*ldwork+(jj-j)] = a[i*lda+jj] | ||
a[i*lda+jj] = 0 | ||
} | ||
} | ||
// Compute current block column of inv(A). | ||
if j+jb < n { | ||
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda) | ||
bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda) | ||
} | ||
bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nasty non-bracketed conditional blocks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the code there is particularly difficult to parse. I had to stare on it for a while and still wasn't completely sure if dtrsm is under the if or not. |
||
} | ||
} | ||
// Apply column interchanges. | ||
for j := n - 2; j >= 0; j-- { | ||
jp := ipiv[j] | ||
if jp != j { | ||
bi.Dswap(n, a[j:], lda, a[jp:], lda) | ||
} | ||
} | ||
work[0] = lworkopt | ||
return true | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't look like a row vs col major issue, so does this affect the NETLIB implementation as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a row vs column major issue. This line calculates "how many columns fit into the workspace we have". For reference the column size is ldwork, for us it's n.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, yes, the assignment to
ldwork
above. Thanks.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep ... sorry.