Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cgo/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,27 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o
}
return ok
}

// Dgetrs solves a system of equations using an LU factorization.
// The system of equations solved is
//  A * X = B    if trans == blas.NoTrans
//  A^T * X = B  if trans == blas.Trans
// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs
// with stride ldb.
//
// On entry b contains the elements of the matrix B. On exit, b contains the
// elements of X, the solution to the system of equations.
//
// a and ipiv contain the LU factorization of A and the permutation indices as
// computed by Dgetrf. ipiv is zero-indexed.
//
// Dgetrs panics with badIpiv if len(ipiv) < n.
func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
// checkMatrix is defined elsewhere in this package; it is applied to both
// the n×n factorization and the n×nrhs right-hand side.
checkMatrix(n, n, a, lda)
checkMatrix(n, nrhs, b, ldb)
if len(ipiv) < n {
panic(badIpiv)
}
// clapack takes the pivots as int32, so ipiv is copied into a temporary
// slice rather than modified in place.
ipiv32 := make([]int32, len(ipiv))
for i, v := range ipiv {
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
}
clapack.Dgetrs(trans, n, nrhs, a, lda, ipiv32, b, ldb)
}
4 changes: 4 additions & 0 deletions cgo/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ func TestDgetf2(t *testing.T) {
// TestDgetrf runs the shared testlapack.DgetrfTest suite against impl, which
// is declared elsewhere in this package.
func TestDgetrf(t *testing.T) {
testlapack.DgetrfTest(t, impl)
}

// TestDgetrs runs the shared testlapack.DgetrsTest suite against impl, which
// is declared elsewhere in this package.
func TestDgetrs(t *testing.T) {
testlapack.DgetrsTest(t, impl)
}
51 changes: 51 additions & 0 deletions native/dgetrs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package native

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

// Dgetrs solves a system of equations using an LU factorization.
// The system of equations solved is
//  A * X = B    if trans == blas.NoTrans
//  A^T * X = B  if trans == blas.Trans
// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs
// with stride ldb.
//
// On entry b contains the elements of the matrix B. On exit, b contains the
// elements of X, the solution to the system of equations.
//
// a and ipiv contain the LU factorization of A and the permutation indices as
// computed by Dgetrf. ipiv is zero-indexed.
//
// Dgetrs panics with badIpiv if len(ipiv) < n, and with badTrans if trans is
// neither blas.NoTrans nor blas.Trans. The trans argument is validated even
// when n or nrhs is zero.
func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
	checkMatrix(n, n, a, lda)
	checkMatrix(n, nrhs, b, ldb)
	if len(ipiv) < n {
		panic(badIpiv)
	}
	// Validate trans before the quick return so that an invalid value is
	// reported regardless of the problem size, as reference LAPACK does.
	if trans != blas.Trans && trans != blas.NoTrans {
		panic(badTrans)
	}
	// Quick return: there is nothing to solve for an empty system.
	if n == 0 || nrhs == 0 {
		return
	}
	bi := blas64.Implementation()
	if trans == blas.NoTrans {
		// Solve A * X = B.
		// Apply the row interchanges recorded in ipiv to b, in forward order.
		impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, 1)
		// Solve L * X = B, updating b. L has an implicit unit diagonal.
		bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
			n, nrhs, 1, a, lda, b, ldb)
		// Solve U * X = B, updating b.
		bi.Dtrsm(blas.Left, blas.Upper, blas.NoTrans, blas.NonUnit,
			n, nrhs, 1, a, lda, b, ldb)
		return
	}
	// Solve A^T * X = B. The triangular solves and the interchanges are
	// applied in the opposite order to the NoTrans case.
	// Solve U^T * X = B, updating b.
	bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit,
		n, nrhs, 1, a, lda, b, ldb)
	// Solve L^T * X = B, updating b.
	bi.Dtrsm(blas.Left, blas.Lower, blas.Trans, blas.Unit,
		n, nrhs, 1, a, lda, b, ldb)
	// Undo the row interchanges by applying them in reverse order.
	impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, -1)
}
4 changes: 4 additions & 0 deletions native/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func TestDgetrf(t *testing.T) {
testlapack.DgetrfTest(t, impl)
}

