forked from hrautila/linalg
/
gbtrs.go
204 lines (190 loc) · 5.57 KB
/
gbtrs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
// Copyright (c) Harri Rautila, 2012, 2013
// This file is part of github.com/hrautila/linalg/lapack package.
// It is free software, distributed under the terms of GNU Lesser General Public
// License Version 3, or any later version. See the COPYING file included in this archive.
package lapack
import (
//"errors"
"fmt"
"github.com/hrautila/linalg"
"github.com/hrautila/matrix"
)
/*
Gbtrs solves a real or complex set of linear equations with a banded
coefficient matrix, given the LU factorization computed by Gbtrf()
or Gbsv().

PURPOSE

Solves linear equations

	A*X = B,   if trans is PNoTrans
	A^T*X = B, if trans is PTrans
	A^H*X = B, if trans is PConjTrans

On entry, A and ipiv contain the LU factorization of an n by n
band matrix A as computed by Gbtrf() or Gbsv(). A is held in band
storage: at least 2*kl+ku+1 rows and one column per matrix column.
On exit B is replaced by the solution X.

ARGUMENTS

	A     float or complex matrix (complex input currently returns an error)
	B     float or complex matrix. Must have the same type as A.
	ipiv  int vector with at least n entries; must not be nil
	kl    nonnegative integer

OPTIONS

	trans    PNoTrans, PTrans or PConjTrans
	n        nonnegative integer. If negative, A.Cols() is used.
	ku       nonnegative integer. If negative, A.Rows()-2*kl-1 is used.
	nrhs     nonnegative integer. If negative, B.Cols() is used.
	ldA      positive integer, ldA >= 2*kl+ku+1. If zero, max(1, A.LeadingIndex()) is used.
	ldB      positive integer, ldB >= max(1,n). If zero, max(1, B.LeadingIndex()) is used.
	offsetA  nonnegative integer
	offsetB  nonnegative integer

If n or nrhs resolves to zero, Gbtrs returns nil without touching B.
*/
func Gbtrs(A, B matrix.Matrix, ipiv []int32, KL int, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	ind.Kl = KL
	// arows/brows are the column strides used by the buffer-size checks:
	// an explicit ldA/ldB, otherwise the matrix row count.
	arows := ind.LDa
	brows := ind.LDb
	if ind.Kl < 0 {
		return onError("Gbtrs: invalid kl")
	}
	if ind.N < 0 {
		// A is in band storage, so its columns (not its 2*kl+ku+1 rows)
		// correspond to the n columns of the coefficient matrix.
		ind.N = A.Cols()
	}
	if ind.Nrhs < 0 {
		// Every column of B is one right-hand side.
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.Ku < 0 {
		// Band storage has 2*kl+ku+1 rows; recover ku from A's height.
		ind.Ku = A.Rows() - 2*ind.Kl - 1
	}
	if ind.Ku < 0 {
		return onError("Gbtrs: invalid ku")
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < 2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: ldA")
	}
	if ind.OffsetA < 0 {
		return onError("Gbtrs: offsetA")
	}
	if A.NumElements() < ind.OffsetA+(ind.N-1)*arows+2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: sizeA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gbtrs: ldB")
	}
	if ind.OffsetB < 0 {
		return onError("Gbtrs: offsetB")
	}
	if B.NumElements() < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gbtrs: sizeB")
	}
	// ipiv is handed directly to dgbtrs, so require n entries; since n > 0
	// here, this also rejects a nil slice.
	if len(ipiv) < ind.N {
		return onError("Gbtrs: size ipiv")
	}
	if !matrix.EqualTypes(A, B) {
		return onError("Gbtrs: arguments not of same type")
	}
	switch Af := A.(type) {
	case *matrix.FloatMatrix:
		// EqualTypes above has established that B has the same type as A.
		Aa := Af.FloatArray()
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		trans := linalg.ParamString(pars.Trans)
		info := dgbtrs(trans, ind.N, ind.Kl, ind.Ku, ind.Nrhs,
			Aa[ind.OffsetA:], ind.LDa, ipiv, Ba[ind.OffsetB:], ind.LDb)
		if info != 0 {
			return onError(fmt.Sprintf("Gbtrs lapack error: %d", info))
		}
		return nil
	case *matrix.ComplexMatrix:
		return onError("Gbtrs: complex not yet implemented")
	}
	// Previously an unsupported type surfaced as a misleading
	// "lapack error: -1"; report it explicitly instead.
	return onError("Gbtrs: unknown types")
}
// GbtrsFloat is the *matrix.FloatMatrix-only form of Gbtrs. It solves
// op(A)*X = B, where op is selected by the trans option, for a band matrix A
// whose LU factors and pivot indices ipiv come from a prior band
// factorization. KL is the kl band parameter.
//
// Arguments and option defaults are validated and resolved by checkGbtrs.
// On success B is overwritten with the solution X. A non-zero status from
// dgbtrs is reported as an error.
func GbtrsFloat(A, B *matrix.FloatMatrix, ipiv []int32, KL int, opts ...linalg.Option) error {
	params, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	idx := linalg.GetIndexOpts(opts...)
	idx.Kl = KL
	if err = checkGbtrs(idx, A, B, ipiv); err != nil {
		return err
	}
	// checkGbtrs also returns nil for an empty problem, so test it here.
	if idx.N == 0 || idx.Nrhs == 0 {
		return nil
	}
	abuf := A.FloatArray()[idx.OffsetA:]
	bbuf := B.FloatArray()[idx.OffsetB:]
	mode := linalg.ParamString(params.Trans)
	status := dgbtrs(mode, idx.N, idx.Kl, idx.Ku, idx.Nrhs,
		abuf, idx.LDa, ipiv, bbuf, idx.LDb)
	if status == 0 {
		return nil
	}
	return onError(fmt.Sprintf("Gbtrs: lapack error: %d", status))
}
// checkGbtrs validates the arguments of a banded LU solve and fills in
// defaults for unset options, updating ind in place.
//
// Negative N, Nrhs and Ku, and zero LDa and LDb, mean "use the default":
//
//	N     A.Cols() (A is in band storage, one column per matrix column)
//	Nrhs  B.Cols() (one right-hand side per column of B)
//	Ku    A.Rows() - 2*Kl - 1
//	LDa   max(1, A.LeadingIndex())
//	LDb   max(1, B.LeadingIndex())
//
// If N or Nrhs resolves to zero, the remaining checks are skipped and nil is
// returned; callers must test for the empty case themselves.
//
// An error is returned in any of these cases:
//   - Kl or the resolved Ku is negative.
//   - LDa < 2*Kl+Ku+1 or LDb < max(1, N).
//   - An offset is negative.
//   - A or B has too few elements for the requested sizes.
//   - ipiv has fewer than N entries (including nil).
func checkGbtrs(ind *linalg.IndexOpts, A, B matrix.Matrix, ipiv []int32) error {
	// arows/brows are the column strides used by the buffer-size checks:
	// an explicit LDa/LDb, otherwise the matrix row count.
	arows := ind.LDa
	brows := ind.LDb
	if ind.Kl < 0 {
		return onError("Gbtrs: invalid kl")
	}
	if ind.N < 0 {
		// Band storage: the columns of A (not its 2*kl+ku+1 rows) give n.
		ind.N = A.Cols()
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.Ku < 0 {
		// Band storage has 2*kl+ku+1 rows; recover ku from A's height.
		ind.Ku = A.Rows() - 2*ind.Kl - 1
	}
	if ind.Ku < 0 {
		return onError("Gbtrs: invalid ku")
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < 2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: lda")
	}
	if ind.OffsetA < 0 {
		return onError("Gbtrs: offsetA")
	}
	if A.NumElements() < ind.OffsetA+(ind.N-1)*arows+2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: sizeA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gbtrs: ldB")
	}
	if ind.OffsetB < 0 {
		return onError("Gbtrs: offsetB")
	}
	if B.NumElements() < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gbtrs: sizeB")
	}
	// Callers pass ipiv straight to dgbtrs, so require n entries; since n > 0
	// here, this also rejects a nil slice.
	if len(ipiv) < ind.N {
		return onError("Gbtrs: size ipiv")
	}
	return nil
}
// Local Variables:
// tab-width: 4
// End: