From 39cd2aeb18bdd0a4001a2fc0cb40ff331811f995 Mon Sep 17 00:00:00 2001 From: btracey Date: Mon, 2 Feb 2015 19:43:47 -0800 Subject: [PATCH] Add symmetric Add function Remove unnecessary fmt statement --- mat64/symmetric.go | 38 ++++++++++++++++++++++++++++++++++ mat64/symmetric_test.go | 45 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/mat64/symmetric.go b/mat64/symmetric.go index 9e73f7d..677eaaa 100644 --- a/mat64/symmetric.go +++ b/mat64/symmetric.go @@ -75,6 +75,44 @@ func (s *SymDense) isZero() bool { return s.mat.N == 0 } +func (s *SymDense) AddSym(a, b Symmetric) { + n := a.Symmetric() + if n != b.Symmetric() { + panic(ErrShape) + } + if s.isZero() { + s.mat = blas64.Symmetric{ + N: n, + Stride: n, + Data: use(s.mat.Data, n*n), + Uplo: blas.Upper, + } + } else if s.mat.N != n { + panic(ErrShape) + } + + if a, ok := a.(RawSymmmetricer); ok { + if b, ok := b.(RawSymmmetricer); ok { + amat, bmat := a.RawSymmetric(), b.RawSymmetric() + for i := 0; i < n; i++ { + btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n] + stmp := s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+n] + for j, v := range amat.Data[i*amat.Stride+i : i*amat.Stride+n] { + stmp[j] = v + btmp[j] + } + } + return + } + } + + for i := 0; i < n; i++ { + stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n] + for j := i; j < n; j++ { + stmp[j] = a.At(i, j) + b.At(i, j) + } + } +} + func (s *SymDense) CopySym(a Symmetric) int { n := a.Symmetric() n = min(n, s.mat.N) diff --git a/mat64/symmetric_test.go b/mat64/symmetric_test.go index a50ed8e..9ca5541 100644 --- a/mat64/symmetric_test.go +++ b/mat64/symmetric_test.go @@ -76,6 +76,51 @@ func (s *S) TestSymAtSet(c *check.C) { c.Check(t.At(1, 2), check.Equals, 12.0) } +func (s *S) TestSymAdd(c *check.C) { + for _, test := range []struct { + n int + }{ + {n: 1}, + {n: 2}, + {n: 3}, + {n: 4}, + {n: 5}, + {n: 10}, + } { + n := test.n + a := NewSymDense(n, nil) + for i := range a.mat.Data { + a.mat.Data[i] = rand.Float64() + } + b := NewSymDense(n, nil) + for i := range a.mat.Data { + b.mat.Data[i] = rand.Float64() + } + var m Dense + m.Add(a, b) + + // Check with new receiver + var s SymDense + s.AddSym(a, b) + for i := 0; i < n; i++ { + for j := i; j < n; j++ { + v := m.At(i, j) + c.Check(s.At(i, j), check.Equals, v) + } + } + + // Check with equal receiver + s.CopySym(a) + s.AddSym(&s, b) + for i := 0; i < n; i++ { + for j := i; j < n; j++ { + v := m.At(i, j) + c.Check(s.At(i, j), check.Equals, v) + } + } + } +} + func (s *S) TestCopy(c *check.C) { for _, test := range []struct { n int