Skip to content

Commit

Permalink
mat: provide internal slice methods that returns the concrete types
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Dec 27, 2019
1 parent ad4f952 commit 5127c36
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion mat/cholesky_test.go
Expand Up @@ -505,7 +505,7 @@ func TestCholeskyExtendVecSym(t *testing.T) {
},
} {
n := test.a.Symmetric()
as := test.a.SliceSym(0, n-1).(*SymDense)
as := test.a.sliceSym(0, n-1)

// Compute the full factorization to use later (do the full factorization
// first to ensure the matrix is positive definite).
Expand Down
8 changes: 6 additions & 2 deletions mat/dense.go
Expand Up @@ -318,6 +318,10 @@ func (m *Dense) DiagView() Diagonal {
// Slice panics with ErrIndexOutOfRange if the slice is outside the capacity
// of the receiver.
func (m *Dense) Slice(i, k, j, l int) Matrix {
return m.slice(i, k, j, l)
}

func (m *Dense) slice(i, k, j, l int) *Dense {
mr, mc := m.Caps()
if i < 0 || mr <= i || j < 0 || mc <= j || k < i || mr < k || l < j || mc < l {
if i == k || j == l {
Expand Down Expand Up @@ -545,7 +549,7 @@ func (m *Dense) Stack(a, b Matrix) {
m.reuseAsNonZeroed(ar+br, ac)

m.Copy(a)
w := m.Slice(ar, ar+br, 0, bc).(*Dense)
w := m.slice(ar, ar+br, 0, bc)
w.Copy(b)
}

Expand All @@ -563,7 +567,7 @@ func (m *Dense) Augment(a, b Matrix) {
m.reuseAsNonZeroed(ar, ac+bc)

m.Copy(a)
w := m.Slice(0, br, ac, ac+bc).(*Dense)
w := m.slice(0, br, ac, ac+bc)
w.Copy(b)
}

Expand Down
2 changes: 1 addition & 1 deletion mat/dense_arithmetic.go
Expand Up @@ -682,7 +682,7 @@ func (m *Dense) Kronecker(a, b Matrix) {
m.reuseAsNonZeroed(ra*rb, ca*cb)
for i := 0; i < ra; i++ {
for j := 0; j < ca; j++ {
m.Slice(i*rb, (i+1)*rb, j*cb, (j+1)*cb).(*Dense).Scale(a.At(i, j), b)
m.slice(i*rb, (i+1)*rb, j*cb, (j+1)*cb).Scale(a.At(i, j), b)
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions mat/gsvd.go
Expand Up @@ -265,16 +265,14 @@ func (gsvd *GSVD) ZeroRTo(dst *Dense) {
capRows: r,
capCols: c,
}
dst.Slice(0, h, c-k-l, c).(*Dense).
Copy(a.Slice(0, h, c-k-l, c))
dst.slice(0, h, c-k-l, c).Copy(a.Slice(0, h, c-k-l, c))
if r < k+l {
b := Dense{
mat: gsvd.b,
capRows: gsvd.p,
capCols: c,
}
dst.Slice(r, k+l, c+r-k-l, c).(*Dense).
Copy(b.Slice(r-k, l, c+r-k-l, c))
dst.slice(r, k+l, c+r-k-l, c).Copy(b.Slice(r-k, l, c+r-k-l, c))
}
}

Expand Down
4 changes: 4 additions & 0 deletions mat/symmetric.go
Expand Up @@ -563,6 +563,10 @@ func (s *SymDense) SubsetSym(a Symmetric, set []int) {
// SliceSym panics with ErrIndexOutOfRange if the slice is outside the
// capacity of the receiver.
func (s *SymDense) SliceSym(i, k int) Symmetric {
return s.sliceSym(i, k)
}

func (s *SymDense) sliceSym(i, k int) *SymDense {
sz := s.cap
if i < 0 || sz < i || k < i || sz < k {
panic(ErrIndexOutOfRange)
Expand Down
2 changes: 1 addition & 1 deletion mat/symmetric_test.go
Expand Up @@ -608,7 +608,7 @@ func TestViewGrowSquare(t *testing.T) {
// Take a subset and check the view matches.
start1 := test.start1
span1 := test.span1
v := s.SliceSym(start1, start1+span1).(*SymDense)
v := s.sliceSym(start1, start1+span1)
for i := 0; i < span1; i++ {
for j := i; j < span1; j++ {
if v.At(i, j) != s.At(start1+i, start1+j) {
Expand Down

0 comments on commit 5127c36

Please sign in to comment.