Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.

Commit 2eb39c4

Browse files
committed
Merge pull request #32 from gonum/qrlqsolve
Add cgo and lapack64 functions for performing a QR and LQ solve from …
2 parents c6728a6 + 2e760db commit 2eb39c4

File tree

8 files changed

+207
-13
lines changed

8 files changed

+207
-13
lines changed

cgo/lapack.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,114 @@ func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64
327327
}
328328
clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
329329
}
330+
331+
// Dormlq multiplies the matrix C by the orthogonal matrix Q defined by the
332+
// slices a and tau. A and tau are as returned from Dgelqf.
333+
// C = Q * C if side == blas.Left and trans == blas.NoTrans
334+
// C = Q^T * C if side == blas.Left and trans == blas.Trans
335+
// C = C * Q if side == blas.Right and trans == blas.NoTrans
336+
// C = C * Q^T if side == blas.Right and trans == blas.Trans
337+
// If side == blas.Left, A is a matrix of size k×m, and if side == blas.Right
338+
// A is of size k×n. This uses a blocked algorithm.
339+
//
340+
// Work is temporary storage, and lwork specifies the usable memory length.
341+
// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
342+
// and this function will panic otherwise.
343+
// Dormlq uses a block algorithm, but the block size is limited
344+
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
345+
// the optimal work length will be stored into work[0].
346+
//
347+
// tau contains the householder scales and must have length at least k, and
348+
// this function will panic otherwise.
349+
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
350+
if side != blas.Left && side != blas.Right {
351+
panic(badSide)
352+
}
353+
if trans != blas.Trans && trans != blas.NoTrans {
354+
panic(badTrans)
355+
}
356+
left := side == blas.Left
357+
if left {
358+
checkMatrix(k, m, a, lda)
359+
} else {
360+
checkMatrix(k, n, a, lda)
361+
}
362+
checkMatrix(m, n, c, ldc)
363+
if len(tau) < k {
364+
panic(badTau)
365+
}
366+
if lwork == -1 {
367+
if left {
368+
work[0] = float64(n)
369+
return
370+
}
371+
work[0] = float64(m)
372+
return
373+
}
374+
if left {
375+
if lwork < n {
376+
panic(badWork)
377+
}
378+
} else {
379+
if lwork < m {
380+
panic(badWork)
381+
}
382+
}
383+
clapack.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc)
384+
}
385+
386+
// Dormqr multiplies the matrix C by the othogonal matrix Q defined by the
387+
// slices a and tau. a and tau are as returned from Dgeqrf.
388+
// C = Q * C if side == blas.Left and trans == blas.NoTrans
389+
// C = Q^T * C if side == blas.Left and trans == blas.Trans
390+
// C = C * Q if side == blas.Right and trans == blas.NoTrans
391+
// C = C * Q^T if side == blas.Right and trans == blas.Trans
392+
// If side == blas.Left, A is a matrix of side k×m, and if side == blas.Right
393+
// A is of size k×n. This uses a blocked algorithm.
394+
//
395+
// tau contains the householder scales and must have length at least k, and
396+
// this function will panic otherwise.
397+
//
398+
// The C interface does not support providing temporary storage. To provide compatibility
399+
// with native, lwork == -1 will not run Dgeqrf but will instead write the minimum
400+
// work necessary to work[0]. If len(work) < lwork, Dgeqrf will panic.
401+
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
402+
left := side == blas.Left
403+
if left {
404+
checkMatrix(m, k, a, lda)
405+
} else {
406+
checkMatrix(n, k, a, lda)
407+
}
408+
checkMatrix(m, n, c, ldc)
409+
410+
if len(tau) < k {
411+
panic(badTau)
412+
}
413+
414+
if lwork == -1 {
415+
if left {
416+
work[0] = float64(m)
417+
return
418+
}
419+
work[0] = float64(n)
420+
return
421+
}
422+
423+
if left {
424+
if lwork < n {
425+
panic(badWork)
426+
}
427+
} else {
428+
if lwork < m {
429+
panic(badWork)
430+
}
431+
}
432+
433+
clapack.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc)
434+
}
435+
436+
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs
437+
// returns whether the solve completed successfully. If A is singular, no solve is performed.
438+
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
439+
return clapack.Dtrtrs(uplo, trans, diag, n, nrhs, a, lda, b, ldb)
440+
}