// TestDgetrs runs the shared testlapack.DgetrsTest suite against impl, which
// is declared elsewhere in this package.
func TestDgetrs(t *testing.T) {
testlapack.DgetrsTest(t, impl)
}

// TestDlange runs the shared testlapack.DlangeTest suite against impl, which
// is declared elsewhere in this package.
func TestDlange(t *testing.T) {
testlapack.DlangeTest(t, impl)
}
Expand Down
116 changes: 116 additions & 0 deletions testlapack/dgetrs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package testlapack

import (
"math/rand"
"testing"

"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
"github.com/gonum/floats"
)

// Dgetrser is the interface required by DgetrsTest. It combines the Dgetrs
// solve under test with the Dgetrf factorization (through the embedded
// Dgetrfer) that DgetrsTest uses to prepare the inputs to Dgetrs.
type Dgetrser interface {
Dgetrfer
Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
}

func DgetrsTest(t *testing.T, impl Dgetrser) {
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have done the smaller loop inside, but this is fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This form matches how the QR-related tests are.

for _, test := range []struct {
n, nrhs, lda, ldb int
tol float64
}{
{3, 3, 0, 0, 1e-14},
{3, 3, 0, 0, 1e-14},
{3, 5, 0, 0, 1e-14},
{3, 5, 0, 0, 1e-14},
{5, 3, 0, 0, 1e-14},
{5, 3, 0, 0, 1e-14},

{3, 3, 8, 10, 1e-14},
{3, 3, 8, 10, 1e-14},
{3, 5, 8, 10, 1e-14},
{3, 5, 8, 10, 1e-14},
{5, 3, 8, 10, 1e-14},
{5, 3, 8, 10, 1e-14},

{300, 300, 0, 0, 1e-10},
{300, 300, 0, 0, 1e-10},
{300, 500, 0, 0, 1e-10},
{300, 500, 0, 0, 1e-10},
{500, 300, 0, 0, 1e-10},
{500, 300, 0, 0, 1e-10},

{300, 300, 700, 600, 1e-10},
{300, 300, 700, 600, 1e-10},
{300, 500, 700, 600, 1e-10},
{300, 500, 700, 600, 1e-10},
{500, 300, 700, 600, 1e-10},
{500, 300, 700, 600, 1e-10},
} {
n := test.n
nrhs := test.nrhs
lda := test.lda
if lda == 0 {
lda = n
}
ldb := test.ldb
if ldb == 0 {
ldb = nrhs
}
a := make([]float64, n*lda)
for i := range a {
a[i] = rand.Float64()
}
b := make([]float64, n*ldb)
for i := range b {
b[i] = rand.Float64()
}
aCopy := make([]float64, len(a))
copy(aCopy, a)
bCopy := make([]float64, len(b))
copy(bCopy, b)

ipiv := make([]int, n)
for i := range ipiv {
ipiv[i] = rand.Int()
}

// Compute the LU factorization.
impl.Dgetrf(n, n, a, lda, ipiv)
// Solve the system of equations given the result.
impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)

// Check that the system of equations holds.
A := blas64.General{
Rows: n,
Cols: n,
Stride: lda,
Data: aCopy,
}
B := blas64.General{
Rows: n,
Cols: nrhs,
Stride: ldb,
Data: bCopy,
}
X := blas64.General{
Rows: n,
Cols: nrhs,
Stride: ldb,
Data: b,
}
tmp := blas64.General{
Rows: n,
Cols: nrhs,
Stride: ldb,
Data: make([]float64, n*ldb),
}
copy(tmp.Data, bCopy)
blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B)
if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)
}
}
}
}