diff --git a/mat64/shadow.go b/mat64/shadow.go index cd6c2f4..bd64fc6 100644 --- a/mat64/shadow.go +++ b/mat64/shadow.go @@ -4,7 +4,10 @@ package mat64 -import "github.com/gonum/blas/blas64" +import ( + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" +) const ( // regionOverlap is the panic string used for the general case @@ -66,10 +69,6 @@ func (m *Dense) checkOverlap(a blas64.General) bool { return false } -// BUG(kortschak): Overlap detection for symmetric and triangular matrices is not -// precise; currently overlap is detected if the bounding rectangles overlap rather -// than exact overlap between visible elements. - func (s *SymDense) checkOverlap(a blas64.Symmetric) bool { mat := s.RawSymmetric() if cap(mat.Data) == 0 || cap(a.Data) == 0 { @@ -100,12 +99,15 @@ func (s *SymDense) checkOverlap(a blas64.Symmetric) bool { panic(mismatchedStrides) } - // TODO(kortschak) Make this analysis more precise. - if off > 0 { + if off < 0 { off = -off mat.N, a.N = a.N, mat.N + // If we created the matrix it will always + // be in the upper triangle, but don't trust + // that this is the case. + mat.Uplo, a.Uplo = a.Uplo, mat.Uplo } - if rectanglesOverlap(off, mat.N, a.N, mat.Stride) { + if trianglesOverlap(off, mat.N, a.N, mat.Stride, mat.Uplo == blas.Upper, a.Uplo == blas.Upper) { panic(regionOverlap) } return false @@ -141,12 +143,12 @@ func (t *TriDense) checkOverlap(a blas64.Triangular) bool { panic(mismatchedStrides) } - // TODO(kortschak) Make this analysis more precise. - if off > 0 { + if off < 0 { off = -off mat.N, a.N = a.N, mat.N + mat.Uplo, a.Uplo = a.Uplo, mat.Uplo } - if rectanglesOverlap(off, mat.N, a.N, mat.Stride) { + if trianglesOverlap(off, mat.N, a.N, mat.Stride, mat.Uplo == blas.Upper, a.Uplo == blas.Upper) { panic(regionOverlap) } return false @@ -223,3 +225,55 @@ func rectanglesOverlap(off, aCols, bCols, stride int) bool { // b strictly wraps and so must overlap with a. return true } + +// trianglesOverlap returns whether the strided triangles a and b overlap +// when b is offset by off elements after a but has at least one element before +// the end of a. a and b have aSize and bSize respectively. +func trianglesOverlap(off, aSize, bSize, stride int, aUpper, bUpper bool) bool { + if !rectanglesOverlap(off, aSize, bSize, stride) { + // Fast return if bounding rectangles do not overlap. + return false + } + + // Find location of b relative to a. + rowOffset := off / stride + colOffset := off % stride + if (off+bSize)%stride < colOffset { + // We have wrapped, so readjust offsets. + rowOffset++ + colOffset -= stride + } + + if aUpper { + // Check whether the upper left of b + // is in the triangle of a + if rowOffset >= 0 && rowOffset <= colOffset { + return true + } + // Check whether the upper right of b + // is in the triangle of a. + return bUpper && rowOffset < colOffset+bSize + } + + // Check whether the upper left of b + // is in the triangle of a + if colOffset >= 0 && rowOffset >= colOffset { + return true + } + if !bUpper { + if colOffset < 0 { + // Check whether the lower left of a + // is in the triangle of b. This + // requires a swap of reference origin. + return -rowOffset+aSize > -colOffset + } + // Check whether the lower left of b + // is in the triangle of a. + return rowOffset+bSize > colOffset + } + + // Check whether the upper right corner of b + // is in a or the upper row of b spans a row + // of a. + return rowOffset > colOffset+bSize || colOffset < 0 +} diff --git a/mat64/shadow_test.go b/mat64/shadow_test.go index c38faf4..6f33fcc 100644 --- a/mat64/shadow_test.go +++ b/mat64/shadow_test.go @@ -7,6 +7,10 @@ package mat64 import ( "math/rand" "testing" + + "github.com/gonum/blas" + "github.com/gonum/blas/blas64" + "github.com/gonum/matrix" ) func TestDenseOverlaps(t *testing.T) { @@ -84,12 +88,151 @@ func TestDenseOverlaps(t *testing.T) { } } +func TestTriDenseOverlaps(t *testing.T) { + type view struct { + i, j, n int + *TriDense + } + + rnd := rand.New(rand.NewSource(1)) + + for _, parentKind := range []matrix.TriKind{matrix.Upper, matrix.Lower} { + for n := 1; n < 20; n++ { + data := make([]float64, n*n) + for i := range data { + data[i] = float64(i + 1) + } + m := NewDense(n, n, data) + mt := denseAsTriDense(m, parentKind) + panicked, message := panics(func() { mt.checkOverlap(mt.RawTriangular()) }) + if !panicked { + t.Error("expected matrix overlap with self") + } + if message != regionIdentity { + t.Errorf("unexpected panic message for self overlap: got: %q want: %q", message, regionIdentity) + } + + for i := 0; i < 1000; i++ { + var views [2]view + for k := range views { + if n > 1 { + views[k].i = rnd.Intn(n - 1) + views[k].j = rnd.Intn(n - 1) + views[k].n = rnd.Intn(n-max(views[k].i, views[k].j)-1) + 1 + } else { + views[k].n = 1 + } + viewKind := []matrix.TriKind{matrix.Upper, matrix.Lower}[rnd.Intn(2)] + views[k].TriDense = denseAsTriDense( + m.View(views[k].i, views[k].j, views[k].n, views[k].n).(*Dense), + viewKind) + + wantPanick := overlapsParentTriangle(views[k].i, views[k].j, views[k].n, parentKind, viewKind) + + panicked, _ = panics(func() { mt.checkOverlap(views[k].RawTriangular()) }) + if panicked != wantPanick { + t.Errorf("unexpected (%d×%d)%s overlap with view {rows=%d:%d, cols=%d:%d}%s got:%t want:%t\n% v\n\n% v\n", + n, n, kindString(parentKind), + views[k].i, views[k].i+views[k].n, views[k].j, views[k].j+views[k].n, kindString(viewKind), + panicked, wantPanick, + Formatted(mt), Formatted(views[k].TriDense)) + } + panicked, _ = panics(func() { views[k].checkOverlap(mt.RawTriangular()) }) + if panicked != wantPanick { + t.Errorf("unexpected {rows=%d:%d, cols=%d:%d}%s overlap with parent (%d×%d)%s got:%t want:%t\n% v\n\n% v\n", + views[k].i, views[k].i+views[k].n, views[k].j, views[k].j+views[k].n, kindString(viewKind), + n, n, kindString(parentKind), + panicked, wantPanick, + Formatted(views[k].TriDense), Formatted(mt)) + } + } + + want := overlapSiblingTriangles( + views[0].i, views[0].j, views[0].n, views[0].mat.Uplo == blas.Upper, + views[1].i, views[1].j, views[1].n, views[1].mat.Uplo == blas.Upper, + ) + + for k, v := range views { + w := views[1-k] + got, _ := panics(func() { v.checkOverlap(w.RawTriangular()) }) + if got != want { + t.Errorf("unexpected result for overlap test for {rows=%d:%d, cols=%d:%d}%s with {rows=%d:%d, cols=%d:%d}%s: got:%t want:%t\n% v\n\n% v\n", + v.i, v.i+v.n, v.j, v.j+v.n, kindString(v.mat.Uplo == blas.Upper), + w.i, w.i+w.n, w.j, w.j+w.n, kindString(w.mat.Uplo == blas.Upper), + got, want, + Formatted(v.TriDense), Formatted(w.TriDense)) + } + } + } + } + } +} + type interval struct{ from, to int } func intervalsOverlap(a, b interval) bool { return a.to > b.from && b.to > a.from } +func overlapsParentTriangle(i, j, n int, parent, view matrix.TriKind) bool { + switch parent { + case matrix.Upper: + if i <= j { + return true + } + if view == matrix.Upper { + return i < j+n + } + + case matrix.Lower: + if i >= j { + return true + } + if view == matrix.Lower { + return i+n > j + } + } + + return false +} + +func overlapSiblingTriangles(ai, aj, an int, aKind matrix.TriKind, bi, bj, bn int, bKind matrix.TriKind) bool { + for i := max(ai, bi); i < min(ai+an, bi+bn); i++ { + var a, b interval + + if aKind == matrix.Upper { + a = interval{from: aj - ai + i, to: aj + an} + } else { + a = interval{from: aj, to: aj - ai + i + 1} + } + + if bKind == matrix.Upper { + b = interval{from: bj - bi + i, to: bj + bn} + } else { + b = interval{from: bj, to: bj - bi + i + 1} + } + + if intervalsOverlap(a, b) { + return true + } + } + return false +} + +func abs(a int) int { + if a < 0 { + return -a + } + return a +} + +func kindString(k matrix.TriKind) string { + if k == matrix.Upper { + return "U" + } + return "L" +} + // See https://github.com/gonum/matrix/issues/359 for details. func TestIssue359(t *testing.T) { for xi := 0; xi < 2; xi++ { @@ -115,3 +258,27 @@ func TestIssue359(t *testing.T) { } } } + +// denseAsTriDense returns a triangular matrix derived from the +// square matrix m, with the orientation specified by kind. +func denseAsTriDense(m *Dense, kind matrix.TriKind) *TriDense { + r, c := m.Dims() + if r != c { + panic(matrix.ErrShape) + } + n := r + uplo := blas.Lower + if kind == matrix.Upper { + uplo = blas.Upper + } + return &TriDense{ + mat: blas64.Triangular{ + N: n, + Stride: m.mat.Stride, + Data: m.mat.Data, + Uplo: uplo, + Diag: blas.NonUnit, + }, + cap: n, + } +}