forked from gonum/gonum
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dgetc2.go
135 lines (116 loc) · 3.33 KB
/
dgetc2.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlapack
import (
"fmt"
"math"
"testing"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack"
)
// Dgetc2er is the interface a type must implement to be exercised by
// Dgetc2Test.
type Dgetc2er interface {
// Dgetc2 is called by the test with the order n of a square matrix, its
// row-major data a with stride lda, and pivot slices ipiv and jpiv of
// length n. The test expects a to be overwritten with the LU factors
// (L strictly below the diagonal with an implicit unit diagonal, U on
// and above the diagonal), ipiv and jpiv to be filled with non-negative
// row and column pivot indices, and the returned k to be non-negative
// when the matrix had to be perturbed.
Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int)
}
// Dgetc2Test exercises impl.Dgetc2 on square matrices of several orders.
//
// Every order n is tested with a tight row stride (lda = n) and a padded one
// (lda = n + 5), first without and then with the perturbation that
// dgetc2Test applies. A fixed seed keeps the generated matrices reproducible.
func Dgetc2Test(t *testing.T, impl Dgetc2er) {
	src := rand.New(rand.NewSource(1))
	orders := []int{0, 1, 2, 3, 4, 5, 10, 20}
	for _, n := range orders {
		strides := [2]int{n, n + 5}
		for _, lda := range strides {
			for _, perturb := range [2]bool{false, true} {
				dgetc2Test(t, impl, src, n, lda, perturb)
			}
		}
	}
}
// dgetc2Test checks one Dgetc2 factorization of a random n×n matrix stored
// with row stride max(1, lda).
//
// The matrix is built as A = L*U from a random unit lower-triangular L and a
// random upper-triangular U. If perturb is true and n > 0, one diagonal
// element of U is zeroed so that A is singular, and Dgetc2 is then expected
// to report a perturbation through a non-negative return value. The test
// rebuilds P*L*U*Q from the Dgetc2 output and requires the maximum column sum
// of |P*L*U*Q - A| to be at most tol.
//
// rnd supplies all random numbers, so the caller controls reproducibility.
func dgetc2Test(t *testing.T, impl Dgetc2er, rnd *rand.Rand, n, lda int, perturb bool) {
	const tol = 1e-14

	name := fmt.Sprintf("n=%v,lda=%v,perturb=%v", n, lda, perturb)

	// Generate a random lower-triangular matrix with unit diagonal.
	l := randomGeneral(n, n, max(1, n), rnd)
	for i := 0; i < n; i++ {
		l.Data[i*l.Stride+i] = 1
		for j := i + 1; j < n; j++ {
			l.Data[i*l.Stride+j] = 0
		}
	}

	// Generate a random upper-triangular matrix.
	u := randomGeneral(n, n, max(1, n), rnd)
	for i := 0; i < n; i++ {
		for j := 0; j < i; j++ {
			u.Data[i*u.Stride+j] = 0
		}
	}
	if perturb && n > 0 {
		// Make U singular by randomly placing a zero on the diagonal.
		i := rnd.Intn(n)
		u.Data[i*u.Stride+i] = 0
	}

	// Construct A = L*U.
	a := zeros(n, n, max(1, lda))
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, a)

	// Allocate slices for pivots and pre-fill them with invalid indices.
	ipiv := make([]int, n)
	jpiv := make([]int, n)
	for i := 0; i < n; i++ {
		ipiv[i] = -1
		jpiv[i] = -1
	}

	// Call Dgetc2 to compute the LU decomposition.
	lu := cloneGeneral(a)
	k := impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)

	if n == 0 {
		return
	}

	if perturb && k < 0 {
		t.Errorf("%v: expected matrix perturbation", name)
	}

	// Verify all indices have been set and lie in [0, n). An invalid index
	// must stop the test here: it is used below to slice diff.Data, which
	// would panic and hide the actual failure.
	valid := true
	for i := 0; i < n; i++ {
		switch {
		case ipiv[i] < 0:
			t.Errorf("%v: ipiv[%d] is not set", name, i)
			valid = false
		case ipiv[i] >= n:
			t.Errorf("%v: ipiv[%d]=%d is out of range", name, i, ipiv[i])
			valid = false
		}
		switch {
		case jpiv[i] < 0:
			t.Errorf("%v: jpiv[%d] is not set", name, i)
			valid = false
		case jpiv[i] >= n:
			t.Errorf("%v: jpiv[%d]=%d is out of range", name, i, jpiv[i])
			valid = false
		}
	}
	if !valid {
		return
	}

	// Construct L and U matrices from Dgetc2 output. The unit diagonal of L
	// is not read from lu; the diagonal of lu belongs to U.
	l = zeros(n, n, n)
	u = zeros(n, n, n)
	for i := 0; i < n; i++ {
		for j := 0; j < i; j++ {
			l.Data[i*l.Stride+j] = lu.Data[i*lu.Stride+j]
		}
		l.Data[i*l.Stride+i] = 1
		for j := i; j < n; j++ {
			u.Data[i*u.Stride+j] = lu.Data[i*lu.Stride+j]
		}
	}
	diff := zeros(n, n, n)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, diff)

	// Apply permutation matrices P and Q to L*U. The interchanges are undone
	// in reverse order, from the last pivot to the first.
	for i := n - 1; i >= 0; i-- {
		ipv := ipiv[i]
		if ipv != i {
			row1 := blas64.Vector{N: n, Data: diff.Data[i*diff.Stride:], Inc: 1}
			row2 := blas64.Vector{N: n, Data: diff.Data[ipv*diff.Stride:], Inc: 1}
			blas64.Swap(row1, row2)
		}
		jpv := jpiv[i]
		if jpv != i {
			col1 := blas64.Vector{N: n, Data: diff.Data[i:], Inc: diff.Stride}
			col2 := blas64.Vector{N: n, Data: diff.Data[jpv:], Inc: diff.Stride}
			blas64.Swap(col1, col2)
		}
	}

	// Compute the residual |P*L*U*Q - A| and check that it is small.
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			diff.Data[i*diff.Stride+j] -= a.Data[i*a.Stride+j]
		}
	}
	resid := dlange(lapack.MaxColumnSum, n, n, diff.Data, diff.Stride)
	if resid > tol || math.IsNaN(resid) {
		t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol)
	}
}