cgo/lapack_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package cgo
77
import (
88
"testing"
99

10+
"github.com/gonum/blas"
1011
"github.com/gonum/lapack/testlapack"
1112
)
1213

@@ -47,3 +48,29 @@ func TestDgetrf(t *testing.T) {
4748
func TestDgetrs(t *testing.T) {
4849
testlapack.DgetrsTest(t, impl)
4950
}
51+
52+
// blockedTranslate transforms some blocked C calls to be the unblocked algorithms
53+
// for testing, as several of the unblocked algorithms are not defined by the C
54+
// interface.
55+
type blockedTranslate struct {
56+
Implementation
57+
}
58+
59+
func (d blockedTranslate) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
60+
impl.Dormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
61+
}
62+
63+
func (d blockedTranslate) Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
64+
impl.Dormlq(side, trans, m, n, k, a, lda, tau, c, ldc, work, len(work))
65+
}
66+
67+
func TestDormqr(t *testing.T) {
68+
testlapack.Dorm2rTest(t, blockedTranslate{impl})
69+
}
70+
71+
/*
72+
// Test disabled because of bug in c interface. Leaving stub for easy reproducer.
73+
func TestDormlq(t *testing.T) {
74+
testlapack.Dorml2Test(t, blockedTranslate{impl})
75+
}
76+
*/

lapack.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ type Float64 interface {
2626
Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
2727
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
2828
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
29-
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
3029
Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool)
3130
Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
31+
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
32+
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
33+
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
34+
Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool)
3235
}
3336

3437
// Direct specifies the direction of the multiplication for the Householder matrix.

lapack64/lapack64.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,53 @@ func Getrf(a blas64.General, ipiv []int) bool {
162162
func Getrs(trans blas.Transpose, a blas64.General, b blas64.General, ipiv []int) {
163163
lapack64.Dgetrs(trans, a.Cols, b.Cols, a.Data, a.Stride, ipiv, b.Data, b.Stride)
164164
}
165+
166+
// Ormlq multiplies the matrix C by the orthogonal matrix Q defined by
167+
// A and tau. A and tau are as returned from Gelqf.
168+
// C = Q * C if side == blas.Left and trans == blas.NoTrans
169+
// C = Q^T * C if side == blas.Left and trans == blas.Trans
170+
// C = C * Q if side == blas.Right and trans == blas.NoTrans
171+
// C = C * Q^T if side == blas.Right and trans == blas.Trans
172+
// If side == blas.Left, A is a matrix of size k×m, and if side == blas.Right
173+
// A is of size k×n. This uses a blocked algorithm.
174+
//
175+
// Work is temporary storage, and lwork specifies the usable memory length.
176+
// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
177+
// and this function will panic otherwise.
178+
// Ormlq uses a block algorithm, but the block size is limited
179+
// by the temporary space available. If lwork == -1, instead of performing Ormlq,
180+
// the optimal work length will be stored into work[0].
181+
//
182+
// Tau contains the householder scales and must have length at least k, and
183+
// this function will panic otherwise.
184+
func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
185+
lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
186+
}
187+
188+
// Ormqr multiplies the matrix C by the orthogonal matrix Q defined by
189+
// A and tau. A and tau are as returned from Geqrf.
190+
// C = Q * C if side == blas.Left and trans == blas.NoTrans
191+
// C = Q^T * C if side == blas.Left and trans == blas.Trans
192+
// C = C * Q if side == blas.Right and trans == blas.NoTrans
193+
// C = C * Q^T if side == blas.Right and trans == blas.Trans
194+
// If side == blas.Left, A is a matrix of size m×k, and if side == blas.Right
195+
// A is of size n×k. This uses a blocked algorithm.
196+
//
197+
// tau contains the householder scales and must have length at least k, and
198+
// this function will panic otherwise.
199+
//
200+
// Work is temporary storage, and lwork specifies the usable memory length.
201+
// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
202+
// and this function will panic otherwise.
203+
// Ormqr uses a block algorithm, but the block size is limited
204+
// by the temporary space available. If lwork == -1, instead of performing Ormqr,
205+
// the optimal work length will be stored into work[0].
206+
func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
207+
lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, a.Stride, tau, c.Data, c.Stride, work, lwork)
208+
}
209+
210+
// Trtrs solves a triangular system of the form A * X = B or A^T * X = B. Trtrs
211+
// returns whether the solve completed successfully. If A is singular, no solve is performed.
212+
func Trtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, a blas64.Triangular, b blas64.General) (ok bool) {
213+
return lapack64.Dtrtrs(uplo, trans, diag, a.N, b.Cols, a.Data, a.Stride, b.Data, b.Stride)
214+
}

