Skip to content

Commit

Permalink
Respond to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
btracey committed Nov 25, 2018
1 parent 26da546 commit 67021b8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 10 deletions.
11 changes: 10 additions & 1 deletion mat/diagonal.go
Expand Up @@ -209,7 +209,16 @@ func (d *DiagDense) DiagFrom(m Matrix) {
Inc: mat.Stride + 1,
Data: mat.Data[:(n-1)*mat.Stride+n],
}
// TODO(kortschak): Add banded triangular handling when the type exists.
case RawTriBander:
mat := r.RawTriBand()
data := mat.Data
if mat.Uplo == blas.Lower {
data = data[mat.K:]
}
vec = blas64.Vector{
Inc: mat.Stride,
Data: data[:(n-1)*mat.Stride+1],
}
case RawTriangular:
mat := r.RawTriangular()
if mat.Diag == blas.Unit {
Expand Down
36 changes: 36 additions & 0 deletions mat/diagonal_test.go
Expand Up @@ -244,6 +244,42 @@ func TestDiagFrom(t *testing.T) {
0, 0, 0, 0, 0, 6,
}),
},
{
mat: NewTriBandDense(6, 2, Upper, []float64{
1, math.NaN(), math.NaN(),
2, math.NaN(), math.NaN(),
3, math.NaN(), math.NaN(),
4, math.NaN(), math.NaN(),
5, math.NaN(), math.NaN(),
6, math.NaN(), math.NaN(),
}),
want: NewDense(6, 6, []float64{
1, 0, 0, 0, 0, 0,
0, 2, 0, 0, 0, 0,
0, 0, 3, 0, 0, 0,
0, 0, 0, 4, 0, 0,
0, 0, 0, 0, 5, 0,
0, 0, 0, 0, 0, 6,
}),
},
{
mat: NewTriBandDense(6, 2, Lower, []float64{
math.NaN(), math.NaN(), 1,
math.NaN(), math.NaN(), 2,
math.NaN(), math.NaN(), 3,
math.NaN(), math.NaN(), 4,
math.NaN(), math.NaN(), 5,
math.NaN(), math.NaN(), 6,
}),
want: NewDense(6, 6, []float64{
1, 0, 0, 0, 0, 0,
0, 2, 0, 0, 0, 0,
0, 0, 3, 0, 0, 0,
0, 0, 0, 4, 0, 0,
0, 0, 0, 0, 5, 0,
0, 0, 0, 0, 0, 6,
}),
},
{
mat: NewTriDense(6, Upper, []float64{
1, math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(),
Expand Down
3 changes: 3 additions & 0 deletions mat/triangular.go
Expand Up @@ -298,6 +298,9 @@ func (t *TriDense) isolatedWorkspace(a Triangular) (w *TriDense, restore func())

// DiagView returns the diagonal as a matrix backed by the original data.
func (t *TriDense) DiagView() Diagonal {
if t.mat.Diag == blas.Unit {
panic("mat: cannot take view of Unit diagonal")
}
n := t.mat.N
return &DiagDense{
mat: blas64.Vector{
Expand Down
16 changes: 7 additions & 9 deletions mat/triband.go
Expand Up @@ -308,20 +308,18 @@ func (t *TriBandDense) RawTriBand() blas64.TriangularBand {

// DiagView returns the diagonal as a matrix backed by the original data.
func (t *TriBandDense) DiagView() Diagonal {
if t.mat.Diag == blas.Unit {
panic("mat: cannot take view of Unit diagonal")
}
n := t.mat.N
if t.isUpper() {
return &DiagDense{
mat: blas64.Vector{
Inc: t.mat.Stride,
Data: t.mat.Data[:(n-1)*t.mat.Stride+1],
},
n: n,
}
data := t.mat.Data
if !t.isUpper() {
data = data[t.mat.K:]
}
return &DiagDense{
mat: blas64.Vector{
Inc: t.mat.Stride,
Data: t.mat.Data[t.mat.K : t.mat.K+(n-1)*t.mat.Stride+1],
Data: data[:(n-1)*t.mat.Stride+1],
},
n: n,
}
Expand Down

0 comments on commit 67021b8

Please sign in to comment.