Skip to content

Commit

Permalink
blas/gonum: avoid bounds checks in Ztbmv
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Apr 25, 2018
1 parent 47734a4 commit 0802470
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions blas/gonum/level2cmplx128.go
Original file line number Diff line number Diff line change
Expand Up @@ -1273,52 +1273,60 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
if uplo == blas.Upper {
if incX == 1 {
for i := 0; i < n; i++ {
xi := x[i]
if diag == blas.NonUnit {
x[i] *= ab[i*ldab]
xi *= ab[i*ldab]
}
kk := min(k, n-i-1)
for j, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[i] += x[i+j+1] * aij
xi += x[i+j+1] * aij
}
x[i] = xi
}
} else {
ix := kx
for i := 0; i < n; i++ {
xi := x[ix]
if diag == blas.NonUnit {
x[ix] *= ab[i*ldab]
xi *= ab[i*ldab]
}
kk := min(k, n-i-1)
jx := ix + incX
for _, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[ix] += x[jx] * aij
xi += x[jx] * aij
jx += incX
}
x[ix] = xi
ix += incX
}
}
} else {
if incX == 1 {
for i := n - 1; i >= 0; i-- {
xi := x[i]
if diag == blas.NonUnit {
x[i] *= ab[i*ldab+k]
xi *= ab[i*ldab+k]
}
kk := min(k, i)
for j, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[i] += x[i-kk+j] * aij
xi += x[i-kk+j] * aij
}
x[i] = xi
}
} else {
ix := kx + (n-1)*incX
for i := n - 1; i >= 0; i-- {
xi := x[ix]
if diag == blas.NonUnit {
x[ix] *= ab[i*ldab+k]
xi *= ab[i*ldab+k]
}
kk := min(k, i)
jx := ix - kk*incX
for _, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[ix] += x[jx] * aij
xi += x[jx] * aij
jx += incX
}
x[ix] = xi
ix -= incX
}
}
Expand All @@ -1328,8 +1336,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
if incX == 1 {
for i := n - 1; i >= 0; i-- {
kk := min(k, n-i-1)
xi := x[i]
for j, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[i+j+1] += x[i] * aij
x[i+j+1] += xi * aij
}
if diag == blas.NonUnit {
x[i] *= ab[i*ldab]
Expand All @@ -1340,8 +1349,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
for i := n - 1; i >= 0; i-- {
kk := min(k, n-i-1)
jx := ix + incX
xi := x[ix]
for _, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[jx] += x[ix] * aij
x[jx] += xi * aij
jx += incX
}
if diag == blas.NonUnit {
Expand All @@ -1354,8 +1364,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
if incX == 1 {
for i := 0; i < n; i++ {
kk := min(k, i)
xi := x[i]
for j, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[i-kk+j] += x[i] * aij
x[i-kk+j] += xi * aij
}
if diag == blas.NonUnit {
x[i] *= ab[i*ldab+k]
Expand All @@ -1366,8 +1377,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
for i := 0; i < n; i++ {
kk := min(k, i)
jx := ix - kk*incX
xi := x[ix]
for _, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[jx] += x[ix] * aij
x[jx] += xi * aij
jx += incX
}
if diag == blas.NonUnit {
Expand All @@ -1382,8 +1394,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
if incX == 1 {
for i := n - 1; i >= 0; i-- {
kk := min(k, n-i-1)
xi := x[i]
for j, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[i+j+1] += x[i] * cmplx.Conj(aij)
x[i+j+1] += xi * cmplx.Conj(aij)
}
if diag == blas.NonUnit {
x[i] *= cmplx.Conj(ab[i*ldab])
Expand All @@ -1394,8 +1407,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
for i := n - 1; i >= 0; i-- {
kk := min(k, n-i-1)
jx := ix + incX
xi := x[ix]
for _, aij := range ab[i*ldab+1 : i*ldab+kk+1] {
x[jx] += x[ix] * cmplx.Conj(aij)
x[jx] += xi * cmplx.Conj(aij)
jx += incX
}
if diag == blas.NonUnit {
Expand All @@ -1408,8 +1422,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
if incX == 1 {
for i := 0; i < n; i++ {
kk := min(k, i)
xi := x[i]
for j, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[i-kk+j] += x[i] * cmplx.Conj(aij)
x[i-kk+j] += xi * cmplx.Conj(aij)
}
if diag == blas.NonUnit {
x[i] *= cmplx.Conj(ab[i*ldab+k])
Expand All @@ -1420,8 +1435,9 @@ func (Implementation) Ztbmv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag
for i := 0; i < n; i++ {
kk := min(k, i)
jx := ix - kk*incX
xi := x[ix]
for _, aij := range ab[i*ldab+k-kk : i*ldab+k] {
x[jx] += x[ix] * cmplx.Conj(aij)
x[jx] += xi * cmplx.Conj(aij)
jx += incX
}
if diag == blas.NonUnit {
Expand Down

0 comments on commit 0802470

Please sign in to comment.