Skip to content

Commit

Permalink
Merge branch 'main' into graphblas_python2
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Apr 10, 2022
2 parents 4c31d5c + 8e9acb9 commit 56784e9
Show file tree
Hide file tree
Showing 26 changed files with 1,739 additions and 512 deletions.
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,3 @@ m = gb.io.to_scipy_sparse_matrix(m, format='csr')
A = gb.io.from_networkx(g)
g = gb.io.to_networkx(A)
```

## Attribution
This library borrows some great ideas from [pygraphblas](https://github.com/michelp/pygraphblas),
especially around parsing operator names from SuiteSparse and the concept of a Scalar which the backend
implementation doesn't need to know about.

8 changes: 0 additions & 8 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,6 @@ Import/Export connectors to the Python ecosystem
A = gb.io.from_networkx(g)
g = gb.io.to_networkx(A)
Attribution
-----------

This library borrows some great ideas from `pygraphblas <https://github.com/michelp/pygraphblas>`_,
especially around parsing operator names from SuiteSparse and the concept of a Scalar which the backend
implementation doesn't need to know about.


Indices and tables
==================

Expand Down
44 changes: 29 additions & 15 deletions graphblas/_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np

from . import agg, binary, monoid, semiring, unary
from .dtypes import lookup_dtype, unify
from .operator import _normalize_type
from .dtypes import INT64, lookup_dtype
from .operator import get_typed_op
from .scalar import Scalar


Expand All @@ -14,8 +14,8 @@ def _get_types(ops, initdtype):
if initdtype is None:
prev = dict(ops[0].types)
else:
initdtype = lookup_dtype(initdtype)
prev = {key: unify(lookup_dtype(val), initdtype).name for key, val in ops[0].types.items()}
op = ops[0]
prev = {key: get_typed_op(op, key, initdtype).return_type for key in op.types}
for op in ops[1:]:
cur = {}
types = op.types
Expand Down Expand Up @@ -43,11 +43,12 @@ def __init__(
composite=None,
custom=None,
types=None,
any_dtype=None,
):
self.name = name
self._initval_orig = initval
self._initval = False if initval is None else initval
self._initdtype = lookup_dtype(type(self._initval))
self._initdtype = lookup_dtype(type(self._initval), self._initval)
self._monoid = monoid
self._semiring = semiring
self._semiring2 = semiring2
Expand All @@ -68,6 +69,7 @@ def __init__(
self._types_orig = types
self._types = None
self._typed_ops = {}
self._any_dtype = any_dtype

@property
def types(self):
Expand All @@ -82,16 +84,16 @@ def types(self):
return self._types

def __getitem__(self, dtype):
dtype = _normalize_type(dtype)
if dtype not in self.types:
dtype = lookup_dtype(dtype)
if not self._any_dtype and dtype not in self.types:
raise KeyError(f"{self.name} does not work with {dtype}")
if dtype not in self._typed_ops:
self._typed_ops[dtype] = TypedAggregator(self, dtype)
return self._typed_ops[dtype]

def __contains__(self, dtype):
dtype = _normalize_type(dtype)
return dtype in self.types
dtype = lookup_dtype(dtype)
return self._any_dtype or dtype in self.types

def __repr__(self):
return f"agg.{self.name}"
Expand All @@ -107,7 +109,12 @@ def __init__(self, agg, dtype):
self.name = agg.name
self.parent = agg
self.type = dtype
self.return_type = agg.types[dtype]
if dtype in agg.types:
self.return_type = agg.types[dtype]
elif agg._any_dtype is True:
self.return_type = dtype
else:
self.return_type = agg._any_dtype

def __repr__(self):
return f"agg.{self.name}[{self.type}]"
Expand Down Expand Up @@ -160,8 +167,7 @@ def _new(self, updater, expr, *, in_composite=False):
if agg._custom is not None:
return agg._custom(self, updater, expr, in_composite=in_composite)

dtype = unify(lookup_dtype(self.type), lookup_dtype(agg._initdtype))
semiring = agg._semiring[dtype]
semiring = get_typed_op(agg._semiring, self.type, agg._initdtype)
if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator":
# Matrix -> Vector
A = expr.args[0]
Expand Down Expand Up @@ -242,13 +248,15 @@ def __reduce__(self):
agg.any = Aggregator("any", monoid=monoid.lor)
agg.min = Aggregator("min", monoid=monoid.min)
agg.max = Aggregator("max", monoid=monoid.max)
agg.any_value = Aggregator("any_value", monoid=monoid.any)
agg.any_value = Aggregator("any_value", monoid=monoid.any, any_dtype=True)
agg.bitwise_all = Aggregator("bitwise_all", monoid=monoid.band)
agg.bitwise_any = Aggregator("bitwise_any", monoid=monoid.bor)
# Other monoids: bxnor bxor eq lxnor lxor

# Semiring-only
agg.count = Aggregator("count", semiring=semiring.plus_pair, semiring2=semiring.plus_first)
agg.count = Aggregator(
"count", semiring=semiring.plus_pair, semiring2=semiring.plus_first, any_dtype=INT64
)
agg.count_nonzero = Aggregator(
"count_nonzero", semiring=semiring.plus_isne, semiring2=semiring.plus_first
)
Expand All @@ -264,7 +272,9 @@ def __reduce__(self):
semiring=semiring.plus_pow,
semiring2=semiring.plus_first,
)
agg.exists = Aggregator("exists", semiring=semiring.any_pair, semiring2=semiring.any_pair)
agg.exists = Aggregator(
"exists", semiring=semiring.any_pair, semiring2=semiring.any_pair, any_dtype=INT64
)

