Skip to content

Commit

Permalink
A few finishing touches
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 9, 2022
1 parent c9e8772 commit 21f483b
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 56 deletions.
2 changes: 1 addition & 1 deletion grblas/binary/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"subtract": "minus",
"true_divide": "truediv",
}
_graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()}
# _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon...
# Not included: maximum, minimum, gcd, hypot, logaddexp, logaddexp2
# lcm, left_shift, nextafter, right_shift

Expand Down
1 change: 0 additions & 1 deletion grblas/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def register_anonymous(dtype, name=None):
status = lib.GrB_Type_new(gb_obj, dtype.itemsize)
check_status_carg(status, "Type", gb_obj[0])
# For now, let's use "opaque" unsigned bytes for the c type.
# grb_name probably isn't useful, right?
rv = DataType(name, gb_obj, None, f"uint8_t[{dtype.itemsize}]", numba_type, dtype)
if dtype not in _registry:
_registry[gb_obj] = rv
Expand Down
17 changes: 9 additions & 8 deletions grblas/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,17 @@ def build(self, rows, columns, values, *, dup_op=None, clear=False, nrows=None,
return

dup_op_given = dup_op is not None
if self.dtype._is_udt: # XXX
dup_op = None
else:
if not dup_op_given:
if not dup_op_given:
if not self.dtype._is_udt:
dup_op = binary.plus
dup_op = get_typed_op(dup_op, self.dtype, kind="binary")
if dup_op.opclass == "Monoid":
dup_op = dup_op.binaryop
else:
self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op")
dup_op = binary.any
# SS:SuiteSparse-specific: we could use NULL for dup_op
dup_op = get_typed_op(dup_op, self.dtype, kind="binary")
if dup_op.opclass == "Monoid":
dup_op = dup_op.binaryop
else:
self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op")

rows = _CArray(rows)
columns = _CArray(columns)
Expand Down
2 changes: 1 addition & 1 deletion grblas/monoid/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"logical_xor": "lxor",
"multiply": "times",
}
_graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()}
# _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon...
# Not included: maximum, minimum, gcd, hypot, logaddexp, logaddexp2


Expand Down
90 changes: 56 additions & 34 deletions grblas/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,32 @@ def _udt_mask(dtype):


class TypedOpBase:
__slots__ = "parent", "name", "type", "return_type", "gb_obj", "gb_name", "__weakref__"
__slots__ = (
"parent",
"name",
"type",
"return_type",
"gb_obj",
"gb_name",
"_type2",
"__weakref__",
)

def __init__(self, parent, name, type_, return_type, gb_obj, gb_name):
def __init__(self, parent, name, type_, return_type, gb_obj, gb_name, dtype2=None):
self.parent = parent
self.name = name
self.type = type_
self.return_type = return_type
self.gb_obj = gb_obj
self.gb_name = gb_name
self._type2 = dtype2

def __repr__(self):
classname = self.opclass.lower()
if classname.endswith("op"):
classname = classname[:-2]
return f"{classname}.{self.name}[{self.type}]"
dtype2 = "" if self._type2 is None else f", {self._type2.name}"
return f"{classname}.{self.name}[{self.type.name}{dtype2}]"

@property
def _carg(self):
Expand All @@ -150,7 +161,10 @@ def is_positional(self):
return self.parent.is_positional

def __reduce__(self):
return (getitem, (self.parent, self.type))
if self._type2 is None or self.type == self._type2:
return (getitem, (self.parent, self.type))
else:
return (getitem, (self.parent, (self.type, self._type2)))


class TypedBuiltinUnaryOp(TypedOpBase):
Expand Down Expand Up @@ -228,6 +242,10 @@ def _semiring_commutes_to(self):
def is_commutative(self):
return self.commutes_to is self

@property
def type2(self):
return self.type if self._type2 is None else self._type2


class TypedBuiltinMonoid(TypedOpBase):
__slots__ = "_identity"
Expand Down Expand Up @@ -280,6 +298,10 @@ def binaryop(self):
def commutes_to(self):
return self

@property
def type2(self):
return self.type


class TypedBuiltinSemiring(TypedOpBase):
__slots__ = ()
Expand Down Expand Up @@ -318,6 +340,8 @@ def commutes_to(self):
def is_commutative(self):
return self.binaryop.is_commutative

type2 = TypedBuiltinBinaryOp.type2


class TypedUserUnaryOp(TypedOpBase):
__slots__ = ()
Expand All @@ -341,8 +365,8 @@ class TypedUserBinaryOp(TypedOpBase):
__slots__ = "_monoid"
opclass = "BinaryOp"

def __init__(self, parent, name, type_, return_type, gb_obj):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}")
def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2)
self._monoid = None

