Skip to content

Commit

Permalink
Broadcasting! plus(A & v) -> any_plus(A @ v.diag()) (#203)
Browse files Browse the repository at this point in the history
* Broadcasting! `plus(A & v)` -> `any_plus(A @ v.diag())`

Fixes #197, and implements nearly everything in the alt proposal.

* Support ewise_union infix: `plus(x | y, left_default=1, right_default=2)`

* Almost forgot this! (require_monoid in vector.ewise_add broadcast recipe)

* Let's name the temp vectors too
  • Loading branch information
eriknw committed Apr 7, 2022
1 parent 2d9a296 commit d9e6921
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 75 deletions.
79 changes: 47 additions & 32 deletions grblas/_infixmethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@
def call_op(self, other, method, op, *, outer=False, union=False):
type1 = output_type(self)
type2 = output_type(other)
if (
type1 is type2
or type1 is Matrix
and type2 is TransposedMatrix
or type1 is TransposedMatrix
and type2 is Matrix
):
types = {Matrix, TransposedMatrix, Vector}
if type1 in types and type2 in types:
if outer:
return op(self | other, require_monoid=False)
return self.ewise_add(other, op, require_monoid=False)
elif union:
return self.ewise_union(other, op, False, False)
else:
return op(self & other)
return self.ewise_mult(other, op)
return op(self, other)


Expand Down Expand Up @@ -71,34 +66,44 @@ def __rxor__(self, other):


def __ixor__(self, other):
if output_type(other) in {Vector, Matrix, TransposedMatrix}:
if other.dtype != BOOL:
raise TypeError(
f"The __ixor__ infix operator, `x ^= y`, is not supported for {other.dtype.name} "
"dtype. It is only supported for BOOL dtype (and it uses ewise_add--the union)."
)
self(binary.lxor) << other
else:
other_type = output_type(other)
if (
other_type is Vector
and self.ndim == 2
or other_type not in {Vector, Matrix, TransposedMatrix}
):
self << __xor__(self, other)
elif other.dtype != BOOL:
raise TypeError(
f"The __ixor__ infix operator, `x ^= y`, is not supported for {other.dtype.name} "
"dtype. It is only supported for BOOL dtype (and it uses ewise_add--the union)."
)
else:
self(binary.lxor) << other
return self


def __ior__(self, other):
if output_type(other) in {Vector, Matrix, TransposedMatrix}:
if other.dtype != BOOL:
raise TypeError(
f"The __ior__ infix operator, `x |= y`, is not supported for {other.dtype.name} "
"dtype. It is only supported for BOOL dtype (and it uses ewise_add--the union)."
)
self(binary.lor) << other
else:
other_type = output_type(other)
if (
other_type is Vector
and self.ndim == 2
or other_type not in {Vector, Matrix, TransposedMatrix}
):
expr = call_op(self, other, "__ior__", binary.lor, outer=True)
if expr.dtype != BOOL:
raise TypeError(
f"The __ior__ infix operator, `x |= y`, is not supported for {expr.dtype.name} "
"dtype. It is only supported for BOOL dtype (and it uses ewise_add--the union)."
)
self << expr
elif other.dtype != BOOL:
raise TypeError(
f"The __ior__ infix operator, `x |= y`, is not supported for {other.dtype.name} "
"dtype. It is only supported for BOOL dtype (and it uses ewise_add--the union)."
)
else:
self(binary.lor) << other
return self


Expand Down Expand Up @@ -147,10 +152,15 @@ def __radd__(self, other):


def __iadd__(self, other):
if output_type(other) in {Vector, Matrix, TransposedMatrix}:
self(binary.plus) << other
else:
other_type = output_type(other)
if (
other_type is Vector
and self.ndim == 2
or other_type not in {Vector, Matrix, TransposedMatrix}
):
self << __add__(self, other)
else:
self(binary.plus) << other
return self


Expand Down Expand Up @@ -350,10 +360,15 @@ def __itruediv__(self, other):
lines.append(f"def __i{method}__(self, other):")
if method in outer:
lines.append(
" if output_type(other) in {Vector, Matrix, TransposedMatrix}:\n"
f" self(binary.{op}) << other\n"
f" else:\n"
f" self << __{method}__(self, other)"
" other_type = output_type(other)\n"
" if (\n"
" other_type is Vector\n"
" and self.ndim == 2\n"
" or other_type not in {Vector, Matrix, TransposedMatrix}\n"
" ):\n"
f" self << __{method}__(self, other)\n"
" else:\n"
f" self(binary.{op}) << other"
)
else:
lines.append(f" self << __{method}__(self, other)")
Expand Down
49 changes: 21 additions & 28 deletions grblas/infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,12 @@ class MatrixInfixExpr(InfixExprBase):

def __init__(self, left, right):
super().__init__(left, right)
self._nrows = left._nrows
self._ncols = left._ncols
if left.ndim == 1:
self._nrows = right._nrows
self._ncols = right._ncols
else:
self._nrows = left._nrows
self._ncols = left._ncols

@property
def nrows(self):
Expand Down Expand Up @@ -330,35 +334,24 @@ def _ewise_infix_expr(left, right, *, method, within):
left_type = output_type(left)
right_type = output_type(right)

if left_type in {Vector, Matrix, TransposedMatrix}:
if not (
left_type is right_type
or (left_type is Matrix and right_type is TransposedMatrix)
or (left_type is TransposedMatrix and right_type is Matrix)
):
if left_type is Vector:
right = left._expect_type(right, Vector, within=within, argname="right")
else:
right = left._expect_type(
right, (Matrix, TransposedMatrix), within=within, argname="right"
)
elif right_type is Vector:
left = right._expect_type(left, Vector, within=within, argname="left")
elif right_type is Matrix or right_type is TransposedMatrix:
left = right._expect_type(left, (Matrix, TransposedMatrix), within=within, argname="left")
types = {Vector, Matrix, TransposedMatrix}
if left_type in types and right_type in types:
# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(left, method)(right, binary.any)
if expr.output_type is Vector:
if method == "ewise_mult":
return VectorEwiseMultExpr(left, right)
return VectorEwiseAddExpr(left, right)
elif method == "ewise_mult":
return MatrixEwiseMultExpr(left, right)
return MatrixEwiseAddExpr(left, right)
if left_type in types:
left._expect_type(right, tuple(types), within=within, argname="right")
elif right_type in types:
right._expect_type(left, tuple(types), within=within, argname="left")
else: # pragma: no cover
raise TypeError(f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}")

# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(left, method)(right, binary.any)
if expr.output_type is Vector:
if method == "ewise_mult":
return VectorEwiseMultExpr(left, right)
return VectorEwiseAddExpr(left, right)
elif method == "ewise_mult":
return MatrixEwiseMultExpr(left, right)
return MatrixEwiseAddExpr(left, right)


def _matmul_infix_expr(left, right, *, within):
left_type = output_type(left)
Expand Down
42 changes: 37 additions & 5 deletions grblas/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from ._ss.matrix import ss
from .base import BaseExpression, BaseType, call
from .dtypes import _INDEX, lookup_dtype, unify
from .exceptions import NoValue, check_status
from .exceptions import DimensionMismatch, NoValue, check_status
from .expr import AmbiguousAssignOrExtract, IndexerResolver, Updater
from .mask import StructuralMask, ValueMask
from .operator import get_typed_op
from .operator import get_semiring, get_typed_op
from .scalar import _MATERIALIZE, Scalar, ScalarExpression, _as_scalar
from .utils import (
_CArray,
Expand Down Expand Up @@ -442,7 +442,7 @@ def ewise_add(self, other, op=monoid.plus, *, require_monoid=True):
method_name = "ewise_add"
other = self._expect_type(
other,
(Matrix, TransposedMatrix),
(Matrix, TransposedMatrix, Vector),
within=method_name,
argname="other",
op=op,
Expand All @@ -460,6 +460,19 @@ def ewise_add(self, other, op=monoid.plus, *, require_monoid=True):
)
else:
self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op")
if other.ndim == 1:
# Broadcast rowwise from the right
# Can we do `C(M.S) << plus(A | v)` -> `C(M.S) << plus(any_second(M @ v.diag()) | A)`?
if self._ncols != other._size:
# Check this before we compute a possibly large matrix below
raise DimensionMismatch(
"Dimensions not compatible for broadcasting Vector from the right "
f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) "
f"must equal Vector.size (={other._size})."
)
full = Vector.new(other.dtype, self._nrows, name="v_full")
full[:] = 0
other = full.outer(other, binary.second).new(name="M_temp")
expr = MatrixExpression(
method_name,
f"GrB_Matrix_eWiseAdd_{op.opclass}",
Expand All @@ -481,11 +494,13 @@ def ewise_mult(self, other, op=binary.times):
"""
method_name = "ewise_mult"
other = self._expect_type(
other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op
other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op
)
op = get_typed_op(op, self.dtype, other.dtype, kind="binary")
# Per the spec, op may be a semiring, but this is weird, so don't.
self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op")
if other.ndim == 1:
return self.mxm(other.diag(name="M_temp"), get_semiring(monoid.any, op))
expr = MatrixExpression(
method_name,
f"GrB_Matrix_eWiseMult_{op.opclass}",
Expand All @@ -512,7 +527,9 @@ def ewise_union(self, other, op, left_default, right_default):
"""
# SS, SuiteSparse-specific: eWiseUnion
method_name = "ewise_union"
other = self._expect_type(other, Matrix, within=method_name, argname="other", op=op)
other = self._expect_type(
other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op
)
if type(left_default) is not Scalar:
try:
left = Scalar.from_value(
Expand Down Expand Up @@ -551,11 +568,26 @@ def ewise_union(self, other, op, left_default, right_default):
self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op")
if op.opclass == "Monoid":
op = op.binaryop
if other.ndim == 1:
# Broadcast rowwise from the right
# Can we do `C(M.S) << plus(A | v)` -> `C(M.S) << plus(any_second(M @ v.diag()) | A)`?
if self._ncols != other._size:
# Check this before we compute a possibly large matrix below
raise DimensionMismatch(
"Dimensions not compatible for broadcasting Vector from the right "
f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) "
f"must equal Vector.size (={other._size})."
)
full = Vector.new(other.dtype, self._nrows, name="v_full")
full[:] = 0
other = full.outer(other, binary.second).new(name="M_temp")
expr = MatrixExpression(
method_name,
"GxB_Matrix_eWiseUnion",
[self, left, other, right],
op=op,
at=self._is_transposed,
bt=other._is_transposed,
expr_repr="{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})",
)
if self.shape != other.shape:
Expand Down
39 changes: 37 additions & 2 deletions grblas/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,26 @@ class TypedBuiltinBinaryOp(TypedOpBase):
__slots__ = ()
opclass = "BinaryOp"

def __call__(self, left, right=None, *, require_monoid=None):
def __call__(
self, left, right=None, *, require_monoid=None, left_default=None, right_default=None
):
if left_default is not None or right_default is not None:
if (
left_default is None
or right_default is None
or require_monoid is not None
or right is not None
or not isinstance(left, InfixExprBase)
or left.method_name != "ewise_add"
):
raise TypeError(
"Specifying `left_default` or `right_default` keyword arguments implies "
"performing `ewise_union` operation with infix notation.\n"
"There is only one valid way to do this:\n\n"
f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y "
"are Vectors or Matrices, and left_default and right_default are scalars."
)
return left.left.ewise_union(left.right, self, left_default, right_default)
if require_monoid is not None:
if right is not None:
raise TypeError(
Expand Down Expand Up @@ -175,7 +194,23 @@ def __init__(self, parent, name, type_, return_type, gb_obj, gb_name):
super().__init__(parent, name, type_, return_type, gb_obj, gb_name)
self._identity = None

def __call__(self, left, right=None):
def __call__(self, left, right=None, *, left_default=None, right_default=None):
if left_default is not None or right_default is not None:
if (
left_default is None
or right_default is None
or right is not None
or not isinstance(left, InfixExprBase)
or left.method_name != "ewise_add"
):
raise TypeError(
"Specifying `left_default` or `right_default` keyword arguments implies "
"performing `ewise_union` operation with infix notation.\n"
"There is only one valid way to do this:\n\n"
f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y "
"are Vectors or Matrices, and left_default and right_default are scalars."
)
return left.left.ewise_union(left.right, self, left_default, right_default)
return _call_op(self, left, right)

@property
Expand Down
29 changes: 25 additions & 4 deletions grblas/tests/test_infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ def test_bad_ewise(s1, v1, A1, A2):
(s1, v1),
(v1, 1),
(1, v1),
(v1, A1),
(A1, v1),
(v1, A1.T),
(A1.T, v1),
(A1, s1),
(s1, A1),
(A1.T, s1),
Expand All @@ -128,6 +124,31 @@ def test_bad_ewise(s1, v1, A1, A2):
left | right
with raises(TypeError, match="Bad type for argument"):
left & right
# These are okay now
for left, right in [
(A1, v1),
(v1, A1.T),
]:
left.ewise_add(right)
left | right
left.ewise_mult(right)
left & right
left.ewise_union(right, op.plus, 0, 0)
# Wrong dimension; can't broadcast
for left, right in [
(v1, A1),
(A1.T, v1),
]:
with raises(DimensionMismatch):
left | right
with raises(DimensionMismatch):
left & right
with raises(DimensionMismatch):
left.ewise_add(right)
with raises(DimensionMismatch):
left.ewise_mult(right)
with raises(DimensionMismatch):
left.ewise_union(right, op.plus, 0, 0)

w = v1[: v1.size - 1].new()
with raises(DimensionMismatch):
Expand Down
8 changes: 8 additions & 0 deletions grblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3088,6 +3088,14 @@ def test_ewise_union():
result = A1.ewise_union(A2, binary.plus, 10, 20).new()
expected = Matrix.from_values([0, 0], [0, 1], [21, 12], nrows=1, ncols=3)
assert result.isequal(expected)

# Test transposed
A2transposed = A2.T.new()
result = A1.ewise_union(A2transposed.T, binary.plus, 10, 20).new()
assert result.isequal(expected)
result = A1.T.ewise_union(A2transposed, binary.plus, 10, 20).new()
assert result.isequal(expected.T.new())

# Handle Scalars
result = A1.ewise_union(A2, binary.plus, Scalar.from_value(10), Scalar.from_value(20)).new()
assert result.isequal(expected)
Expand Down

0 comments on commit d9e6921

Please sign in to comment.