Skip to content

Commit

Permalink
More coverage. Anything else needed?
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 9, 2022
1 parent 1ce8421 commit c9e8772
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 47 deletions.
53 changes: 13 additions & 40 deletions grblas/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _udt_mask(dtype):
rv = np.concatenate(masks)
else:
rv = np.ones(dtype.itemsize, dtype=bool)
# assert rv.size == dtype.itemsize, (rv.size, dtype.itemsize)
# assert rv.size == dtype.itemsize
_udt_mask_cache[dtype] = rv
return rv

Expand Down Expand Up @@ -215,16 +215,14 @@ def monoid(self):
@property
def commutes_to(self):
commutes_to = self.parent.commutes_to
if commutes_to is not None and self.type in commutes_to._typed_ops:
if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt):
return commutes_to[self.type]
# TODO: what about UDTs and `gb.binary.any`?

@property
def _semiring_commutes_to(self):
commutes_to = self.parent._semiring_commutes_to
if commutes_to is not None and self.type in commutes_to._typed_ops:
if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt):
return commutes_to[self.type]
# TODO: what about UDTs and `gb.binary.any`?

@property
def is_commutative(self):
Expand Down Expand Up @@ -475,7 +473,7 @@ def _call(self, *args, **kwargs):
self._monoid(*args, **kwargs)
except Exception:
binop._monoid = None
assert binop._monoid is not binop
# assert binop._monoid is not binop
if self.is_commutative:
binop._commutes_to = binop
# Don't bother yet with creating `binop.commutes_to` (but we could!)
Expand Down Expand Up @@ -806,11 +804,11 @@ def _deserialize(cls, name, *args):


def _identity(x):
return x
return x # pragma: no cover


def _one(x):
return 1
return 1 # pragma: no cover


class UnaryOp(OpBase):
Expand Down Expand Up @@ -940,31 +938,7 @@ def _getitem_udt(self, dtype, dtype2):
if ret_type is not dtype and ret_type.gb_obj is dtype.gb_obj:
ret_type = dtype

# Numba is unable to handle BOOL correctly right now, but we have a workaround
# See: https://github.com/numba/numba/issues/5395
# We're relying on coercion behaving correctly here
return_type = INT8 if ret_type == BOOL else ret_type

# Build wrapper because GraphBLAS wants pointers and void return
nt = numba.types
wrapper_sig = nt.void(
nt.CPointer(return_type.numba_type),
nt.CPointer(dtype.numba_type),
)
if ret_type == BOOL:

def unary_wrapper(z_ptr, x_ptr): # pragma: no cover
z = numba.carray(z_ptr, 1)
x = numba.carray(x_ptr, 1)
z[0] = bool(numba_func(x[0]))

else:

def unary_wrapper(z_ptr, x_ptr): # pragma: no cover
z = numba.carray(z_ptr, 1)
x = numba.carray(x_ptr, 1)
z[0] = numba_func(x[0])

unary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype)
unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper)
new_unary = ffi_new("GrB_UnaryOp*")
check_status_carg(
Expand Down Expand Up @@ -1117,7 +1091,7 @@ def _abssecond(x, y):


def _rpow(x, y):
return y**x
return y**x # pragma: no cover


def _isclose(rel_tol=1e-7, abs_tol=0.0):
Expand Down Expand Up @@ -1177,7 +1151,7 @@ def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None):
else:
xname = "x_ptr"
elif dtype == BOOL:
xname = "bool(x_ptr[0])" # is this still necessary?
xname = "bool(x_ptr[0])"
else:
xname = "x_ptr[0]"

Expand All @@ -1190,7 +1164,7 @@ def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None):
else:
yname = ", y_ptr"
elif dtype2 == BOOL:
yname = ", bool(y_ptr[0])" # is this still necessary?
yname = ", bool(y_ptr[0])"
else:
yname = ", y_ptr[0]"

Expand Down Expand Up @@ -1414,7 +1388,7 @@ def _getitem_udt(self, dtype, dtype2):

nt = numba.types
if self.name == "eq" and not self._anonymous:
assert dtype == dtype2 # XXX: must be same size?
# assert dtype.np_type == dtype2.np_type
itemsize = dtype.np_type.itemsize
mask = _udt_mask(dtype.np_type)
ret_type = BOOL
Expand Down Expand Up @@ -1451,7 +1425,7 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
z_ptr[0] = (x[mask] == y[mask]).all()

elif self.name == "ne" and not self._anonymous:
assert dtype == dtype2 # XXX: must be same size?
# assert dtype.np_type == dtype2.np_type
itemsize = dtype.np_type.itemsize
mask = _udt_mask(dtype.np_type)
ret_type = BOOL
Expand Down Expand Up @@ -1497,7 +1471,6 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
ret_type = dtype
elif ret_type.gb_obj is dtype2.gb_obj:
ret_type = dtype2
1 / 0
binary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype, dtype2)

binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper)
Expand Down Expand Up @@ -2298,7 +2271,7 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala
dtype = unify(
dtype, dtype2, is_left_scalar=is_left_scalar, is_right_scalar=is_right_scalar
)
except TypeError:
except (TypeError, AttributeError):
if op.is_positional:
return op[UINT64]
if op._udt_types is None:
Expand Down
25 changes: 25 additions & 0 deletions grblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3185,7 +3185,32 @@ def test_udt():
assert A.reduce_rowwise(monoid.any).new().isequal(expected)
rows, cols, values = A.to_values()
assert A.isequal(Matrix.from_values(rows, cols, values))
assert A.isequal(Matrix.from_values(rows, cols, values, dtype=A.dtype))
info = A.ss.export()
result = A.ss.import_any(**info)
assert result.isequal(A)
info = A.ss.export("cooc")
result = A.ss.import_any(**info)
assert result.isequal(A)
result = unary.positioni(A).new()
expected = Matrix.from_values([0, 0, 1, 1], [0, 1, 0, 1], [0, 0, 1, 1])
assert result.isequal(expected)
# agg: any_value, first, last, first_index, last_index, count

