Skip to content

Commit

Permalink
Add seek argument to SuiteSparse iterators, and add docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 11, 2022
1 parent 62e8fbd commit 203b17f
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 24 deletions.
71 changes: 59 additions & 12 deletions grblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def build_scalar(self, rows, columns, value):
],
)

def _begin_iter(self):
def _begin_iter(self, seek):
it_ptr = ffi.new("GxB_Iterator*")
info = lib.GxB_Iterator_new(it_ptr)
it = it_ptr[0]
Expand All @@ -585,16 +585,35 @@ def _begin_iter(self):
if info != success: # pragma: no cover
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Matrix iterator failed to attach")
info = lib.GxB_Matrix_Iterator_seek(it, 0)
if info != success: # pragma: no cover
if seek < 0:
p = lib.GxB_Matrix_Iterator_getpmax(it)
seek += p
if seek < 0:
seek = 0
info = lib.GxB_Matrix_Iterator_seek(it, seek)
if info != success:
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Matrix iterator failed to seek")
return it_ptr

def iterkeys(self):
if self._parent._nvals == 0:
def iterkeys(self, seek=0):
"""Iterate over all the row and column indices of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Matrix_Iterator_getIndex
Expand All @@ -609,10 +628,24 @@ def iterkeys(self):
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Matrix iterator failed")

def itervalues(self):
if self._parent._nvals == 0:
def itervalues(self, seek=0):
"""Iterate over all the values of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}")
Expand All @@ -624,10 +657,24 @@ def itervalues(self):
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Matrix iterator failed")

def iteritems(self):
if self._parent._nvals == 0:
def iteritems(self, seek=0):
"""Iterate over all the row, column, and value triples of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Matrix_Iterator_getIndex
Expand Down
71 changes: 59 additions & 12 deletions grblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def build_scalar(self, indices, value):
],
)

def _begin_iter(self):
def _begin_iter(self, seek):
it_ptr = ffi.new("GxB_Iterator*")
info = lib.GxB_Iterator_new(it_ptr)
it = it_ptr[0]
Expand All @@ -286,16 +286,35 @@ def _begin_iter(self):
if info != success: # pragma: no cover
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Vector iterator failed to attach")
info = lib.GxB_Vector_Iterator_seek(it, 0)
if info != success: # pragma: no cover
if seek < 0:
p = lib.GxB_Vector_Iterator_getpmax(it)
seek += p
if seek < 0:
seek = 0
info = lib.GxB_Vector_Iterator_seek(it, seek)
if info != success:
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Vector iterator failed to seek")
return it_ptr

def iterkeys(self):
if self._parent._nvals == 0:
def iterkeys(self, seek=0):
"""Iterate over all the indices of a Vector.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Vector objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Vector should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Vector_Iterator_getIndex
Expand All @@ -307,10 +326,24 @@ def iterkeys(self):
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Vector iterator failed")

def itervalues(self):
if self._parent._nvals == 0:
def itervalues(self, seek=0):
"""Iterate over all the values of a Vector.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Vector objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Vector should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}")
Expand All @@ -322,10 +355,24 @@ def itervalues(self):
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Vector iterator failed")

def iteritems(self):
if self._parent._nvals == 0:
def iteritems(self, seek=0):
"""Iterate over all the indices and values of a Vector.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Vector objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Vector should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it_ptr = self._begin_iter()
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Vector_Iterator_getIndex
Expand Down
16 changes: 16 additions & 0 deletions grblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,22 @@ def test_iteration(A):
assert sorted(zip(rows, columns)) == sorted(A.ss.iterkeys())
assert sorted(values) == sorted(A.ss.itervalues())
assert sorted(zip(rows, columns, values)) == sorted(A.ss.iteritems())
N = rows.size

A = Matrix.ss.import_bitmapr(**A.ss.export("bitmapr"))
assert A.ss.format == "bitmapr"
assert len(list(A.ss.iterkeys(3))) == N - A[0, :3].new().nvals
assert len(list(A.ss.iterkeys(-3))) == A[-1, -3:].new().nvals

A = Matrix.ss.import_csr(**A.ss.export("csr"))
assert A.ss.format == "csr"
assert len(list(A.ss.iterkeys(3))) == N - 3
assert len(list(A.ss.iterkeys(-3))) == 3
assert len(list(A.ss.itervalues(N))) == 0
assert len(list(A.ss.iteritems(N + 1))) == 0
assert len(list(A.ss.iterkeys(N + 2))) == 0
assert len(list(A.ss.iterkeys(-N))) == N
assert len(list(A.ss.itervalues(-N - 1))) == N


def test_udt():
Expand Down
14 changes: 14 additions & 0 deletions grblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,6 +1840,20 @@ def test_iteration(v):
assert sorted(values) == sorted(v.ss.itervalues())
assert sorted(zip(indices, values)) == sorted(v.ss.iteritems())

N = indices.size
v = Vector.ss.import_bitmap(**v.ss.export("bitmap"))
assert v.ss.format == "bitmap"
assert len(list(v.ss.iterkeys(4))) == 2
assert len(list(v.ss.itervalues(-3))) == 2
assert len(list(v.ss.iteritems(-v.size))) == N
assert len(list(v.ss.iterkeys(v.size + 1))) == 0

v = Vector.ss.import_sparse(**v.ss.export("sparse"))
assert v.ss.format == "sparse"
assert len(list(v.ss.iterkeys(2))) == 2
assert len(list(v.ss.itervalues(N))) == 0
assert len(list(v.ss.iteritems(N + 1))) == 0


def test_broadcasting(A, v):
# Vector on left
Expand Down

0 comments on commit 203b17f

Please sign in to comment.