Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Add symmetric Add function
Browse files Browse the repository at this point in the history
Remove unnecessary fmt statement
  • Loading branch information
btracey committed Feb 3, 2015
1 parent 02ed560 commit 39cd2ae
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
38 changes: 38 additions & 0 deletions mat64/symmetric.go
Expand Up @@ -75,6 +75,44 @@ func (s *SymDense) isZero() bool {
return s.mat.N == 0
}

func (s *SymDense) AddSym(a, b Symmetric) {
n := a.Symmetric()
if n != b.Symmetric() {
panic(ErrShape)
}
if s.isZero() {
s.mat = blas64.Symmetric{
N: n,
Stride: n,
Data: use(s.mat.Data, n*n),
Uplo: blas.Upper,
}
} else if s.mat.N != n {
panic(ErrShape)
}

if a, ok := a.(RawSymmmetricer); ok {
if b, ok := b.(RawSymmmetricer); ok {
amat, bmat := a.RawSymmetric(), b.RawSymmetric()
for i := 0; i < n; i++ {
btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
stmp := s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+n]
for j, v := range amat.Data[i*amat.Stride+i : i*amat.Stride+n] {
stmp[j] = v + btmp[j]
}
}
return
}
}

for i := 0; i < n; i++ {
stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
for j := i; j < n; j++ {
stmp[j] = a.At(i, j) + b.At(i, j)
}
}
}

func (s *SymDense) CopySym(a Symmetric) int {
n := a.Symmetric()
n = min(n, s.mat.N)
Expand Down
45 changes: 45 additions & 0 deletions mat64/symmetric_test.go
Expand Up @@ -76,6 +76,51 @@ func (s *S) TestSymAtSet(c *check.C) {
c.Check(t.At(1, 2), check.Equals, 12.0)
}

func (s *S) TestSymAdd(c *check.C) {
for _, test := range []struct {
n int
}{
{n: 1},
{n: 2},
{n: 3},
{n: 4},
{n: 5},
{n: 10},
} {
n := test.n
a := NewSymDense(n, nil)
for i := range a.mat.Data {
a.mat.Data[i] = rand.Float64()
}
b := NewSymDense(n, nil)
for i := range a.mat.Data {
b.mat.Data[i] = rand.Float64()
}
var m Dense
m.Add(a, b)

// Check with new receiver
var s SymDense
s.AddSym(a, b)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
v := m.At(i, j)
c.Check(s.At(i, j), check.Equals, v)
}
}

// Check with equal receiver
s.CopySym(a)
s.AddSym(&s, b)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
v := m.At(i, j)
c.Check(s.At(i, j), check.Equals, v)
}
}
}
}

func (s *S) TestCopy(c *check.C) {
for _, test := range []struct {
n int
Expand Down

0 comments on commit 39cd2ae

Please sign in to comment.