Skip to content

Commit

Permalink
Merge c014515 into 11453e6
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jun 21, 2017
2 parents 11453e6 + c014515 commit cd708e3
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 0 deletions.
147 changes: 147 additions & 0 deletions mat/band.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright ©2017 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat

import (
"gonum.org/v1/gonum/blas/blas64"
)

var (
bandDense *BandDense
_ Matrix = bandDense
_ Band = bandDense
_ RawBand = bandDense
)

// BandDense represents a band matrix in dense storage format.
type BandDense struct {
mat blas64.Band
}

type Band interface {
Matrix
// Triangular returns the number of rows/columns in the matrix and its
// orientation.
Bandwidth() (kl, ku int)

// TBand is the equivalent of the T() method in the Matrix interface but
// guarantees the transpose is of band type.
TBand() Band
}

type RawBand interface {
RawBand() blas64.Band
}

var (
_ Matrix = TransposeBand{}
_ Band = TransposeBand{}
_ UntransposeBander = TransposeBand{}
)

// TransposeBand is a type for performing an implicit transpose of a Band
// matrix. It implements the Band interface, returning values from the
// transpose of the matrix within.
type TransposeBand struct {
Band Band
}

// At returns the value of the element at row i and column j of the transposed
// matrix, that is, row j and column i of the Band field.
func (t TransposeBand) At(i, j int) float64 {
return t.Band.At(j, i)
}

// Dims returns the dimensions of the transposed matrix. Band matrices are
// square and thus this is the same size as the original Triangular.
func (t TransposeBand) Dims() (r, c int) {
c, r = t.Band.Dims()
return r, c
}

// T performs an implicit transpose by returning the Band field.
func (t TransposeBand) T() Matrix {
return t.Band
}

// Bandwidth returns the number of rows/columns in the matrix and its orientation.
func (t TransposeBand) Bandwidth() (kl, ku int) {
kl, ku = t.Band.Bandwidth()
return ku, kl
}

// TBand performs an implicit transpose by returning the Band field.
func (t TransposeBand) TBand() Band {
return t.Band
}

// Untranspose returns the Band field.
func (t TransposeBand) Untranspose() Matrix {
return t.Band
}

func (t TransposeBand) UntransposeBand() Band {
return t.Band
}

// NewBandDense creates a new Band matrix with r rows and c columns. If data == nil,
// a new slice is allocated for the backing slice. If len(data) == r*(kl+ku+1),
// data is used as the backing slice, and changes to the elements of the returned BandDense
// will be reflected in data. If neither of these is true, NewBandDense will panic.
//
// The data must be arranged in row-major order. Only the values in the band portion
// of the matrix are used.
func NewBandDense(r, c, kl, ku int, data []float64) *BandDense {
if r < 0 || c < 0 || kl < 0 || ku < 0 {
panic("mat: negative dimension")
}
if kl+1 > r || ku+1 > c {
panic("mat: band out of range")
}
bc := kl + ku + 1
if data != nil && len(data) != r*bc {
panic(ErrShape)
}
if data == nil {
data = make([]float64, r*bc)
}
return &BandDense{
mat: blas64.Band{
Rows: r,
Cols: c,
KL: kl,
KU: ku,
Stride: bc,
Data: data,
},
}
}

// Dims returns the number of rows and columns in the matrix.
func (b *BandDense) Dims() (r, c int) {
return b.mat.Rows, b.mat.Cols
}

// Dims returns the upper and lower bandwidths of the matrix.
func (b *BandDense) Bandwidth() (kl, ku int) {
return b.mat.KL, b.mat.KU
}

// T performs an implicit transpose by returning the receiver inside a Transpose.
func (b *BandDense) T() Matrix {
return Transpose{b}
}

// TBand performs an implicit transpose by returning the receiver inside a TransposeBand.
func (b *BandDense) TBand() Band {
return TransposeBand{b}
}

// RawMatrix returns the underlying blas64.Band used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in returned blas64.Band.
func (b *BandDense) RawBand() blas64.Band {
return b.mat
}
144 changes: 144 additions & 0 deletions mat/band_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright ©2017 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat

import (
"reflect"
"testing"

"gonum.org/v1/gonum/blas/blas64"
)

func TestNewBand(t *testing.T) {
for i, test := range []struct {
data []float64
r, c int
kl, ku int
mat *BandDense
}{
{
data: []float64{
-1, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, -1,
19, 20, -1, -1,
},
r: 6, c: 6,
kl: 1, ku: 2,
mat: &BandDense{
mat: blas64.Band{
Rows: 6,
Cols: 6,
KL: 1,
KU: 2,
Stride: 4,
Data: []float64{
-1, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, -1,
19, 20, -1, -1,
},
},
},
},
} {
band := NewBandDense(test.r, test.c, test.kl, test.ku, test.data)
rows, cols := band.Dims()

if rows != test.r {
t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.r)
}
if cols != test.c {
t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.c)
}
if !reflect.DeepEqual(band, test.mat) {
t.Errorf("unexpected data slice for test %d: got: %v want: %v", i, band, test.mat)
}
}
}

