Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions cgo/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

// Copied from lapack/native. Keep in sync.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reflect this comment in the native location as well - I should have done that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, oops. I had this in at one point as well.

const (
absIncNotOne = "lapack: increment not one or negative one"
badDirect = "lapack: bad direct"
badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range"
Expand Down Expand Up @@ -76,7 +77,7 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
return clapack.Dpotrf(ul, n, a, lda)
}

// Dgetf2 computes the LU decomposition of the m×n matrix a.
// Dgetf2 computes the LU decomposition of the m×n matrix A.
// The LU decomposition is a factorization of a into
// A = P * L * U
// where P is a permutation matrix, L is a unit lower triangular matrix, and
Expand All @@ -85,9 +86,9 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
//
// ipiv is a permutation vector. It indicates that row i of the matrix was
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
// otherwise.
// otherwise. ipiv is zero-indexed.
//
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
// Dgetf2 returns whether the matrix A is singular. The LU decomposition will
// be computed regardless of the singularity of A, but division by zero
// will occur if the false is returned and the result is used to solve a
// system of equations.
Expand All @@ -100,7 +101,38 @@ func (Implementation) Dgetf2(m, n int, a []float64, lda int, ipiv []int) (ok boo
ipiv32 := make([]int32, len(ipiv))
ok = clapack.Dgetf2(m, n, a, lda, ipiv32)
for i, v := range ipiv32 {
ipiv[i] = int(v) - 1 // OpenBLAS returns one indexed.
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
}
return ok
}

// Dgetrf computes the LU decomposition of the m×n matrix A.
// The LU decomposition is a factorization of a into
// A = P * L * U
// where P is a permutation matrix, L is a unit lower triangular matrix, and
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
// in place into a.
//
// ipiv is a permutation vector. It indicates that row i of the matrix was
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we note that ipiv is zero-based here. It may be surprising that we don't do one-based when everyone (?) else does.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
// otherwise. ipiv is zero-indexed.
//
// Dgetrf is the blocked version of the algorithm.
//
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
// be computed regardless of the singularity of A, but division by zero
// will occur if the false is returned and the result is used to solve a
// system of equations.
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
mn := min(m, n)
checkMatrix(m, n, a, lda)
if len(ipiv) < mn {
panic(badIpiv)
}
ipiv32 := make([]int32, len(ipiv))
ok = clapack.Dgetrf(m, n, a, lda, ipiv32)
for i, v := range ipiv32 {
ipiv[i] = int(v) - 1 // Transform to zero-indexed.
}
return ok
}
4 changes: 4 additions & 0 deletions cgo/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ func TestDpotrf(t *testing.T) {
func TestDgetf2(t *testing.T) {
testlapack.Dgetf2Test(t, impl)
}

func TestDgetrf(t *testing.T) {
testlapack.DgetrfTest(t, impl)
}
5 changes: 1 addition & 4 deletions native/dgeqrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/gonum/lapack"
)

// Dgeqrf computes the QR factorization of the m×n matrix a using a blocked
// Dgeqrf computes the QR factorization of the m×n matrix A using a blocked
// algorithm. Please see the documentation for Dgeqr2 for a description of the
// parameters at entry and exit.
//
Expand All @@ -21,9 +21,6 @@ import (
//
// tau must be at least len min(m,n), and this function will panic otherwise.
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
// TODO(btracey): This algorithm is oriented for column-major storage.
// Consider modifying the algorithm to better suit row-major storage.

// nb is the optimal blocksize, i.e. the number of columns transformed at a time.
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
lworkopt := n * max(nb, 1)
Expand Down
6 changes: 3 additions & 3 deletions native/dgetf2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/gonum/blas/blas64"
)