np_dtype = np.dtype("(3,)uint16")
udt = dtypes.register_anonymous(np_dtype, "has_subdtype")
A = Matrix.new(udt, nrows=2, ncols=2)
A[:, :] = (1, 2, 3)
rows, cols, values = A.to_values()
assert_array_equal(values, np.array([[1, 2, 3]] * 4))
result = Matrix.from_values(rows, cols, values)
assert A.isequal(result)
assert result.isequal(A)
result = Matrix.from_values(rows, cols, values, dtype=udt)
assert result.isequal(A)
info = A.ss.export()
result = A.ss.import_any(**info)
assert result.isequal(A)
info = A.ss.export("coor")
result = A.ss.import_any(**info)
assert result.isequal(A)
22 changes: 19 additions & 3 deletions grblas/tests/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def test_get_typed_op():
)
with pytest.raises(ValueError, match="Unknown binary or aggregator"):
operator.get_typed_op("bad_op_name", dtypes.INT64, kind="binary|aggregator")
with pytest.raises(Exception):
# get_typed_op expects dtypes to already be dtypes
operator.get_typed_op(binary.plus, dtypes.INT64, "bad dtype")


def test_unaryop_udf():
Expand Down Expand Up @@ -726,7 +729,10 @@ def test_monoid_attributes():
assert monoid.numpy.add.binaryop is binary.numpy.add
assert monoid.numpy.add.identities == {typ: 0 for typ in monoid.numpy.add.types}

binop = BinaryOp.register_anonymous(lambda x, y: x + y, name="plus")
def plus(x, y): # pragma: no cover
return x + y

binop = BinaryOp.register_anonymous(plus, name="plus")
op = Monoid.register_anonymous(binop, 0, name="plus")
assert op.binaryop is binop
assert op[int].binaryop is binop[int]
Expand Down Expand Up @@ -1076,6 +1082,11 @@ def _udt_first(x, y):
assert udt_first(v, 1).new().isequal(v)
assert udt_first[udt, dtypes.INT64].return_type == udt
assert udt_first[dtypes.INT64, udt].return_type == dtypes.INT64
assert udt_first[udt, dtypes.BOOL].return_type == udt
assert udt_first[dtypes.BOOL, udt].return_type == dtypes.BOOL
udt_dup = dtypes.register_anonymous(record_dtype)
assert udt_first[udt, udt_dup].return_type == udt
# assert udt_first[udt_dup, udt].return_type == udt ?

udt_any = Monoid.register_new("udt_any", udt_first, (0, 0))
assert udt in udt_any
Expand Down Expand Up @@ -1103,20 +1114,25 @@ def _udt_first(x, y):
class BreakCompile:
pass

def badfunc(x):
def badfunc(x): # pragma: no cover
return BreakCompile(x)

badunary = UnaryOp.register_anonymous(badfunc, is_udt=True)
assert udt not in badunary
assert int not in badunary

def badfunc(x, y):
def badfunc(x, y): # pragma: no cover
return BreakCompile(x)

badbinary = BinaryOp.register_anonymous(badfunc, is_udt=True)
assert udt not in badbinary
assert int not in badbinary

assert binary.first[udt].return_type is udt
assert binary.first[udt].commutes_to is binary.second[udt]
assert semiring.any_firsti[int].commutes_to is semiring.any_secondj[int]
assert semiring.any_firsti[udt].commutes_to is semiring.any_secondj[udt]


def test_dir():
for mod in [unary, binary, monoid, semiring, op]:
Expand Down
14 changes: 10 additions & 4 deletions grblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,16 +1751,19 @@ def test_udt():
assert v.reduce(agg.first_index).new() == 0
assert v.reduce(agg.last_index).new() == 2
assert v.reduce(agg.count).new() == 3
info = v.ss.export()
result = Vector.ss.import_any(**info)
assert result.isequal(v)

# arrays as dtypes!
np_dtype = np.dtype("(3,)uint16")
udt = dtypes.register_anonymous(np_dtype, "has_subdtype")
s = Scalar.new(udt)
udt2 = dtypes.register_anonymous(np_dtype, "has_subdtype")
s = Scalar.new(udt2)
s.value = [0, 0, 0]

v = Vector.new(udt, size=2)
v = Vector.new(udt2, size=2)
v[0] = 0
w = Vector.new(udt, size=2)
w = Vector.new(udt2, size=2)
w[0] = (0, 0, 0)
assert v.isequal(w)
assert v.ewise_mult(w, binary.eq).new().reduce(monoid.land).new()
Expand All @@ -1778,6 +1781,9 @@ def test_udt():
assert_array_equal(values, np.ones((2, 3), dtype=np.uint16))
assert v.isequal(Vector.from_values(indices, values, dtype=v.dtype))
assert v.isequal(Vector.from_values(indices, values))
info = v.ss.export()
result = Vector.ss.import_any(**info)
assert result.isequal(v)

s = v.reduce(monoid.any).new()
assert (s.value == [1, 1, 1]).all()
Expand Down

0 comments on commit c9e8772

Please sign in to comment.