# Semiring and finalize
agg.hypot = Aggregator(
Expand Down Expand Up @@ -564,11 +574,13 @@ def _first_last(agg, updater, expr, *, in_composite, semiring_):
"first",
custom=partial(_first_last, semiring_=semiring.min_secondi),
types=[binary.first],
any_dtype=True,
)
agg.last = Aggregator(
"last",
custom=partial(_first_last, semiring_=semiring.max_secondi),
types=[binary.second],
any_dtype=True,
)


Expand Down Expand Up @@ -601,9 +613,11 @@ def _first_last_index(agg, updater, expr, *, in_composite, semiring):
"first_index",
custom=partial(_first_last_index, semiring=semiring.min_secondi),
types=[semiring.min_secondi],
any_dtype=INT64,
)
agg.last_index = Aggregator(
"last_index",
custom=partial(_first_last_index, semiring=semiring.max_secondi),
types=[semiring.min_secondi],
any_dtype=INT64,
)
32 changes: 18 additions & 14 deletions graphblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"M_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

nrows = parent._nrows
Expand Down Expand Up @@ -847,25 +847,27 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent.gb_obj = ffi.NULL
else:
parent.clear()
return rv
elif format == "coor":
info = self._export(
rv = self._export(
"csr", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["rows"] = indptr_to_indices(info.pop("indptr"))
info["cols"] = info.pop("col_indices")
info["sorted_rows"] = True
info["format"] = "coor"
return info
rv["rows"] = indptr_to_indices(rv.pop("indptr"))
rv["cols"] = rv.pop("col_indices")
rv["sorted_rows"] = True
rv["format"] = "coor"
elif format == "cooc":
info = self._export(
rv = self._export(
"csc", sort=sort, give_ownership=give_ownership, raw=False, method=method
)
info["cols"] = indptr_to_indices(info.pop("indptr"))
info["rows"] = info.pop("row_indices")
info["sorted_cols"] = True
info["format"] = "cooc"
return info
rv["cols"] = indptr_to_indices(rv.pop("indptr"))
rv["rows"] = rv.pop("row_indices")
rv["sorted_cols"] = True
rv["format"] = "cooc"
else:
raise ValueError(f"Invalid format: {format}")
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

if method == "export":
mhandle = ffi_new("GrB_Matrix*", parent._carg)
Expand Down Expand Up @@ -1175,6 +1177,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
rv["values"] = values
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion graphblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
parent = self._parent
else:
parent = self._parent.dup(name=f"v_{method}")
dtype = np.dtype(parent.dtype.np_type)
dtype = parent.dtype.np_type
index_dtype = np.dtype(np.uint64)

if format is None:
Expand Down Expand Up @@ -481,6 +481,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m
)
if method == "export":
parent.gb_obj = ffi.NULL
if parent.dtype._is_udt:
rv["dtype"] = parent.dtype
return rv

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions graphblas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ def update(self, expr):
return self._update(expr)

def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None):
# TODO: check expected output type (now included in Expression object)
if not isinstance(expr, BaseExpression):
if isinstance(expr, AmbiguousAssignOrExtract):
if expr._is_scalar and self._is_scalar:
Expand Down Expand Up @@ -557,7 +556,7 @@ def __init__(
raise ValueError(f"No default expr_repr for len(args) == {len(args)}")
self.expr_repr = expr_repr
if dtype is None:
self.dtype = lookup_dtype(op.return_type)
self.dtype = op.return_type
else:
self.dtype = dtype
self._value = None
Expand Down
1 change: 1 addition & 0 deletions graphblas/binary/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"subtract": "minus",
"true_divide": "truediv",
}
# _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon...
# Not included: maximum, minimum, gcd, hypot, logaddexp, logaddexp2
# lcm, left_shift, nextafter, right_shift

Expand Down

0 comments on commit 56784e9

Please sign in to comment.