Skip to content

Commit

Permalink
[sparse] change call signature of coo primitive wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 7, 2022
1 parent 03a50c0 commit 424536d
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 81 deletions.
195 changes: 132 additions & 63 deletions jax/experimental/sparse/coo.py
Expand Up @@ -15,7 +15,7 @@
"""COO (coordinate format) matrix object and associated primitives."""

import operator
from typing import Tuple
from typing import Any, NamedTuple, Tuple
import warnings

import numpy as np
Expand All @@ -39,6 +39,14 @@
except ImportError:
hipsparse = None


Dtype = Any
Shape = Tuple[int, ...]

class COOInfo(NamedTuple):
shape: Shape


@tree_util.register_pytree_node_class
class COO(JAXSparse):
"""Experimental COO matrix implemented in JAX; API subject to change."""
Expand All @@ -48,16 +56,16 @@ class COO(JAXSparse):
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
_info = property(lambda self: COOInfo(self.shape))
_bufs = property(lambda self: (self.data, self.row, self.col))

def __init__(self, args, *, shape):
self.data, self.row, self.col = _safe_asarray(args)
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
if nse is None:
nse = (mat != 0).sum()
return cls(coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)
return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)

@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
Expand All @@ -70,24 +78,26 @@ def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
return cls((data, row, col), shape=shape)

def todense(self):
return coo_todense(self.data, self.row, self.col, shape=self.shape)
return coo_todense(self)

def transpose(self, axes=None):
assert axes is None
if axes is not None:
raise NotImplementedError("axes argument to transpose()")
return COO((self.data, self.col, self.row), shape=self.shape[::-1])

def tree_flatten(self):
return (self.data, self.row, self.col), {"shape": self.shape}
return (self.data, self.row, self.col), self._info._asdict()

def __matmul__(self, other):
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
other = jnp.asarray(other)
data, other = _promote_dtypes(self.data, other)
self_promoted = COO((data, self.row, self.col), **self._info._asdict())
if other.ndim == 1:
return coo_matvec(data, self.row, self.col, other, shape=self.shape)
return coo_matvec(self_promoted, other)
elif other.ndim == 2:
return coo_matmat(data, self.row, self.col, other, shape=self.shape)
return coo_matmat(self_promoted, other)
else:
raise NotImplementedError(f"matmul with object of shape {other.shape}")

Expand All @@ -96,54 +106,64 @@ def __matmul__(self, other):

coo_todense_p = core.Primitive('coo_todense')

def coo_todense(data, row, col, *, shape):
def coo_todense(mat):
"""Convert a COO-format sparse matrix to a dense matrix.
Args:
mat : COO matrix
Returns:
mat_dense: dense version of ``mat``
"""
return _coo_todense(mat.data, mat.row, mat.col, spinfo=mat._info)

def _coo_todense(data, row, col, *, spinfo):
"""Convert COO-format sparse matrix to a dense matrix.
Args:
data : array of shape ``(nse,)``.
row : array of shape ``(nse,)``
col : array of shape ``(nse,)`` and dtype ``row.dtype``
shape : length-2 tuple representing the matrix shape
spinfo : COOInfo object containing matrix metadata
Returns:
mat : array with specified shape and dtype matching ``data``
"""
return coo_todense_p.bind(data, row, col, shape=shape)
return coo_todense_p.bind(data, row, col, spinfo=spinfo)

@coo_todense_p.def_impl
def _coo_todense_impl(data, row, col, *, shape):
return jnp.zeros(shape, data.dtype).at[row, col].add(data)
def _coo_todense_impl(data, row, col, *, spinfo):
return jnp.zeros(spinfo.shape, data.dtype).at[row, col].add(data)

@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, shape):
return core.ShapedArray(shape, data.dtype)
def _coo_todense_abstract_eval(data, row, col, *, spinfo):
return core.ShapedArray(spinfo.shape, data.dtype)

_coo_todense_translation_rule = xla.lower_fun(
_coo_todense_impl, multiple_results=False, new_style=True)

