Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions mat64/index_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,39 @@ func (m *Vector) set(r int, v float64) {
}
m.mat.Data[r*m.mat.Inc] = v
}

// At returns the element at row r and column c.
func (t *Symmetric) At(r, c int) float64 {
return t.at(r, c)
}

func (t *Symmetric) at(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if r > c {
r, c = c, r
}
return t.mat.Data[r*t.mat.Stride+c]
}

// SetSym sets the elements at (r,c) and (c,r) to the value v.
func (t *Symmetric) SetSym(r, c int, v float64) {
t.set(r, c, v)
}

func (t *Symmetric) set(r, c int, v float64) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add new line between funcs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if r > c {
r, c = c, r
}
t.mat.Data[r*t.mat.Stride+c] = v
}
36 changes: 36 additions & 0 deletions mat64/index_no_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,39 @@ func (m *Vector) Set(r, c int, v float64) {
func (m *Vector) set(r int, v float64) {
m.mat.Data[r*m.mat.Inc] = v
}

// At returns the element at row r and column c.
func (t *Symmetric) At(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
return t.at(r, c)
}

func (t *Symmetric) at(r, c int) float64 {
if r > c {
r, c = c, r
}
return t.mat.Data[r*t.mat.Stride+c]
}

// SetSym sets the elements at (r,c) and (c,r) to the value v.
func (t *Symmetric) SetSym(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
t.set(r, c, v)
}

func (t *Symmetric) set(r, c int, v float64) {
if r > c {
r, c = c, r
}
t.mat.Data[r*t.mat.Stride+c] = v
}
41 changes: 41 additions & 0 deletions mat64/symmetric.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package mat64

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

// Symmetric represents a symmetric matrix.
type Symmetric struct {
mat blas64.Symmetric
}

// NewSymmetric constructs an n x n symmetric matrix. If len(mat) == n * n,
// mat will be used to hold the underlying data, or if mat == nil, new data will be allocated.
// The underlying data representation is the same as a Dense matrix, except
// the values of the entries in the lower triangular portion are completely ignored.
func NewSymmetric(n int, mat []float64) *Symmetric {
if n < 0 {
panic("mat64: negative dimension")
}
if mat != nil && n*n != len(mat) {
panic(ErrShape)
}
if mat == nil {
mat = make([]float64, n*n)
}
return &Symmetric{blas64.Symmetric{
N: n,
Stride: n,
Data: mat,
Uplo: blas.Upper,
}}
}

func (s *Symmetric) Dims() (r, c int) {
return s.mat.N, s.mat.N
}

func (s *Symmetric) RawSymmetric() blas64.Symmetric {
return s.mat
}
75 changes: 75 additions & 0 deletions mat64/symmetric_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package mat64

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
"gopkg.in/check.v1"
)

func (s *S) TestNewSymmetric(c *check.C) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test for a panic with ErrShape?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for i, test := range []struct {
data []float64
N int
mat *Symmetric
}{
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check for matrix equality with a Dense?

data: []float64{
1, 2, 3,
4, 5, 6,
7, 8, 9,
},
N: 3,
mat: &Symmetric{blas64.Symmetric{
N: 3,
Stride: 3,
Uplo: blas.Upper,
Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
}},
},
} {
t := NewSymmetric(test.N, test.data)
rows, cols := t.Dims()
c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i))
c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i))
c.Check(t, check.DeepEquals, test.mat, check.Commentf("Test %d", i))

m := NewDense(test.N, test.N, test.data)
c.Check(t.mat.Data, check.DeepEquals, m.mat.Data, check.Commentf("Test %d", i))

c.Check(func() { NewSymmetric(3, []float64{1, 2}) }, check.PanicMatches, ErrShape.Error())
}
}

func (s *S) TestTriAtSet(c *check.C) {
t := &Symmetric{blas64.Symmetric{
N: 3,
Stride: 3,
Uplo: blas.Upper,
Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
}}
rows, cols := t.Dims()
// Check At out of bounds
c.Check(func() { t.At(rows, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, cols) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.At(rows+1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, cols+1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.At(-1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, -1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))

// Check Set out of bounds
c.Check(func() { t.SetSym(rows, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.SetSym(0, cols, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.SetSym(rows+1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.SetSym(0, cols+1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.SetSym(-1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.SetSym(0, -1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))

c.Check(t.At(2, 1), check.Equals, 6.0)
c.Check(t.At(1, 2), check.Equals, 6.0)
t.SetSym(1, 2, 15)
c.Check(t.At(2, 1), check.Equals, 15.0)
c.Check(t.At(1, 2), check.Equals, 15.0)
t.SetSym(2, 1, 12)
c.Check(t.At(2, 1), check.Equals, 12.0)
c.Check(t.At(1, 2), check.Equals, 12.0)
}