From 32477bac0b9b9ee399d47f7099eb4ddcf342d6e3 Mon Sep 17 00:00:00 2001 From: Vladimir Chalupecky Date: Sun, 22 Jul 2018 23:17:32 +0200 Subject: [PATCH] testlapack clean isOrthonormal --- lapack/testlapack/dgelq2.go | 19 +++---------------- lapack/testlapack/dgeqp3.go | 16 +++------------- lapack/testlapack/dgeqr2.go | 16 +++------------- lapack/testlapack/dgerq2.go | 16 +++------------- lapack/testlapack/dgerqf.go | 16 +++------------- lapack/testlapack/dgesvd.go | 24 ++++-------------------- lapack/testlapack/dlaqp2.go | 16 +++------------- lapack/testlapack/dlaqps.go | 17 ++++------------- lapack/testlapack/general.go | 31 +++++++++++++++++++------------ 9 files changed, 45 insertions(+), 126 deletions(-) diff --git a/lapack/testlapack/dgelq2.go b/lapack/testlapack/dgelq2.go index 513f4bcbb7..e60dd2b38d 100644 --- a/lapack/testlapack/dgelq2.go +++ b/lapack/testlapack/dgelq2.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -70,21 +69,9 @@ func Dgelq2Test(t *testing.T, impl Dgelq2er) { Q := constructQ("LQ", m, n, a, lda, tau) - // Check that Q is orthonormal - for i := 0; i < Q.Rows; i++ { - nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]}) - if math.Abs(nrm-1) > 1e-14 { - t.Errorf("Q not normal. Norm is %v", nrm) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(Q.Rows, - blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]}, - blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]}, - ) - if math.Abs(dot) > 1e-14 { - t.Errorf("Q not orthogonal. Dot is %v", dot) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(Q) { + t.Errorf("Case %v: Q not orthogonal", c) } L := blas64.General{ diff --git a/lapack/testlapack/dgeqp3.go b/lapack/testlapack/dgeqp3.go index 0028dd5371..73e479b9cf 100644 --- a/lapack/testlapack/dgeqp3.go +++ b/lapack/testlapack/dgeqp3.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -98,18 +97,9 @@ func Dgeqp3Test(t *testing.T, impl Dgeqp3er) { // Q based on the vectors. q := constructQ("QR", m, n, a, lda, tau) - // Check that q is orthonormal - for i := 0; i < m; i++ { - nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}) - if math.Abs(nrm-1) > 1e-13 { - t.Errorf("Case %v, q not normal", c) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]}) - if math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", c) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", c) } // Check that A * P = Q * R r := blas64.General{ diff --git a/lapack/testlapack/dgeqr2.go b/lapack/testlapack/dgeqr2.go index 999d263ad7..2247a1439e 100644 --- a/lapack/testlapack/dgeqr2.go +++ b/lapack/testlapack/dgeqr2.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -72,18 +71,9 @@ func Dgeqr2Test(t *testing.T, impl Dgeqr2er) { // Q based on the vectors. q := constructQ("QR", m, n, a, lda, tau) - // Check that q is orthonormal - for i := 0; i < m; i++ { - nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}) - if math.Abs(nrm-1) > 1e-14 { - t.Errorf("Case %v, q not normal", c) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]}) - if math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", c) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", c) } // Check that A = Q * R r := blas64.General{ diff --git a/lapack/testlapack/dgerq2.go b/lapack/testlapack/dgerq2.go index b9607d08af..d05618938e 100644 --- a/lapack/testlapack/dgerq2.go +++ b/lapack/testlapack/dgerq2.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -71,18 +70,9 @@ func Dgerq2Test(t *testing.T, impl Dgerq2er) { // Q based on the vectors. q := constructQ("RQ", m, n, a, lda, tau) - // Check that q is orthonormal - for i := 0; i < q.Rows; i++ { - nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}) - if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 { - t.Errorf("Case %v, q not normal", c) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]}) - if math.IsNaN(dot) || math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", c) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", c) } // Check that A = R * Q r := blas64.General{ diff --git a/lapack/testlapack/dgerqf.go b/lapack/testlapack/dgerqf.go index 968ebd83b4..54d452159b 100644 --- a/lapack/testlapack/dgerqf.go +++ b/lapack/testlapack/dgerqf.go @@ -5,7 +5,6 @@ package testlapack import ( - "math" "testing" "golang.org/x/exp/rand" @@ -95,18 +94,9 @@ func DgerqfTest(t *testing.T, impl Dgerqfer) { // Q based on the vectors. q := constructQ("RQ", m, n, a, lda, tau) - // Check that q is orthonormal - for i := 0; i < q.Rows; i++ { - nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}) - if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 { - t.Errorf("Case %v, q not normal", c) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]}) - if math.IsNaN(dot) || math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", c) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", c) } // Check that A = R * Q r := blas64.General{ diff --git a/lapack/testlapack/dgesvd.go b/lapack/testlapack/dgesvd.go index 4042e1566c..6295b7d4ea 100644 --- a/lapack/testlapack/dgesvd.go +++ b/lapack/testlapack/dgesvd.go @@ -261,27 +261,11 @@ func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float6 if !thin { // Check that U and V are orthogonal. - for i := 0; i < uMat.Rows; i++ { - for j := i + 1; j < uMat.Rows; j++ { - dot := blas64.Dot(uMat.Cols, - blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]}, - blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]}, - ) - if dot > 1e-8 { - t.Errorf("U not orthogonal %s", errStr) - } - } + if !isOrthonormal(uMat) { + t.Errorf("U not orthogonal %s", errStr) } - for i := 0; i < vTMat.Rows; i++ { - for j := i + 1; j < vTMat.Rows; j++ { - dot := blas64.Dot(vTMat.Cols, - blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]}, - blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]}, - ) - if dot > 1e-8 { - t.Errorf("V not orthogonal %s", errStr) - } - } + if !isOrthonormal(vTMat) { + t.Errorf("V not orthogonal %s", errStr) } } } diff --git a/lapack/testlapack/dlaqp2.go b/lapack/testlapack/dlaqp2.go index 1e8aee4215..99f3929af0 100644 --- a/lapack/testlapack/dlaqp2.go +++ b/lapack/testlapack/dlaqp2.go @@ -6,7 +6,6 @@ package testlapack import ( "fmt" - "math" "testing" "gonum.org/v1/gonum/blas" @@ -74,18 +73,9 @@ func Dlaqp2Test(t *testing.T, impl Dlaqp2er) { mo := m - test.offset q := constructQ("QR", mo, n, a.Data[test.offset*a.Stride:], a.Stride, tau) - // Check that q is orthonormal - for i := 0; i < mo; i++ { - nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}) - if math.Abs(nrm-1) > 1e-13 { - t.Errorf("Case %v, q not normal", ti) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]}) - if math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", ti) - } - } + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", ti) } // Check that A * P = Q * R diff --git a/lapack/testlapack/dlaqps.go b/lapack/testlapack/dlaqps.go index 590489f16c..647b4ed847 100644 --- a/lapack/testlapack/dlaqps.go +++ b/lapack/testlapack/dlaqps.go @@ -6,7 +6,6 @@ package testlapack import ( "fmt" - "math" "testing" "gonum.org/v1/gonum/blas" @@ -70,18 +69,10 @@ func DlaqpsTest(t *testing.T, impl Dlaqpser) { mo := m - test.offset q := constructQ("QR", mo, kb, a.Data[test.offset*a.Stride:], a.Stride, tau) - // Check that q is orthonormal - for i := 0; i < mo; i++ { - nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}) - if math.Abs(nrm-1) > 1e-13 { - t.Errorf("Case %v, q not normal", ti) - } - for j := 0; j < i; j++ { - dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]}) - if math.Abs(dot) > 1e-14 { - t.Errorf("Case %v, q not orthogonal", ti) - } - } + + // Check that Q is orthogonal. + if !isOrthonormal(q) { + t.Errorf("Case %v, Q not orthogonal", ti) } // Check that A * P = Q * R diff --git a/lapack/testlapack/general.go b/lapack/testlapack/general.go index 15b2febba9..d8bcf5f659 100644 --- a/lapack/testlapack/general.go +++ b/lapack/testlapack/general.go @@ -818,26 +818,33 @@ func printRowise(a []float64, m, n, lda int, beyond bool) { } } -// isOrthonormal checks that a general matrix is orthonormal. +// isOrthonormal returns whether a square matrix Q is orthogonal. func isOrthonormal(q blas64.General) bool { + if q.Rows != q.Cols { + panic("matrix not square") + } n := q.Rows + // A real square matrix is orthogonal if and only if its rows form + // an orthonormal basis of the Euclidean space R^n. + const tol = 1e-10 for i := 0; i < n; i++ { - for j := i; j < n; j++ { + nrm := blas64.Nrm2(n, blas64.Vector{Data: q.Data[i*q.Stride:], Inc: 1}) + if math.IsNaN(nrm) { + return false + } + if math.Abs(nrm-1) > tol { + return false + } + for j := i + 1; j < n; j++ { dot := blas64.Dot(n, - blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, - blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]}, + blas64.Vector{Data: q.Data[i*q.Stride:], Inc: 1}, + blas64.Vector{Data: q.Data[j*q.Stride:], Inc: 1}, ) if math.IsNaN(dot) { return false } - if i == j { - if math.Abs(dot-1) > 1e-10 { - return false - } - } else { - if math.Abs(dot) > 1e-10 { - return false - } + if math.Abs(dot) > tol { + return false } } }