Skip to content
This repository has been archived by the owner on Dec 9, 2018. It is now read-only.

Commit

Permalink
Merge pull request #29 from gonum/fixdtbmv
Browse files: browse the repository at this point in the history
Fixdtbmv
  • Loading branch information
btracey committed Dec 11, 2014
2 parents be8a6a6 + d0c2b69 commit 6b4b681
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 138 deletions.
4 changes: 4 additions & 0 deletions cblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ func TestDger(t *testing.T) {
testblas.DgerTest(t, blasser)
}

// TestDtbmv runs the testblas.DtbmvTest suite against blasser.
// NOTE(review): blasser is declared outside this view; presumably it is this
// package's BLAS implementation under test — confirm.
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, blasser)
}

// TestDtxmv runs the testblas.DtxmvTest suite against blasser.
// NOTE(review): blasser is declared outside this view; presumably it is this
// package's BLAS implementation under test — confirm.
func TestDtxmv(t *testing.T) {
testblas.DtxmvTest(t, blasser)
}
Expand Down
280 changes: 144 additions & 136 deletions goblas/level2double.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,19 +674,6 @@ func (b Blas) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x
// where x is an n element vector and A is an n by n unit, or non-unit,
// upper or lower triangular band matrix.
func (Blas) Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int) {
// Verify inputs
// Transform for row major
if tA == blas.NoTrans {
tA = blas.Trans
} else {
tA = blas.NoTrans
}
if ul == blas.Upper {
ul = blas.Lower
} else {
ul = blas.Upper
}

if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
Expand Down Expand Up @@ -717,148 +704,169 @@ func (Blas) Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []fl
} else if incX != 1 {
kx = 0
}
_ = kx

nonunit := d != blas.Unit