@property
Expand All @@ -356,9 +380,10 @@ def monoid(self):
commutes_to = TypedBuiltinBinaryOp.commutes_to
_semiring_commutes_to = TypedBuiltinBinaryOp._semiring_commutes_to
is_commutative = TypedBuiltinBinaryOp.is_commutative
__call__ = TypedBuiltinBinaryOp.__call__
orig_func = TypedUserUnaryOp.orig_func
_numba_func = TypedUserUnaryOp._numba_func
type2 = TypedBuiltinBinaryOp.type2
__call__ = TypedBuiltinBinaryOp.__call__


class TypedUserMonoid(TypedOpBase):
Expand All @@ -373,20 +398,22 @@ def __init__(self, parent, name, type_, return_type, gb_obj, binaryop, identity)
binaryop._monoid = self

commutes_to = TypedBuiltinMonoid.commutes_to
type2 = TypedBuiltinMonoid.type2
__call__ = TypedBuiltinMonoid.__call__


class TypedUserSemiring(TypedOpBase):
__slots__ = "monoid", "binaryop"
opclass = "Semiring"

def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}")
def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop, dtype2=None):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2)
self.monoid = monoid
self.binaryop = binaryop

commutes_to = TypedBuiltinSemiring.commutes_to
is_commutative = TypedBuiltinSemiring.is_commutative
type2 = TypedBuiltinBinaryOp.type2
__call__ = TypedBuiltinSemiring.__call__


Expand Down Expand Up @@ -649,7 +676,12 @@ def __repr__(self):
return f"{self._modname}.{self.name}"

def __getitem__(self, type_):
if not self._is_udt:
if type(type_) is tuple:
dtype1, dtype2 = type_
dtype1 = lookup_dtype(dtype1)
dtype2 = lookup_dtype(dtype2)
return get_typed_op(self, dtype1, dtype2)
elif not self._is_udt:
type_ = lookup_dtype(type_)
if type_ not in self._typed_ops:
if self._udt_types is None:
Expand All @@ -659,13 +691,8 @@ def __getitem__(self, type_):
else:
return self._typed_ops[type_]
# This is a UDT or is able to operate on UDTs such as `first` any `any`
if type(type_) is tuple:
dtype1, dtype2 = type_
dtype1 = lookup_dtype(dtype1)
dtype2 = lookup_dtype(dtype2)
else:
dtype1 = dtype2 = lookup_dtype(type_)
return self._getitem_udt(dtype1, dtype2)
dtype = lookup_dtype(type_)
return self._compile_udt(dtype, dtype)

def _add(self, op):
self._typed_ops[op.type] = op
Expand All @@ -677,12 +704,6 @@ def __delitem__(self, type_):
del self.types[type_]

def __contains__(self, type_):
if not self._is_udt:
type_ = lookup_dtype(type_)
if type_ in self._typed_ops or self.is_positional:
return True
elif self._udt_types is None:
return False
try:
self[type_]
except (TypeError, KeyError, numba.NumbaError):
Expand Down Expand Up @@ -927,7 +948,7 @@ def unary_wrapper(z, x):
else:
raise UdfParseError("Unable to parse function using Numba")

def _getitem_udt(self, dtype, dtype2):
def _compile_udt(self, dtype, dtype2):
if dtype in self._udt_types:
return self._udt_ops[dtype]

Expand Down Expand Up @@ -1379,7 +1400,7 @@ def binary_wrapper(z, x, y):
else:
raise UdfParseError("Unable to parse function using Numba")

def _getitem_udt(self, dtype, dtype2):
def _compile_udt(self, dtype, dtype2):
if dtype2 is None:
dtype2 = dtype
dtypes = (dtype, dtype2)
Expand Down Expand Up @@ -1434,7 +1455,6 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
nt.CPointer(UINT8.numba_type),
nt.CPointer(UINT8.numba_type),
)

if mask.all():

def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
Expand Down Expand Up @@ -1485,9 +1505,10 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover
op = TypedUserBinaryOp(
self,
self.name,
dtype, # Should we do a tuple for both inputs?
dtype,
ret_type,
new_binary[0],
dtype2=dtype2,
)
self._udt_types[dtypes] = ret_type
self._udt_ops[dtypes] = op
Expand Down Expand Up @@ -1775,7 +1796,7 @@ def _build(cls, name, binaryop, identity, *, anonymous=False):
new_type_obj._add(op)
return new_type_obj

def _getitem_udt(self, dtype, dtype2):
def _compile_udt(self, dtype, dtype2):
if dtype2 is None:
dtype2 = dtype
elif dtype != dtype2:
Expand All @@ -1785,7 +1806,7 @@ def _getitem_udt(self, dtype, dtype2):
)
if dtype in self._udt_types:
return self._udt_ops[dtype]
binaryop = self.binaryop._getitem_udt(dtype, dtype2)
binaryop = self.binaryop._compile_udt(dtype, dtype2)
from .scalar import Scalar

ret_type = binaryop.return_type
Expand Down Expand Up @@ -1998,13 +2019,13 @@ def _build(cls, name, monoid, binaryop, *, anonymous=False):
new_type_obj._add(op)
return new_type_obj

