From 8cb42192e0e0b08ed60c6c306a63d0b2ea460e51 Mon Sep 17 00:00:00 2001 From: Vladimir Chalupecky Date: Mon, 11 Mar 2019 17:33:47 +0100 Subject: [PATCH] lapack/netlib: adopt parameter checks from lapack/gonum --- lapack/netlib/lapack.go | 2423 ++++++++++++++++++++++++++------------- 1 file changed, 1620 insertions(+), 803 deletions(-) diff --git a/lapack/netlib/lapack.go b/lapack/netlib/lapack.go index 3e7cedfa..f393ef1e 100644 --- a/lapack/netlib/lapack.go +++ b/lapack/netlib/lapack.go @@ -13,48 +13,6 @@ import ( "gonum.org/v1/netlib/lapack/lapacke" ) -func min(m, n int) int { - if m < n { - return m - } - return n -} - -func max(m, n int) int { - if m < n { - return n - } - return m -} - -// checkMatrix verifies the parameters of a matrix input. -// Copied from lapack/native. Keep in sync. -func checkMatrix(m, n int, a []float64, lda int) { - if m < 0 { - panic("lapack: has negative number of rows") - } - if n < 0 { - panic("lapack: has negative number of columns") - } - if lda < n { - panic("lapack: stride less than number of columns") - } - if len(a) < (m-1)*lda+n { - panic("lapack: insufficient matrix slice length") - } -} - -// checkVector verifies the parameters of a vector input. -// Copied from lapack/native. Keep in sync. -func checkVector(n int, v []float64, inc int) { - if n < 0 { - panic("lapack: negative vector length") - } - if (inc > 0 && (n-1)*inc >= len(v)) || (inc < 0 && (1-n)*inc >= len(v)) { - panic("lapack: insufficient vector slice length") - } -} - // Implementation is the cgo-based C implementation of LAPACK routines. type Implementation struct{} @@ -91,15 +49,28 @@ var _ lapack.Float64 = Implementation{} // // Dgeqp3 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgeqp3(m, n int, a []float64, lda int, jpvt []int, tau, work []float64, lwork int) { - checkMatrix(m, n, a, lda) - if len(jpvt) != n { - panic(badIpiv) + minmn := min(m, n) + iws := 3*n + 1 + if minmn == 0 { + iws = 1 } - if len(tau) != min(m, n) { - panic(badTau) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < iws && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(work) < max(1, lwork) { - panic(badWork) + + // Quick return if possible. + if minmn == 0 { + work[0] = 1 + return } // Don't update jpvt if querying lwkopt. @@ -108,17 +79,24 @@ func (impl Implementation) Dgeqp3(m, n int, a []float64, lda int, jpvt []int, ta return } - jpvt32 := make([]int32, len(jpvt)) + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(jpvt) != n: + panic(badLenJpvt) + case len(tau) < minmn: + panic(shortTau) + } + + jpvt32 := make([]int32, n) for i, v := range jpvt { v++ if v != int(int32(v)) || v < 0 || n < v { - panic("lapack: jpvt element out of range") + panic(badJpvt) } jpvt32[i] = int32(v) } - lapacke.Dgeqp3(m, n, a, lda, jpvt32, tau, work, lwork) - for i, v := range jpvt32 { jpvt[i] = int(v - 1) } @@ -147,18 +125,36 @@ func (impl Implementation) Dgeqp3(m, n int, a []float64, lda int, jpvt []int, ta // // Dgerqf is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgerqf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { - checkMatrix(m, n, a, lda) - - if len(work) < max(1, lwork) { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, m) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - if lwork != -1 && lwork < max(1, m) { - panic(badWork) - } + // Quick return if possible. k := min(m, n) - if len(tau) != k { - panic(badTau) + if k == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dgerqf(m, n, a, lda, tau, work, -1) + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) != k: + panic(badLenTau) } lapacke.Dgerqf(m, n, a, lda, tau, work, lwork) @@ -182,25 +178,22 @@ func (impl Implementation) Dgerqf(m, n int, a []float64, lda int, tau, work []fl // // Dlacn2 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dlacn2(n int, v, x []float64, isgn []int, est float64, kase int, isave *[3]int) (float64, int) { - if n < 1 { - panic("lapack: non-positive n") - } - checkVector(n, x, 1) - checkVector(n, v, 1) - if len(isgn) < n { - panic("lapack: insufficient isgn length") - } - if isave[0] < 0 || isave[0] > 5 { - panic("lapack: bad isave value") - } - if isave[0] == 0 && kase != 0 { - panic("lapack: bad isave value") + switch { + case n < 1: + panic(nLT1) + case len(v) < n: + panic(shortV) + case len(x) < n: + panic(shortX) + case len(isgn) < n: + panic(shortIsgn) + case isave[0] < 0 || 5 < isave[0]: + panic(badIsave) + case isave[0] == 0 && kase != 0: + panic(badIsave) } isgn32 := make([]int32, n) - for i, v := range isgn { - isgn32[i] = int32(v) - } pest := []float64{est} // Save one allocation by putting isave and kase into the same slice. isavekase := []int32{int32(isave[0]), int32(isave[1]), int32(isave[2]), int32(kase)} @@ -219,8 +212,30 @@ func (impl Implementation) Dlacn2(n int, v, x []float64, isgn []int, est float64 // a triangular portion with blas.Upper or blas.Lower, or can specify all of the // elemest with blas.All. func (impl Implementation) Dlacpy(uplo blas.Uplo, m, n int, a []float64, lda int, b []float64, ldb int) { - checkMatrix(m, n, a, lda) - checkMatrix(m, n, b, ldb) + switch { + case uplo != blas.Upper && uplo != blas.Lower && uplo != blas.All: + panic(badUplo) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, n): + panic(badLdB) + } + + if m == 0 || n == 0 { + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(b) < (m-1)*ldb+n: + panic(shortB) + } + lapacke.Dlacpy(byte(uplo), m, n, a, lda, b, ldb) } @@ -237,12 +252,29 @@ func (impl Implementation) Dlacpy(uplo blas.Uplo, m, n int, a []float64, lda int // // k must have length n, otherwise Dlapmt will panic. k is zero-indexed. func (impl Implementation) Dlapmt(forward bool, m, n int, x []float64, ldx int, k []int) { - checkMatrix(m, n, x, ldx) - if len(k) != n { - panic(badKperm) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case ldx < max(1, n): + panic(badLdX) } - if n <= 1 { + // Quick return if possible. + if m == 0 || n == 0 { + return + } + + switch { + case len(x) < (m-1)*ldx+n: + panic(shortX) + case len(k) != n: + panic(badLenK) + } + + // Quick return if possible. + if n == 1 { return } @@ -250,7 +282,7 @@ func (impl Implementation) Dlapmt(forward bool, m, n int, x []float64, ldx int, if forward { forwrd = 1 } - k32 := make([]int32, len(k)) + k32 := make([]int32, n) for i, v := range k { v++ // Convert to 1-based indexing. if v != int(int32(v)) { @@ -258,7 +290,6 @@ func (impl Implementation) Dlapmt(forward bool, m, n int, x []float64, ldx int, } k32[i] = int32(v) } - lapacke.Dlapmt(forwrd, m, n, x, ldx, k32) } @@ -315,34 +346,58 @@ func (Implementation) Dlapy2(x, y float64) float64 { // // Dlarfb is an internal routine. It is exported for testing purposes. func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) { - if side != blas.Left && side != blas.Right { - panic(badSide) + nv := m + if side == blas.Right { + nv = n } - if trans != blas.Trans && trans != blas.NoTrans { + switch { + case side != blas.Left && side != blas.Right: + panic(badSide) + case trans != blas.Trans && trans != blas.NoTrans: panic(badTrans) - } - if direct != lapack.Forward && direct != lapack.Backward { + case direct != lapack.Forward && direct != lapack.Backward: panic(badDirect) - } - if store != lapack.ColumnWise && store != lapack.RowWise { - panic(badStore) - } - checkMatrix(m, n, c, ldc) - if k < 0 { + case store != lapack.ColumnWise && store != lapack.RowWise: + panic(badStoreV) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: panic(kLT0) + case store == lapack.ColumnWise && ldv < max(1, k): + panic(badLdV) + case store == lapack.RowWise && ldv < max(1, nv): + panic(badLdV) + case ldt < max(1, k): + panic(badLdT) + case ldc < max(1, n): + panic(badLdC) + case ldwork < max(1, k): + panic(badLdWork) + } + + if m == 0 || n == 0 { + return } - checkMatrix(k, k, t, ldt) - nv := m + nw := n if side == blas.Right { - nv = n nw = m } - if store == lapack.ColumnWise { - checkMatrix(nv, k, v, ldv) - } else { - checkMatrix(k, nv, v, ldv) + switch { + case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k: + panic(shortV) + case store == lapack.RowWise && len(v) < (k-1)*ldv+nv: + panic(shortV) + case len(t) < (k-1)*ldt+k: + panic(shortT) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case len(work) < (nw-1)*ldwork+k: + panic(shortWork) } + // TODO(vladimir-ch): Replace the following two lines with // checkMatrix(nw, k, work, ldwork) // if and when the issue @@ -367,13 +422,25 @@ func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack // // Dlarfg is an internal routine. It is exported for testing purposes. func (impl Implementation) Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64) { - if n < 0 { + switch { + case n < 0: panic(nLT0) + case incX <= 0: + panic(badIncX) } + if n <= 1 { return alpha, 0 } - checkVector(n-1, x, incX) + + aincX := incX + if aincX < 0 { + aincX = -aincX + } + if len(x) < 1+(n-2)*aincX { + panic(shortX) + } + _alpha := []float64{alpha} _tau := []float64{0} lapacke.Dlarfg(n, _alpha, x, incX, _tau) @@ -398,24 +465,38 @@ func (impl Implementation) Dlarfg(n int, alpha float64, x []float64, incX int) ( // tau contains the scalar factors of the elementary reflectors H_i. // // Dlarft is an internal routine. It is exported for testing purposes. -func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, - v []float64, ldv int, tau []float64, t []float64, ldt int) { - if n == 0 { - return +func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int) { + mv, nv := n, k + if store == lapack.RowWise { + mv, nv = k, n } - if n < 0 || k < 0 { - panic(negDimension) - } - if direct != lapack.Forward && direct != lapack.Backward { + switch { + case direct != lapack.Forward && direct != lapack.Backward: panic(badDirect) + case store != lapack.RowWise && store != lapack.ColumnWise: + panic(badStoreV) + case n < 0: + panic(nLT0) + case k < 1: + panic(kLT1) + case ldv < max(1, nv): + panic(badLdV) + case len(tau) < k: + panic(shortTau) + case ldt < max(1, k): + panic(shortT) } - if store != lapack.RowWise && store != lapack.ColumnWise { - panic(badStore) + + if n == 0 { + return } - if len(tau) < k { - panic(badTau) + + switch { + case len(v) < (mv-1)*ldv+nv: + panic(shortV) + case len(t) < (k-1)*ldt+k: + panic(shortT) } - checkMatrix(k, k, t, ldt) lapacke.Dlarft(byte(direct), byte(store), n, k, v, ldv, tau, t, ldt) } @@ -429,15 +510,25 @@ func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int // If norm == lapack.MaxColumnSum, work must be of length n, and this function will panic otherwise. // There are no restrictions on work for the other matrix norms. func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 { - checkMatrix(m, n, a, lda) - switch norm { - case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.Frobenius, lapack.MaxAbs: - default: + switch { + case norm != lapack.MaxRowSum && norm != lapack.MaxColumnSum && norm != lapack.Frobenius && norm != lapack.MaxAbs: panic(badNorm) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + if m == 0 || n == 0 { + return 0 } - if norm == lapack.MaxColumnSum && len(work) < n { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(badLdA) + case norm == lapack.MaxColumnSum && len(work) < n: + panic(shortWork) } + return lapacke.Dlange(byte(norm), m, n, a, lda, work) } @@ -445,18 +536,29 @@ func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, // norm == lapack.MaxColumnSum or norm == lapackMaxRowSum work must have length // at least n, otherwise work is unused. func (impl Implementation) Dlansy(norm lapack.MatrixNorm, uplo blas.Uplo, n int, a []float64, lda int, work []float64) float64 { - checkMatrix(n, n, a, lda) - switch norm { - case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.Frobenius, lapack.MaxAbs: - default: + switch { + case norm != lapack.MaxRowSum && norm != lapack.MaxColumnSum && norm != lapack.Frobenius && norm != lapack.MaxAbs: panic(badNorm) + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if (norm == lapack.MaxColumnSum || norm == lapack.MaxRowSum) && len(work) < n { - panic(badWork) + + // Quick return if possible. + if n == 0 { + return 0 } - if uplo != blas.Upper && uplo != blas.Lower { - panic(badUplo) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case (norm == lapack.MaxColumnSum || norm == lapack.MaxRowSum) && len(work) < n: + panic(shortWork) } + return lapacke.Dlansy(byte(norm), byte(uplo), n, a, lda, work) } @@ -464,21 +566,32 @@ func (impl Implementation) Dlansy(norm lapack.MatrixNorm, uplo blas.Uplo, n int, // norm == lapack.MaxColumnSum work must have length at least n, otherwise work // is unused. func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, m, n int, a []float64, lda int, work []float64) float64 { - checkMatrix(m, n, a, lda) - switch norm { - case lapack.MaxRowSum, lapack.MaxColumnSum, lapack.Frobenius, lapack.MaxAbs: - default: + switch { + case norm != lapack.MaxRowSum && norm != lapack.MaxColumnSum && norm != lapack.Frobenius && norm != lapack.MaxAbs: panic(badNorm) - } - if uplo != blas.Upper && uplo != blas.Lower { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.Unit && diag != blas.NonUnit { + case diag != blas.Unit && diag != blas.NonUnit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + minmn := min(m, n) + if minmn == 0 { + return 0 } - if norm == lapack.MaxColumnSum && len(work) < n { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case norm == lapack.MaxColumnSum && len(work) < n: + panic(shortWork) } + return lapacke.Dlantr(byte(norm), byte(uplo), byte(diag), m, n, a, lda, work) } @@ -501,20 +614,35 @@ func (impl Implementation) Dlantr(norm lapack.MatrixNorm, uplo blas.Uplo, diag b // == blas.Right, otherwise Dlarfx will panic. work is not referenced if H has // order < 11. func (impl Implementation) Dlarfx(side blas.Side, m, n int, v []float64, tau float64, c []float64, ldc int, work []float64) { - checkMatrix(m, n, c, ldc) - switch side { - case blas.Left: - checkVector(m, v, 1) - if len(work) < n && m > 10 { - panic(badWork) - } - case blas.Right: - checkVector(n, v, 1) - if len(work) < m && n > 10 { - panic(badWork) - } - default: + switch { + case side != blas.Left && side != blas.Right: panic(badSide) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case ldc < max(1, n): + panic(badLdC) + } + + // Quick return if possible. + if m == 0 || n == 0 { + return + } + + nh := m + lwork := n + if side == blas.Right { + nh = n + lwork = m + } + switch { + case len(v) < nh: + panic(shortV) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case nh > 10 && len(work) < lwork: + panic(shortWork) } lapacke.Dlarfx(byte(side), m, n, v, tau, c, ldc, work) @@ -527,13 +655,39 @@ func (impl Implementation) Dlarfx(side blas.Side, m, n int, v []float64, tau flo // // Dlascl is an internal routine. It is exported for testing purposes. func (impl Implementation) Dlascl(kind lapack.MatrixType, kl, ku int, cfrom, cto float64, m, n int, a []float64, lda int) { - checkMatrix(m, n, a, lda) - if cfrom == 0 { - panic(zeroDiv) + switch kind { + default: + panic(badMatrixType) + case 'H', 'B', 'Q', 'Z': // See dlascl.f. + case lapack.General, lapack.UpperTri, lapack.LowerTri: + if lda < max(1, n) { + panic(badLdA) + } } - if math.IsNaN(cfrom) || math.IsNaN(cto) { - panic(nanScale) + switch { + case cfrom == 0: + panic(zeroCFrom) + case math.IsNaN(cfrom): + panic(nanCFrom) + case math.IsNaN(cto): + panic(nanCTo) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + } + + if n == 0 || m == 0 { + return + } + + switch kind { + case lapack.General, lapack.UpperTri, lapack.LowerTri: + if len(a) < (m-1)*lda+n { + panic(shortA) + } } + lapacke.Dlascl(byte(kind), kl, ku, cfrom, cto, m, n, a, lda) } @@ -545,7 +699,24 @@ func (impl Implementation) Dlascl(kind lapack.MatrixType, kl, ku int, cfrom, cto // // Dlaset is an internal routine. It is exported for testing purposes. func (impl Implementation) Dlaset(uplo blas.Uplo, m, n int, alpha, beta float64, a []float64, lda int) { - checkMatrix(m, n, a, lda) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + minmn := min(m, n) + if minmn == 0 { + return + } + + if len(a) < (m-1)*lda+n { + panic(shortA) + } + lapacke.Dlaset(byte(uplo), m, n, alpha, beta, a, lda) } @@ -556,12 +727,15 @@ func (impl Implementation) Dlaset(uplo blas.Uplo, m, n int, alpha, beta float64, // // Dlasrt is an internal routine. It is exported for testing purposes. func (impl Implementation) Dlasrt(s lapack.Sort, n int, d []float64) { - checkVector(n, d, 1) - switch s { - default: + switch { + case s != lapack.SortIncreasing && s != lapack.SortDecreasing: panic(badSort) - case lapack.SortIncreasing, lapack.SortDecreasing: + case n < 0: + panic(nLT0) + case len(d) < n: + panic(shortD) } + lapacke.Dlasrt(byte(s), n, d[:n]) } @@ -584,15 +758,27 @@ func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []in panic(badK2) case k1 < 0 || k2 < k1: panic(badK1) + case lda < max(1, n): + panic(badLdA) + case len(a) < (k2-1)*lda+n: + panic(shortA) case len(ipiv) != k2+1: - panic(badIpiv) + panic(badLenIpiv) case incX != 1 && incX != -1: panic(absIncNotOne) } - ipiv32 := make([]int32, len(ipiv)) + if n == 0 { + return + } + + ipiv32 := make([]int32, k2+1) for i, v := range ipiv { - ipiv32[i] = int32(v + 1) + v++ + if v != int(int32(v)) { + panic("lapack: ipiv element out of range") + } + ipiv32[i] = int32(v) } lapacke.Dlaswp(n, a, lda, k1+1, k2+1, ipiv32, incX) } @@ -603,16 +789,24 @@ func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []in // is computed and stored in-place into a. If a is not positive definite, false // is returned. This is the blocked version of the algorithm. func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool) { - // ul is checked in lapacke.Dpotrf. - if n < 0 { + switch { + case ul != blas.Upper && ul != blas.Lower: + panic(badUplo) + case n < 0: panic(nLT0) - } - if lda < n { + case lda < max(1, n): panic(badLdA) } + + // Quick return if possible. if n == 0 { return true } + + if len(a) < (n-1)*lda+n { + panic(shortA) + } + return lapacke.Dpotrf(byte(ul), n, a, lda) } @@ -631,8 +825,6 @@ func (impl Implementation) Dpotri(uplo blas.Uplo, n int, a []float64, lda int) ( panic(nLT0) case lda < max(1, n): panic(badLdA) - case len(a) < (n-1)*lda+n: - panic("lapack: a has insufficient length") } // Quick return if possible. @@ -640,6 +832,10 @@ func (impl Implementation) Dpotri(uplo blas.Uplo, n int, a []float64, lda int) ( return true } + if len(a) < (n-1)*lda+n { + panic(shortA) + } + return lapacke.Dpotri(byte(uplo), n, a, lda) } @@ -651,12 +847,31 @@ func (impl Implementation) Dpotri(uplo blas.Uplo, n int, a []float64, lda int) ( // as computed by Dpotrf. On entry, B contains the right-hand side matrix B, on // return it contains the solution matrix X. func (Implementation) Dpotrs(uplo blas.Uplo, n, nrhs int, a []float64, lda int, b []float64, ldb int) { - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - checkMatrix(n, n, a, lda) - checkMatrix(n, nrhs, b, ldb) - + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + } + + // Quick return if possible. + if n == 0 || nrhs == 0 { + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + } + lapacke.Dpotrs(byte(uplo), n, nrhs, a, lda, b, ldb) } @@ -704,18 +919,31 @@ func (Implementation) Dpotrs(uplo blas.Uplo, n, nrhs int, a []float64, lda int, // // Dgebal is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebal(job lapack.BalanceJob, n int, a []float64, lda int, scale []float64) (ilo, ihi int) { - switch job { - default: + switch { + case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale: panic(badBalanceJob) - case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale: + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + ilo = 0 + ihi = n - 1 + + if n == 0 { + return ilo, ihi } - checkMatrix(n, n, a, lda) - if len(scale) != n { - panic("lapack: bad length of scale") + + switch { + case len(scale) != n: + panic(shortScale) + case len(a) < (n-1)*lda+n: + panic(shortA) } - ilo32 := make([]int32, 1) - ihi32 := make([]int32, 1) + ilo32 := []int32{0} + ihi32 := []int32{0} lapacke.Dgebal(byte(job), n, a, lda, ilo32, ihi32, scale) ilo = int(ilo32[0]) - 1 ihi = int(ihi32[0]) - 1 @@ -740,26 +968,38 @@ func (impl Implementation) Dgebal(job lapack.BalanceJob, n int, a []float64, lda // // Dgebak is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, ilo, ihi int, scale []float64, m int, v []float64, ldv int) { - switch job { - default: + switch { + case job != lapack.BalanceNone && job != lapack.Permute && job != lapack.Scale && job != lapack.PermuteScale: panic(badBalanceJob) - case lapack.BalanceNone, lapack.Permute, lapack.Scale, lapack.PermuteScale: - } - var bside blas.Side - switch side { - default: + case side != lapack.EVLeft && side != lapack.EVRight: panic(badEVSide) - case lapack.EVLeft: - bside = blas.Left - case lapack.EVRight: - bside = blas.Right - } - checkMatrix(n, m, v, ldv) - switch { + case n < 0: + panic(nLT0) case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case m < 0: + panic(mLT0) + case ldv < max(1, m): + panic(badLdV) + } + + // Quick return if possible. + if n == 0 || m == 0 { + return + } + + switch { + case len(scale) < n: + panic(shortScale) + case len(v) < (n-1)*ldv+m: + panic(shortV) + } + + // Quick return if possible. + if job == lapack.BalanceNone { + return } // Convert permutation indices to 1-based. @@ -769,7 +1009,7 @@ func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, for j := ihi + 1; j < n; j++ { scale[j]++ } - lapacke.Dgebak(byte(job), byte(bside), n, ilo+1, ihi+1, scale, m, v, ldv) + lapacke.Dgebak(byte(job), byte(side), n, ilo+1, ihi+1, scale, m, v, ldv) // Convert permutation indices back to 0-based. for j := 0; j < ilo; j++ { scale[j]-- @@ -815,37 +1055,49 @@ func (impl Implementation) Dgebak(job lapack.BalanceJob, side lapack.EVSide, n, // // Dbdsqr returns whether the decomposition was successful. func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) { - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case ncvt < 0: + panic(ncvtLT0) + case nru < 0: + panic(nruLT0) + case ncc < 0: + panic(nccLT0) + case ldvt < max(1, ncvt): + panic(badLdVT) + case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0): + panic(badLdU) + case ldc < max(1, ncc): + panic(badLdC) + } + + // Quick return if possible. + if n == 0 { + return true } - if ncvt != 0 { - checkMatrix(n, ncvt, vt, ldvt) + + if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 { + panic(shortVT) } - if nru != 0 { - checkMatrix(nru, n, u, ldu) + if len(u) < (nru-1)*ldu+n && nru != 0 { + panic(shortU) } - if ncc != 0 { - checkMatrix(n, ncc, c, ldc) + if len(c) < (n-1)*ldc+ncc && ncc != 0 { + panic(shortC) } if len(d) < n { - panic(badD) + panic(shortD) } if len(e) < n-1 { - panic(badE) + panic(shortE) } if len(work) < 4*(n-1) { - panic(badWork) - } - // An address must be passed to cgo. If lengths are zero, allocate a slice. - if len(vt) == 0 { - vt = make([]float64, 1) - } - if len(u) == 0 { - vt = make([]float64, 1) - } - if len(c) == 0 { - c = make([]float64, 1) + panic(shortWork) } + return lapacke.Dbdsqr(byte(uplo), n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work) } @@ -893,25 +1145,42 @@ func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, v // // Dgebrd is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64, lwork int) { - checkMatrix(m, n, a, lda) - minmn := min(m, n) - if len(d) < minmn { - panic(badD) - } - if len(e) < minmn-1 { - panic(badE) - } - if len(tauQ) < minmn { - panic(badTauQ) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, max(m, n)) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(tauP) < minmn { - panic(badTauP) + + // Quick return if possible. + minmn := min(m, n) + if minmn == 0 { + work[0] = 1 + return } - if lwork != -1 && lwork < max(1, max(m, n)) { - panic(badWork) + + if lwork == -1 { + lapacke.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, -1) + return } - if len(work) < max(1, lwork) { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(d) < minmn: + panic(shortD) + case len(e) < minmn-1: + panic(shortE) + case len(tauQ) < minmn: + panic(shortTauQ) + case len(tauP) < minmn: + panic(shortTauP) } lapacke.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork) @@ -928,30 +1197,33 @@ func (impl Implementation) Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, ta // work is a temporary data slice of length at least 4*n and Dgecon will panic otherwise. // // iwork is a temporary data slice of length at least n and Dgecon will panic otherwise. -// Elements of iwork must fit within the int32 type or Dgecon will panic. func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 { - checkMatrix(n, n, a, lda) - if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum { - panic("bad norm") - } - if len(work) < 4*n { - panic(badWork) + switch { + case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum: + panic(badNorm) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if len(iwork) < n { - panic(badWork) + + // Quick return if possible. + if n == 0 { + return 1 } - rcond := make([]float64, 1) - _iwork := make([]int32, len(iwork)) - for i, v := range iwork { - if v != int(int32(v)) { - panic("lapack: iwork element out of range") - } - _iwork[i] = int32(v) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 4*n: + panic(shortWork) + case len(iwork) < n: + panic(shortIWork) } + + rcond := []float64{0} + _iwork := make([]int32, n) lapacke.Dgecon(byte(norm), n, a, lda, anorm, rcond, work, _iwork) - for i, v := range _iwork { - iwork[i] = int(v) - } return rcond[0] } @@ -973,13 +1245,30 @@ func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, ld // // Work is temporary storage of length at least m and this function will panic otherwise. func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) - if len(tau) < min(m, n) { - panic(badTau) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + k := min(m, n) + if k == 0 { + return } - if len(work) < m { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) + case len(work) < m: + panic(shortWork) } + lapacke.Dgelq2(m, n, a, lda, tau, work) } @@ -995,20 +1284,37 @@ func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []fl // // tau must have length at least min(m,n), and this function will panic otherwise. func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { - if lwork == -1 { - work[0] = float64(m) - return - } - checkMatrix(m, n, a, lda) - if len(work) < lwork { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, m) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - if lwork < m { - panic(badWork) + + k := min(m, n) + if k == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dgelqf(m, n, a, lda, tau, work, -1) + return } - if len(tau) < min(m, n) { - panic(badTau) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) } + lapacke.Dgelqf(m, n, a, lda, tau, work, lwork) } @@ -1035,14 +1341,30 @@ func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []fl // // Work is temporary storage of length at least n and this function will panic otherwise. func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) { - checkMatrix(m, n, a, lda) - if len(work) < n { - panic(badWork) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case len(work) < n: + panic(shortWork) } + + // Quick return if possible. k := min(m, n) - if len(tau) < k { - panic(badTau) + if k == 0 { + return } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) + } + lapacke.Dgeqr2(m, n, a, lda, tau, work) } @@ -1059,17 +1381,38 @@ func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []fl // // tau must have length at least min(m,n), and this function will panic otherwise. func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) { - if len(work) < max(1, lwork) { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, n) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - checkMatrix(m, n, a, lda) - if lwork < n && lwork != -1 { - panic(badWork) - } + + // Quick return if possible. k := min(m, n) - if len(tau) < k { - panic(badTau) + if k == 0 { + work[0] = 1 + return } + + if lwork == -1 { + lapacke.Dgeqrf(m, n, a, lda, tau, work, -1) + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) + } + lapacke.Dgeqrf(m, n, a, lda, tau, work, lwork) } @@ -1130,21 +1473,38 @@ func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []fl // Dgehrd is an internal routine. It is exported for testing purposes. func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, work []float64, lwork int) { switch { + case n < 0: + panic(nLT0) case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case lda < max(1, n): + panic(badLdA) case lwork < max(1, n) && lwork != -1: - panic(badWork) + panic(badLWork) case len(work) < lwork: panic(shortWork) } - if lwork != -1 { - checkMatrix(n, n, a, lda) - if len(tau) != n-1 && n > 0 { - panic(badTau) - } + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dgehrd(n, ilo+1, ihi+1, a, lda, tau, work, -1) + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(tau) != n-1: + panic(badLenTau) } + lapacke.Dgehrd(n, ilo+1, ihi+1, a, lda, tau, work, lwork) } @@ -1179,23 +1539,47 @@ func (impl Implementation) Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, wo // length. func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool { mn := min(m, n) - if lwork == -1 { - work[0] = float64(mn + max(mn, nrhs)) + minwrk := mn + max(mn, nrhs) + switch { + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + case lwork < max(1, minwrk) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if mn == 0 || nrhs == 0 { + impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb) + work[0] = 1 return true } - checkMatrix(m, n, a, lda) - checkMatrix(max(m, n), nrhs, b, ldb) - if len(work) < lwork { - panic(shortWork) + + if lwork == -1 { + return lapacke.Dgels(byte(trans), m, n, nrhs, a, lda, b, ldb, work, -1) } - if lwork < mn+max(mn, nrhs) { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(b) < (max(m, n)-1)*ldb+nrhs: + panic(shortB) } + return lapacke.Dgels(byte(trans), m, n, nrhs, a, lda, b, ldb, work, lwork) } -const noSVDO = "dgesvd: not coded for overwrite" - // Dgesvd computes the singular value decomposition of the input matrix A. // // The singular value decomposition is @@ -1239,39 +1623,70 @@ const noSVDO = "dgesvd: not coded for overwrite" // // Dgesvd returns whether the decomposition successfully completed. func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool) { - checkMatrix(m, n, a, lda) - if jobU == lapack.SVDAll { - checkMatrix(m, m, u, ldu) - } else if jobU == lapack.SVDStore { - checkMatrix(m, min(m, n), u, ldu) - } - if jobVT == lapack.SVDAll { - checkMatrix(n, n, vt, ldvt) - } else if jobVT == lapack.SVDStore { - checkMatrix(min(m, n), n, vt, ldvt) - } - if jobU == lapack.SVDOverwrite && jobVT == lapack.SVDOverwrite { - panic(noSVDO) - } - if len(s) < min(m, n) { - panic(badS) - } - if jobU == lapack.SVDOverwrite || jobVT == lapack.SVDOverwrite { - panic("lapack: SVD not coded to overwrite original matrix") - } - minWork := max(5*min(m, n), 3*min(m, n)+max(m, n)) - if lwork != -1 { - if len(work) < lwork { - panic(badWork) - } - if lwork < minWork { - panic(badWork) - } + wantua := jobU == lapack.SVDAll + wantus := jobU == lapack.SVDStore + wantuo := jobU == lapack.SVDOverwrite + wantun := jobU == lapack.SVDNone + if !(wantua || wantus || wantuo || wantun) { + panic(badSVDJob) } - if lwork == -1 { - work[0] = float64(minWork) + + wantva := jobVT == lapack.SVDAll + wantvs := jobVT == lapack.SVDStore + wantvas := wantva || wantvs + wantvo := jobVT == lapack.SVDOverwrite + wantvn := jobVT == lapack.SVDNone + if !(wantva || wantvs || wantvo || wantvn) { + panic(badSVDJob) + } + + if wantuo && wantvo { + panic(bothSVDOver) + } + + minmn := min(m, n) + minwork := 1 + if minmn > 0 { + minwork = max(3*minmn+max(m, n), 5*minmn) + } + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case ldu < 1, wantua && ldu < m, wantus && ldu < minmn: + panic(badLdU) + case ldvt < 1 || (wantvas && ldvt < n): + panic(badLdVT) + case lwork < minwork && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if minmn == 0 { + work[0] = 1 return true } + + if lwork == -1 { + return lapacke.Dgesvd(byte(jobU), byte(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work, -1) + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(s) < minmn: + panic(shortS) + case (len(u) < (m-1)*ldu+m && wantua) || (len(u) < (m-1)*ldu+minmn && wantus): + panic(shortU) + case (len(vt) < (n-1)*ldvt+n && wantva) || (len(vt) < (minmn-1)*ldvt+n && wantvs): + panic(shortVT) + } + return lapacke.Dgesvd(byte(jobU), byte(jobVT), m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork) } @@ -1292,11 +1707,28 @@ func (impl Implementation) Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float // system of equations. func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok bool) { mn := min(m, n) - checkMatrix(m, n, a, lda) - if len(ipiv) < mn { - panic(badIpiv) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - ipiv32 := make([]int32, len(ipiv)) + + // Quick return if possible. + if mn == 0 { + return true + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(ipiv) != mn: + panic(badLenIpiv) + } + + ipiv32 := make([]int32, mn) ok = lapacke.Dgetf2(m, n, a, lda, ipiv32) for i, v := range ipiv32 { ipiv[i] = int(v) - 1 // Transform to zero-indexed. @@ -1323,14 +1755,31 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo // system of equations. func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) { mn := min(m, n) - checkMatrix(m, n, a, lda) - if len(ipiv) < mn { - panic(badIpiv) + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - ipiv32 := make([]int32, len(ipiv)) - ok = lapacke.Dgetrf(m, n, a, lda, ipiv32) - for i, v := range ipiv32 { - ipiv[i] = int(v) - 1 // Transform to zero-indexed. + + // Quick return if possible. + if mn == 0 { + return true + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(ipiv) != mn: + panic(badLenIpiv) + } + + ipiv32 := make([]int32, mn) + ok = lapacke.Dgetrf(m, n, a, lda, ipiv32) + for i, v := range ipiv32 { + ipiv[i] = int(v) - 1 // Transform to zero-indexed. } return ok } @@ -1348,23 +1797,42 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o // by the temporary space available. If lwork == -1, instead of performing Dgetri, // the optimal work length will be stored into work[0]. func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) { - checkMatrix(n, n, a, lda) - if len(ipiv) < n { - panic(badIpiv) + iws := max(1, n) + switch { + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < iws && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if lwork == -1 { - work[0] = float64(n) + + if n == 0 { + work[0] = 1 return true } - if lwork < n { - panic(badWork) + + if lwork == -1 { + return lapacke.Dgetri(n, a, lda, nil, work, -1) + return true } - if len(work) < lwork { - panic(badWork) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(ipiv) != n: + panic(badLenIpiv) } - ipiv32 := make([]int32, len(ipiv)) + + ipiv32 := make([]int32, n) for i, v := range ipiv { - ipiv32[i] = int32(v) + 1 // Transform to one-indexed. + v++ // Transform to one-indexed. + if v != int(int32(v)) { + panic("lapack: ipiv element out of range") + } + ipiv32[i] = int32(v) } return lapacke.Dgetri(n, a, lda, ipiv32, work, lwork) } @@ -1381,14 +1849,40 @@ func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work // a and ipiv contain the LU factorization of A and the permutation indices as // computed by Dgetrf. ipiv is zero-indexed. func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) { - checkMatrix(n, n, a, lda) - checkMatrix(n, nrhs, b, ldb) - if len(ipiv) < n { - panic(badIpiv) + switch { + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + } + + // Quick return if possible. + if n == 0 || nrhs == 0 { + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + case len(ipiv) != n: + panic(badLenIpiv) } - ipiv32 := make([]int32, len(ipiv)) + + ipiv32 := make([]int32, n) for i, v := range ipiv { - ipiv32[i] = int32(v) + 1 // Transform to one-indexed. + v++ // Transform to one-indexed. + if v != int(int32(v)) { + panic("lapack: ipiv element out of range") + } + ipiv32[i] = int32(v) } lapacke.Dgetrs(byte(trans), n, nrhs, a, lda, ipiv32, b, ldb) } @@ -1490,58 +1984,66 @@ func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64 // lwork is -1, work[0] holds the optimal lwork on return, but Dggsvd3 does // not perform the GSVD. func (impl Implementation) Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool) { - checkMatrix(m, n, a, lda) - checkMatrix(p, n, b, ldb) - - switch jobU { - case lapack.GSVDU: - checkMatrix(m, m, u, ldu) - case lapack.GSVDNone: - default: + wantu := jobU == lapack.GSVDU + wantv := jobV == lapack.GSVDV + wantq := jobQ == lapack.GSVDQ + switch { + case !wantu && jobU != lapack.GSVDNone: panic(badGSVDJob + "U") - } - switch jobV { - case lapack.GSVDV: - checkMatrix(p, p, v, ldv) - case lapack.GSVDNone: - default: + case !wantv && jobV != lapack.GSVDNone: panic(badGSVDJob + "V") - } - switch jobQ { - case lapack.GSVDQ: - checkMatrix(n, n, q, ldq) - case lapack.GSVDNone: - default: + case !wantq && jobQ != lapack.GSVDNone: panic(badGSVDJob + "Q") + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case p < 0: + panic(pLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, n): + panic(badLdB) + case ldu < 1, wantu && ldu < m: + panic(badLdU) + case ldv < 1, wantv && ldv < p: + panic(badLdV) + case ldq < 1, wantq && ldq < n: + panic(badLdQ) + case len(iwork) < n: + panic(shortWork) + case lwork < 1 && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(alpha) != n { - panic(badAlpha) - } - if len(beta) != n { - panic(badBeta) + // Determine optimal work length. + if lwork == -1 { + lapacke.Dggsvd3(byte(jobU), byte(jobV), byte(jobQ), m, n, p, nil, nil, a, lda, b, ldb, alpha, beta, u, ldu, v, ldv, q, ldq, work, -1, nil) + return 0, 0, true } - if lwork != -1 && lwork <= n { - panic(badWork) - } - if len(work) < max(1, lwork) { - panic(shortWork) - } - if len(iwork) < n { - panic(badWork) + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(b) < (p-1)*ldb+n: + panic(shortB) + case wantu && len(u) < (m-1)*ldu+m: + panic(shortU) + case wantv && len(v) < (p-1)*ldv+p: + panic(shortV) + case wantq && len(q) < (n-1)*ldq+n: + panic(shortQ) + case len(alpha) != n: + panic(badLenAlpha) + case len(beta) != n: + panic(badLenBeta) } _k := []int32{0} _l := []int32{0} - _iwork := make([]int32, len(iwork)) - for i, v := range iwork { - v++ - if v != int(int32(v)) { - panic("lapack: iwork element out of range") - } - _iwork[i] = int32(v) - } + _iwork := make([]int32, n) ok = lapacke.Dggsvd3(byte(jobU), byte(jobV), byte(jobQ), m, n, p, _k, _l, a, lda, b, ldb, alpha, beta, u, ldu, v, ldv, q, ldq, work, lwork, _iwork) for i, v := range _iwork { iwork[i] = int(v - 1) @@ -1593,61 +2095,67 @@ func (impl Implementation) Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, // // Dggsvp3 is an internal routine. It is exported for testing purposes. func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int) { - checkMatrix(m, n, a, lda) - checkMatrix(p, n, b, ldb) - wantu := jobU == lapack.GSVDU - if !wantu && jobU != lapack.GSVDNone { - panic(badGSVDJob + "U") - } - if jobU != lapack.GSVDNone { - checkMatrix(m, m, u, ldu) - } - wantv := jobV == lapack.GSVDV - if !wantv && jobV != lapack.GSVDNone { - panic(badGSVDJob + "V") - } - if jobV != lapack.GSVDNone { - checkMatrix(p, p, v, ldv) - } - wantq := jobQ == lapack.GSVDQ - if !wantq && jobQ != lapack.GSVDNone { + switch { + case !wantu && jobU != lapack.GSVDNone: + panic(badGSVDJob + "U") + case !wantv && jobV != lapack.GSVDNone: + panic(badGSVDJob + "V") + case !wantq && jobQ != lapack.GSVDNone: panic(badGSVDJob + "Q") - } - if jobQ != lapack.GSVDNone { - checkMatrix(n, n, q, ldq) + case m < 0: + panic(mLT0) + case p < 0: + panic(pLT0) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, n): + panic(badLdB) + case ldu < 1, wantu && ldu < m: + panic(badLdU) + case ldv < 1, wantv && ldv < p: + panic(badLdV) + case ldq < 1, wantq && ldq < n: + panic(badLdQ) + case len(iwork) != n: + panic(shortWork) + case lwork < 1 && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(tau) < n { - panic(badTau) - } - if len(iwork) != n { - panic(badWork) - } - if lwork != -1 && lwork < 1 { - panic(badWork) + if lwork == -1 { + lapacke.Dggsvp3(byte(jobU), byte(jobV), byte(jobQ), m, p, n, a, lda, b, ldb, tola, tolb, nil, nil, u, ldu, v, ldv, q, ldq, nil, tau, work, -1) + return 0, 0 } - if len(work) < max(1, lwork) { - panic(shortWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(b) < (p-1)*ldb+n: + panic(shortB) + case wantu && len(u) < (m-1)*ldu+m: + panic(shortU) + case wantv && len(v) < (p-1)*ldv+p: + panic(shortV) + case wantq && len(q) < (n-1)*ldq+n: + panic(shortQ) + case len(tau) < n: + // tau check must come after lwkopt query since + // the Dggsvd3 call for lwkopt query may have + // lwork == -1, and tau is provided by work. + panic(shortTau) } _k := []int32{0} _l := []int32{0} - _iwork := make([]int32, len(iwork)) - for i, v := range iwork { - v++ - if v != int(int32(v)) { - panic("lapack: iwork element out of range") - } - _iwork[i] = int32(v) - } + _iwork := make([]int32, n) lapacke.Dggsvp3(byte(jobU), byte(jobV), byte(jobQ), m, p, n, a, lda, b, ldb, tola, tolb, _k, _l, u, ldu, v, ldv, q, ldq, _iwork, tau, work, lwork) - for i, v := range _iwork { - iwork[i] = int(v - 1) - } - return int(_k[0]), int(_l[0]) } @@ -1662,39 +2170,53 @@ func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, // P^T is of order n. If k < n, then Dorgbr returns the first m rows of P^T, // where n >= m >= k. If k >= n, then Dorgbr returns P^T as an n×n matrix. func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { + wantq := vect == lapack.GenerateQ mn := min(m, n) - var wantq bool - switch vect { - case lapack.GenerateQ: - wantq = true - case lapack.GeneratePT: - default: + switch { + case vect != lapack.GenerateQ && vect != lapack.GeneratePT: panic(badGenOrtho) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case wantq && n > m: + panic(nGTM) + case wantq && n < min(m, k): + panic("lapack: n < min(m,k)") + case !wantq && m > n: + panic(mGTN) + case !wantq && m < min(n, k): + panic("lapack: m < min(n,k)") + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, mn) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if wantq { - if m < n || n < min(m, k) || m < min(m, k) { - panic(badDims) - } - } else { - if n < m || m < min(n, k) || n < min(n, k) { - panic(badDims) - } - } - if wantq { - checkMatrix(m, k, a, lda) - } else { - checkMatrix(k, n, a, lda) + + // Quick return if possible. + if m == 0 || n == 0 { + work[0] = 1 + return } + if lwork == -1 { - work[0] = float64(mn) + lapacke.Dorgbr(byte(vect), m, n, k, a, lda, tau, work, -1) return } - if len(work) < lwork { - panic(badWork) - } - if lwork < mn { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case wantq && len(tau) < min(m, k): + panic(shortTau) + case !wantq && len(tau) < min(n, k): + panic(shortTau) } + lapacke.Dorgbr(byte(vect), m, n, k, a, lda, tau, work, lwork) } @@ -1727,18 +2249,38 @@ func (impl Implementation) Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64 // // Dorghr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorghr(n, ilo, ihi int, a []float64, lda int, tau, work []float64, lwork int) { - checkMatrix(n, n, a, lda) nh := ihi - ilo switch { case ilo < 0 || max(1, n) <= ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) + case lda < max(1, n): + panic(badLdA) case lwork < max(1, nh) && lwork != -1: - panic(badWork) + panic(badLWork) case len(work) < max(1, lwork): panic(shortWork) } + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dorghr(n, ilo+1, ihi+1, a, lda, tau, work, -1) + return + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(tau) < n-1: + panic(shortTau) + } + lapacke.Dorghr(n, ilo+1, ihi+1, a, lda, tau, work, lwork) } @@ -1759,29 +2301,40 @@ func (impl Implementation) Dorghr(n, ilo, ihi int, a []float64, lda int, tau, wo // // Dorglq is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { - if lwork == -1 { - work[0] = float64(m) - return - } - checkMatrix(m, n, a, lda) - if k < 0 { + switch { + case m < 0: + panic(mLT0) + case n < m: + panic(nLTM) + case k < 0: panic(kLT0) - } - if k > m { + case k > m: panic(kGTM) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, m) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if m > n { - panic(nLTM) - } - if len(tau) < k { - panic(badTau) + + if m == 0 { + work[0] = 1 + return } - if len(work) < lwork { - panic(shortWork) + + if lwork == -1 { + lapacke.Dorglq(m, n, k, a, lda, tau, work, -1) + return } - if lwork < m { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) } + lapacke.Dorglq(m, n, k, a, lda, tau, work, lwork) } @@ -1809,24 +2362,40 @@ func (impl Implementation) Dorglq(m, n, k int, a []float64, lda int, tau, work [ // Dorgql is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { switch { + case m < 0: + panic(mLT0) case n < 0: panic(nLT0) - case m < n: - panic(mLTN) + case n > m: + panic(nGTM) case k < 0: panic(kLT0) case k > n: panic(kGTN) + case lda < max(1, n): + panic(badLdA) case lwork < max(1, n) && lwork != -1: - panic(badWork) - case len(work) < lwork: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - if lwork != -1 { - checkMatrix(m, n, a, lda) - if len(tau) < k { - panic(badTau) - } + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dorgql(m, n, k, a, lda, tau, work, -1) + return + } + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) } lapacke.Dorgql(m, n, k, a, lda, tau, work, lwork) @@ -1850,29 +2419,42 @@ func (impl Implementation) Dorgql(m, n, k int, a []float64, lda int, tau, work [ // // Dorgqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) { - if lwork == -1 { - work[0] = float64(n) - return - } - checkMatrix(m, n, a, lda) - if k < 0 { + switch { + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case n > m: + panic(nGTM) + case k < 0: panic(kLT0) - } - if k > n { + case k > n: panic(kGTN) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, n) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if n > m { - panic(mLTN) - } - if len(tau) < k { - panic(badTau) + + if n == 0 { + work[0] = 1 + return } - if len(work) < lwork { - panic(shortWork) + + if lwork == -1 { + lapacke.Dorgqr(m, n, k, a, lda, tau, work, -1) + return } - if lwork < n { - panic(badWork) + + switch { + case len(a) < (m-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) } + lapacke.Dorgqr(m, n, k, a, lda, tau, work, lwork) } @@ -1895,20 +2477,36 @@ func (impl Implementation) Dorgqr(m, n, k int, a []float64, lda int, tau, work [ // // Dorgtr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int) { - checkMatrix(n, n, a, lda) - if len(tau) < n-1 { - panic(badTau) + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, n-1) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(work) < lwork { - panic(badWork) + + if n == 0 { + work[0] = 1 + return } - if lwork < n-1 && lwork != -1 { - panic(badWork) + + if lwork == -1 { + lapacke.Dorgtr(byte(uplo), n, a, lda, tau, work, -1) + return } - upper := uplo == blas.Upper - if !upper && uplo != blas.Lower { - panic(badUplo) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(tau) < n-1: + panic(shortTau) } + lapacke.Dorgtr(byte(uplo), n, a, lda, tau, work, lwork) } @@ -1948,36 +2546,61 @@ func (impl Implementation) Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, t // // Dormbr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - if side != blas.Left && side != blas.Right { - panic(badSide) - } - if trans != blas.NoTrans && trans != blas.Trans { - panic(badTrans) - } - if vect != lapack.ApplyP && vect != lapack.ApplyQ { - panic(badApplyOrtho) - } nq := n nw := m if side == blas.Left { nq = m nw = n } - if vect == lapack.ApplyQ { - checkMatrix(nq, min(nq, k), a, lda) - } else { - checkMatrix(min(nq, k), nq, a, lda) + applyQ := vect == lapack.ApplyQ + switch { + case !applyQ && vect != lapack.ApplyP: + panic(badApplyOrtho) + case side != blas.Left && side != blas.Right: + panic(badSide) + case trans != blas.NoTrans && trans != blas.Trans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case applyQ && lda < max(1, min(nq, k)): + panic(badLdA) + case !applyQ && lda < max(1, nq): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) + case lwork < max(1, nw) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) } - if len(tau) < min(nq, k) { - panic(badTau) + + // Quick return if possible. + if m == 0 || n == 0 { + work[0] = 1 + return } - checkMatrix(m, n, c, ldc) - if len(work) < lwork { - panic(shortWork) + + if lwork == -1 { + lapacke.Dormbr(byte(vect), byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, -1) + return } - if lwork < max(1, nw) && lwork != -1 { - panic(badWork) + + minnqk := min(nq, k) + switch { + case applyQ && len(a) < (nq-1)*lda+minnqk: + panic(shortA) + case !applyQ && len(a) < (minnqk-1)*lda+nq: + panic(shortA) + case len(tau) < minnqk: + panic(shortTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) } + lapacke.Dormbr(byte(vect), byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, lwork) } @@ -2025,39 +2648,53 @@ func (impl Implementation) Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans // // Dormhr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dormhr(side blas.Side, trans blas.Transpose, m, n, ilo, ihi int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - var ( - nq int // The order of Q. - nw int // The minimum length of work. - ) - switch side { - case blas.Left: + nq := n // The order of Q. + nw := m // The minimum length of work. + if side == blas.Left { nq = m nw = n - case blas.Right: - nq = n - nw = m - default: - panic(badSide) } switch { + case side != blas.Left && side != blas.Right: + panic(badSide) case trans != blas.NoTrans && trans != blas.Trans: panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) case ilo < 0 || max(1, nq) <= ilo: panic(badIlo) case ihi < min(ilo, nq-1) || nq <= ihi: panic(badIhi) + case lda < max(1, nq): + panic(badLdA) case lwork < max(1, nw) && lwork != -1: - panic(badWork) + panic(badLWork) case len(work) < max(1, lwork): panic(shortWork) } - if lwork != -1 { - checkMatrix(m, n, c, ldc) - checkMatrix(nq, nq, a, lda) - if len(tau) != nq-1 && nq > 0 { - panic(badTau) - } + + // Quick return if possible. + if m == 0 || n == 0 { + work[0] = 1 + return } + + if lwork == -1 { + lapacke.Dormhr(byte(side), byte(trans), m, n, ilo+1, ihi+1, a, lda, tau, c, ldc, work, -1) + return + } + + switch { + case len(a) < (nq-1)*lda+nq: + panic(shortA) + case len(c) < (m-1)*ldc+n: + panic(shortC) + case len(tau) != nq-1: + panic(badLenTau) + } + lapacke.Dormhr(byte(side), byte(trans), m, n, ilo+1, ihi+1, a, lda, tau, c, ldc, work, lwork) } @@ -2080,31 +2717,56 @@ func (impl Implementation) Dormhr(side blas.Side, trans blas.Transpose, m, n, il // tau contains the Householder scales and must have length at least k, and // this function will panic otherwise. func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - if side != blas.Left && side != blas.Right { - panic(badSide) - } - if trans != blas.Trans && trans != blas.NoTrans { - panic(badTrans) - } left := side == blas.Left + nw := m if left { - checkMatrix(k, m, a, lda) - } else { - checkMatrix(k, n, a, lda) - } - checkMatrix(m, n, c, ldc) - if len(tau) < k { - panic(badTau) + nw = n } - if len(work) < lwork { + switch { + case !left && side != blas.Right: + panic(badSide) + case trans != blas.Trans && trans != blas.NoTrans: + panic(badTrans) + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case left && lda < max(1, m): + panic(badLdA) + case !left && lda < max(1, n): + panic(badLdA) + case lwork < max(1, nw) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - nw := m - if left { - nw = n + + // Quick return if possible. + if m == 0 || n == 0 || k == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dormlq(byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, -1) + return } - if lwork < max(1, nw) && lwork != -1 { - panic(badWork) + + switch { + case left && len(a) < (k-1)*lda+m: + panic(shortA) + case !left && len(a) < (k-1)*lda+n: + panic(shortA) + case len(tau) < k: + panic(shortTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) } lapacke.Dormlq(byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, lwork) @@ -2138,35 +2800,56 @@ func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k // If lwork is -1, instead of performing Dormqr, the optimal workspace size will // be stored into work[0]. func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) { - var nq, nw int - switch side { - default: - panic(badSide) - case blas.Left: + left := side == blas.Left + nq := n + nw := m + if left { nq = m nw = n - case blas.Right: - nq = n - nw = m } switch { + case !left && side != blas.Right: + panic(badSide) case trans != blas.NoTrans && trans != blas.Trans: panic(badTrans) - case m < 0 || n < 0: - panic(negDimension) - case k < 0 || nq < k: - panic("lapack: invalid value of k") - case len(work) < lwork: + case m < 0: + panic(mLT0) + case n < 0: + panic(nLT0) + case k < 0: + panic(kLT0) + case left && k > m: + panic(kGTM) + case !left && k > n: + panic(kGTN) + case lda < max(1, k): + panic(badLdA) + case ldc < max(1, n): + panic(badLdC) + case lwork < max(1, nw) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) - case lwork < max(1, nw) && lwork != -1: - panic(badWork) } - if lwork != -1 { - checkMatrix(nq, k, a, lda) - checkMatrix(m, n, c, ldc) - if len(tau) != k { - panic(badTau) - } + + // Quick return if possible. + if m == 0 || n == 0 || k == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dormqr(byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, -1) + return + } + + switch { + case len(a) < (nq-1)*lda+k: + panic(shortA) + case len(tau) != k: + panic(badLenTau) + case len(c) < (m-1)*ldc+n: + panic(shortC) } lapacke.Dormqr(byte(side), byte(trans), m, n, k, a, lda, tau, c, ldc, work, lwork) @@ -2181,30 +2864,35 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k // work is a temporary data slice of length at least 3*n and Dpocon will panic otherwise. // // iwork is a temporary data slice of length at least n and Dpocon will panic otherwise. -// Elements of iwork must fit within the int32 type or Dpocon will panic. func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 { - checkMatrix(n, n, a, lda) - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case anorm < 0: + panic(negANorm) } - if len(work) < 3*n { - panic(badWork) - } - if len(iwork) < n { - panic(badWork) + + // Quick return if possible. + if n == 0 { + return 1 } - rcond := make([]float64, 1) - _iwork := make([]int32, len(iwork)) - for i, v := range iwork { - if v != int(int32(v)) { - panic("lapack: iwork element out of range") - } - _iwork[i] = int32(v) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 3*n: + panic(shortWork) + case len(iwork) < n: + panic(shortIWork) } + + rcond := []float64{0} + _iwork := make([]int32, n) lapacke.Dpocon(byte(uplo), n, a, lda, anorm, rcond, work, _iwork) - for i, v := range _iwork { - iwork[i] = int(v) - } return rcond[0] } @@ -2233,23 +2921,29 @@ func (impl Implementation) Dpocon(uplo blas.Uplo, n int, a []float64, lda int, a // // Dsteqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool) { - if n < 0 { + switch { + case compz != lapack.EVCompNone && compz != lapack.EVTridiag && compz != lapack.EVOrig: + panic(badEVComp) + case n < 0: panic(nLT0) + case ldz < 1, compz != lapack.EVCompNone && ldz < n: + panic(badLdZ) } - if len(d) < n { - panic(badD) - } - if len(e) < n-1 { - panic(badE) - } - if compz != lapack.EVCompNone && compz != lapack.EVTridiag && compz != lapack.EVOrig { - panic(badEVComp) + + // Quick return if possible. + if n == 0 { + return true } - if compz != lapack.EVCompNone { - if len(work) < max(1, 2*n-2) { - panic(badWork) - } - checkMatrix(n, n, z, ldz) + + switch { + case len(d) < n: + panic(shortD) + case len(e) < n-1: + panic(shortE) + case compz != lapack.EVCompNone && len(z) < (n-1)*ldz+n: + panic(shortZ) + case compz != lapack.EVCompNone && len(work) < max(1, 2*n-2): + panic(shortWork) } return lapacke.Dsteqr(byte(compz), n, d, e, z, ldz, work) @@ -2271,14 +2965,17 @@ func (impl Implementation) Dsterf(n int, d, e []float64) (ok bool) { if n < 0 { panic(nLT0) } + + // Quick return if possible. if n == 0 { return true } - if len(d) < n { - panic(badD) - } - if len(e) < n-1 { - panic(badE) + + switch { + case len(d) < n: + panic(shortD) + case len(e) < n-1: + panic(shortE) } return lapacke.Dsterf(n, d, e) @@ -2300,17 +2997,38 @@ func (impl Implementation) Dsterf(n int, d, e []float64) (ok bool) { // limited by the usable length. If lwork == -1, instead of computing Dsyev the // optimal work length is stored into work[0]. func (impl Implementation) Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool) { - checkMatrix(n, n, a, lda) + switch { + case jobz != lapack.EVNone && jobz != lapack.EVCompute: + panic(badEVJob) + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < max(1, 3*n-1) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if n == 0 { + return true + } + if lwork == -1 { - work[0] = 3*float64(n) - 1 + return lapacke.Dsyev(byte(jobz), byte(uplo), n, a, lda, w, work, -1) return } - if len(work) < lwork { - panic(badWork) - } - if lwork < 3*n-1 { - panic(badWork) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(w) < n: + panic(shortW) } + return lapacke.Dsyev(byte(jobz), byte(uplo), n, a, lda, w, work, lwork) } @@ -2360,24 +3078,39 @@ func (impl Implementation) Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []f // // Dsytrd is an internal routine. It is exported for testing purposes. func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int) { - checkMatrix(n, n, a, lda) - if len(d) < n { - panic(badD) - } - if len(e) < n-1 { - panic(badE) - } - if len(tau) < n-1 { - panic(badTau) - } - if len(work) < lwork { + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case lwork < 1 && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): panic(shortWork) } - if lwork != -1 && lwork < 1 { - panic(badWork) + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return + } + + if lwork == -1 { + lapacke.Dsytrd(byte(uplo), n, a, lda, d, e, tau, work, -1) + return } - if uplo != blas.Upper && uplo != blas.Lower { - panic(badUplo) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(d) < n: + panic(shortD) + case len(e) < n-1: + panic(shortE) + case len(tau) < n-1: + panic(shortTau) } lapacke.Dsytrd(byte(uplo), n, a, lda, d, e, tau, work, lwork) @@ -2389,35 +3122,36 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d // work is a temporary data slice of length at least 3*n and Dtrcon will panic otherwise. // // iwork is a temporary data slice of length at least n and Dtrcon will panic otherwise. -// Elements of iwork must fit within the int32 type or Dtrcon will panic. func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int, work []float64, iwork []int) float64 { - if norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum { + switch { + case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum: panic(badNorm) - } - if uplo != blas.Upper && uplo != blas.Lower { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.NonUnit && diag != blas.Unit { + case diag != blas.NonUnit && diag != blas.Unit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } - if len(work) < 3*n { - panic(badWork) + + if n == 0 { + return 1 } - if len(iwork) < n { - panic(badWork) + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(work) < 3*n: + panic(shortWork) + case len(iwork) < n: + panic(shortIWork) } + rcond := []float64{0} - _iwork := make([]int32, len(iwork)) - for i, v := range iwork { - if v != int(int32(v)) { - panic("lapack: iwork element out of range") - } - _iwork[i] = int32(v) - } + _iwork := make([]int32, n) lapacke.Dtrcon(byte(norm), byte(uplo), byte(diag), n, a, lda, rcond, work, _iwork) - for i, v := range _iwork { - iwork[i] = int(v) - } return rcond[0] } @@ -2461,29 +3195,37 @@ func (impl Implementation) Dtrcon(norm lapack.MatrixNorm, uplo blas.Uplo, diag b // // Dtrexc is an internal routine. It is exported for testing purposes. func (impl Implementation) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) { - checkMatrix(n, n, t, ldt) - switch compq { - default: - panic("lapack: bad value of compq") - case lapack.UpdateSchurNone: - // q is not referenced but LAPACKE checks that ldq >= n always. - q = nil - ldq = max(1, n) - case lapack.UpdateSchur: - checkMatrix(n, n, q, ldq) - } - if (ifst < 0 || n <= ifst) && n > 0 { - panic("lapack: ifst out of range") + switch { + case compq != lapack.UpdateSchur && compq != lapack.UpdateSchurNone: + panic(badUpdateSchurComp) + case n < 0: + panic(nLT0) + case ldt < max(1, n): + panic(badLdT) + case ldq < 1, compq == lapack.UpdateSchur && ldq < n: + panic(badLdQ) + case (ifst < 0 || n <= ifst) && n > 0: + panic(badIfst) + case (ilst < 0 || n <= ilst) && n > 0: + panic(badIlst) } - if (ilst < 0 || n <= ilst) && n > 0 { - panic("lapack: ilst out of range") + + // Quick return if possible. + if n == 0 { + return ifst, ilst, true } - if len(work) < n { - panic(badWork) + + switch { + case len(t) < (n-1)*ldt+n: + panic(shortT) + case compq == lapack.UpdateSchur && len(q) < (n-1)*ldq+n: + panic(shortQ) + case len(work) < n: + panic(shortWork) } // Quick return if possible. - if n <= 1 { + if n == 1 { return ifst, ilst, true } @@ -2502,13 +3244,25 @@ func (impl Implementation) Dtrexc(compq lapack.UpdateSchurComp, n int, t []float // Dtrtri returns whether the matrix a is singular. // If the matrix is singular, the inversion is not performed. func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) { - checkMatrix(n, n, a, lda) - if uplo != blas.Upper && uplo != blas.Lower { + switch { + case uplo != blas.Upper && uplo != blas.Lower: panic(badUplo) - } - if diag != blas.NonUnit && diag != blas.Unit { + case diag != blas.NonUnit && diag != blas.Unit: panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) } + + if n == 0 { + return true + } + + if len(a) < (n-1)*lda+n { + panic(shortA) + } + return lapacke.Dtrtri(byte(uplo), byte(diag), n, a, lda) } @@ -2516,6 +3270,34 @@ func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []flo // Dtrtrs returns whether the solve completed successfully. // If A is singular, no solve is performed. func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) { + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case diag != blas.NonUnit && diag != blas.Unit: + panic(badDiag) + case n < 0: + panic(nLT0) + case nrhs < 0: + panic(nrhsLT0) + case lda < max(1, n): + panic(badLdA) + case ldb < max(1, nrhs): + panic(badLdB) + } + + if n == 0 { + return true + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(b) < (n-1)*ldb+nrhs: + panic(shortB) + } + return lapacke.Dtrtrs(byte(uplo), byte(trans), byte(diag), n, nrhs, a, lda, b, ldb) } @@ -2627,45 +3409,52 @@ func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag bla // // Dhseqr is an internal routine. It is exported for testing purposes. func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, z []float64, ldz int, work []float64, lwork int) (unconverged int) { - switch job { - default: + wantz := compz == lapack.SchurHess || compz == lapack.SchurOrig + + switch { + case job != lapack.EigenvaluesOnly && job != lapack.EigenvaluesAndSchur: panic(badSchurJob) - case lapack.EigenvaluesOnly, lapack.EigenvaluesAndSchur: - } - var wantz bool - switch compz { - default: + case compz != lapack.SchurNone && compz != lapack.SchurHess && compz != lapack.SchurOrig: panic(badSchurComp) - case lapack.SchurNone: - case lapack.SchurHess, lapack.SchurOrig: - wantz = true - } - switch { case n < 0: panic(nLT0) case ilo < 0 || max(0, n-1) < ilo: panic(badIlo) case ihi < min(ilo, n-1) || n <= ihi: panic(badIhi) - case len(work) < lwork: - panic(shortWork) + case ldh < max(1, n): + panic(badLdH) + case ldz < 1, wantz && ldz < n: + panic(badLdZ) case lwork < max(1, n) && lwork != -1: - panic(badWork) - } - if lwork != -1 { - checkMatrix(n, n, h, ldh) - switch { - case wantz: - checkMatrix(n, n, z, ldz) - case len(wr) < n: - panic("lapack: wr has insufficient length") - case len(wi) < n: - panic("lapack: wi has insufficient length") - } + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return 0 + } + + // Quick return in case of a workspace query. + if lwork == -1 { + return lapacke.Dhseqr(byte(job), byte(compz), n, ilo+1, ihi+1, h, ldh, wr, wi, z, ldz, work, -1) + } + + switch { + case len(h) < (n-1)*ldh+n: + panic(shortH) + case wantz && len(z) < (n-1)*ldz+n: + panic(shortZ) + case len(wr) < n: + panic(shortWr) + case len(wi) < n: + panic(shortWi) } - return lapacke.Dhseqr(byte(job), byte(compz), n, ilo+1, ihi+1, - h, ldh, wr, wi, z, ldz, work, lwork) + return lapacke.Dhseqr(byte(job), byte(compz), n, ilo+1, ihi+1, h, ldh, wr, wi, z, ldz, work, lwork) } // Dgeev computes the eigenvalues and, optionally, the left and/or right @@ -2716,52 +3505,31 @@ func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n // computed and wr[first:] and wi[first:] contain those eigenvalues which have // converged. func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob, n int, a []float64, lda int, wr, wi []float64, vl []float64, ldvl int, vr []float64, ldvr int, work []float64, lwork int) (first int) { - var wantvl bool - switch jobvl { - default: - panic("lapack: invalid LeftEVJob") - case lapack.LeftEVCompute: - wantvl = true - case lapack.LeftEVNone: - wantvl = false - } - var wantvr bool - switch jobvr { - default: - panic("lapack: invalid RightEVJob") - case lapack.RightEVCompute: - wantvr = true - case lapack.RightEVNone: - wantvr = false - } - switch { - case n < 0: - panic(nLT0) - case len(work) < lwork: - panic(shortWork) - } + wantvl := jobvl == lapack.LeftEVCompute + wantvr := jobvr == lapack.RightEVCompute var minwrk int if wantvl || wantvr { minwrk = max(1, 4*n) } else { minwrk = max(1, 3*n) } - if lwork != -1 { - checkMatrix(n, n, a, lda) - if wantvl { - checkMatrix(n, n, vl, ldvl) - } - if wantvr { - checkMatrix(n, n, vr, ldvr) - } - switch { - case len(wr) != n: - panic("lapack: bad length of wr") - case len(wi) != n: - panic("lapack: bad length of wi") - case lwork < minwrk: - panic(badWork) - } + switch { + case jobvl != lapack.LeftEVCompute && jobvl != lapack.LeftEVNone: + panic(badLeftEVJob) + case jobvr != lapack.RightEVCompute && jobvr != lapack.RightEVNone: + panic(badRightEVJob) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + case ldvl < 1 || (ldvl < n && wantvl): + panic(badLdVL) + case ldvr < 1 || (ldvr < n && wantvr): + panic(badLdVR) + case lwork < minwrk && lwork != -1: + panic(badLWork) + case len(work) < lwork: + panic(shortWork) } // Quick return if possible. @@ -2770,12 +3538,31 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob return 0 } - first = lapacke.Dgeev(byte(jobvl), byte(jobvr), n, a, max(n, lda), wr, wi, - vl, max(n, ldvl), vr, max(n, ldvr), work, lwork) - if lwork == -1 && int(work[0]) < minwrk { - work[0] = float64(minwrk) + // TODO(vladimir-ch): The calls to lapacke.Dgeev below require max(n,ldvl) and + // max(n,ldvr) because the leading dimension checks in + // LAPACKE_dgeev_work are too strict. This has been reported in + // https://github.com/Reference-LAPACK/lapack/issues/327 + // Remove the calls to max if and when the upstream fixes this. + + if lwork == -1 { + lapacke.Dgeev(byte(jobvl), byte(jobvr), n, a, lda, wr, wi, vl, max(n, ldvl), vr, max(n, ldvr), work, -1) + return 0 + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(wr) != n: + panic(badLenWr) + case len(wi) != n: + panic(badLenWi) + case len(vl) < (n-1)*ldvl+n && wantvl: + panic(shortVL) + case len(vr) < (n-1)*ldvr+n && wantvr: + panic(shortVR) } - return first + + return lapacke.Dgeev(byte(jobvl), byte(jobvr), n, a, lda, wr, wi, vl, max(n, ldvl), vr, max(n, ldvr), work, lwork) } // Dtgsja computes the generalized singular value decomposition (GSVD) @@ -2923,48 +3710,78 @@ func (impl Implementation) Dgeev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob // // Dtgsja is an internal routine. It is exported for testing purposes. func (impl Implementation) Dtgsja(jobU, jobV, jobQ lapack.GSVDJob, m, p, n, k, l int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64) (cycles int, ok bool) { - checkMatrix(m, n, a, lda) - checkMatrix(p, n, b, ldb) - - if len(alpha) != n { - panic(badAlpha) - } - if len(beta) != n { - panic(badBeta) - } - initu := jobU == lapack.GSVDUnit wantu := initu || jobU == lapack.GSVDU - if !initu && !wantu && jobU != lapack.GSVDNone { - panic(badGSVDJob + "U") - } - if jobU != lapack.GSVDNone { - checkMatrix(m, m, u, ldu) - } initv := jobV == lapack.GSVDUnit wantv := initv || jobV == lapack.GSVDV - if !initv && !wantv && jobV != lapack.GSVDNone { - panic(badGSVDJob + "V") - } - if jobV != lapack.GSVDNone { - checkMatrix(p, p, v, ldv) - } initq := jobQ == lapack.GSVDUnit wantq := initq || jobQ == lapack.GSVDQ - if !initq && !wantq && jobQ != lapack.GSVDNone { + + switch { + case !initu && !wantu && jobU != lapack.GSVDNone: + panic(badGSVDJob + "U") + case !initv && !wantv && jobV != lapack.GSVDNone: + panic(badGSVDJob + "V") + case !initq && !wantq && jobQ != lapack.GSVDNone: panic(badGSVDJob + "Q") - } - if jobQ != lapack.GSVDNone { - checkMatrix(n, n, q, ldq) - } + case m < 0: + panic(mLT0) + case p < 0: + panic(pLT0) + case n < 0: + panic(nLT0) - if len(work) < 2*n { - panic(badWork) + case lda < max(1, n): + panic(badLdA) + case len(a) < (m-1)*lda+n: + panic(shortA) + + case ldb < max(1, n): + panic(badLdB) + case len(b) < (p-1)*ldb+n: + panic(shortB) + + case len(alpha) != n: + panic(badLenAlpha) + case len(beta) != n: + panic(badLenBeta) + + case ldu < 1, wantu && ldu < m: + panic(badLdU) + case wantu && len(u) < (m-1)*ldu+m: + panic(shortU) + + case ldv < 1, wantv && ldv < p: + panic(badLdV) + case wantv && len(v) < (p-1)*ldv+p: + panic(shortV) + + case ldq < 1, wantq && ldq < n: + panic(badLdQ) + case wantq && len(q) < (n-1)*ldq+n: + panic(shortQ) + + case len(work) < 2*n: + panic(shortWork) } ncycle := []int32{0} ok = lapacke.Dtgsja(byte(jobU), byte(jobV), byte(jobQ), m, p, n, k, l, a, lda, b, ldb, tola, tolb, alpha, beta, u, ldu, v, ldv, q, ldq, work, ncycle) return int(ncycle[0]), ok } + +func min(m, n int) int { + if m < n { + return m + } + return n +} + +func max(m, n int) int { + if m < n { + return n + } + return m +}