Skip to content
This repository has been archived by the owner on Dec 9, 2018. It is now read-only.

Commit

Permalink
Fixed dsymv implementation and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Dec 15, 2014
1 parent 1f56c5d commit 41c4018
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 37 deletions.
4 changes: 4 additions & 0 deletions cblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ func TestDtbsv(t *testing.T) {
func TestDsbmv(t *testing.T) {
testblas.DsbmvTest(t, blasser)
}

func TestDsymv(t *testing.T) {
testblas.DsymvTest(t, blasser)
}
117 changes: 80 additions & 37 deletions goblas/level2double.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,69 +616,112 @@ func (b Blas) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x
// Set up start points
var kx, ky int
if incX > 0 {
kx = 1
kx = 0
} else {
kx = -(n - 1) * incX
}
if incY > 0 {
ky = 1
ky = 0
} else {
ky = -(n - 1) * incY
}

// Form y = beta * y
if beta != 1 {
b.Dscal(n, beta, y, incY)
if incY > 0 {
b.Dscal(n, beta, y, incY)
} else {
b.Dscal(n, beta, y, -incY)
}
}

if alpha == 0 {
return
}

// TODO: Need to think about changing the major and minor
// looping when row major (help with cache misses)
if n == 1 {
y[0] += alpha * a[0] * x[0]
return
}

// Form y = Ax + y
switch {
default:
panic("goblas: unreachable")
case ul == blas.Upper:
jx := kx
jy := ky
for j := 0; j < n; j++ {
tmp1 := alpha * x[jx]
var tmp2 float64
ix := kx
if ul == blas.Upper {
if incX == 1 {
iy := ky
for i := 0; i < j-2; i++ {
y[iy] += tmp1 * a[i*lda+j]
tmp2 += a[i*lda+j] * x[ix]
ix += incX
for i := 0; i < n; i++ {
xv := x[i] * alpha
sum := x[i] * a[i*lda+i]
jy := ky + (i+1)*incY
atmp := a[i*lda+i+1 : i*lda+n]
for j, v := range atmp {
jp := j + i + 1
sum += x[jp] * v
y[jy] += xv * v
jy += incY
}
y[iy] += alpha * sum
iy += incY
}
y[jy] += tmp1*a[j*lda+j] + alpha*tmp2
jx += incX
jy += incY
return
}
ix := kx
iy := ky
for i := 0; i < n; i++ {
xv := x[ix] * alpha
sum := x[ix] * a[i*lda+i]
jx := kx + (i+1)*incX
jy := ky + (i+1)*incY
atmp := a[i*lda+i+1 : i*lda+n]
for _, v := range atmp {
sum += x[jx] * v
y[jy] += xv * v
jx += incX
jy += incY
}
y[iy] += alpha * sum
ix += incX
iy += incY
}
case ul == blas.Lower:
return
}
// Cases where a is lower triangular.
if incX == 1 {
iy := ky
for i := 0; i < n; i++ {
jy := ky
xv := alpha * x[i]
atmp := a[i*lda : i*lda+i]
var sum float64
for j, v := range atmp {
sum += x[j] * v
y[jy] += xv * v
jy += incY
}
sum += x[i] * a[i*lda+i]
sum *= alpha
y[iy] += sum
iy += incY
}
return
}
ix := kx
iy := ky
for i := 0; i < n; i++ {
jx := kx
jy := ky
for j := 0; j < n; j++ {
tmp1 := alpha * x[jx]
var tmp2 float64
y[jy] += tmp1 * a[j*lda+j]
ix := jx
iy := jy
for i := j; i < n; i++ {
ix += incX
iy += incY
y[iy] += tmp1 * a[i*lda+j]
tmp2 += a[i*lda+j] * x[ix]
}
y[jy] += alpha * tmp2
xv := alpha * x[ix]
atmp := a[i*lda : i*lda+i]
var sum float64
for _, v := range atmp {
sum += x[jx] * v
y[jy] += xv * v
jx += incX
jy += incY
}
sum += x[ix] * a[i*lda+i]
sum *= alpha
y[iy] += sum
ix += incX
iy += incY
}
}

Expand Down
4 changes: 4 additions & 0 deletions goblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ func TestDsbmv(t *testing.T) {
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, blasser)
}

func TestDsymv(t *testing.T) {
testblas.DsymvTest(t, blasser)
}
73 changes: 73 additions & 0 deletions testblas/dsymv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package testblas

import (
"testing"

"github.com/gonum/blas"
"github.com/gonum/floats"
)

type Dsymver interface {
Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}

func DsymvTest(t *testing.T, blasser Dsymver) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64
x []float64
y []float64
alpha float64
beta float64
ans []float64
}{
{
ul: blas.Upper,
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 8, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 8, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
} {
incTest := func(incX, incY, extra int) {
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
aFlat := flatten(test.a)
ans := makeIncremented(test.ans, incY, extra)

blasser.Dsymv(test.ul, test.n, test.alpha, aFlat, test.n, x, incX, test.beta, y, incY)
if !floats.EqualApprox(ans, y, 1e-14) {
t.Errorf("Case %v, incX=%v, incY=%v: Want %v, got %v.", i, incX, incY, ans, y)
}
}
incTest(1, 1, 0)
incTest(2, 3, 0)
incTest(3, 2, 0)
incTest(-3, 2, 0)
incTest(-2, 4, 0)
incTest(2, -1, 0)
incTest(-3, -4, 3)
}
}

0 comments on commit 41c4018

Please sign in to comment.