def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
*, shape):
*, spinfo):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
shape=shape)
spinfo=spinfo)
if cusparse is not None:
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)]
else:
return [hipsparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
return [hipsparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)]

def _coo_todense_jvp(data_dot, data, row, col, *, shape):
return coo_todense(data_dot, row, col, shape=shape)
def _coo_todense_jvp(data_dot, data, row, col, *, spinfo):
return _coo_todense(data_dot, row, col, spinfo=spinfo)

def _coo_todense_transpose(ct, data, row, col, *, shape):
def _coo_todense_transpose(ct, data, row, col, *, spinfo):
# Note: we assume that transpose has the same sparsity pattern.
# Can we check this?
assert ad.is_undefined_primal(data)
if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ct.shape == shape
assert ct.shape == spinfo.shape
assert row.aval.dtype == col.aval.dtype
assert ct.dtype == data.aval.dtype
return _coo_extract(row, col, ct), row, col
Expand All @@ -161,7 +181,24 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True

def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
"""Create a COO-format sparse matrix from a dense matrix.
Args:
mat : array to be converted to COO.
nse : number of specified entries in ``mat``. If not specified,
it will be computed from the input matrix.
index_dtype : dtype of sparse indices
Returns:
mat_coo : COO representation of the matrix.
"""
if nse is None:
nse = (mat != 0).sum()
nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)

def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
"""Create COO-format sparse matrix from a dense matrix.
Args:
Expand Down Expand Up @@ -220,7 +257,7 @@ def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
M, = primals
Mdot, = tangents

primals_out = coo_fromdense(M, nse=nse, index_dtype=index_dtype)
primals_out = _coo_fromdense(M, nse=nse, index_dtype=index_dtype)
data, row, col = primals_out

if type(Mdot) is ad.Zero:
Expand All @@ -239,7 +276,7 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ad.is_undefined_primal(M)
return coo_todense(data, row, col, shape=M.aval.shape)
return _coo_todense(data, row, col, spinfo=COOInfo(shape=M.aval.shape))

ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
Expand All @@ -255,7 +292,23 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):

coo_matvec_p = core.Primitive('coo_matvec')

def coo_matvec(data, row, col, v, *, shape, transpose=False):
def coo_matvec(mat, v, transpose=False):
"""Product of COO sparse matrix and a dense vector.
Args:
mat : COO matrix
v : one-dimensional array of size ``(mat.shape[0] if transpose else mat.shape[1],)`` and
dtype ``mat.dtype``
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
y : array of shape ``(mat.shape[1] if transpose else mat.shape[0],)`` representing
the matrix vector product.
"""
return _coo_matvec(*mat._bufs, v, spinfo=mat._info, transpose=transpose)

def _coo_matvec(data, row, col, v, *, spinfo, transpose=False):
"""Product of COO sparse matrix and a dense vector.
Args:
Expand All @@ -272,58 +325,58 @@ def coo_matvec(data, row, col, v, *, shape, transpose=False):
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
return coo_matvec_p.bind(data, row, col, v, shape=shape, transpose=transpose)
return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)

@coo_matvec_p.def_impl
def _coo_matvec_impl(data, row, col, v, *, shape, transpose):
def _coo_matvec_impl(data, row, col, v, *, spinfo, transpose):
v = jnp.asarray(v)
if transpose:
row, col = col, row
out_shape = shape[1] if transpose else shape[0]
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
dv = data * v[col]
return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)

@coo_matvec_p.def_abstract_eval
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose):
assert data.shape == row.shape == col.shape
assert data.dtype == v.dtype
assert row.dtype == col.dtype
assert len(shape) == 2
assert len(spinfo.shape) == 2
assert v.ndim == 1
assert v.shape[0] == (shape[0] if transpose else shape[1])
out_shape = shape[1] if transpose else shape[0]
assert v.shape[0] == (spinfo.shape[0] if transpose else spinfo.shape[1])
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
return core.ShapedArray((out_shape,), data.dtype)

_coo_matvec_translation_rule = xla.lower_fun(
_coo_matvec_impl, multiple_results=False, new_style=True)

