Skip to content

Commit

Permalink
blas/testblas: add test for Ztrsv
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Nov 26, 2017
1 parent b4ff6c6 commit fb148a2
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 4 deletions.
4 changes: 4 additions & 0 deletions blas/gonum/level2cmplx128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ func TestZher2(t *testing.T) {
func TestZtrmv(t *testing.T) {
testblas.ZtrmvTest(t, impl)
}

func TestZtrsv(t *testing.T) {
testblas.ZtrsvTest(t, impl)
}
26 changes: 22 additions & 4 deletions blas/testblas/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ func zsame(x, y []complex128) bool {
return true
}

func zEqualApprox(x, y []complex128, tol float64) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
w := y[i]
if cmplx.IsNaN(v) && cmplx.IsNaN(w) {
continue
}
if cmplx.Abs(v-w) > tol {
return false
}
}
return true
}

func makeZVector(data []complex128, inc int) []complex128 {
if inc == 0 {
panic("bad test")
Expand All @@ -289,21 +305,23 @@ func makeZGeneral(data []complex128, m, n int, ld int) []complex128 {
if m < 0 || n < 0 {
panic("bad test")
}
if len(data) != m*n {
if data != nil && len(data) != m*n {
panic("bad test")
}
if ld < max(1, n) {
panic("bad test")
}
if len(data) == 0 {
if m == 0 || n == 0 {
return nil
}
a := make([]complex128, (m-1)*ld+n)
for i := range a {
a[i] = cmplx.NaN()
}
for i := 0; i < m; i++ {
copy(a[i*ld:i*ld+n], data[i*n:i*n+n])
if data != nil {
for i := 0; i < m; i++ {
copy(a[i*ld:i*ld+n], data[i*n:i*n+n])
}
}
return a
}
Expand Down
94 changes: 94 additions & 0 deletions blas/testblas/ztrsv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package testblas

import (
"math/cmplx"
"testing"

"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
)

type Ztrsver interface {
Ztrsv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n int, a []complex128, lda int, x []complex128, incX int)

Ztrmver
}

func ZtrsvTest(t *testing.T, impl Ztrsver) {
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
for _, trans := range []blas.Transpose{blas.NoTrans} {
for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
for _, n := range []int{0, 1, 2, 3, 4, 20, 50} {
for _, lda := range []int{max(1, n), n + 11} {
// for _, incX := range []int{-11, -3, -2, -1, 1, 2, 3, 7} {
for _, incX := range []int{1, 2} {
ztrsvTest(t, impl, uplo, trans, diag, n, lda, incX)
}
}
}
}
}
}
}

func ztrsvTest(t *testing.T, impl Ztrsver, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, lda, incX int) {
rnd := rand.New(rand.NewSource(1))
a := makeZGeneral(nil, n, n, lda)
if uplo == blas.Upper {
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
re := float64(rnd.Intn(21) - 10)
im := float64(rnd.Intn(21) - 10)
if i == j && re == 0 && im == 0 {
re = float64(rnd.Intn(10) + 1)
im = float64(rnd.Intn(10) + 1)
}
a[i*lda+j] = complex(re, im)
}
}
} else {
for i := 0; i < n; i++ {
for j := 0; j <= i; j++ {
re := float64(rnd.Intn(21) - 10)
im := float64(rnd.Intn(21) - 10)
if i == j && re == 0 && im == 0 {
re = float64(rnd.Intn(10) + 1)
im = float64(rnd.Intn(10) + 1)
}
a[i*lda+j] = complex(re, im)
}
}
}
if diag == blas.Unit {
for i := 0; i < n; i++ {
a[i*lda+i] = cmplx.NaN()
}
}
aCopy := make([]complex128, len(a))
copy(aCopy, a)

xtest := make([]complex128, n)
for i := range xtest {
re := float64(rnd.Intn(21) - 10)
im := float64(rnd.Intn(21) - 10)
xtest[i] = complex(re, im)
}
x := makeZVector(xtest, incX)
want := make([]complex128, len(x))
copy(want, x)

impl.Ztrmv(uplo, trans, diag, n, a, lda, x, incX)
impl.Ztrsv(uplo, trans, diag, n, a, lda, x, incX)

if !zsame(a, aCopy) {
t.Errorf("Case uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v: unexpected modification of A", uplo, trans, diag, n, lda, incX)
}

if !zEqualApprox(x, want, 1e-9) {
t.Errorf("Case uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v: unexpected result\nwant %v\ngot %v", uplo, trans, diag, n, lda, incX, want, x)
}
}

0 comments on commit fb148a2

Please sign in to comment.