func TestBandAtSet(t *testing.T) {
band := NewBandDense(6, 6, 1, 2, []float64{
-1, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, -1,
19, 20, -1, -1,
})
/*
rows, cols := band.Dims()
// Check At out of bounds
for _, row := range []int{-1, rows, rows + 1} {
panicked, message := panics(func() { band.At(row, 0) })
if !panicked || message != ErrRowAccess.Error() {
t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
}
}
for _, col := range []int{-1, cols, cols + 1} {
panicked, message := panics(func() { band.At(0, col) })
if !panicked || message != ErrColAccess.Error() {
t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
}
}
// Check Set out of bounds
for _, row := range []int{-1, rows, rows + 1} {
panicked, message := panics(func() { band.SetBand(row, 0, 1.2) })
if !panicked || message != ErrRowAccess.Error() {
t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
}
}
for _, col := range []int{-1, cols, cols + 1} {
panicked, message := panics(func() { band.SetBand(0, col, 1.2) })
if !panicked || message != ErrColAccess.Error() {
t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
}
}
*/
for _, st := range []struct {
row, col int
}{
{row: 0, col: 3},
{row: 0, col: 4},
{row: 0, col: 5},
{row: 1, col: 4},
{row: 1, col: 5},
{row: 2, col: 5},
{row: 2, col: 0},
{row: 3, col: 1},
{row: 4, col: 2},
{row: 5, col: 3},
} {
panicked, message := panics(func() { band.SetBand(st.row, st.col, 1.2) })
if !panicked || message != ErrBandSet.Error() {
t.Errorf("expected panic for %+v %s", st, message)
}
}
/*
for _, st := range []struct {
row, col int
uplo blas.Uplo
orig, new float64
}{
{row: 2, col: 1, uplo: blas.Lower, orig: 8, new: 15},
{row: 1, col: 2, uplo: blas.Upper, orig: 6, new: 15},
} {
tri.mat.Uplo = st.uplo
if e := tri.At(st.row, st.col); e != st.orig {
t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
}
tri.SetTri(st.row, st.col, st.new)
if e := tri.At(st.row, st.col); e != st.new {
t.Errorf("unexpected value for At(%d, %d) after SetTri(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e)
}
}
*/
}
1 change: 1 addition & 0 deletions mat/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ var (
ErrPivot = Error{"matrix: malformed pivot list"}
ErrTriangle = Error{"matrix: triangular storage mismatch"}
ErrTriangleSet = Error{"matrix: triangular set out of bounds"}
ErrBandSet = Error{"matrix: band set out of bounds"}
ErrSliceLengthMismatch = Error{"matrix: input slice length mismatch"}
ErrNotPSD = Error{"matrix: input not positive symmetric definite"}
ErrFailedEigen = Error{"matrix: eigendecomposition not successful"}
Expand Down
39 changes: 39 additions & 0 deletions mat/index_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,42 @@ func (t *TriDense) set(i, j int, v float64) {
}
t.mat.Data[i*t.mat.Stride+j] = v
}

// At returns the element at row i, column j.
func (b *BandDense) At(i, j int) float64 {
return b.at(i, j)
}

func (b *BandDense) at(i, j int) float64 {
if uint(i) >= uint(b.mat.Rows) {
panic(ErrRowAccess)
}
if uint(j) >= uint(b.mat.Cols) {
panic(ErrColAccess)
}
pj := j + b.mat.KL - i
if pj < 0 || b.mat.KL+b.mat.KU+1 <= pj {
return 0
}
return b.mat.Data[i*b.mat.Stride+pj]
}

// SetBand sets the element at row i, column j to the value v.
// It panics if the location is outside the appropriate region of the matrix.
func (b *BandDense) SetBand(i, j int, v float64) {
b.set(i, j, v)
}

func (b *BandDense) set(i, j int, v float64) {
if uint(i) >= uint(b.mat.Rows) {
panic(ErrRowAccess)
}
if uint(j) >= uint(b.mat.Cols) {
panic(ErrColAccess)
}
pj := j + b.mat.KL - i
if pj < 0 || b.mat.KL+b.mat.KU+1 <= pj {
panic(ErrBandSet)
}
b.mat.Data[i*b.mat.Stride+pj] = v
}
40 changes: 40 additions & 0 deletions mat/index_no_bound_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,43 @@ func (t *TriDense) SetTri(i, j int, v float64) {
func (t *TriDense) set(i, j int, v float64) {
t.mat.Data[i*t.mat.Stride+j] = v
}

// At returns the element at row i, column j.
func (b *BandDense) At(i, j int) float64 {
if uint(i) >= uint(b.mat.Rows) {
panic(ErrRowAccess)
}
if uint(j) >= uint(b.mat.Cols) {
panic(ErrColAccess)
}
return b.at(i, j)
}

func (b *BandDense) at(i, j int) float64 {
pj := j + b.mat.KL - i
if pj < 0 || b.mat.KL+b.mat.KU+1 <= pj {
return 0
}
return b.mat.Data[i*b.mat.Stride+pj]
}

// SetBand sets the element at row i, column j to the value v.
// It panics if the location is outside the appropriate region of the matrix.
func (b *BandDense) SetBand(i, j int, v float64) {
if uint(i) >= uint(b.mat.Rows) {
panic(ErrRowAccess)
}
if uint(j) >= uint(b.mat.Cols) {
panic(ErrColAccess)
}
pj := j + b.mat.KL - i
if pj < 0 || b.mat.KL+b.mat.KU+1 <= pj {
panic(ErrBandSet)
}
b.set(i, j, v)
}

func (b *BandDense) set(i, j int, v float64) {
pj := j + b.mat.KL - i
b.mat.Data[i*b.mat.Stride+pj] = v
}

0 comments on commit cd708e3

Please sign in to comment.