Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
303 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// Copyright ©2018 The Gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package mat | ||
|
||
import ( | ||
"gonum.org/v1/gonum/blas" | ||
"gonum.org/v1/gonum/blas/blas64" | ||
) | ||
|
||
var ( | ||
diagDense *DiagDense | ||
_ Matrix = diagDense | ||
_ Diagonal = diagDense | ||
_ MutableDiagonal = diagDense | ||
_ Triangular = diagDense | ||
_ Symmetric = diagDense | ||
_ Banded = diagDense | ||
_ RawBander = diagDense | ||
_ RawSymBander = diagDense | ||
) | ||
|
||
// Diagonal represents a diagonal matrix, that is a square matrix that only | ||
// has non-zero terms on the diagonal. | ||
type Diagonal interface { | ||
Matrix | ||
// Diag returns the number of rows/columns in the matrix | ||
Diag() int | ||
} | ||
|
||
// MutableDiagonal is a Diagonal matrix whose elements can be set. | ||
type MutableDiagonal interface { | ||
Diagonal | ||
SetDiag(i int, v float64) | ||
} | ||
|
||
// DiagDense represents a diagonal matrix in dense storage format. | ||
type DiagDense struct { | ||
data []float64 | ||
} | ||
|
||
// NewDiagonal creates a new Diagonal matrix with n rows and n columns. | ||
// The length of data must be n or data must be nil, otherwise NewDiagonal | ||
// will panic. | ||
func NewDiagonal(n int, data []float64) *DiagDense { | ||
if n < 0 { | ||
panic("mat: negative dimension") | ||
} | ||
if data == nil { | ||
data = make([]float64, n) | ||
} | ||
if len(data) != n { | ||
panic(ErrShape) | ||
} | ||
return &DiagDense{ | ||
data: data, | ||
} | ||
} | ||
|
||
// Diag returns the dimension of the receiver. | ||
func (d *DiagDense) Diag() int { | ||
return len(d.data) | ||
} | ||
|
||
// Dims returns the dimensions of the matrix. | ||
func (d *DiagDense) Dims() (r, c int) { | ||
return len(d.data), len(d.data) | ||
} | ||
|
||
// T returns the transpose of the matrix. | ||
func (d *DiagDense) T() Matrix { | ||
return d | ||
} | ||
|
||
// TTri returns the transpose of the matrix. Note that Diagonal matrices are | ||
// Upper by default | ||
func (d *DiagDense) TTri() Triangular { | ||
return TransposeTri{d} | ||
} | ||
|
||
func (d *DiagDense) TBand() Banded { | ||
return TransposeBand{d} | ||
} | ||
|
||
func (d *DiagDense) Bandwidth() (kl, ku int) { | ||
return 0, 0 | ||
} | ||
|
||
// Symmetric implements the Symmetric interface. | ||
func (d *DiagDense) Symmetric() int { | ||
return len(d.data) | ||
} | ||
|
||
// Triangle implements the Triangular interface. | ||
func (d *DiagDense) Triangle() (int, TriKind) { | ||
return len(d.data), Upper | ||
} | ||
|
||
func (d *DiagDense) RawBand() blas64.Band { | ||
return blas64.Band{ | ||
Rows: len(d.data), | ||
Cols: len(d.data), | ||
KL: 0, | ||
KU: 0, | ||
Stride: 1, | ||
Data: d.data, | ||
} | ||
} | ||
|
||
func (d *DiagDense) RawSymBand() blas64.SymmetricBand { | ||
return blas64.SymmetricBand{ | ||
N: len(d.data), | ||
K: 0, | ||
Stride: 1, | ||
Uplo: blas.Upper, | ||
Data: d.data, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright ©2018 The Gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package mat | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestNewDiagonal(t *testing.T) { | ||
for i, test := range []struct { | ||
data []float64 | ||
n int | ||
mat *DiagDense | ||
dense *Dense | ||
}{ | ||
{ | ||
data: []float64{1, 2, 3, 4, 5, 6}, | ||
n: 6, | ||
mat: &DiagDense{ | ||
data: []float64{1, 2, 3, 4, 5, 6}, | ||
}, | ||
dense: NewDense(6, 6, []float64{ | ||
1, 0, 0, 0, 0, 0, | ||
0, 2, 0, 0, 0, 0, | ||
0, 0, 3, 0, 0, 0, | ||
0, 0, 0, 4, 0, 0, | ||
0, 0, 0, 0, 5, 0, | ||
0, 0, 0, 0, 0, 6, | ||
}), | ||
}, | ||
} { | ||
band := NewDiagonal(test.n, test.data) | ||
rows, cols := band.Dims() | ||
|
||
if rows != test.n { | ||
t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n) | ||
} | ||
if cols != test.n { | ||
t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n) | ||
} | ||
if !reflect.DeepEqual(band, test.mat) { | ||
t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat) | ||
} | ||
if !Equal(band, test.mat) { | ||
t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat) | ||
} | ||
if !Equal(band, test.dense) { | ||
t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense)) | ||
} | ||
} | ||
} | ||
|
||
func TestDiagonalAtSet(t *testing.T) { | ||
for _, n := range []int{1, 3, 8} { | ||
for _, nilstart := range []bool{true, false} { | ||
var diag *DiagDense | ||
if nilstart { | ||
diag = NewDiagonal(n, nil) | ||
} else { | ||
data := make([]float64, n) | ||
diag = NewDiagonal(n, data) | ||
// Test the data is used. | ||
for i := range data { | ||
data[i] = -float64(i) - 1 | ||
v := diag.At(i, i) | ||
if v != data[i] { | ||
t.Errorf("Diag shadow mismatch. Got %v, want %v", v, data[i]) | ||
} | ||
} | ||
} | ||
for i := 0; i < n; i++ { | ||
for j := 0; j < n; j++ { | ||
if i != j { | ||
if diag.At(i, j) != 0 { | ||
t.Errorf("Diag returned non-zero off diagonal element at %d, %d", i, j) | ||
} | ||
} | ||
v := float64(i) + 1 | ||
diag.SetDiag(i, v) | ||
v2 := diag.At(i, i) | ||
if v2 != v { | ||
t.Errorf("Diag at/set mismatch. Got %v, want %v", v, v2) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.