Skip to content

Commit

Permalink
More coverage; better handling of array subdtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 9, 2022
1 parent c105b92 commit 1ce8421
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 37 deletions.
32 changes: 18 additions & 14 deletions grblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"M_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

nrows = parent._nrows
Expand Down Expand Up @@ -847,25 +847,27 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent.gb_obj = ffi.NULL
else:
parent.clear()
return rv
elif format == "coor":
info = self._export(
rv = self._export(
"csr", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["rows"] = indptr_to_indices(info.pop("indptr"))
info["cols"] = info.pop("col_indices")
info["sorted_rows"] = True
info["format"] = "coor"
return info
rv["rows"] = indptr_to_indices(rv.pop("indptr"))
rv["cols"] = rv.pop("col_indices")
rv["sorted_rows"] = True
rv["format"] = "coor"
elif format == "cooc":
info = self._export(
rv = self._export(
"csc", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["cols"] = indptr_to_indices(info.pop("indptr"))
info["rows"] = info.pop("row_indices")
info["sorted_cols"] = True
info["format"] = "cooc"
return info
rv["cols"] = indptr_to_indices(rv.pop("indptr"))
rv["rows"] = rv.pop("row_indices")
rv["sorted_cols"] = True
rv["format"] = "cooc"
else:
raise ValueError(f"Invalid format: {format}")
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

if method == "export":
mhandle = ffi_new("GrB_Matrix*", parent._carg)
Expand Down Expand Up @@ -1175,6 +1177,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
rv["values"] = values
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion grblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"v_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

if format is None:
Expand Down Expand Up @@ -481,6 +481,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
)
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
Expand Down
5 changes: 1 addition & 4 deletions grblas/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __eq__(self, other):
return self is other
# Attempt to use `other` as a lookup key
try:
other = lookup_dtype(other)
return self is other
return self is lookup_dtype(other)
except ValueError:
raise TypeError(f"Invalid or unknown datatype: {other}") from None

Expand All @@ -49,7 +48,6 @@ def __lt__(self, other):

def __reduce__(self):
if self._is_udt:
1 / 0
return (self._deserialize, (self.name, self.np_type, self._is_anonymous))
if self.gb_name == "GrB_Index":
return "_INDEX"
Expand All @@ -69,7 +67,6 @@ def _is_udt(self):

@staticmethod
def _deserialize(name, dtype, is_anonymous):
1 / 0
if is_anonymous:
return register_anonymous(dtype, name)
if name in _registry:
Expand Down
9 changes: 6 additions & 3 deletions grblas/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def build(self, rows, columns, values, *, dup_op=None, clear=False, nrows=None,
rows = ints_to_numpy_buffer(rows, np.uint64, name="row indices")
columns = ints_to_numpy_buffer(columns, np.uint64, name="column indices")
values, dtype = values_to_numpy_buffer(values, self.dtype)
n = values.size
n = values.shape[0]
if rows.size != n or columns.size != n:
raise ValueError(
f"`rows` and `columns` and `values` lengths must match: "
Expand Down Expand Up @@ -386,7 +386,7 @@ def from_values(
"""
rows = ints_to_numpy_buffer(rows, np.uint64, name="row indices")
columns = ints_to_numpy_buffer(columns, np.uint64, name="column indices")
values, dtype = values_to_numpy_buffer(values, dtype)
values, new_dtype = values_to_numpy_buffer(values, dtype)
# Compute nrows and ncols if not provided
if nrows is None:
if rows.size == 0:
Expand All @@ -396,8 +396,11 @@ def from_values(
if columns.size == 0:
raise ValueError("No column indices provided. Unable to infer ncols.")
ncols = int(columns.max()) + 1
if dtype is None and values.ndim > 1:
        # Look for array subdtype
new_dtype = lookup_dtype(np.dtype((new_dtype.np_type, values.shape[1:])))
# Create the new matrix
C = cls.new(dtype, nrows, ncols, name=name)
C = cls.new(new_dtype, nrows, ncols, name=name)
if values.ndim == 0:
if dup_op is not None:
raise ValueError(
Expand Down
17 changes: 7 additions & 10 deletions grblas/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __contains__(self, type_):
return False
try:
self[type_]
except (TypeError, KeyError):
except (TypeError, KeyError, numba.NumbaError):
return False
else:
return True
Expand Down Expand Up @@ -935,10 +935,7 @@ def _getitem_udt(self, dtype, dtype2):

numba_func = self._numba_func
sig = (dtype.numba_type,)
try:
numba_func.compile(sig)
except numba.TypingError:
raise TypeError("TODO")
numba_func.compile(sig) # Should we catch and give additional error message?
ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type)
if ret_type is not dtype and ret_type.gb_obj is dtype.gb_obj:
ret_type = dtype
Expand Down Expand Up @@ -1493,10 +1490,7 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
else:
numba_func = self._numba_func
sig = (dtype.numba_type, dtype2.numba_type)
try:
numba_func.compile(sig)
except numba.TypingError:
raise TypeError("TODO")
numba_func.compile(sig) # Should we catch and give additional error message?
ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type)
if ret_type is not dtype and ret_type is not dtype2:
if ret_type.gb_obj is dtype.gb_obj:
Expand Down Expand Up @@ -1812,7 +1806,10 @@ def _getitem_udt(self, dtype, dtype2):
if dtype2 is None:
dtype2 = dtype
elif dtype != dtype2:
raise TypeError("TODO")
raise TypeError(
"Monoid inputs must be the same dtype (got {dtype} and {dtype2}); "
"unable to coerce when using UDTs."
)
if dtype in self._udt_types:
return self._udt_ops[dtype]
binaryop = self.binaryop._getitem_udt(dtype, dtype2)
Expand Down
Binary file added grblas/tests/pickle3.pkl
Binary file not shown.
36 changes: 36 additions & 0 deletions grblas/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pickle

import numpy as np
import pytest

import grblas as gb
Expand Down Expand Up @@ -273,3 +274,38 @@ def test_deserialize_parameterized():
# Again!
with open(os.path.join(thisdir, "pickle2.pkl"), "rb") as f:
pickle.load(f) # TODO: check results


def test_udt():
    """Pickle round-trip for user-defined types (UDTs), including pre-pickled fixtures."""
    # A registered (named) UDT must unpickle to the very same singleton object.
    record_dtype = np.dtype([("x", np.bool_), ("y", np.int64)], align=True)
    udt = gb.dtypes.register_new("PickleUDT", record_dtype)
    assert not udt._is_anonymous
    assert pickle.loads(pickle.dumps(udt)) is udt

    # Anonymous UDTs need not round-trip to the same object,
    # but the underlying numpy dtype must be preserved.
    np_dtype = np.dtype("(3,)uint16")
    udt2 = gb.dtypes.register_anonymous(np_dtype, "pickling")
    assert udt2._is_anonymous
    assert pickle.loads(pickle.dumps(udt2)).np_type == udt2.np_type

    # Load a fixture pickled in an earlier session to verify cross-process stability.
    thisdir = os.path.dirname(__file__)
    with open(os.path.join(thisdir, "pickle3.pkl"), "rb") as f:
        d = pickle.load(f)
    udt3 = d["PickledUDT"]
    v = d["v"]
    assert udt3.name == "PickledUDT"
    # Unpickling a named UDT makes it reachable as an attribute of gb.dtypes.
    assert udt3 is gb.dtypes.PickledUDT
    assert v.dtype == udt3
    expected = gb.Vector.new(udt3, size=2)
    expected[0] = (False, 1)
    expected[1] = (True, 3)
    assert expected.isequal(v)

    # Named-but-anonymous subdtype from the fixture: it keeps its name
    # yet is NOT registered as a gb.dtypes attribute.
    udt4 = d["pickled_subdtype"]
    A = d["A"]
    assert udt4.name == "pickled_subdtype"
    assert not hasattr(gb.dtypes, udt4.name)
    assert A.dtype == udt4
    expected = gb.Matrix.new(udt4, nrows=1, ncols=2)
    expected[0, 0] = (1, 2)
    expected[0, 1] = (3, 4)
    assert expected.isequal(A)
2 changes: 2 additions & 0 deletions grblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,8 @@ def test_udt():

indices, values = v.to_values()
assert_array_equal(values, np.ones((2, 3), dtype=np.uint16))
assert v.isequal(Vector.from_values(indices, values, dtype=v.dtype))
assert v.isequal(Vector.from_values(indices, values))

s = v.reduce(monoid.any).new()
assert (s.value == [1, 1, 1]).all()
Expand Down
10 changes: 8 additions & 2 deletions grblas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ def ints_to_numpy_buffer(array, dtype, *, name="array", copy=False, ownable=Fals
return array


def _get_subdtype(dtype):
while dtype.subdtype is not None:
dtype = dtype.subdtype[0]
return dtype


def values_to_numpy_buffer(array, dtype=None, *, copy=False, ownable=False, order="C"):
if dtype is not None:
dtype = lookup_dtype(dtype)
array = np.array(array, dtype.np_type, copy=copy, order=order)
array = np.array(array, _get_subdtype(dtype.np_type), copy=copy, order=order)
else:
is_input_np = isinstance(array, np.ndarray)
array = np.array(array, copy=copy, order=order)
Expand Down Expand Up @@ -127,7 +133,7 @@ def __init__(self, array=None, dtype=_INDEX, *, size=None, name=None):
if size is not None:
self.array = np.empty(size, dtype=dtype.np_type)
else:
self.array = np.array(array, dtype=dtype.np_type, copy=False, order="C")
self.array = np.array(array, dtype=_get_subdtype(dtype.np_type), copy=False, order="C")
c_type = dtype.c_type if dtype._is_udt else f"{dtype.c_type}*"
self._carg = ffi.cast(c_type, ffi.from_buffer(self.array))
self.dtype = dtype
Expand Down
9 changes: 6 additions & 3 deletions grblas/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def build(self, indices, values, *, dup_op=None, clear=False, size=None):
# TODO: accept `dtype` keyword to match the dtype of `values`?
indices = ints_to_numpy_buffer(indices, np.uint64, name="indices")
values, dtype = values_to_numpy_buffer(values, self.dtype)
n = values.size
n = values.shape[0]
if indices.size != n:
raise ValueError(
f"`indices` and `values` lengths must match: {indices.size} != {values.size}"
Expand Down Expand Up @@ -355,14 +355,17 @@ def from_values(cls, indices, values, dtype=None, *, size=None, dup_op=None, nam
values may be a scalar, in which case duplicate indices are ignored.
"""
indices = ints_to_numpy_buffer(indices, np.uint64, name="indices")
values, dtype = values_to_numpy_buffer(values, dtype)
values, new_dtype = values_to_numpy_buffer(values, dtype)
# Compute size if not provided
if size is None:
if indices.size == 0:
raise ValueError("No indices provided. Unable to infer size.")
size = int(indices.max()) + 1
if dtype is None and values.ndim > 1:
            # Array subdtype
new_dtype = lookup_dtype(np.dtype((new_dtype.np_type, values.shape[1:])))
# Create the new vector
w = cls.new(dtype, size, name=name)
w = cls.new(new_dtype, size, name=name)
if values.ndim == 0:
if dup_op is not None:
raise ValueError(
Expand Down

0 comments on commit 1ce8421

Please sign in to comment.