From c97817ffdf9b47a078a67cdf14e030d93dbc9394 Mon Sep 17 00:00:00 2001 From: btracey Date: Mon, 12 Jan 2015 19:46:13 -0800 Subject: [PATCH 1/4] Added Triangular type, At and Set methods, and tests --- mat64/index_bound_checks.go | 54 ++++++++++++++++++++++-- mat64/index_no_bound_checks.go | 54 ++++++++++++++++++++++-- mat64/matrix.go | 2 + mat64/triangular.go | 55 ++++++++++++++++++++++++ mat64/triangular_test.go | 76 ++++++++++++++++++++++++++++++++++ 5 files changed, 233 insertions(+), 8 deletions(-) create mode 100644 mat64/triangular.go create mode 100644 mat64/triangular_test.go diff --git a/mat64/index_bound_checks.go b/mat64/index_bound_checks.go index f518c74..2a7c299 100644 --- a/mat64/index_bound_checks.go +++ b/mat64/index_bound_checks.go @@ -8,16 +8,18 @@ package mat64 +import "github.com/gonum/blas" + func (m *Dense) At(r, c int) float64 { return m.at(r, c) } func (m *Dense) at(r, c int) float64 { if r >= m.mat.Rows || r < 0 { - panic("index error: row access out of bounds") + panic(ErrRowAccess) } if c >= m.mat.Cols || c < 0 { - panic("index error: column access out of bounds") + panic(ErrColAccess) } return m.mat.Data[r*m.mat.Stride+c] } @@ -28,10 +30,54 @@ func (m *Dense) Set(r, c int, v float64) { func (m *Dense) set(r, c int, v float64) { if r >= m.mat.Rows || r < 0 { - panic("index error: row access out of bounds") + panic(ErrRowAccess) } if c >= m.mat.Cols || c < 0 { - panic("index error: column access out of bounds") + panic(ErrColAccess) } m.mat.Data[r*m.mat.Stride+c] = v } + +// At returns the element at row r and column c. +func (t *Triangular) At(r, c int) float64 { + return t.at(r, c) +} + +func (t *Triangular) at(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper { + if r > c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] + } + if r < c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// Set sets the element at row r and column c. Set panics if the +func (t *Triangular) Set(r, c int, v float64) { + t.set(r, c, v) +} + +func (t *Triangular) set(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper && r > c { + panic("mat64: triangular set out of bounds") + } else if t.mat.Uplo == blas.Lower && r < c { + panic("mat64: triangular set out of bounds") + } + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/index_no_bound_checks.go b/mat64/index_no_bound_checks.go index 1a6e76e..4c1f8ce 100644 --- a/mat64/index_no_bound_checks.go +++ b/mat64/index_no_bound_checks.go @@ -8,12 +8,14 @@ package mat64 +import "github.com/gonum/blas" + func (m *Dense) At(r, c int) float64 { if r >= m.mat.Rows || r < 0 { - panic("index error: row access out of bounds") + panic(ErrRowAccess) } if c >= m.mat.Cols || c < 0 { - panic("index error: column access out of bounds") + panic(ErrColAccess) } return m.at(r, c) } @@ -24,10 +26,10 @@ func (m *Dense) at(r, c int) float64 { func (m *Dense) Set(r, c int, v float64) { if r >= m.mat.Rows || r < 0 { - panic("index error: row access out of bounds") + panic(ErrRowAccess) } if c >= m.mat.Cols || c < 0 { - panic("index error: column access out of bounds") + panic(ErrColAccess) } m.set(r, c, v) } @@ -35,3 +37,47 @@ func (m *Dense) Set(r, c int, v float64) { func (m *Dense) set(r, c int, v float64) { m.mat.Data[r*m.mat.Stride+c] = v } + +// At returns the element at row r and column c. +func (t *Triangular) At(r, c int) float64 { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + return t.at(r, c) +} + +func (t *Triangular) at(r, c int) float64 { + if t.mat.Uplo == blas.Upper { + if r > c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] + } + if r < c { + return 0 + } + return t.mat.Data[r*t.mat.Stride+c] +} + +// Set sets the element at row r and column c. Set panics if the +func (t *Triangular) Set(r, c int, v float64) { + if r >= t.mat.N || r < 0 { + panic(ErrRowAccess) + } + if c >= t.mat.N || c < 0 { + panic(ErrColAccess) + } + if t.mat.Uplo == blas.Upper && r > c { + panic("mat64: triangular set out of bounds") + } else if t.mat.Uplo == blas.Lower && r < c { + panic("mat64: triangular set out of bounds") + } + t.set(r, c, v) +} + +func (t *Triangular) set(r, c int, v float64) { + t.mat.Data[r*t.mat.Stride+c] = v +} diff --git a/mat64/matrix.go b/mat64/matrix.go index 178bf19..93db020 100644 --- a/mat64/matrix.go +++ b/mat64/matrix.go @@ -394,6 +394,8 @@ const ( ErrShape = Error("mat64: dimension mismatch") ErrIllegalStride = Error("mat64: illegal stride") ErrPivot = Error("mat64: malformed pivot list") + ErrRowAccess = Error("index error: row access out of bounds") + ErrColAccess = Error("index error: column access out of bounds") ) func min(a, b int) int { diff --git a/mat64/triangular.go b/mat64/triangular.go new file mode 100644 index 0000000..12426b6 --- /dev/null +++ b/mat64/triangular.go @@ -0,0 +1,55 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) + +var ( + triangular *Triangular + + _ Matrix = triangular + _ Mutable = triangular +) + +type TriType int + +const ( + Upper TriType = TriType(blas.Upper) + Lower TriType = TriType(blas.Lower) +) + +// A triangular matrix has the same underlying data representation as a Dense matrix +// but the entries that aren't in the populated half are completely ignored. + +// Triangular represents an upper or lower triangular matrix. +type Triangular struct { + mat blas64.Triangular +} + +func NewTriangular(n int, t TriType, mat []float64) *Triangular { + if mat != nil && n*n != len(mat) { + panic(ErrShape) + } + if mat == nil { + mat = make([]float64, n*n) + } + if t != Upper && t != Lower { + panic("mat64: bad TriSide") + } + return &Triangular{blas64.Triangular{ + N: n, + Stride: n, + Data: mat, + Uplo: blas.Uplo(t), + Diag: blas.NonUnit, + }} +} + +func (t *Triangular) Dims() (r, c int) { + return t.mat.N, t.mat.N +} + +func (t *Triangular) RawTriangular() blas64.Triangular { + return t.mat +} diff --git a/mat64/triangular_test.go b/mat64/triangular_test.go new file mode 100644 index 0000000..3ac2382 --- /dev/null +++ b/mat64/triangular_test.go @@ -0,0 +1,76 @@ +package mat64 + +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "gopkg.in/check.v1" +) + +func (s *S) TestNewTriangular(c *check.C) { + for i, test := range []struct { + data []float64 + N int + t TriType + mat *Triangular + }{ + { + data: []float64{ + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + }, + N: 3, + t: Upper, + mat: &Triangular{blas64.Triangular{ + N: 3, + Stride: 3, + Uplo: blas.Uplo(Upper), + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + Diag: blas.NonUnit, + }}, + }, + } { + t := NewTriangular(test.N, test.t, test.data) + rows, cols := t.Dims() + c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i)) + c.Check(t, check.DeepEquals, test.mat, check.Commentf("Test %d", i)) + } +} + +func (s *S) TestTriAtSet(c *check.C) { + t := &Triangular{blas64.Triangular{ + N: 3, + Stride: 3, + Uplo: blas.Upper, + Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + Diag: blas.NonUnit, + }} + rows, cols := t.Dims() + // Check At out of bounds + c.Check(func() { t.At(rows, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(rows+1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, cols+1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.At(-1, 0) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.At(0, -1) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + + // Check Set out of bounds + c.Check(func() { t.Set(rows, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, cols, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(rows+1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, cols+1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(-1, 0, 1.2) }, check.PanicMatches, ErrRowAccess.Error(), check.Commentf("Test row out of bounds")) + c.Check(func() { t.Set(0, -1, 1.2) }, check.PanicMatches, ErrColAccess.Error(), check.Commentf("Test col out of bounds")) + c.Check(func() { t.Set(2, 1, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test lower access")) + t.mat.Uplo = blas.Lower + c.Check(func() { t.Set(1, 2, 1.2) }, check.PanicMatches, "mat64: triangular set out of bounds", check.Commentf("Test upper access")) + + c.Check(t.At(2, 1), check.Equals, 8.0) + t.Set(2, 1, 15) + c.Check(t.At(2, 1), check.Equals, 15.0) + t.mat.Uplo = blas.Upper + c.Check(t.At(1, 2), check.Equals, 6.0) + t.Set(1, 2, 15) + c.Check(t.At(1, 2), check.Equals, 15.0) +} From bcf721a8b57a1eae43260e2a9efba6c69feb5d12 Mon Sep 17 00:00:00 2001 From: btracey Date: Mon, 12 Jan 2015 23:52:35 -0800 Subject: [PATCH 2/4] Minor fixes from comments --- mat64/index_bound_checks.go | 6 ++++-- mat64/index_no_bound_checks.go | 6 ++++-- mat64/triangular.go | 3 +++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mat64/index_bound_checks.go b/mat64/index_bound_checks.go index 2a7c299..a379ce3 100644 --- a/mat64/index_bound_checks.go +++ b/mat64/index_bound_checks.go @@ -62,7 +62,8 @@ func (t *Triangular) at(r, c int) float64 { return t.mat.Data[r*t.mat.Stride+c] } -// Set sets the element at row r and column c. Set panics if the +// Set sets the element at row r and column c. Set panics if the location is outside +// the appropriate half of the matrix. func (t *Triangular) Set(r, c int, v float64) { t.set(r, c, v) } @@ -76,7 +77,8 @@ func (t *Triangular) set(r, c int, v float64) { } if t.mat.Uplo == blas.Upper && r > c { panic("mat64: triangular set out of bounds") - } else if t.mat.Uplo == blas.Lower && r < c { + } + if t.mat.Uplo == blas.Lower && r < c { panic("mat64: triangular set out of bounds") } t.mat.Data[r*t.mat.Stride+c] = v diff --git a/mat64/index_no_bound_checks.go b/mat64/index_no_bound_checks.go index 4c1f8ce..307470e 100644 --- a/mat64/index_no_bound_checks.go +++ b/mat64/index_no_bound_checks.go @@ -62,7 +62,8 @@ func (t *Triangular) at(r, c int) float64 { return t.mat.Data[r*t.mat.Stride+c] } -// Set sets the element at row r and column c. Set panics if the +// Set sets the element at row r and column c. Set panics if the location is outside +// the appropriate half of the matrix. func (t *Triangular) Set(r, c int, v float64) { if r >= t.mat.N || r < 0 { panic(ErrRowAccess) @@ -72,7 +73,8 @@ func (t *Triangular) Set(r, c int, v float64) { } if t.mat.Uplo == blas.Upper && r > c { panic("mat64: triangular set out of bounds") - } else if t.mat.Uplo == blas.Lower && r < c { + } + if t.mat.Uplo == blas.Lower && r < c { panic("mat64: triangular set out of bounds") } t.set(r, c, v) diff --git a/mat64/triangular.go b/mat64/triangular.go index 12426b6..8d50b8f 100644 --- a/mat64/triangular.go +++ b/mat64/triangular.go @@ -28,6 +28,9 @@ type Triangular struct { } func NewTriangular(n int, t TriType, mat []float64) *Triangular { + if n < 0 { + panic("mat64: negative dimension") + } if mat != nil && n*n != len(mat) { panic(ErrShape) } From 92c22c9bd57a46923f5919a054bcafa13052699e Mon Sep 17 00:00:00 2001 From: btracey Date: Tue, 13 Jan 2015 11:07:44 -0800 Subject: [PATCH 3/4] Added documentation to NewTriangular --- mat64/triangular.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mat64/triangular.go b/mat64/triangular.go index 8d50b8f..4a9dc46 100644 --- a/mat64/triangular.go +++ b/mat64/triangular.go @@ -19,14 +19,16 @@ const ( Lower TriType = TriType(blas.Lower) ) -// A triangular matrix has the same underlying data representation as a Dense matrix -// but the entries that aren't in the populated half are completely ignored. - // Triangular represents an upper or lower triangular matrix. type Triangular struct { mat blas64.Triangular } +// NewTriangular constructs an n x n triangular matrix with the given orientation. +// if len(mat) == n * n, mat will be used to hold the underlying data, or if +// mat == nil, new data will be allocated. +// The underlying data representation is the same as a Dense matrix, except +// the values of the entries in the opposite half are completely ignored. func NewTriangular(n int, t TriType, mat []float64) *Triangular { if n < 0 { panic("mat64: negative dimension") From f3c4b085493f7bff7d9b1604731bbeaacd033fc5 Mon Sep 17 00:00:00 2001 From: btracey Date: Thu, 15 Jan 2015 11:45:18 -0800 Subject: [PATCH 4/4] Removed TriType and made it a boolean --- mat64/triangular.go | 25 ++++++++++--------------- mat64/triangular_test.go | 16 ++++++++-------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/mat64/triangular.go b/mat64/triangular.go index 4a9dc46..e8d6888 100644 --- a/mat64/triangular.go +++ b/mat64/triangular.go @@ -12,24 +12,18 @@ var ( _ Mutable = triangular ) -type TriType int - -const ( - Upper TriType = TriType(blas.Upper) - Lower TriType = TriType(blas.Lower) -) - // Triangular represents an upper or lower triangular matrix. type Triangular struct { mat blas64.Triangular } -// NewTriangular constructs an n x n triangular matrix with the given orientation. -// if len(mat) == n * n, mat will be used to hold the underlying data, or if +// NewTriangular constructs an n x n triangular matrix. The constructed matrix +// is upper triangular if upper == true, and lower triangular otherwise. +// If len(mat) == n * n, mat will be used to hold the underlying data, or if // mat == nil, new data will be allocated. -// The underlying data representation is the same as a Dense matrix, except -// the values of the entries in the opposite half are completely ignored. -func NewTriangular(n int, t TriType, mat []float64) *Triangular { +// The underlying data representation is the same as that of a Dense matrix, +// except the values of the entries in the opposite half are completely ignored. +func NewTriangular(n int, upper bool, mat []float64) *Triangular { if n < 0 { panic("mat64: negative dimension") } @@ -39,14 +33,15 @@ func NewTriangular(n int, t TriType, mat []float64) *Triangular { if mat == nil { mat = make([]float64, n*n) } - if t != Upper && t != Lower { - panic("mat64: bad TriSide") + uplo := blas.Lower + if upper { + uplo = blas.Upper } return &Triangular{blas64.Triangular{ N: n, Stride: n, Data: mat, - Uplo: blas.Uplo(t), + Uplo: uplo, Diag: blas.NonUnit, }} } diff --git a/mat64/triangular_test.go b/mat64/triangular_test.go index 3ac2382..69ae862 100644 --- a/mat64/triangular_test.go +++ b/mat64/triangular_test.go @@ -8,10 +8,10 @@ import ( func (s *S) TestNewTriangular(c *check.C) { for i, test := range []struct { - data []float64 - N int - t TriType - mat *Triangular + data []float64 + N int + upper bool + mat *Triangular }{ { data: []float64{ @@ -19,18 +19,18 @@ func (s *S) TestNewTriangular(c *check.C) { 4, 5, 6, 7, 8, 9, }, - N: 3, - t: Upper, + N: 3, + upper: true, mat: &Triangular{blas64.Triangular{ N: 3, Stride: 3, - Uplo: blas.Uplo(Upper), + Uplo: blas.Upper, Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, Diag: blas.NonUnit, }}, }, } { - t := NewTriangular(test.N, test.t, test.data) + t := NewTriangular(test.N, test.upper, test.data) rows, cols := t.Dims() c.Check(rows, check.Equals, test.N, check.Commentf("Test %d", i)) c.Check(cols, check.Equals, test.N, check.Commentf("Test %d", i))