Skip to content

Commit

Permalink
blas/netlib: fix dgbmv bounds checks
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jul 10, 2017
1 parent 7370c3d commit 8862e36
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
15 changes: 11 additions & 4 deletions blas/netlib/blas.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ const (
rowMajor order = 101 + iota
)

func min(a, b int) int {
if a < b {
return a
}
return b
}

func max(a, b int) int {
if a > b {
return a
Expand Down Expand Up @@ -1447,7 +1454,7 @@ func (Implementation) Sgbmv(tA blas.Transpose, m, n, kL, kU int, alpha float32,
if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) {
panic("blas: y index out of range")
}
if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
if lda*(min(m, n+kL)-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
C.cblas_sgbmv(C.enum_CBLAS_ORDER(rowMajor), C.enum_CBLAS_TRANSPOSE(tA), C.int(m), C.int(n), C.int(kL), C.int(kU), C.float(alpha), (*C.float)(_a), C.int(lda), (*C.float)(_x), C.int(incX), C.float(beta), (*C.float)(_y), C.int(incY))
Expand Down Expand Up @@ -1820,7 +1827,7 @@ func (Implementation) Dgbmv(tA blas.Transpose, m, n, kL, kU int, alpha float64,
if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) {
panic("blas: y index out of range")
}
if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
if lda*(min(m, n+kL)-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
C.cblas_dgbmv(C.enum_CBLAS_ORDER(rowMajor), C.enum_CBLAS_TRANSPOSE(tA), C.int(m), C.int(n), C.int(kL), C.int(kU), C.double(alpha), (*C.double)(_a), C.int(lda), (*C.double)(_x), C.int(incX), C.double(beta), (*C.double)(_y), C.int(incY))
Expand Down Expand Up @@ -2183,7 +2190,7 @@ func (Implementation) Cgbmv(tA blas.Transpose, m, n, kL, kU int, alpha complex64
if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) {
panic("blas: y index out of range")
}
if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
if lda*(min(m, n+kL)-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
C.cblas_cgbmv(C.enum_CBLAS_ORDER(rowMajor), C.enum_CBLAS_TRANSPOSE(tA), C.int(m), C.int(n), C.int(kL), C.int(kU), unsafe.Pointer(&alpha), unsafe.Pointer(_a), C.int(lda), unsafe.Pointer(_x), C.int(incX), unsafe.Pointer(&beta), unsafe.Pointer(_y), C.int(incY))
Expand Down Expand Up @@ -2507,7 +2514,7 @@ func (Implementation) Zgbmv(tA blas.Transpose, m, n, kL, kU int, alpha complex12
if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) {
panic("blas: y index out of range")
}
if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
if lda*(min(m, n+kL)-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
C.cblas_zgbmv(C.enum_CBLAS_ORDER(rowMajor), C.enum_CBLAS_TRANSPOSE(tA), C.int(m), C.int(n), C.int(kL), C.int(kU), unsafe.Pointer(&alpha), unsafe.Pointer(_a), C.int(lda), unsafe.Pointer(_x), C.int(incX), unsafe.Pointer(&beta), unsafe.Pointer(_y), C.int(incY))
Expand Down
9 changes: 8 additions & 1 deletion blas/netlib/generate_blas.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ func othersShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter)

switch {
case has["kL"] && has["kU"]:
fmt.Fprintf(buf, ` if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
fmt.Fprintf(buf, ` if lda*(min(m, n+kL)-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
`)
Expand Down Expand Up @@ -831,6 +831,13 @@ const (
rowMajor order = 101 + iota
)
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
Expand Down

0 comments on commit 8862e36

Please sign in to comment.