// Dgetf2 computes the LU decomposition of the m×n matrix a.
// Dgetf2 computes the LU decomposition of the m×n matrix A.
// The LU decomposition is a factorization of a into
// A = P * L * U
// where P is a permutation matrix, L is a unit lower triangular matrix, and
Expand All @@ -15,9 +15,9 @@ import (
//
// ipiv is a permutation vector. It indicates that row i of the matrix was
// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
// otherwise.
// otherwise. ipiv is zero-indexed.
//
// Dgetf2 returns whether the matrix a is singular. The LU decomposition will
// Dgetf2 returns whether the matrix A is singular. The LU decomposition will
// be computed regardless of the singularity of A, but division by zero
// will occur if the false is returned and the result is used to solve a
// system of equations.
Expand Down
66 changes: 66 additions & 0 deletions native/dgetrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package native

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

// Dgetrf computes the LU decomposition of the m×n matrix a.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capitalisation of "A".

// The LU decomposition is a factorization of a into
// A = P * L * U
// where P is a permutation matrix, L is a unit lower triangular matrix, and
// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
// in place into a.
//
// ipiv is a permutation vector. It indicates that row i of the matrix was
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that ipiv is zero-based (probably in other places also - maybe another PR?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the comment for the LU cases.

// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
// otherwise. ipiv is zero-indexed.
//
// Dgetrf is the blocked version of the algorithm.
//
// Dgetrf returns whether the matrix A is singular. The LU decomposition will
// be computed regardless of the singularity of A, but division by zero
// will occur if the false is returned and the result is used to solve a
// system of equations.
func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (ok bool) {
mn := min(m, n)
checkMatrix(m, n, a, lda)
if len(ipiv) < mn {
panic(badIpiv)
}
if m == 0 || n == 0 {
return
}
bi := blas64.Implementation()
nb := impl.Ilaenv(1, "DGETRF", " ", m, n, -1, -1)
if nb <= 1 || nb >= min(m, n) {
// Use the unblocked algorithm.
return impl.Dgetf2(m, n, a, lda, ipiv)
}
ok = true
for j := 0; j < mn; j += nb {
jb := min(mn-j, nb)
blockOk := impl.Dgetf2(m-j, jb, a[j*lda+j:], lda, ipiv[j:])
if !blockOk {
ok = false
}
for i := j; i <= min(m-1, j+jb-1); i++ {
ipiv[i] = j + ipiv[i]
}
impl.Dlaswp(j, a, lda, j, j+jb-1, ipiv, 1)
if j+jb < n {
impl.Dlaswp(n-j-jb, a[j+jb:], lda, j, j+jb-1, ipiv, 1)
bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
jb, n-j-jb, 1,
a[j*lda+j:], lda,
a[j*lda+j+jb:], lda)
if j+jb < m {
bi.Dgemm(blas.NoTrans, blas.NoTrans, m-j-jb, n-j-jb, jb, -1,
a[(j+jb)*lda+j:], lda,
a[j*lda+j+jb:], lda,
1, a[(j+jb)*lda+j+jb:], lda)
}
}
}
return ok
}
23 changes: 23 additions & 0 deletions native/dlaswp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package native

import "github.com/gonum/blas/blas64"

// Dlaswp swaps the rows k1 to k2 of a according to the indices in ipiv.
// a is a matrix with n columns and stride lda. incX is the increment for ipiv.
// k1 and k2 are zero-indexed. If incX is negative, then loops from k2 to k1
func (impl Implementation) Dlaswp(n int, a []float64, lda, k1, k2 int, ipiv []int, incX int) {
if incX != 1 && incX != -1 {
panic(absIncNotOne)
}
bi := blas64.Implementation()
if incX == 1 {
for k := k1; k <= k2; k++ {
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
}
return
}
for k := k2; k >= k1; k-- {
bi.Dswap(n, a[k*lda:], 1, a[ipiv[k]*lda:], 1)
}
return
}
2 changes: 2 additions & 0 deletions native/general.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ type Implementation struct{}

var _ lapack.Float64 = Implementation{}

// This list is duplicated in lapack/cgo. Keep in sync.
const (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sync comment.

absIncNotOne = "lapack: increment not one or negative one"
badDirect = "lapack: bad direct"
badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range"
Expand Down
4 changes: 4 additions & 0 deletions native/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func TestDgetf2(t *testing.T) {
testlapack.Dgetf2Test(t, impl)
}

func TestDgetrf(t *testing.T) {
testlapack.DgetrfTest(t, impl)
}

func TestDlange(t *testing.T) {
testlapack.DlangeTest(t, impl)
}
Expand Down
Loading