Skip to content

Commit

Permalink
testlapack clean isOrthonormal
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Jul 22, 2018
1 parent 286f685 commit 32477ba
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 126 deletions.
19 changes: 3 additions & 16 deletions lapack/testlapack/dgelq2.go
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"math"
"testing"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -70,21 +69,9 @@ func Dgelq2Test(t *testing.T, impl Dgelq2er) {

Q := constructQ("LQ", m, n, a, lda, tau)

// Check that Q is orthonormal
for i := 0; i < Q.Rows; i++ {
nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]})
if math.Abs(nrm-1) > 1e-14 {
t.Errorf("Q not normal. Norm is %v", nrm)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(Q.Rows,
blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]},
blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]},
)
if math.Abs(dot) > 1e-14 {
t.Errorf("Q not orthogonal. Dot is %v", dot)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(Q) {
t.Errorf("Case %v: Q not orthogonal", c)
}

L := blas64.General{
Expand Down
16 changes: 3 additions & 13 deletions lapack/testlapack/dgeqp3.go
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"math"
"testing"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -98,18 +97,9 @@ func Dgeqp3Test(t *testing.T, impl Dgeqp3er) {
// Q based on the vectors.
q := constructQ("QR", m, n, a, lda, tau)

// Check that q is orthonormal
for i := 0; i < m; i++ {
nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]})
if math.Abs(nrm-1) > 1e-13 {
t.Errorf("Case %v, q not normal", c)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]})
if math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", c)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", c)
}
// Check that A * P = Q * R
r := blas64.General{
Expand Down
16 changes: 3 additions & 13 deletions lapack/testlapack/dgeqr2.go
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"math"
"testing"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -72,18 +71,9 @@ func Dgeqr2Test(t *testing.T, impl Dgeqr2er) {
// Q based on the vectors.
q := constructQ("QR", m, n, a, lda, tau)

// Check that q is orthonormal
for i := 0; i < m; i++ {
nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]})
if math.Abs(nrm-1) > 1e-14 {
t.Errorf("Case %v, q not normal", c)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]})
if math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", c)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", c)
}
// Check that A = Q * R
r := blas64.General{
Expand Down
16 changes: 3 additions & 13 deletions lapack/testlapack/dgerq2.go
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"math"
"testing"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -71,18 +70,9 @@ func Dgerq2Test(t *testing.T, impl Dgerq2er) {
// Q based on the vectors.
q := constructQ("RQ", m, n, a, lda, tau)

// Check that q is orthonormal
for i := 0; i < q.Rows; i++ {
nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]})
if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 {
t.Errorf("Case %v, q not normal", c)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
if math.IsNaN(dot) || math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", c)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", c)
}
// Check that A = R * Q
r := blas64.General{
Expand Down
16 changes: 3 additions & 13 deletions lapack/testlapack/dgerqf.go
Expand Up @@ -5,7 +5,6 @@
package testlapack

import (
"math"
"testing"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -95,18 +94,9 @@ func DgerqfTest(t *testing.T, impl Dgerqfer) {
// Q based on the vectors.
q := constructQ("RQ", m, n, a, lda, tau)

// Check that q is orthonormal
for i := 0; i < q.Rows; i++ {
nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]})
if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 {
t.Errorf("Case %v, q not normal", c)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
if math.IsNaN(dot) || math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", c)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", c)
}
// Check that A = R * Q
r := blas64.General{
Expand Down
24 changes: 4 additions & 20 deletions lapack/testlapack/dgesvd.go
Expand Up @@ -261,27 +261,11 @@ func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float6

if !thin {
// Check that U and V are orthogonal.
for i := 0; i < uMat.Rows; i++ {
for j := i + 1; j < uMat.Rows; j++ {
dot := blas64.Dot(uMat.Cols,
blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
)
if dot > 1e-8 {
t.Errorf("U not orthogonal %s", errStr)
}
}
if !isOrthonormal(uMat) {
t.Errorf("U not orthogonal %s", errStr)
}
for i := 0; i < vTMat.Rows; i++ {
for j := i + 1; j < vTMat.Rows; j++ {
dot := blas64.Dot(vTMat.Cols,
blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
)
if dot > 1e-8 {
t.Errorf("V not orthogonal %s", errStr)
}
}
if !isOrthonormal(vTMat) {
t.Errorf("V not orthogonal %s", errStr)
}
}
}
16 changes: 3 additions & 13 deletions lapack/testlapack/dlaqp2.go
Expand Up @@ -6,7 +6,6 @@ package testlapack

import (
"fmt"
"math"
"testing"

"gonum.org/v1/gonum/blas"
Expand Down Expand Up @@ -74,18 +73,9 @@ func Dlaqp2Test(t *testing.T, impl Dlaqp2er) {

mo := m - test.offset
q := constructQ("QR", mo, n, a.Data[test.offset*a.Stride:], a.Stride, tau)
// Check that q is orthonormal
for i := 0; i < mo; i++ {
nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]})
if math.Abs(nrm-1) > 1e-13 {
t.Errorf("Case %v, q not normal", ti)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]})
if math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", ti)
}
}
// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", ti)
}

// Check that A * P = Q * R
Expand Down
17 changes: 4 additions & 13 deletions lapack/testlapack/dlaqps.go
Expand Up @@ -6,7 +6,6 @@ package testlapack

import (
"fmt"
"math"
"testing"

"gonum.org/v1/gonum/blas"
Expand Down Expand Up @@ -70,18 +69,10 @@ func DlaqpsTest(t *testing.T, impl Dlaqpser) {

mo := m - test.offset
q := constructQ("QR", mo, kb, a.Data[test.offset*a.Stride:], a.Stride, tau)
// Check that q is orthonormal
for i := 0; i < mo; i++ {
nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]})
if math.Abs(nrm-1) > 1e-13 {
t.Errorf("Case %v, q not normal", ti)
}
for j := 0; j < i; j++ {
dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]})
if math.Abs(dot) > 1e-14 {
t.Errorf("Case %v, q not orthogonal", ti)
}
}

// Check that Q is orthogonal.
if !isOrthonormal(q) {
t.Errorf("Case %v, Q not orthogonal", ti)
}

// Check that A * P = Q * R
Expand Down
31 changes: 19 additions & 12 deletions lapack/testlapack/general.go
Expand Up @@ -818,26 +818,33 @@ func printRowise(a []float64, m, n, lda int, beyond bool) {
}
}

// isOrthonormal checks that a general matrix is orthonormal.
// isOrthonormal returns whether a square matrix Q is orthogonal.
func isOrthonormal(q blas64.General) bool {
if q.Rows != q.Cols {
panic("matrix not square")
}
n := q.Rows
// A real square matrix is orthogonal if and only if its rows form
// an orthonormal basis of the Euclidean space R^n.
const tol = 1e-10
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
nrm := blas64.Nrm2(n, blas64.Vector{Data: q.Data[i*q.Stride:], Inc: 1})
if math.IsNaN(nrm) {
return false
}
if math.Abs(nrm-1) > tol {
return false
}
for j := i + 1; j < n; j++ {
dot := blas64.Dot(n,
blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]},
blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]},
blas64.Vector{Data: q.Data[i*q.Stride:], Inc: 1},
blas64.Vector{Data: q.Data[j*q.Stride:], Inc: 1},
)
if math.IsNaN(dot) {
return false
}
if i == j {
if math.Abs(dot-1) > 1e-10 {
return false
}
} else {
if math.Abs(dot) > 1e-10 {
return false
}
if math.Abs(dot) > tol {
return false
}
}
}
Expand Down

0 comments on commit 32477ba

Please sign in to comment.