Skip to content
This repository has been archived by the owner on Nov 24, 2018. It is now read-only.

Commit

Permalink
native: clean Dlasy2 for 1x2 and 2x1 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Jun 9, 2016
1 parent ded3f05 commit 4fc6a17
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions native/dlasy2.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import (
//
// Dlasy2 is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool) {
// TODO(vladimir-ch): Add input validation checks conditionally skipped
// using the build tag mechanism.

ok = true
// Quick return if possible.
if n1 == 0 || n2 == 0 {
Expand Down Expand Up @@ -58,11 +61,13 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
return scale, xnorm, ok
}

bi := blas64.Implementation()

if n1+n2 == 3 {
// 1×2 or 2×1 case.
var (
smin float64
tmp [4]float64 // tmp is used as a 2×2 column-major matrix.
tmp [4]float64 // tmp is used as a 2×2 row-major matrix.
btmp [2]float64
)
if n1 == 1 && n2 == 2 {
Expand All @@ -75,11 +80,11 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
tmp[0] = tl[0] + sgn*tr[0]
tmp[3] = tl[0] + sgn*tr[ldtr+1]
if tranr {
tmp[1] = sgn * tr[ldtr]
tmp[2] = sgn * tr[1]
} else {
tmp[1] = sgn * tr[1]
tmp[2] = sgn * tr[ldtr]
} else {
tmp[1] = sgn * tr[ldtr]
tmp[2] = sgn * tr[1]
}
btmp[0] = b[0]
btmp[1] = b[1]
Expand All @@ -93,35 +98,40 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
tmp[0] = tl[0] + sgn*tr[0]
tmp[3] = tl[ldtl+1] + sgn*tr[0]
if tranl {
tmp[1] = tl[1]
tmp[2] = tl[ldtl]
} else {
tmp[1] = tl[ldtl]
tmp[2] = tl[1]
} else {
tmp[1] = tl[1]
tmp[2] = tl[ldtl]
}
btmp[0] = b[0]
btmp[1] = b[ldb]
}

// Solve 2×2 system using complete pivoting.
// Set pivots less than smin to smin.
bi := blas64.Implementation()

ipiv := bi.Idamax(len(tmp), tmp[:], 1)
// Compute the upper triangular matrix [u11 u12].
// [ 0 u22]
u11 := tmp[ipiv]
if math.Abs(u11) <= smin {
ok = false
u11 = smin
}
locu12 := [4]int{2, 3, 0, 1}
locu12 := [4]int{1, 0, 3, 2} // Index in tmp of the element on the same row as the pivot.
u12 := tmp[locu12[ipiv]]
locl21 := [4]int{1, 0, 3, 2}
locl21 := [4]int{2, 3, 0, 1} // Index in tmp of the element on the same column as the pivot.
l21 := tmp[locl21[ipiv]] / u11
locu22 := [4]int{3, 2, 1, 0}
u22 := tmp[locu22[ipiv]] - u12*l21
locu22 := [4]int{3, 2, 1, 0} // Index in tmp of the remaining element.
u22 := tmp[locu22[ipiv]] - l21*u12
if math.Abs(u22) <= smin {
ok = false
u22 = smin
}
if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3.
if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3.
// The pivot was in the second row, swap the elements of
// the right-hand side.
btmp[0], btmp[1] = btmp[1], btmp[0]-l21*btmp[1]
} else {
btmp[1] -= l21 * btmp[0]
Expand All @@ -132,17 +142,21 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
btmp[0] *= scale
btmp[1] *= scale
}
x21 := btmp[1] / u22
x20 := btmp[0]/u11 - (u12/u11)*x21
if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3.
x20, x21 = x21, x20
// Solve the system [u11 u12] [x21] = [ btmp[0] ].
// [ 0 u22] [x22] [ btmp[1] ]
x22 := btmp[1] / u22
x21 := btmp[0]/u11 - (u12/u11)*x22
if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3.
// The pivot was in the second column, swap the elements
// of the solution.
x21, x22 = x22, x21
}
x[0] = x20
x[0] = x21
if n1 == 1 {
x[1] = x21
x[1] = x22
xnorm = math.Abs(x[0]) + math.Abs(x[1])
} else {
x[ldx] = x21
x[ldx] = x22
xnorm = math.Max(math.Abs(x[0]), math.Abs(x[ldx]))
}
return scale, xnorm, ok
Expand All @@ -160,7 +174,7 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
smin = math.Max(eps*smin, smlnum)

var t16 [16]float64 // t16 is used as a 4×4 column-major matrix.
var t16 [16]float64 // t16 is used as a 4×4 row-major matrix.
t16[0*4+0] = tl[0] + sgn*tr[0]
t16[1*4+1] = tl[ldtl+1] + sgn*tr[0]
t16[2*4+2] = tl[0] + sgn*tr[ldtr+1]
Expand Down Expand Up @@ -195,7 +209,6 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
btmp[3] = b[ldb+1]

// Perform elimination.
bi := blas64.Implementation()
var jpiv [4]int
for i := 0; i < 3; i++ {
var (
Expand All @@ -212,10 +225,12 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
}
}
if ipsv != i {
// Swap rows ipsv and i.
bi.Dswap(4, t16[ipsv*4:], 1, t16[i*4:], 1)
btmp[ipsv], btmp[i] = btmp[i], btmp[ipsv]
}
if jpsv != i {
// Swap columns jpsv and i.
bi.Dswap(4, t16[jpsv:], 4, t16[i:], 4)
}
jpiv[i] = jpsv
Expand All @@ -239,6 +254,7 @@ func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []floa
8*smlnum*math.Abs(btmp[1]) > math.Abs(t16[1*4+1]) ||
8*smlnum*math.Abs(btmp[2]) > math.Abs(t16[2*4+2]) ||
8*smlnum*math.Abs(btmp[3]) > math.Abs(t16[3*4+3]) {

maxbtmp := math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
maxbtmp = math.Max(maxbtmp, math.Max(math.Abs(btmp[2]), math.Abs(btmp[3])))
scale = 1 / 8 / maxbtmp
Expand Down

0 comments on commit 4fc6a17

Please sign in to comment.