native/dorm2r.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package native
66

77
import "github.com/gonum/blas"
88

9-
// Dorm2r multiplies a general matrix c by an orthogonal matrix from a QR factorization
9+
// Dorm2r multiplies a general matrix C by an orthogonal matrix from a QR factorization
1010
// determined by Dgeqrf.
1111
// C = Q * C if side == blas.Left and trans == blas.NoTrans
1212
// C = Q^T * C if side == blas.Left and trans == blas.Trans

native/dormlq.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ import (
99
"github.com/gonum/lapack"
1010
)
1111

12-
// Dormlq multiplies the matrix c by the othogonal matrix q defined by the
12+
// Dormlq multiplies the matrix C by the orthogonal matrix Q defined by the
1313
// slices a and tau. A and tau are as returned from Dgelqf.
1414
// C = Q * C if side == blas.Left and trans == blas.NoTrans
1515
// C = Q^T * C if side == blas.Left and trans == blas.Trans
1616
// C = C * Q if side == blas.Right and trans == blas.NoTrans
1717
// C = C * Q^T if side == blas.Right and trans == blas.Trans
18-
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
19-
// a is of size k×n. This uses a blocked algorithm.
18+
// If side == blas.Left, A is a matrix of size k×m, and if side == blas.Right
19+
// A is of size k×n. This uses a blocked algorithm.
2020
//
2121
// Work is temporary storage, and lwork specifies the usable memory length.
2222
// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
@@ -25,7 +25,7 @@ import (
2525
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
2626
// the optimal work length will be stored into work[0].
2727
//
28-
// Tau contains the householder scales and must have length at least k, and
28+
// tau contains the householder scales and must have length at least k, and
2929
// this function will panic otherwise.
3030
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
3131
if side != blas.Left && side != blas.Right {

native/dormqr.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ import (
99
"github.com/gonum/lapack"
1010
)
1111

12-
// Dormqr multiplies the matrix c by the othogonal matrix q defined by the
12+
// Dormqr multiplies the matrix C by the orthogonal matrix Q defined by the
1313
// slices a and tau. A and tau are as returned from Dgeqrf.
1414
// C = Q * C if side == blas.Left and trans == blas.NoTrans
1515
// C = Q^T * C if side == blas.Left and trans == blas.Trans
1616
// C = C * Q if side == blas.Right and trans == blas.NoTrans
1717
// C = C * Q^T if side == blas.Right and trans == blas.Trans
18-
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
19-
// a is of size k×n. This uses a blocked algorithm.
18+
// If side == blas.Left, A is a matrix of size m×k, and if side == blas.Right
19+
// A is of size n×k. This uses a blocked algorithm.
2020
//
2121
// Work is temporary storage, and lwork specifies the usable memory length.
2222
// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
@@ -25,7 +25,7 @@ import (
2525
// by the temporary space available. If lwork == -1, instead of performing Dormqr,
2626
// the optimal work length will be stored into work[0].
2727
//
28-
// Tau contains the householder scales and must have length at least k, and
28+
// tau contains the householder scales and must have length at least k, and
2929
// this function will panic otherwise.
3030
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
3131
left := side == blas.Left
@@ -37,6 +37,10 @@ func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k
3737
}
3838
checkMatrix(m, n, c, ldc)
3939

40+
if len(tau) < k {
41+
panic(badTau)
42+
}
43+
4044
const nbmax = 64
4145
nw := n
4246
if side == blas.Right {

native/dtrtrs.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ import (
99
"github.com/gonum/blas/blas64"
1010
)
1111

12-
// Dtrtrs solves a triangular system of the form a * x = b or a^T * x = b. Dtrtrs
13-
// checks for singularity in a. If a is singular, false is returned and no solve
14-
// is performed. True is returned otherwise.
12+
// Dtrtrs solves a triangular system of the form A * X = B or A^T * X = B. Dtrtrs
13+
// returns whether the solve completed successfully. If A is singular, no solve is performed.
1514
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
1615
nounit := diag == blas.NonUnit
1716
if n == 0 {

0 commit comments

Comments
 (0)