if tA == blas.NoTrans {
if ul == blas.Upper {
if incX == 1 {
for j := 0; j < n; j++ {
if x[j] != 0 {
temp := x[j]
l := k - j
for i := max(0, j-k); i < j; i++ {
x[i] += temp * a[l+i+j*lda]
}
if d == blas.NonUnit {
x[j] *= a[k+j*lda]
}
}
}
} else {
jx := kx
for j := 0; j < n; j++ {
if x[jx] != 0 {
temp := x[jx]
ix := kx
l := k - j
for i := max(0, j-k); i < j; i++ {
x[ix] += temp * a[l+i+j*lda]
ix += incX
}
if d == blas.NonUnit {
x[jx] *= a[k+j*lda]
}
for i := 0; i < n; i++ {
u := min(1+k, n-i)
var sum float64
atmp := a[i*lda:]
xtmp := x[i:]
for j := 1; j < u; j++ {
sum += xtmp[j] * atmp[j]
}
jx += incX
if j >= k {
kx += incX
if nonunit {
sum += xtmp[0] * atmp[0]
} else {
sum += xtmp[0]
}
x[i] = sum
}
return
}
} else {

if incX == 1 {
for j := n - 1; j >= 0; j-- {
if x[j] != 0 {
temp := x[j]
l := -j
for i := min(n-1, j+k); i >= j+1; i-- {
x[i] += temp * a[l+i+j*lda]
}
if d == blas.NonUnit {
x[j] *= a[0+j*lda]
}
}
ix := kx
for i := 0; i < n; i++ {
u := min(1+k, n-i)
var sum float64
atmp := a[i*lda:]
jx := incX
for j := 1; j < u; j++ {
sum += x[ix+jx] * atmp[j]
jx += incX
}
} else {
kx += (n - 1) * incX
jx := kx
for j := n - 1; j >= 0; j-- {
if x[jx] != 0 {
temp := x[jx]
ix := kx
l := -j
for i := min(n-1, j+k); i >= j+1; i-- {
x[ix] += temp * a[l+i+j*lda]
ix -= incX
}
if d == blas.NonUnit {
x[jx] *= a[0+j*lda]
}
}
jx -= incX
if n-j > k {
kx -= incX
}
if nonunit {
sum += x[ix] * atmp[0]
} else {
sum += x[ix]
}
x[ix] = sum
ix += incX
}
return
}
} else {

if ul == blas.Upper {
if incX == 1 {
for j := n - 1; j >= 0; j-- {
temp := x[j]
l := k - j
if d == blas.NonUnit {
temp *= a[k+j*lda]
}
for i := j - 1; i >= max(0, j-k); i-- {
temp += a[l+i+j*lda] * x[i]
}
x[j] = temp
if incX == 1 {
for i := n - 1; i >= 0; i-- {
l := max(0, k-i)
atmp := a[i*lda:]
var sum float64
for j := l; j < k; j++ {
sum += x[i-k+j] * atmp[j]
}
} else {
kx += (n - 1) * incX
jx := kx
for j := n - 1; j >= 0; j-- {
temp := x[jx]
kx -= incX
ix := kx
l := k - j
if d == blas.NonUnit {
temp *= a[k+j*lda]
}
for i := j - 1; i >= max(0, j-k); i-- {
temp += a[l+i+j*lda] * x[ix]
ix -= incX
}
x[jx] = temp
jx -= incX
if nonunit {
sum += x[i] * atmp[k]
} else {
sum += x[i]
}
x[i] = sum
}
} else {

if incX == 1 {
for j := 0; j < n; j++ {
temp := x[j]
l := -j
if d == blas.NonUnit {
temp *= a[0+j*lda]
}
for i := j + 1; i < min(n, j+k+1); i++ {
temp += a[l+i+j*lda] * x[i]
}
x[j] = temp
}
return
}
ix := kx + (n-1)*incX
for i := n - 1; i >= 0; i-- {
l := max(0, k-i)
atmp := a[i*lda:]
var sum float64
jx := l * incX
for j := l; j < k; j++ {
sum += x[ix-k*incX+jx] * atmp[j]
jx += incX
}
if nonunit {
sum += x[ix] * atmp[k]
} else {
jx := kx
for j := 0; j < n; j++ {
temp := x[jx]
kx += incX
ix := kx
l := -j
if d == blas.NonUnit {
temp *= a[0+j*lda]
}
for i := j + 1; i < min(n, j+k+1); i++ {
temp += a[l+i+j*lda] * x[ix]
ix += incX
}
x[jx] = temp
jx += incX
sum += x[ix]
}
x[ix] = sum
ix -= incX
}
return
}
if ul == blas.Upper {
if incX == 1 {
for i := n - 1; i >= 0; i-- {
u := k + 1
if i < u {
u = i + 1
}
var sum float64
for j := 1; j < u; j++ {
sum += x[i-j] * a[(i-j)*lda+j]
}
if nonunit {
sum += x[i] * a[i*lda]
} else {
sum += x[i]
}
x[i] = sum
}
return
}
ix := kx + (n-1)*incX
for i := n - 1; i >= 0; i-- {
u := k + 1
if i < u {
u = i + 1
}
var sum float64
jx := incX
for j := 1; j < u; j++ {
sum += x[ix-jx] * a[(i-j)*lda+j]
jx += incX
}
if nonunit {
sum += x[ix] * a[i*lda]
} else {
sum += x[ix]
}
x[ix] = sum
ix -= incX
}
return
}
if incX == 1 {
for i := 0; i < n; i++ {
u := k
if i+k >= n {
u = n - i - 1
}
var sum float64
for j := 0; j < u; j++ {
sum += x[i+j+1] * a[(i+j+1)*lda+k-j-1]
}
if nonunit {
sum += x[i] * a[i*lda+k]
} else {
sum += x[i]
}
x[i] = sum
}
return
}
ix := kx
for i := 0; i < n; i++ {
u := k
if i+k >= n {
u = n - i - 1
}
var (
sum float64
jx int
)
for j := 0; j < u; j++ {
sum += x[ix+jx+incX] * a[(i+j+1)*lda+k-j-1]
jx += incX
}
if nonunit {
sum += x[ix] * a[i*lda+k]
} else {
sum += x[ix]
}
x[ix] = sum
ix += incX
}
}

Expand Down
4 changes: 4 additions & 0 deletions goblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ func TestDtbsv(t *testing.T) {
// TestDsbmv runs the testblas.DsbmvTest suite against blasser.
// NOTE(review): blasser is declared outside this view; presumably it is this
// package's BLAS implementation under test — confirm.
func TestDsbmv(t *testing.T) {
testblas.DsbmvTest(t, blasser)
}

// TestDtbmv runs the testblas.DtbmvTest suite against blasser.
// NOTE(review): blasser is declared outside this view; presumably it is this
// package's BLAS implementation under test — confirm.
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, blasser)
}

0 comments on commit 6b4b681

Please sign in to comment.