diff --git a/mat64/index_bound_checks.go b/mat64/index_bound_checks.go index ff06c71..077517c 100644 --- a/mat64/index_bound_checks.go +++ b/mat64/index_bound_checks.go @@ -63,3 +63,39 @@ func (m *Vector) set(r int, v float64) { } m.mat.Data[r*m.mat.Inc] = v } + +// At returns the element at row r and column c. +func (t *Symmetric) At(r, c int) float64 { + return t.at(r, c) +} + +func (t *Symmetric) at(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if r > c { + r, c = c, r + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// SetSym sets the elements at (r,c) and (c,r) to the value v. +func (t *Symmetric) SetSym(r, c int, v float64) { + t.set(r, c, v) +} + +func (t *Symmetric) set(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if r > c { + r, c = c, r + } + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/index_no_bound_checks.go b/mat64/index_no_bound_checks.go index 5ca231d..e1bf12a 100644 --- a/mat64/index_no_bound_checks.go +++ b/mat64/index_no_bound_checks.go @@ -63,3 +63,39 @@ func (m *Vector) Set(r, c int, v float64) { func (m *Vector) set(r int, v float64) { m.mat.Data[r*m.mat.Inc] = v } + +// At returns the element at row r and column c. +func (t *Symmetric) At(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + return t.at(r, c) +} + +func (t *Symmetric) at(r, c int) float64 { + if r > c { + r, c = c, r + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// SetSym sets the elements at (r,c) and (c,r) to the value v. +func (t *Symmetric) SetSym(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + t.set(r, c, v) +} + +func (t *Symmetric) set(r, c int, v float64) { + if r > c { + r, c = c, r + } + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/symmetric.go b/mat64/symmetric.go new file mode 100644 index 0000000..4ac2d22 --- /dev/null +++ b/mat64/symmetric.go @@ -0,0 +1,41 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +// Symmetric represents a symmetric matrix. +type Symmetric struct { + mat blas64.Symmetric +} + +// NewSymmetric constructs an n x n symmetric matrix. If len(mat) == n * n, +// mat will be used to hold the underlying data, or if mat == nil, new data will be allocated. +// The underlying data representation is the same as a Dense matrix, except +// the values of the entries in the lower triangular portion are completely ignored. +func NewSymmetric(n int, mat []float64) *Symmetric { + if n < 0 { + panic("mat64: negative dimension") + } + if mat != nil && n*n != len(mat) { + panic(ErrShape) + } + if mat == nil { + mat = make([]float64, n*n) + } + return &Symmetric{blas64.Symmetric{ + N: n, + Stride: n, + Data: mat, + Uplo: blas.Upper, + }} +} + +func (s *Symmetric) Dims() (r, c int) { + return s.mat.N, s.mat.N +} + +func (s *Symmetric) RawSymmetric() blas64.Symmetric { + return s.mat +} diff --git a/mat64/symmetric_test.go b/mat64/symmetric_test.go new file mode 100644 index 0000000..fec1a00 --- /dev/null +++ b/mat64/symmetric_test.go @@ -0,0 +1,75 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "gopkg.in/check.v1" +) + +func (s *S) TestNewSymmetric(c *check.C) { + for i, test := range []struct { + data []float64 + N int + mat *Symmetric + }{ + { + data: []float64{ + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + }, + N: 3, + mat: &Symmetric{blas64.Symmetric{ + N: 3, + Stride: 3, + Uplo: blas.Upper, + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }}, + }, + } { + t := NewSymmetric(test.N, test.data) + rows, cols := t.Dims() + c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(t, check.DeepEquals, test.mat, check.Commentf("Test %d", i)) + + m := NewDense(test.N, test.N, test.data) + c.Check(t.mat.Data, check.DeepEquals, m.mat.Data, check.Commentf("Test %d", i)) + + c.Check(func() { NewSymmetric(3, []float64{1, 2}) }, check.PanicMatches, ErrShape.Error()) + } +} + +func (s *S) TestTriAtSet(c *check.C) { + t := &Symmetric{blas64.Symmetric{ + N: 3, + Stride: 3, + Uplo: blas.Upper, + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }} + rows, cols := t.Dims() + // Check At out of bounds + c.Check(func() { t.At(rows, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(rows+1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols+1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(-1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, -1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + + // Check Set out of bounds + c.Check(func() { t.SetSym(rows, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.SetSym(0, cols, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.SetSym(rows+1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.SetSym(0, cols+1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.SetSym(-1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.SetSym(0, -1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + + c.Check(t.At(2, 1), check.Equals, 6.0) + c.Check(t.At(1, 2), check.Equals, 6.0) + t.SetSym(1, 2, 15) + c.Check(t.At(2, 1), check.Equals, 15.0) + c.Check(t.At(1, 2), check.Equals, 15.0) + t.SetSym(2, 1, 12) + c.Check(t.At(2, 1), check.Equals, 12.0) + c.Check(t.At(1, 2), check.Equals, 12.0) +}