def _getitem_udt(self, dtype, dtype2):
def _compile_udt(self, dtype, dtype2):
if dtype2 is None:
dtype2 = dtype
dtypes = (dtype, dtype2)
if dtypes in self._udt_types:
return self._udt_ops[dtypes]
binaryop = self.binaryop._getitem_udt(dtype, dtype2)
binaryop = self.binaryop._compile_udt(dtype, dtype2)
monoid = self.monoid[binaryop.return_type]
ret_type = monoid.return_type
new_semiring = ffi_new("GrB_Semiring*")
Expand All @@ -2013,11 +2034,12 @@ def _getitem_udt(self, dtype, dtype2):
op = TypedUserSemiring(
new_semiring,
self.name,
dtype, # Should we do a tuple for both inputs?
dtype,
ret_type,
new_semiring[0],
monoid,
binaryop,
dtype2=dtype2,
)
self._udt_types[dtypes] = dtype
self._udt_ops[dtypes] = op
Expand Down Expand Up @@ -2265,7 +2287,7 @@ def _is_udt(self):
def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scalar=False, kind=None):
if isinstance(op, OpBase):
if op._is_udt:
return op._getitem_udt(dtype, dtype2)
return op._compile_udt(dtype, dtype2)
if dtype2 is not None:
try:
dtype = unify(
Expand All @@ -2276,7 +2298,7 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala
return op[UINT64]
if op._udt_types is None:
raise
return op._getitem_udt(dtype, dtype2)
return op._compile_udt(dtype, dtype2)
return op[dtype]
elif isinstance(op, ParameterizedUdf):
op = op() # Use default parameters of parameterized UDFs
Expand Down
Binary file modified grblas/tests/pickle3.pkl
Binary file not shown.
10 changes: 9 additions & 1 deletion grblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3195,7 +3195,15 @@ def test_udt():
result = unary.positioni(A).new()
expected = Matrix.from_values([0, 0, 1, 1], [0, 1, 0, 1], [0, 0, 1, 1])
assert result.isequal(expected)
# agg: any_value, first, last, first_index, last_index, count

# Just make sure these work
for aggop in [agg.any_value, agg.first, agg.last, agg.count]:
A.reduce_rowwise(aggop).new()
A.reduce_columnwise(aggop).new()
A.reduce_scalar(aggop).new()
for aggop in [agg.first_index, agg.last_index]:
A.reduce_rowwise(aggop).new()
A.reduce_columnwise(aggop).new()

np_dtype = np.dtype("(3,)uint16")
udt = dtypes.register_anonymous(np_dtype, "has_subdtype")
Expand Down
6 changes: 6 additions & 0 deletions grblas/tests/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,12 @@ def badfunc(x, y): # pragma: no cover
assert semiring.any_firsti[int].commutes_to is semiring.any_secondj[int]
assert semiring.any_firsti[udt].commutes_to is semiring.any_secondj[udt]

assert binary.second[udt].type is udt
assert binary.second[udt].type2 is udt
assert binary.second[udt, dtypes.INT8].type is udt
assert binary.second[udt, dtypes.INT8].type2 is dtypes.INT8
assert monoid.any[udt].type2 is udt


def test_dir():
for mod in [unary, binary, monoid, semiring, op]:
Expand Down
4 changes: 4 additions & 0 deletions grblas/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,7 @@ def test_udt():
expected[0, 0] = (1, 2)
expected[0, 1] = (3, 4)
assert expected.isequal(A)

any_udt = d["any[udt]"]
assert any_udt is gb.binary.any[udt3]
assert pickle.loads(pickle.dumps(gb.binary.first[udt, int])) is gb.binary.first[udt, int]
5 changes: 5 additions & 0 deletions grblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,8 @@ def test_udt():
info = v.ss.export()
result = Vector.ss.import_any(**info)
assert result.isequal(v)
for aggop in [agg.any_value, agg.first, agg.last, agg.count, agg.first_index, agg.last_index]:
v.reduce(aggop).new()

# arrays as dtypes!
np_dtype = np.dtype("(3,)uint16")
Expand Down Expand Up @@ -1796,6 +1798,9 @@ def test_udt():
t = Scalar.from_value(1)
assert s == t
assert t == s
# Just make sure these work
for aggop in [agg.any_value, agg.first, agg.last, agg.count, agg.first_index, agg.last_index]:
v.reduce(aggop).new()


def test_infix_outer():
Expand Down
2 changes: 1 addition & 1 deletion grblas/unary/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
_unary_names.update({"conj", "conjugate"})
_numpy_to_graphblas["conj"] = "conj"
_numpy_to_graphblas["conjugate"] = "conj"
_graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()}
# _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon...

_operator._STANDARD_OPERATOR_NAMES.update(f"unary.numpy.{name}" for name in _unary_names)
__all__ = list(_unary_names)
Expand Down

0 comments on commit 21f483b

Please sign in to comment.