Skip to content

Commit

Permalink
Change x.ss.iso_value to return a Scalar; also, coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 11, 2022
1 parent 203b17f commit cdb451f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion grblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def is_iso(self):
def iso_value(self):
if self.is_iso:
# This may not be thread-safe if the parent is being modified in another thread
return next(self.itervalues())
return Scalar.from_value(next(self.itervalues()), dtype=self._parent.dtype, name="")
raise ValueError("Matrix is not iso-valued")

@property
Expand Down
4 changes: 2 additions & 2 deletions grblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..base import call
from ..dtypes import _INDEX, INT64, UINT64, lookup_dtype
from ..exceptions import _error_code_lookup, check_status, check_status_carg
from ..scalar import _as_scalar
from ..scalar import Scalar, _as_scalar
from ..utils import _CArray, ints_to_numpy_buffer, libget, values_to_numpy_buffer, wrapdoc
from .matrix import MatrixArray, _concat_mn, normalize_chunks
from .prefix_scan import prefix_scan
Expand Down Expand Up @@ -100,7 +100,7 @@ def is_iso(self):
def iso_value(self):
if self.is_iso:
# This may not be thread-safe if the parent is being modified in another thread
return next(self.itervalues())
return Scalar.from_value(next(self.itervalues()), dtype=self._parent.dtype, name="")
raise ValueError("Vector is not iso-valued")

@property
Expand Down
1 change: 1 addition & 0 deletions grblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,7 @@ def test_iteration(v):
assert len(list(v.ss.iterkeys(4))) == 2
assert len(list(v.ss.itervalues(-3))) == 2
assert len(list(v.ss.iteritems(-v.size))) == N
assert len(list(v.ss.itervalues(-v.size - 1))) == N
assert len(list(v.ss.iterkeys(v.size + 1))) == 0

v = Vector.ss.import_sparse(**v.ss.export("sparse"))
Expand Down

0 comments on commit cdb451f

Please sign in to comment.