Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mat64/index_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package mat64

import "github.com/gonum/blas"

func (m *Dense) At(r, c int) float64 {
return m.at(r, c)
}
Expand Down Expand Up @@ -63,3 +65,49 @@ func (m *Vector) set(r int, v float64) {
}
m.mat.Data[r*m.mat.Inc] = v
}

// At returns the element at row r and column c.
func (t *Triangular) At(r, c int) float64 {
return t.at(r, c)
}

func (t *Triangular) at(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper {
if r > c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
if r < c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}

// Set sets the element at row r and column c. Set panics if the location is outside
// the appropriate half of the matrix.
func (t *Triangular) Set(r, c int, v float64) {
t.set(r, c, v)
}

func (t *Triangular) set(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper && r > c {
panic("mat64: triangular set out of bounds")
}
if t.mat.Uplo == blas.Lower && r < c {
panic("mat64: triangular set out of bounds")
}
t.mat.Data[r*t.mat.Stride+c] = v
}
48 changes: 48 additions & 0 deletions mat64/index_no_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package mat64

import "github.com/gonum/blas"

func (m *Dense) At(r, c int) float64 {
if r >= m.mat.Rows || r < 0 {
panic(ErrRowAccess)
Expand Down Expand Up @@ -63,3 +65,49 @@ func (m *Vector) Set(r, c int, v float64) {
func (m *Vector) set(r int, v float64) {
m.mat.Data[r*m.mat.Inc] = v
}

// At returns the element at row r and column c.
func (t *Triangular) At(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
return t.at(r, c)
}

func (t *Triangular) at(r, c int) float64 {
if t.mat.Uplo == blas.Upper {
if r > c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
if r < c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}

// Set sets the element at row r and column c. Set panics if the location is outside
// the appropriate half of the matrix.
func (t *Triangular) Set(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper && r > c {
panic("mat64: triangular set out of bounds")
}
if t.mat.Uplo == blas.Lower && r < c {
panic("mat64: triangular set out of bounds")
}
t.set(r, c, v)
}

func (t *Triangular) set(r, c int, v float64) {
t.mat.Data[r*t.mat.Stride+c] = v
}
55 changes: 55 additions & 0 deletions mat64/triangular.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package mat64

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)

var (
triangular *Triangular

_ Matrix = triangular
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naive comment: isn't "cleaner" to just spell it out
_ Matrix = (*Triangular)(nil)

admittedly more verbose, but no "pollution" of the global scope.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm following the pattern that exists there now. We can change it if there's a better way, but consistency is better.

_ Mutable = triangular
)

// Triangular represents an upper or lower triangular matrix.
type Triangular struct {
mat blas64.Triangular
}

// NewTriangular constructs an n x n triangular matrix. The constructed matrix
// is upper triangular if upper == true, and lower triangular otherwise.
// If len(mat) == n * n, mat will be used to hold the underlying data, or if
// mat == nil, new data will be allocated.
// The underlying data representation is the same as that of a Dense matrix,
// except the values of the entries in the opposite half are completely ignored.
func NewTriangular(n int, upper bool, mat []float64) *Triangular {
if n < 0 {
panic("mat64: negative dimension")
}
if mat != nil && n*n != len(mat) {
panic(ErrShape)
}
if mat == nil {
mat = make([]float64, n*n)
}
uplo := blas.Lower
if upper {
uplo = blas.Upper
}
return &Triangular{blas64.Triangular{
N: n,
Stride: n,
Data: mat,
Uplo: uplo,
Diag: blas.NonUnit,
}}
}

func (t *Triangular) Dims() (r, c int) {
return t.mat.N, t.mat.N
}

func (t *Triangular) RawTriangular() blas64.Triangular {
return t.mat
}
76 changes: 76 additions & 0 deletions mat64/triangular_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package mat64

import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
"gopkg.in/check.v1"
)

func (s *S) TestNewTriangular(c *check.C) {
for i, test := range []struct {
data []float64
N int
upper bool
mat *Triangular
}{
{
data: []float64{
1, 2, 3,
4, 5, 6,
7, 8, 9,
},
N: 3,
upper: true,
mat: &Triangular{blas64.Triangular{
N: 3,
Stride: 3,
Uplo: blas.Upper,
Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
Diag: blas.NonUnit,
}},
},
} {
t := NewTriangular(test.N, test.upper, test.data)
rows, cols := t.Dims()
c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i))
c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i))
c.Check(t, check.DeepEquals, test.mat, check.Commentf("Test %d", i))
}
}

func (s *S) TestTriAtSet(c *check.C) {
t := &Triangular{blas64.Triangular{
N: 3,
Stride: 3,
Uplo: blas.Upper,
Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
Diag: blas.NonUnit,
}}
rows, cols := t.Dims()
// Check At out of bounds
c.Check(func() { t.At(rows, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, cols) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.At(rows+1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, cols+1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.At(-1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.At(0, -1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))

// Check Set out of bounds
c.Check(func() { t.Set(rows, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.Set(0, cols, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.Set(rows+1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.Set(0, cols+1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.Set(-1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds"))
c.Check(func() { t.Set(0, -1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds"))
c.Check(func() { t.Set(2, 1, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test lower access"))
t.mat.Uplo = blas.Lower
c.Check(func() { t.Set(1, 2, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test upper access"))

c.Check(t.At(2, 1), check.Equals, 8.0)
t.Set(2, 1, 15)
c.Check(t.At(2, 1), check.Equals, 15.0)
t.mat.Uplo = blas.Upper
c.Check(t.At(1, 2), check.Equals, 6.0)
t.Set(1, 2, 15)
c.Check(t.At(1, 2), check.Equals, 15.0)
}