Skip to content

Commit

Permalink
blas/gonum: clean up and fix parameter checks
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Nov 14, 2018
1 parent df84973 commit 42795c7
Show file tree
Hide file tree
Showing 8 changed files with 580 additions and 318 deletions.
7 changes: 6 additions & 1 deletion blas/gonum/dgemm.go
Expand Up @@ -65,7 +65,7 @@ func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a
}

// Quick return if possible.
if m == 0 || n == 0 || ((alpha == 0 || k == 0) && beta == 1) {
if m == 0 || n == 0 {
return
}

Expand All @@ -91,6 +91,11 @@ func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a
panic(shortC)
}

// Quick return if possible.
if (alpha == 0 || k == 0) && beta == 1 {
return
}

// scale c
if beta != 1 {
if beta == 0 {
Expand Down
36 changes: 18 additions & 18 deletions blas/gonum/level1double.go
Expand Up @@ -23,7 +23,7 @@ func (Implementation) Dnrm2(n int, x []float64, incX int) float64 {
}
return 0
}
if incX > 0 && (n-1)*incX >= len(x) {
if len(x) <= (n-1)*incX {
panic(shortX)
}
if n < 2 {
Expand Down Expand Up @@ -97,7 +97,7 @@ func (Implementation) Dasum(n int, x []float64, incX int) float64 {
}
return 0
}
if incX > 0 && (n-1)*incX >= len(x) {
if len(x) <= (n-1)*incX {
panic(shortX)
}
if incX == 1 {
Expand All @@ -123,15 +123,15 @@ func (Implementation) Idamax(n int, x []float64, incX int) int {
}
return -1
}
if incX > 0 && (n-1)*incX >= len(x) {
if len(x) <= (n-1)*incX {
panic(shortX)
}
if n < 2 {
if n == 1 {
return 0
}
if n == 0 {
return -1 // Netlib returns invalid index when n == 0
return -1 // Netlib returns invalid index when n == 0.
}
panic(nLT0)
}
Expand Down Expand Up @@ -175,10 +175,10 @@ func (Implementation) Dswap(n int, x []float64, incX int, y []float64, incY int)
}
panic(nLT0)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
if (incX > 0 && len(x) <= (n-1)*incX) || (incX < 0 && len(x) <= (1-n)*incX) {
panic(shortX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
if (incY > 0 && len(y) <= (n-1)*incY) || (incY < 0 && len(y) <= (1-n)*incY) {
panic(shortY)
}
if incX == 1 && incY == 1 {
Expand Down Expand Up @@ -217,10 +217,10 @@ func (Implementation) Dcopy(n int, x []float64, incX int, y []float64, incY int)
}
panic(nLT0)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
if (incX > 0 && len(x) <= (n-1)*incX) || (incX < 0 && len(x) <= (1-n)*incX) {
panic(shortX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
if (incY > 0 && len(y) <= (n-1)*incY) || (incY < 0 && len(y) <= (1-n)*incY) {
panic(shortY)
}
if incX == 1 && incY == 1 {
Expand Down Expand Up @@ -256,10 +256,10 @@ func (Implementation) Daxpy(n int, alpha float64, x []float64, incX int, y []flo
}
panic(nLT0)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
if (incX > 0 && len(x) <= (n-1)*incX) || (incX < 0 && len(x) <= (1-n)*incX) {
panic(shortX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
if (incY > 0 && len(y) <= (n-1)*incY) || (incY < 0 && len(y) <= (1-n)*incY) {
panic(shortY)
}
if alpha == 0 {
Expand Down Expand Up @@ -289,7 +289,7 @@ func (Implementation) Daxpy(n int, alpha float64, x []float64, incX int, y []flo
// c = a/r, the cosine of the plane rotation
// s = b/r, the sine of the plane rotation
//
// NOTE: There is a discrepancy between the refence implementation and the BLAS
// NOTE: There is a discrepancy between the reference implementation and the BLAS
// technical manual regarding the sign for r when a or b are zero.
// Drotg agrees with the definition in the manual and other
// common BLAS implementations.
Expand Down Expand Up @@ -444,10 +444,10 @@ func (Implementation) Drot(n int, x []float64, incX int, y []float64, incY int,
}
panic(nLT0)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
if (incX > 0 && len(x) <= (n-1)*incX) || (incX < 0 && len(x) <= (1-n)*incX) {
panic(shortX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
if (incY > 0 && len(y) <= (n-1)*incY) || (incY < 0 && len(y) <= (1-n)*incY) {
panic(shortY)
}
if incX == 1 && incY == 1 {
Expand Down Expand Up @@ -488,10 +488,10 @@ func (Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int,
}
panic(nLT0)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
if (incX > 0 && len(x) <= (n-1)*incX) || (incX < 0 && len(x) <= (1-n)*incX) {
panic(shortX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
if (incY > 0 && len(y) <= (n-1)*incY) || (incY < 0 && len(y) <= (1-n)*incY) {
panic(shortY)
}

Expand Down Expand Up @@ -590,15 +590,15 @@ func (Implementation) Dscal(n int, alpha float64, x []float64, incX int) {
}
return
}
if (n-1)*incX >= len(x) {
panic(shortX)
}
if n < 1 {
if n == 0 {
return
}
panic(nLT0)
}
if (n-1)*incX >= len(x) {
panic(shortX)
}
if alpha == 0 {
if incX == 1 {
x = x[:n]
Expand Down
36 changes: 18 additions & 18 deletions blas/gonum/level1single.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 42795c7

Please sign in to comment.