diff --git a/mat64/index_bound_checks.go b/mat64/index_bound_checks.go index ff06c71..2e03752 100644 --- a/mat64/index_bound_checks.go +++ b/mat64/index_bound_checks.go @@ -8,6 +8,8 @@ package mat64 +import "github.com/gonum/blas" + func (m *Dense) At(r, c int) float64 { return m.at(r, c) } @@ -63,3 +65,49 @@ func (m *Vector) set(r int, v float64) { } m.mat.Data[r*m.mat.Inc] = v } + +// At returns the element at row r and column c. +func (t *Triangular) At(r, c int) float64 { + return t.at(r, c) +} + +func (t *Triangular) at(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper { + if r > c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] + } + if r < c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// Set sets the element at row r and column c. Set panics if the location is outside +// the appropriate half of the matrix. +func (t *Triangular) Set(r, c int, v float64) { + t.set(r, c, v) +} + +func (t *Triangular) set(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper && r > c { + panic("mat64: triangular set out of bounds") + } + if t.mat.Uplo == blas.Lower && r < c { + panic("mat64: triangular set out of bounds") + } + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/index_no_bound_checks.go b/mat64/index_no_bound_checks.go index 5ca231d..cbefbc6 100644 --- a/mat64/index_no_bound_checks.go +++ b/mat64/index_no_bound_checks.go @@ -8,6 +8,8 @@ package mat64 +import "github.com/gonum/blas" + func (m *Dense) At(r, c int) float64 { if r >= m.mat.Rows || r < 0 { panic(ErrRowAccess) @@ -63,3 +65,49 @@ func (m *Vector) Set(r, c int, v float64) { func (m *Vector) set(r int, v float64) { m.mat.Data[r*m.mat.Inc] = v } + +// At returns the element at row r and column c. +func (t *Triangular) At(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + return t.at(r, c) +} + +func (t *Triangular) at(r, c int) float64 { + if t.mat.Uplo == blas.Upper { + if r > c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] + } + if r < c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// Set sets the element at row r and column c. Set panics if the location is outside +// the appropriate half of the matrix. +func (t *Triangular) Set(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper && r > c { + panic("mat64: triangular set out of bounds") + } + if t.mat.Uplo == blas.Lower && r < c { + panic("mat64: triangular set out of bounds") + } + t.set(r, c, v) +} + +func (t *Triangular) set(r, c int, v float64) { + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/triangular.go b/mat64/triangular.go new file mode 100644 index 0000000..e8d6888 --- /dev/null +++ b/mat64/triangular.go @@ -0,0 +1,55 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +var ( + triangular *Triangular + + _ Matrix = triangular + _ Mutable = triangular +) + +// Triangular represents an upper or lower triangular matrix. +type Triangular struct { + mat blas64.Triangular +} + +// NewTriangular constructs an n x n triangular matrix. The constructed matrix +// is upper triangular if upper == true, and lower triangular otherwise. +// If len(mat) == n * n, mat will be used to hold the underlying data, or if +// mat == nil, new data will be allocated. +// The underlying data representation is the same as that of a Dense matrix, +// except the values of the entries in the opposite half are completely ignored. +func NewTriangular(n int, upper bool, mat []float64) *Triangular { + if n < 0 { + panic("mat64: negative dimension") + } + if mat != nil && n*n != len(mat) { + panic(ErrShape) + } + if mat == nil { + mat = make([]float64, n*n) + } + uplo := blas.Lower + if upper { + uplo = blas.Upper + } + return &Triangular{blas64.Triangular{ + N: n, + Stride: n, + Data: mat, + Uplo: uplo, + Diag: blas.NonUnit, + }} +} + +func (t *Triangular) Dims() (r, c int) { + return t.mat.N, t.mat.N +} + +func (t *Triangular) RawTriangular() blas64.Triangular { + return t.mat +} diff --git a/mat64/triangular_test.go b/mat64/triangular_test.go new file mode 100644 index 0000000..69ae862 --- /dev/null +++ b/mat64/triangular_test.go @@ -0,0 +1,76 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "gopkg.in/check.v1" +) + +func (s *S) TestNewTriangular(c *check.C) { + for i, test := range []struct { + data []float64 + N int + upper bool + mat *Triangular + }{ + { + data: []float64{ + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + }, + N: 3, + upper: true, + mat: &Triangular{blas64.Triangular{ + N: 3, + Stride: 3, + Uplo: blas.Upper, + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + Diag: blas.NonUnit, + }}, + }, + } { + t := NewTriangular(test.N, test.upper, test.data) + rows, cols := t.Dims() + c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(t, check.DeepEquals, test.mat, check.Commentf("Test %d", i)) + } +} + +func (s *S) TestTriAtSet(c *check.C) { + t := &Triangular{blas64.Triangular{ + N: 3, + Stride: 3, + Uplo: blas.Upper, + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + Diag: blas.NonUnit, + }} + rows, cols := t.Dims() + // Check At out of bounds + c.Check(func() { t.At(rows, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(rows+1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols+1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(-1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, -1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + + // Check Set out of bounds + c.Check(func() { t.Set(rows, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, cols, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(rows+1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, cols+1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(-1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, -1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(2, 1, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test lower access")) + t.mat.Uplo = blas.Lower + c.Check(func() { t.Set(1, 2, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test upper access")) + + c.Check(t.At(2, 1), check.Equals, 8.0) + t.Set(2, 1, 15) + c.Check(t.At(2, 1), check.Equals, 15.0) + t.mat.Uplo = blas.Upper + c.Check(t.At(1, 2), check.Equals, 6.0) + t.Set(1, 2, 15) + c.Check(t.At(1, 2), check.Equals, 15.0) +}