Skip to content
This repository has been archived by the owner on Nov 24, 2018. It is now read-only.

Commit

Permalink
dtgsja towards functional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Feb 20, 2017
1 parent 10492e2 commit ba3a174
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions testlapack/dtgsja.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"fmt"
"math/rand"
"testing"

Expand All @@ -28,6 +27,7 @@ func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
ok bool
}{
{m: 10, p: 10, n: 10, k: 5, l: 3, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
{m: 10, p: 10, n: 10, k: 6, l: 4, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
} {
m := test.m
p := test.p
Expand Down Expand Up @@ -126,7 +126,10 @@ func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
dst := zeroR
dst.Cols = k + l
dst.Data = zeroR.Data[n-k-l:]
copyGeneral(dst, a)
src := a
src.Cols = k + l
src.Data = a.Data[n-k-l:]
copyGeneral(dst, src)

// D1
for i := 0; i < k; i++ {
Expand All @@ -146,11 +149,18 @@ func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
dst.Rows = m
dst.Cols = k + l
dst.Data = zeroR.Data[n-k-l:]
copyGeneral(dst, a)
src := a
src.Cols = m
src.Data = a.Data[n-k-l:]
copyGeneral(dst, src)
dst.Rows = k + l - m
dst.Cols = k + l - m
dst.Data = zeroR.Data[m*zeroR.Stride+n-(k+l-m):]
copyGeneral(dst, b)
src = b
src.Rows = k + l - m
src.Cols = k + l - m
src.Data = b.Data[(m-k)*b.Stride+n+m-k-l:]
copyGeneral(dst, src)

// D1
for i := 0; i < k; i++ {
Expand Down Expand Up @@ -178,6 +188,11 @@ func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
d10r := nanGeneral(m, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)

if !equalApproxGeneral(uAns, d10r, 1e-14) {
t.Errorf("test %d: U^T*A*Q != D1*[ 0 R ]\nU^T*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
cas, uAns, d10r)
}

// Check V^T*B*Q = D2*[ 0 R ].
vTmp := nanGeneral(p, n, n)
blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
Expand All @@ -187,9 +202,9 @@ func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
d20r := nanGeneral(p, n, n)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)

fmt.Println(uAns)
fmt.Println(d10r)
fmt.Println(vAns)
fmt.Println(d20r)
if !equalApproxGeneral(vAns, d20r, 1e-14) {
t.Errorf("test %d: V^T*B*Q != D2*[ 0 R ]\nV^T*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",
cas, vAns, d20r)
}
}
}

0 comments on commit ba3a174

Please sign in to comment.