Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lapack/netlib: add Dtbtrs #77

Merged
merged 2 commits into from
Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.13
require (
github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237 // indirect
golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495
gonum.org/v1/gonum v0.8.1
gonum.org/v1/gonum v0.8.1-0.20200930085651-eea0b5cb5cc9
modernc.org/cc v1.0.0
modernc.org/golex v1.0.0 // indirect
modernc.org/mathutil v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.8.1-0.20200930085651-eea0b5cb5cc9 h1:EhU7NlbjQJrI6umH0+aSOGAIooYzDS56UvW6jAzplz8=
gonum.org/v1/gonum v0.8.1-0.20200930085651-eea0b5cb5cc9/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/gonum v0.8.1 h1:wGtP3yGpc5mCLOLeTeBdjeui9oZSz5De0eOjMLC/QuQ=
gonum.org/v1/gonum v0.8.1/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
Expand Down
18 changes: 9 additions & 9 deletions lapack/netlib/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ package netlib

import "gonum.org/v1/gonum/blas"

// convDpbToLapacke converts a symmetric band matrix A in CBLAS row-major layout
// to LAPACKE row-major layout and stores the result in B.
// bandTriToLapacke converts a triangular or symmetric band matrix A in CBLAS
// row-major layout to LAPACKE row-major layout and stores the result in B.
//
// For example, when n = 6, kd = 2 and uplo == 'U', convDpbToLapacke converts
// For example, when n = 6, kd = 2 and uplo == 'U', bandTriToLapacke converts
// A = a00 a01 a02
// a11 a12 a13
// a22 a23 a24
Expand All @@ -25,7 +25,7 @@ import "gonum.org/v1/gonum/blas"
// stored in a slice as
// b = [* * a02 a13 a24 a35 * a01 a12 a23 a34 a45 a00 a11 a22 a33 a44 a55]
//
// When n = 6, kd = 2 and uplo == 'L', convDpbToLapacke converts
// When n = 6, kd = 2 and uplo == 'L', bandTriToLapacke converts
// A = * * a00
// * a10 a11
// a20 a21 a22
Expand All @@ -42,7 +42,7 @@ import "gonum.org/v1/gonum/blas"
// b = [a00 a11 a22 a33 a44 a55 a10 a21 a32 a43 a54 * a20 a31 a42 a53 * * ]
//
// In these example elements marked as * are not referenced.
func convDpbToLapacke(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
func bandTriToLapacke(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
if uplo == blas.Upper {
for i := 0; i < n; i++ {
for jb := 0; jb < min(n-i, kd+1); jb++ {
Expand All @@ -60,10 +60,10 @@ func convDpbToLapacke(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float
}
}

// convDpbToGonum converts a symmetric band matrix A in LAPACKE row-major layout
// to CBLAS row-major layout and stores the result in B. In other words, it
// performs the inverse conversion to convDpbToLapacke.
func convDpbToGonum(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
// bandTriToGonum converts a triangular or symmetric band matrix A in LAPACKE
// row-major layout to CBLAS row-major layout and stores the result in B. In
// other words, it performs the inverse conversion to bandTriToLapacke.
func bandTriToGonum(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) {
if uplo == blas.Upper {
for j := 0; j < n; j++ {
for ib := max(0, kd-j); ib < kd+1; ib++ {
Expand Down
10 changes: 5 additions & 5 deletions lapack/netlib/conv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"gonum.org/v1/gonum/floats"
)

func TestConvDpb(t *testing.T) {
func TestConvBandTri(t *testing.T) {
for ti, test := range []struct {
uplo blas.Uplo
n, kd int
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestConvDpb(t *testing.T) {
}
ldb := max(1, n)

convDpbToLapacke(uplo, n, kd, a, lda, got, ldb)
bandTriToLapacke(uplo, n, kd, a, lda, got, ldb)
if !floats.Equal(test.a, a) {
t.Errorf("%v: unexpected modification of A in conversion to LAPACKE row-major", name)
}
Expand All @@ -88,7 +88,7 @@ func TestConvDpb(t *testing.T) {
got[i] = -1
}

convDpbToGonum(uplo, n, kd, b, ldb, got, lda)
bandTriToGonum(uplo, n, kd, b, ldb, got, lda)
if !floats.Equal(test.b, b) {
t.Errorf("%v: unexpected modification of B in conversion to Gonum row-major", name)
}
Expand Down Expand Up @@ -118,8 +118,8 @@ func TestConvDpb(t *testing.T) {
b[i] = rnd.NormFloat64()
}

convDpbToLapacke(uplo, n, kd, a, lda, b, ldb)
convDpbToGonum(uplo, n, kd, b, ldb, a, lda)
bandTriToLapacke(uplo, n, kd, a, lda, b, ldb)
bandTriToGonum(uplo, n, kd, b, ldb, a, lda)

if !floats.Equal(a, aCopy) {
t.Errorf("%v: conversion does not roundtrip", name)
Expand Down
54 changes: 50 additions & 4 deletions lapack/netlib/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ func (impl Implementation) Dpbcon(uplo blas.Uplo, n, kd int, ab []float64, ldab

_ldab := n
_ab := make([]float64, (kd+1)*_ldab)
convDpbToLapacke(uplo, n, kd, ab, ldab, _ab, _ldab)
bandTriToLapacke(uplo, n, kd, ab, ldab, _ab, _ldab)
_rcond := []float64{0}
_iwork := make([]int32, n)
lapacke.Dpbcon(byte(uplo), n, kd, _ab, _ldab, anorm, _rcond, work, _iwork)
Expand Down Expand Up @@ -892,9 +892,9 @@ func (impl Implementation) Dpbtrf(uplo blas.Uplo, n, kd int, ab []float64, ldab

ldabConv := n
abConv := make([]float64, (kd+1)*ldabConv)
convDpbToLapacke(uplo, n, kd, ab, ldab, abConv, ldabConv)
bandTriToLapacke(uplo, n, kd, ab, ldab, abConv, ldabConv)
info := lapacke.Dpbtrf(byte(uplo), n, kd, abConv, ldabConv)
convDpbToGonum(uplo, n, kd, abConv, ldabConv, ab, ldab)
bandTriToGonum(uplo, n, kd, abConv, ldabConv, ab, ldab)
return info
}

Expand Down Expand Up @@ -937,7 +937,7 @@ func (Implementation) Dpbtrs(uplo blas.Uplo, n, kd, nrhs int, ab []float64, ldab

ldabConv := n
abConv := make([]float64, (kd+1)*ldabConv)
convDpbToLapacke(uplo, n, kd, ab, ldab, abConv, ldabConv)
bandTriToLapacke(uplo, n, kd, ab, ldab, abConv, ldabConv)
lapacke.Dpbtrs(byte(uplo), n, kd, nrhs, abConv, ldabConv, b, ldb)
}

Expand Down Expand Up @@ -3274,6 +3274,52 @@ func (impl Implementation) Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d
lapacke.Dsytrd(byte(uplo), n, a, lda, d, e, tau, work, lwork)
}

// Dtbtrs solves a triangular system of the form
// A * X = B if trans == blas.NoTrans
// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans
// where A is an n×n triangular band matrix with kd super- or subdiagonals, and
// B is an n×nrhs matrix.
//
// Dtbtrs returns whether A is non-singular. If A is singular, no solution X is
// computed.
func (impl Implementation) Dtbtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, kd, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
switch {
case uplo != blas.Upper && uplo != blas.Lower:
panic(badUplo)
case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
panic(badTrans)
case diag != blas.NonUnit && diag != blas.Unit:
panic(badDiag)
case n < 0:
panic(nLT0)
case kd < 0:
panic(kdLT0)
case nrhs < 0:
panic(nrhsLT0)
case lda < kd+1:
panic(badLdA)
case ldb < max(1, nrhs):
panic(badLdB)
}

// Quick return if possible.
if n == 0 {
return true
}

switch {
case len(a) < (n-1)*lda+kd+1:
panic(shortA)
case len(b) < (n-1)*ldb+nrhs:
panic(shortB)
}

ldaConv := n
aConv := make([]float64, (kd+1)*ldaConv)
bandTriToLapacke(uplo, n, kd, a, lda, aConv, ldaConv)
return lapacke.Dtbtrs(byte(uplo), byte(trans), byte(diag), n, kd, nrhs, aConv, ldaConv, b, ldb)
}

// Dtrcon estimates the reciprocal of the condition number of a triangular matrix A.
// The condition number computed may be based on the 1-norm or the ∞-norm.
//
Expand Down
5 changes: 5 additions & 0 deletions lapack/netlib/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ func TestDtgsja(t *testing.T) {
testlapack.DtgsjaTest(t, impl)
}

func TestDtbtrs(t *testing.T) {
t.Parallel()
testlapack.DtbtrsTest(t, impl)
}

func TestDtrexc(t *testing.T) {
testlapack.DtrexcTest(t, impl)
}
Expand Down