Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions cgo/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,38 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o
return ok
}

// Dgetri computes the inverse of the matrix A using the LU factorization computed
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
// Dgetrf and on exit contains the reciprocal of the original matrix.
//
// Dtrtri will not perform the inversion if the matrix is singular, and returns
// a boolean indicating whether the inversion was successful.
//
// The C interface does not support providing temporary storage. To provide compatibility
// with native, lwork == -1 will not run Dgetri but will instead write the minimum
// work necessary to work[0]. If len(work) < lwork, Dgetri will panic.
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
checkMatrix(n, n, a, lda)
if len(ipiv) < n {
panic(badIpiv)
}
if lwork == -1 {
work[0] = float64(n)
return true
}
if lwork < n {
panic(badWork)
}
if len(work) < lwork {
panic(badWork)
}
ipiv32 := make([]int32, len(ipiv))
for i, v := range ipiv {
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
}
return clapack.Dgetri(n, a, lda, ipiv32)
}

// Dgetrs solves a system of equations using an LU factorization.
// The system of equations solved is
// A * X = B if trans == blas.Trans
Expand Down
4 changes: 4 additions & 0 deletions cgo/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func TestDgetrf(t *testing.T) {
testlapack.DgetrfTest(t, impl)
}

func TestDgetri(t *testing.T) {
testlapack.DgetriTest(t, impl)
}

func TestDgetrs(t *testing.T) {
testlapack.DgetrsTest(t, impl)
}
Expand Down
88 changes: 88 additions & 0 deletions native/dgetri.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package native

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

// Dgetri computes the inverse of the matrix A using the LU factorization computed
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
// Dgetrf and on exit contains the reciprocal of the original matrix.
//
// Dgetri will not perform the inversion if the matrix is singular, and returns
// a boolean indicating whether the inversion was successful.
//
// Work is temporary storage, and lwork specifies the usable memory length.
// At minimum, lwork >= n and this function will panic otherwise.
// Dgetri is a blocked inversion, but the block size is limited
// by the temporary space available. If lwork == -1, instead of performing Dgetri,
// the optimal work length will be stored into work[0].
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
checkMatrix(n, n, a, lda)
if len(ipiv) < n {
panic(badIpiv)
}
nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1)
if lwork == -1 {
work[0] = float64(n * nb)
return true
}
if lwork < n {
panic(badWork)
}
if len(work) < lwork {
panic(badWork)
}
if n == 0 {
return true
}
ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda)
if !ok {
return false
}
nbmin := 2
ldwork := nb
if nb > 1 && nb < n {
iws := max(ldwork*n, 1)
if lwork < iws {
nb = lwork / ldwork
nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1))
}
}
bi := blas64.Implementation()
// TODO(btracey): Replace this with a more row-major oriented algorithm.
if nb < nbmin || nb >= n {
// Unblocked code.
for j := n - 1; j >= 0; j-- {
for i := j + 1; i < n; i++ {
work[i*ldwork] = a[i*lda+j]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why stride this by ldwork? Ahh, I see, you have transposed work. What for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't transposed work. Before, the indexing is effectively work(i,1), and I have transformed that indexing into row major. this particular loop is easy to re-orient, but it's more complicated below in the Dgemv and Dgemm calls.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, that's what I mean.

a[i*lda+j] = 0
}
if j < n {
bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda)
}
}
} else {
nn := ((n - 1) / nb) * nb
for j := nn; j >= 0; j -= nb {
jb := min(nb, n-j)
for jj := j; jj < j+jb-1; jj++ {
for i := jj + 1; i < n; i++ {
work[i*ldwork+(jj-j)] = a[i*lda+jj]
a[i*lda+jj] = 0
}
}
if j+jb < n {
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda)
bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda)
}
}
}
for j := n - 2; j >= 0; j-- {
jp := ipiv[j]
if jp != j {
bi.Dswap(n, a[j:], lda, a[jp:], lda)
}
}
return true
}
4 changes: 2 additions & 2 deletions native/dtrtri.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
// into a. This is the BLAS level 3 version of the algorithm which builds upon
// Dtrti2 to operate on matrix blocks instead of only individual columns.
//
// Dtrti returns whether the matrix a is singular or whether it's not singular.
// If the matrix is singular the inversion is not performed.
// Dtrtri will not perform the inversion if the matrix is singular, and returns
// a boolean indicating whether the inversion was successful.
func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) {
checkMatrix(n, n, a, lda)
if uplo != blas.Upper && uplo != blas.Lower {
Expand Down
4 changes: 4 additions & 0 deletions native/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func TestDgeqrf(t *testing.T) {
testlapack.DgeqrfTest(t, impl)
}

func TestDgetri(t *testing.T) {
testlapack.DgetriTest(t, impl)
}

func TestDgetf2(t *testing.T) {
testlapack.Dgetf2Test(t, impl)
}
Expand Down
84 changes: 84 additions & 0 deletions testlapack/dgetri.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package testlapack

import (
"math"
"math/rand"
"testing"

"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

type Dgetrier interface {
Dgetrfer
Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool
}

func DgetriTest(t *testing.T, impl Dgetrier) {
bi := blas64.Implementation()
for _, test := range []struct {
n, lda int
}{
{5, 0},
{5, 8},
{45, 0},
{45, 50},
{65, 0},
{65, 70},
{150, 0},
{150, 250},
} {
n := test.n
lda := test.lda
if lda == 0 {
lda = n
}
// Generate a random well conditioned matrix
perm := rand.Perm(n)
a := make([]float64, n*lda)
for i := 0; i < n; i++ {
a[i*lda+perm[i]] = 1
}
for i := range a {
a[i] += 0.01 * rand.Float64()
}
aCopy := make([]float64, len(a))
copy(aCopy, a)
ipiv := make([]int, n)
// Compute LU decomposition.
impl.Dgetrf(n, n, a, lda, ipiv)
// Compute inverse.
work := make([]float64, 1)
impl.Dgetri(n, a, lda, ipiv, work, -1)
work = make([]float64, int(work[0]))
lwork := len(work)

ok := impl.Dgetri(n, a, lda, ipiv, work, lwork)
if !ok {
t.Errorf("Unexpected singular matrix.")
}

// Check that A(inv) * A = I.
ans := make([]float64, len(a))
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
isEye := true
for i := 0; i < n; i++ {
for j := 0; j < n; j++ {
if i == j {
// This tolerance is so high because computing matrix inverses
// is very unstable.
if math.Abs(ans[i*lda+j]-1) > 2e-2 {
isEye = false
}
} else {
if math.Abs(ans[i*lda+j]) > 2e-2 {
isEye = false
}
}
}
}
if !isEye {
t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
}
}
}