Skip to content
This repository has been archived by the owner on Dec 9, 2018. It is now read-only.

Commit

Permalink
Merge pull request #48 from gonum/adddspr2
Browse files Browse the repository at this point in the history
Adddspr2
  • Loading branch information
btracey committed Dec 21, 2014
2 parents aa37936 + df03edd commit c64865e
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 1 deletion.
4 changes: 4 additions & 0 deletions cblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func TestDsyr2(t *testing.T) {
testblas.Dsyr2Test(t, blasser)
}

func TestDspr2(t *testing.T) {
testblas.Dspr2Test(t, blasser)
}

func TestDspr(t *testing.T) {
testblas.DsprTest(t, blasser)
}
Expand Down
92 changes: 91 additions & 1 deletion goblas/level2double.go
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,96 @@ func (Blas) Dspr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []
offset += i + 2
}
}

// Dsyr2 performs the symmetric rank-2 update
// a += alpha * x * y^T + alpha * y * x^T
// where a is in packed format.
func (Blas) Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64) {
panic("referenceblas: function not implemented")
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if n < 0 {
panic(nLT0)
}
if incX == 0 || incY == 0 {
panic(zeroInc)
}

if len(a) < (n*(n+1))/2 {
panic("goblas: not enough data in a")
}
var ky, kx int
if incY > 0 {
ky = 0
} else {
ky = -(n - 1) * incY
}
if incX > 0 {
kx = 0
} else {
kx = -(n - 1) * incX
}
var offset int // Offset is the index of (i,i).
if ul == blas.Upper {
if incX == 1 && incY == 1 {
for i := 0; i < n; i++ {
atmp := a[offset:]
xi := x[i]
yi := y[i]
xtmp := x[i:n]
ytmp := y[i:n]
for j, v := range xtmp {
atmp[j] += alpha * (xi*ytmp[j] + v*yi)
}
offset += n - i
}
return
}
ix := kx
iy := ky
for i := 0; i < n; i++ {
jx := kx + i*incX
jy := ky + i*incY
atmp := a[offset:]
xi := x[ix]
yi := y[iy]
for j := 0; j < n-i; j++ {
atmp[j] += alpha * (xi*y[jy] + x[jx]*yi)
jx += incX
jy += incY
}
ix += incX
iy += incY
offset += n - i
}
return
}
if incX == 1 && incY == 1 {
for i := 0; i < n; i++ {
atmp := a[offset-i:]
xi := x[i]
yi := y[i]
xtmp := x[:i+1]
for j, v := range xtmp {
atmp[j] += alpha * (xi*y[j] + v*yi)
}
offset += i + 2
}
return
}
ix := kx
iy := ky
for i := 0; i < n; i++ {
jx := kx
jy := ky
atmp := a[offset-i:]
for j := 0; j <= i; j++ {
atmp[j] += alpha * (x[ix]*y[jy] + x[jx]*y[iy])
jx += incX
jy += incY
}
ix += incX
iy += incY
offset += i + 2
}
}
4 changes: 4 additions & 0 deletions goblas/level2double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func TestDsyr2(t *testing.T) {
testblas.Dsyr2Test(t, blasser)
}

func TestDspr2(t *testing.T) {
testblas.Dspr2Test(t, blasser)
}

func TestDspr(t *testing.T) {
testblas.DsprTest(t, blasser)
}
Expand Down
76 changes: 76 additions & 0 deletions testblas/dspr2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package testblas

import (
"testing"

"github.com/gonum/blas"
"github.com/gonum/floats"
)

type Dspr2er interface {
Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64)
}

func Dspr2Test(t *testing.T, blasser Dspr2er) {
for i, test := range []struct {
n int
a [][]float64
ul blas.Uplo
x []float64
y []float64
alpha float64
ans [][]float64
}{
{
n: 3,
a: [][]float64{
{7, 2, 4},
{0, 3, 5},
{0, 0, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Upper,
ans: [][]float64{
{47, 56, 72},
{0, 75, 95},
{0, 0, 118},
},
},
{
n: 3,
a: [][]float64{
{7, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Lower,
ans: [][]float64{
{47, 0, 0},
{56, 75, 0},
{72, 95, 118},
},
},
} {
incTest := func(incX, incY, extra int) {
aFlat := flattenTriangular(test.a, test.ul)
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
blasser.Dspr2(test.ul, test.n, test.alpha, x, incX, y, incY, aFlat)
ansFlat := flattenTriangular(test.ans, test.ul)
if !floats.EqualApprox(aFlat, ansFlat, 1e-14) {
t.Errorf("Case %v, incX = %v, incY = %v. Want %v, got %v.", i, incX, incY, ansFlat, aFlat)
}
}
incTest(1, 1, 0)
incTest(-2, 1, 0)
incTest(-2, 3, 0)
incTest(2, -3, 0)
incTest(3, -2, 0)
incTest(-3, -4, 0)
}
}

0 comments on commit c64865e

Please sign in to comment.