def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
v, *, shape, transpose):
v, *, spinfo, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
shape=shape, transpose=transpose)
spinfo=spinfo, transpose=transpose)
if cusparse is not None:
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape,
transpose=transpose)]
else:
return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape,
transpose=transpose)]

def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
return _coo_matvec(data_dot, row, col, v, spinfo=spinfo, transpose=transpose)

def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, spinfo, transpose):
return _coo_matvec(data, row, col, v_dot, spinfo=spinfo, transpose=transpose)

def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
assert not ad.is_undefined_primal(row)
assert not ad.is_undefined_primal(col)

if ad.is_undefined_primal(v):
return data, row, col, coo_matvec(data, row, col, ct, shape=shape, transpose=not transpose)
return data, row, col, _coo_matvec(data, row, col, ct, spinfo=spinfo, transpose=not transpose)
else:
v = jnp.asarray(v)
# The following line does this, but more efficiently:
Expand All @@ -342,7 +395,23 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):

coo_matmat_p = core.Primitive('coo_matmat')

def coo_matmat(data, row, col, B, *, shape, transpose=False):
def coo_matmat(mat, B, *, transpose=False):
"""Product of COO sparse matrix and a dense matrix.
Args:
mat : COO matrix
B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
dtype ``mat.dtype``
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
representing the matrix-matrix product.
"""
return _coo_matmat(*mat._bufs, B, spinfo=mat._info, transpose=transpose)

def _coo_matmat(data, row, col, B, *, spinfo, transpose=False):
"""Product of COO sparse matrix and a dense matrix.
Args:
Expand All @@ -359,56 +428,56 @@ def coo_matmat(data, row, col, B, *, shape, transpose=False):
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product.
"""
return coo_matmat_p.bind(data, row, col, B, shape=shape, transpose=transpose)
return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)

@coo_matmat_p.def_impl
def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
def _coo_matmat_impl(data, row, col, B, *, spinfo, transpose):
B = jnp.asarray(B)
if transpose:
row, col = col, row
out_shape = shape[1] if transpose else shape[0]
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
dB = data[:, None] * B[col]
return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)

@coo_matmat_p.def_abstract_eval
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose):
assert data.shape == row.shape == col.shape
assert data.dtype == B.dtype
assert B.ndim == 2
assert len(shape) == 2
assert B.shape[0] == (shape[0] if transpose else shape[1])
out_shape = shape[1] if transpose else shape[0]
assert len(spinfo.shape) == 2
assert B.shape[0] == (spinfo.shape[0] if transpose else spinfo.shape[1])
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)

_coo_matmat_translation_rule = xla.lower_fun(
_coo_matmat_impl, multiple_results=False, new_style=True)

def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
B, *, shape, transpose):
B, *, spinfo, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
shape=shape, transpose=transpose)
spinfo=spinfo, transpose=transpose)
if cusparse is not None:
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape,
transpose=transpose)]
else:
return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape,
transpose=transpose)]

def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, shape, transpose):
return coo_matmat(data_dot, row, col, B, shape=shape, transpose=transpose)
def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
return _coo_matmat(data_dot, row, col, B, spinfo=spinfo, transpose=transpose)

def _coo_matmat_jvp_right(B_dot, data, row, col, B, *, shape, transpose):
return coo_matmat(data, row, col, B_dot, shape=shape, transpose=transpose)
def _coo_matmat_jvp_right(B_dot, data, row, col, B, *, spinfo, transpose):
return _coo_matmat(data, row, col, B_dot, spinfo=spinfo, transpose=transpose)

def _coo_matmat_transpose(ct, data, row, col, B, *, shape, transpose):
def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
assert not ad.is_undefined_primal(row)
assert not ad.is_undefined_primal(col)
if ad.is_undefined_primal(B):
return data, row, col, coo_matmat(data, row, col, ct, shape=shape, transpose=not transpose)
return data, row, col, _coo_matmat(data, row, col, ct, spinfo=spinfo, transpose=not transpose)
else:
B = jnp.asarray(B)
return (ct[row] * B[col]).sum(1), row, col, B
Expand Down

0 comments on commit 424536d

Please sign in to comment.