Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cgo/lapack.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
// Copied from lapack/native. Keep in sync.
const (
absIncNotOne = "lapack: increment not one or negative one"
badDiag = "lapack: bad diag"
badDirect = "lapack: bad direct"
badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range"
Expand Down
334 changes: 334 additions & 0 deletions native/dlatrs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
package native

import (
"math"

"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

// Dlatrs solves a triangular system of equations scaled to prevent overflow. It
// solves
// A * x = scale * b if trans == blas.NoTrans
// A^T * x = scale * b if trans == blas.Trans
// where the scale s is set for numeric stability.
//
// A is an n×n triangular matrix. On entry, the slice x contains the values of
// of b, and on exit it contains the solution vector x.
//
// If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
// part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
// than or equal to the infinity norm, and greater than or equal to the one-norm
// otherwise. If normin == false, then cnorm is treated as an output, and is set
// to contain the 1-norm of the off-diagonal part of the j^th column of A.
func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
if uplo != blas.Upper && uplo != blas.Lower {
panic(badUplo)
}
if trans != blas.Trans && trans != blas.NoTrans {
panic(badTrans)
}
if diag != blas.Unit && diag != blas.NonUnit {
panic(badDiag)
}
upper := uplo == blas.Upper
noTrans := trans == blas.NoTrans
nonUnit := diag == blas.NonUnit

if n < 0 {
panic(nLT0)
}
checkMatrix(n, n, a, lda)
checkVector(n, x, 1)
checkVector(n, cnorm, 1)

if n == 0 {
return
}
scale = 1
bi := blas64.Implementation()
if !normin {
if upper {
for j := 0; j < n; j++ {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I take it the lda here is because we are row major.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. It's computing the sum over the column.

cnorm[j] = bi.Dasum(j, a[j:], lda)
}
} else {
for j := 0; j < n-1; j++ {
cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
}
cnorm[n-1] = 0
}
}
// Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
imax := bi.Idamax(n, cnorm, 1)
tmax := cnorm[imax]
var tscal float64
if tmax <= bignum {
tscal = 1
} else {
tscal = 1 / (smlnum * tmax)
bi.Dscal(n, tscal, cnorm, 1)
}

// Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
j := bi.Idamax(n, x, 1)
xmax := math.Abs(x[j])
xbnd := xmax
var grow float64
var jfirst, jlast, jinc int
if noTrans {
if upper {
jfirst = n - 1
jlast = 0
jinc = -1
} else {
jfirst = 0
jlast = n - 1
jinc = 1
}
// Compute the growth in A * x = b.
if tscal != 1 {
grow = 0
goto Finish
}
if nonUnit {
grow = 1 / math.Max(xbnd, smlnum)
xbnd = grow
for j := jfirst; j != jlast; j += jinc {
if grow <= smlnum {
goto Finish
}
tjj := math.Abs(a[j*lda+j])
xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
if tjj+cnorm[j] >= smlnum {
grow *= tjj / (tjj + cnorm[j])
} else {
grow = 0
}
}
grow = xbnd
} else {
grow = math.Min(1, 1/math.Max(xbnd, smlnum))
for j := jfirst; j != jlast; j += jinc {
if grow <= smlnum {
goto Finish
}
grow *= 1 / (1 + cnorm[j])
}
}
} else {
if upper {
jfirst = 0
jlast = n - 1
jinc = 1
} else {
jfirst = n - 1
jlast = 0
jinc = -1
}
if tscal != 1 {
grow = 0
goto Finish
}
if nonUnit {
grow = 1 / (math.Max(xbnd, smlnum))
xbnd = grow
for j := jfirst; j != jlast; j += jinc {
if grow <= smlnum {
goto Finish
}
xj := 1 + cnorm[j]
grow = math.Min(grow, xbnd/xj)
tjj := math.Abs(a[j*lda+j])
if xj > tjj {
xbnd *= tjj / xj
}
}
grow = math.Min(grow, xbnd)
} else {
grow = math.Min(1, 1/math.Max(xbnd, smlnum))
for j := jfirst; j != jlast; j += jinc {
if grow <= smlnum {
goto Finish
}
xj := 1 + cnorm[j]
grow /= xj
}
}
}

Finish:
if grow*tscal > smlnum {
bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
// TODO(btracey): check if this else is everything
} else {
if xmax > bignum {
scale = bignum / xmax
bi.Dscal(n, scale, x, 1)
xmax = bignum
}
if noTrans {
for j := jfirst; j != jlast; j += jinc {
xj := math.Abs(x[j])
var tjjs float64
if nonUnit {
tjjs = a[j*lda+j] * tscal
} else {
tjjs = tscal
if tscal == 1 {
break
}
}
tjj := math.Abs(tjjs)
if tjj > smlnum {
if tjj < 1 {
if xj > tjj*bignum {
rec := 1 / xj
bi.Dscal(n, rec, x, 1)
scale *= rec
xmax *= rec
}
}
x[j] /= tjjs
xj = math.Abs(x[j])
} else if tjj > 0 {
if xj > tjj*bignum {
rec := (tjj * bignum) / xj
if cnorm[j] > 1 {
rec /= cnorm[j]
}
bi.Dscal(n, rec, x, 1)
scale *= rec
xmax *= rec
}
x[j] /= tjjs
xj = math.Abs(x[j])
} else {
for i := 0; i < n; i++ {
x[i] = 0
}
x[j] = 1
xj = 1
scale = 0
xmax = 0
}
if xj > 1 {
rec := 1 / xj
if cnorm[j] > (bignum-xmax)*rec {
rec *= 0.5
bi.Dscal(n, rec, x, 1)
scale *= rec
}
} else if xj*cnorm[j] > bignum-xmax {
bi.Dscal(n, 0.5, x, 1)
scale *= 0.5
}
if upper {
if j > 0 {
bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
i := bi.Idamax(j, x, 1)
xmax = math.Abs(x[i])
}
} else {
if j < n-1 {
bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
i := j + bi.Idamax(n-j-1, x[j+1:], 1)
xmax = math.Abs(x[i])
}
}
}
} else {
for j := jfirst; j != jlast; j += jinc {
xj := math.Abs(x[j])
uscal := tscal
rec := 1 / math.Max(xmax, 1)
var tjjs float64
if cnorm[j] > (bignum-xj)*rec {
rec *= 0.5
if nonUnit {
tjjs = a[j*lda+j] * tscal
} else {
tjjs = tscal
}
tjj := math.Abs(tjjs)
if tjj > 1 {
rec = math.Min(1, rec*tjj)
uscal /= tjjs
}
if rec < 1 {
bi.Dscal(n, rec, x, 1)
scale *= rec
xmax *= rec
}
}
var sumj float64
if uscal == 1 {
if upper {
sumj = bi.Ddot(j, a[j:], lda, x, 1)
} else if j < n-1 {
sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
}
} else {
if upper {
for i := 0; i < j; i++ {
sumj += (a[i*lda+j] * uscal) * x[i]
}
} else if j < n {
for i := j + 1; i < n; i++ {
sumj += (a[i*lda+j] * uscal) * x[i]
}
}
}
if uscal == tscal {
x[j] -= sumj
xj := math.Abs(x[j])
var tjjs float64
if nonUnit {
tjjs = a[j*lda+j] * tscal
} else {
tjjs = tscal
if tscal == 1 {
goto Out2
}
}
tjj := math.Abs(tjjs)
if tjj > smlnum {
if tjj < 1 {
if xj > tjj*bignum {
rec = 1 / xj
bi.Dscal(n, rec, x, 1)
scale *= rec
xmax *= rec
}
}
x[j] /= tjjs
} else if tjj > 0 {
if xj > tjj*bignum {
rec = (tjj * bignum) / xj
bi.Dscal(n, rec, x, 1)
scale *= rec
xmax *= rec
}
x[j] /= tjjs
} else {
for i := 0; i < n; i++ {
x[i] = 0
}
x[j] = 1
scale = 0
xmax = 0
}
} else {
x[j] = x[j]/tjjs - sumj
}
Out2:
xmax = math.Max(xmax, math.Abs(x[j]))
}
}
scale /= tscal
}
if tscal != 1 {
bi.Dscal(n, 1/tscal, cnorm, 1)
}
return scale
}
4 changes: 4 additions & 0 deletions native/general.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var _ lapack.Float64 = Implementation{}
// This list is duplicated in lapack/cgo. Keep in sync.
const (
absIncNotOne = "lapack: increment not one or negative one"
badDiag = "lapack: bad diag"
badDirect = "lapack: bad direct"
badIpiv = "lapack: insufficient permutation length"
badLdA = "lapack: index of a out of range"
Expand Down Expand Up @@ -79,6 +80,7 @@ func max(a, b int) int {
// TODO(btracey): Is there a better way to find the smallest number such that 1+E > 1

var dlamchE, dlamchS, dlamchP float64
var smlnum, bignum float64

func init() {
onePlusEps := math.Nextafter(1, math.Inf(1))
Expand All @@ -92,4 +94,6 @@ func init() {
dlamchS = sfmin
radix := 2.0
dlamchP = radix * eps
smlnum = dlamchS / dlamchP
bignum = 1 / smlnum
}