Skip to content

Commit

Permalink
Merge branch 'main' into graphblas_python
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 12, 2022
2 parents 56784e9 + ee086c7 commit 5ba7876
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
if [[ ${{ matrix.cfg.sourcetype }} == "wheel" ]]; then
pip install suitesparse-graphblas
else
conda install -c conda-forge "graphblas>=6"
conda install -c conda-forge "graphblas>=7.0.2"
fi
if [[ ${{ matrix.cfg.sourcetype }} == "source" ]]; then
pip install --no-binary=all suitesparse-graphblas
Expand Down
154 changes: 135 additions & 19 deletions graphblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .. import ffi, lib, monoid
from ..base import call, record_raw
from ..dtypes import _INDEX, INT64, lookup_dtype
from ..exceptions import check_status, check_status_carg
from ..exceptions import _error_code_lookup, check_status, check_status_carg
from ..scalar import Scalar, _as_scalar
from ..utils import (
_CArray,
Expand All @@ -26,6 +26,7 @@
from .utils import get_order

ffi_new = ffi.new
NULL = ffi.NULL


@njit
Expand Down Expand Up @@ -354,7 +355,8 @@ def is_iso(self):
@property
def iso_value(self):
if self.is_iso:
return self._parent.reduce_scalar(monoid.any).new(name="")
# This may not be thread-safe if the parent is being modified in another thread
return Scalar.from_value(next(self.itervalues()), dtype=self._parent.dtype, name="")
raise ValueError("Matrix is not iso-valued")

@property
Expand Down Expand Up @@ -574,6 +576,120 @@ def build_scalar(self, rows, columns, value):
],
)

def _begin_iter(self, seek):
it_ptr = ffi.new("GxB_Iterator*")
info = lib.GxB_Iterator_new(it_ptr)
it = it_ptr[0]
success = lib.GrB_SUCCESS
info = lib.GxB_Matrix_Iterator_attach(it, self._parent._carg, NULL)
if info != success: # pragma: no cover
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Matrix iterator failed to attach")
if seek < 0:
p = lib.GxB_Matrix_Iterator_getpmax(it)
seek += p
if seek < 0:
seek = 0
info = lib.GxB_Matrix_Iterator_seek(it, seek)
if info != success:
lib.GxB_Iterator_free(it_ptr)
raise _error_code_lookup[info]("Matrix iterator failed to seek")
return it_ptr

def iterkeys(self, seek=0):
"""Iterate over all the row and column indices of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Matrix_Iterator_getIndex
next_func = lib.GxB_Matrix_Iterator_next
row_ptr = ffi_new("GrB_Index*")
col_ptr = ffi_new("GrB_Index*")
while info == success:
key_func(it, row_ptr, col_ptr)
yield (row_ptr[0], col_ptr[0])
info = next_func(it)
lib.GxB_Iterator_free(it_ptr)
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Matrix iterator failed")

def itervalues(self, seek=0):
"""Iterate over all the values of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}")
next_func = lib.GxB_Matrix_Iterator_next
while info == success:
yield val_func(it)
info = next_func(it)
lib.GxB_Iterator_free(it_ptr)
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Matrix iterator failed")

def iteritems(self, seek=0):
"""Iterate over all the row, column, and value triples of a Matrix.
Parameters
----------
seek : int, default 0
Index of entry to seek to. May be negative to seek backwards from the end.
Matrix objects in bitmap format seek as if it's full format (i.e., it
ignores the bitmap mask).
The Matrix should not be modified during iteration; doing so will
result in undefined behavior.
"""
try:
it_ptr = self._begin_iter(seek)
except StopIteration:
return
it = it_ptr[0]
info = success = lib.GrB_SUCCESS
key_func = lib.GxB_Matrix_Iterator_getIndex
val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}")
next_func = lib.GxB_Matrix_Iterator_next
row_ptr = ffi_new("GrB_Index*")
col_ptr = ffi_new("GrB_Index*")
while info == success:
key_func(it, row_ptr, col_ptr)
yield (row_ptr[0], col_ptr[0], val_func(it))
info = next_func(it)
lib.GxB_Iterator_free(it_ptr)
if info != lib.GxB_EXHAUSTED: # pragma: no cover
raise _error_code_lookup[info]("Matrix iterator failed")

def export(self, format=None, *, sort=False, give_ownership=False, raw=False):
"""
GxB_Matrix_export_xxx
Expand Down Expand Up @@ -844,7 +960,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
if give_ownership:
if method == "export":
parent.__del__()
parent.gb_obj = ffi.NULL
parent.gb_obj = NULL
else:
parent.clear()
elif format == "coor":
Expand Down Expand Up @@ -883,7 +999,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
Ap_size = ffi_new("GrB_Index*")
Ax_size = ffi_new("GrB_Index*")
if sort:
jumbled = ffi.NULL
jumbled = NULL
else:
jumbled = ffi_new("bool*")
is_iso = ffi_new("bool*")
Expand All @@ -903,7 +1019,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
Ax_size,
is_iso,
jumbled,
ffi.NULL,
NULL,
),
parent,
)
Expand Down Expand Up @@ -945,7 +1061,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
Ax_size,
is_iso,
jumbled,
ffi.NULL,
NULL,
),
parent,
)
Expand Down Expand Up @@ -993,7 +1109,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
is_iso,
nvec,
jumbled,
ffi.NULL,
NULL,
),
parent,
)
Expand Down Expand Up @@ -1048,7 +1164,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
is_iso,
nvec,
jumbled,
ffi.NULL,
NULL,
),
parent,
)
Expand Down Expand Up @@ -1100,7 +1216,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
Ax_size,
is_iso,
nvals_,
ffi.NULL,
NULL,
),
parent,
)
Expand Down Expand Up @@ -1150,7 +1266,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
Ax,
Ax_size,
is_iso,
ffi.NULL,
NULL,
),
parent,
)
Expand All @@ -1176,7 +1292,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
rv["format"] = format
rv["values"] = values
if method == "export":
parent.gb_obj = ffi.NULL
parent.gb_obj = NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv
Expand Down Expand Up @@ -1338,7 +1454,7 @@ def _import_csr(
values.nbytes,
is_iso,
not sorted_cols,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -1513,7 +1629,7 @@ def _import_csc(
values.nbytes,
is_iso,
not sorted_rows,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -1709,7 +1825,7 @@ def _import_hypercsr(
is_iso,
nvec,
not sorted_cols,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -1905,7 +2021,7 @@ def _import_hypercsc(
is_iso,
nvec,
not sorted_rows,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -2085,7 +2201,7 @@ def _import_bitmapr(
values.nbytes,
is_iso,
nvals,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -2263,7 +2379,7 @@ def _import_bitmapc(
values.nbytes,
is_iso,
nvals,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -2413,7 +2529,7 @@ def _import_fullr(
Ax,
values.nbytes,
is_iso,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down Expand Up @@ -2561,7 +2677,7 @@ def _import_fullc(
Ax,
values.nbytes,
is_iso,
ffi.NULL,
NULL,
)
if method == "import":
check_status_carg(
Expand Down

0 comments on commit 5ba7876

Please sign in to comment.