/
bcsr.py
584 lines (467 loc) · 22.5 KB
/
bcsr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BCSR (Bached compressed row) matrix object and associated primitives."""
from __future__ import annotations
import operator
from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from jax import core
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements,
_csr_to_coo, _dot_general_validated_shape,
SparseInfo, Shape)
import jax.numpy as jnp
from jax._src import api_util
from jax._src.lax.lax import DotDimensionNumbers
from jax.util import split_list, safe_zip
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.typing import Array, ArrayLike, DTypeLike
class BCSRProperties(NamedTuple):
  """Layout metadata inferred from a set of BCSR buffers.

  Fields, in tuple order: ``n_batch``, ``n_dense``, ``nse``.
  """
  n_batch: int  # count of leading batch dimensions
  n_dense: int  # count of trailing dense dimensions
  nse: int  # number of stored elements (last axis of ``indices``)
def _compatible(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
                           shape: Sequence[int]) -> BCSRProperties:
  """Check BCSR index buffers against ``shape`` and derive the layout.

  The batch rank is taken from ``indices`` (every axis but the last), ``nse``
  is the length of its last axis, and whatever in ``shape`` follows the batch
  dimensions and the two sparse dimensions is counted as dense.

  Args:
    indices: integer array of shape ``batch_shape + (nse,)``.
    indptr: integer array of shape ``batch_shape + (shape[n_batch] + 1,)``.
    shape: full shape of the (batched) matrix.

  Returns:
    A ``BCSRProperties`` holding ``n_batch``, ``n_dense`` and ``nse``.

  Raises:
    AssertionError: (when assertions are enabled) if either buffer has a
      non-integer dtype, or ``shape`` is too short for the batch dimensions
      plus two sparse dimensions.
    ValueError: if the batch dimensions of ``indices``/``indptr`` are not
      broadcast-compatible with ``shape``, or ``indptr`` has the wrong length.
  """
  for index_buffer in (indices, indptr):
    assert jnp.issubdtype(index_buffer.dtype, jnp.integer)
  shape = tuple(shape)
  n_batch = indices.ndim - 1
  nse = indices.shape[-1]
  n_dense = len(shape) - n_batch - 2
  assert n_dense >= 0
  batch_shape = shape[:n_batch]
  if not _compatible(indices.shape[:n_batch], batch_shape):
    raise ValueError(f"indices batch dimensions not compatible for {indices.shape=}, {shape=}")
  if not _compatible(indptr.shape[:n_batch], batch_shape):
    raise ValueError(f"indptr batch dimensions not compatible for {indptr.shape=}, {shape=}")
  # indptr holds one offset per row plus a final end offset.
  n_rows = shape[n_batch]
  if indptr.shape[n_batch:] != (n_rows + 1,):
    raise ValueError("indptr shape must match the matrix shape plus 1.")
  return BCSRProperties(n_batch=n_batch, n_dense=n_dense, nse=nse)
def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
                   indptr: jnp.ndarray, shape: Sequence[int]) -> BCSRProperties:
  """Validate all three BCSR buffers against ``shape``.

  The index buffers are checked by ``_validate_bcsr_indices``. ``data`` must
  then have batch dimensions broadcast-compatible with ``shape`` and end in
  ``(nse,) + dense_shape``.

  Returns:
    The ``BCSRProperties`` computed from the index buffers.

  Raises:
    ValueError: on any shape inconsistency (plus whatever
      ``_validate_bcsr_indices`` raises).
  """
  props = _validate_bcsr_indices(indices, indptr, shape)
  shape = tuple(shape)
  n_batch, n_dense, nse = props
  n_sparse = len(shape) - n_batch - n_dense
  if n_sparse != 2:
    raise ValueError("BCSR array must have 2 sparse dimensions; "
                     f"{n_sparse} is given.")
  if not _compatible(data.shape[:n_batch], shape[:n_batch]):
    raise ValueError(f"data batch dimensions not compatible for {data.shape=}, {shape=}")
  expected_tail = (nse,) + shape[n_batch + 2:]
  if data.shape[-(n_dense + 1):] != expected_tail:
    raise ValueError(f"Invalid {data.shape=} for {nse=}, {n_batch=}, {n_dense=}")
  return props
def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
                  shape: Sequence[int]) -> jnp.ndarray:
  """Expand BCSR (indices, indptr) into BCOO-style indices.

  ``_csr_to_coo`` handles a single unbatched matrix. It is wrapped in one
  ``_broadcasting_vmap`` layer per batch dimension, and its outputs are stacked
  along a new trailing axis.
  """
  props = _validate_bcsr_indices(indices, indptr, shape)
  convert = _csr_to_coo
  batch_levels = props.n_batch
  while batch_levels > 0:
    convert = _broadcasting_vmap(convert)
    batch_levels -= 1
  return jnp.stack(convert(indices, indptr), axis=indices.ndim)
#--------------------------------------------------------------------
# bcsr_fromdense
# Primitive converting a dense array to BCSR buffers. It produces three
# outputs (data, indices, indptr), hence ``multiple_results``.
bcsr_fromdense_p = core.Primitive('bcsr_fromdense')
bcsr_fromdense_p.multiple_results = True
# Context message passed to ``core.concrete_or_error`` when ``nse`` is not a
# concrete value (e.g. it is traced inside jit).
_TRACED_NSE_ERROR = """
The error arose for the nse argument of bcsr_fromdense. In order for
BCSR.fromdense() to be used in traced/compiled code, you must pass a concrete
value to the nse (number of stored elements) argument.
"""
def bcsr_fromdense(mat: ArrayLike, *, nse: Optional[int] = None, n_batch: int = 0,
                   n_dense: int = 0, index_dtype: DTypeLike = jnp.int32) -> BCSR:
  """Create BCSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to BCSR, with
      ``ndim = n_batch + 2 + n_dense`` (exactly two sparse dimensions).
    nse : number of stored elements in each batch. If ``None`` (default), it is
      computed from ``mat`` with ``_count_stored_elements``. In either case the
      value must be concrete; a traced value raises an error carrying
      ``_TRACED_NSE_ERROR``.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    mat_bcsr: BCSR representation of the matrix, with ``shape == mat.shape``.
  """
  mat = jnp.asarray(mat)
  if nse is None:
    nse = _count_stored_elements(mat, n_batch, n_dense)
  nse_int: int = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  return BCSR(_bcsr_fromdense(mat, nse=nse_int, n_batch=n_batch, n_dense=n_dense,
                              index_dtype=index_dtype),
              shape=mat.shape)
def _bcsr_fromdense(mat: ArrayLike, *, nse: int, n_batch: int = 0, n_dense: int = 0,
                    index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
  """Create BCSR-format buffers from a dense matrix.

  Args:
    mat : array to be converted to BCSR, with
      ``ndim = n_batch + n_sparse + n_dense`` and ``n_sparse == 2``.
    nse : number of stored elements in each batch. It must be concrete; it is
      checked with ``core.concrete_or_error`` using ``_TRACED_NSE_ERROR``.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    A tuple ``(data, indices, indptr)``:
      data : array of shape
        ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
        and dtype ``mat.dtype``.
      indices : array of shape ``mat.shape[:n_batch] + (nse,)`` and dtype
        ``index_dtype``.
      indptr : array of shape ``mat.shape[:n_batch] + (mat.shape[n_batch] + 1,)``
        and dtype ``index_dtype``.
  """
  mat = jnp.asarray(mat)
  nse = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  return bcsr_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                               index_dtype=index_dtype)
@bcsr_fromdense_p.def_impl
def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  """Implementation of bcsr_fromdense_p: build BCOO first, then compress rows."""
  mat = jnp.asarray(mat)
  if mat.ndim - n_dense - n_batch != 2:
    raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
  as_bcoo = bcoo.bcoo_fromdense(mat, nse=nse, n_batch=n_batch,
                                n_dense=n_dense, index_dtype=index_dtype)
  bcsr_indices, bcsr_indptr = bcoo._bcoo_to_bcsr(as_bcoo.indices,
                                                 shape=mat.shape)
  return as_bcoo.data, bcsr_indices, bcsr_indptr
@bcsr_fromdense_p.def_abstract_eval
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
  """Abstract evaluation rule for bcsr_fromdense_p.

  NOTE: despite ``bcoo`` in its name, this is registered for the BCSR
  primitive. It returns avals for (data, indices, indptr).
  """
  n_sparse = mat.ndim - n_batch - n_dense
  if n_sparse != 2:
    raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
  batch_shape = mat.shape[:n_batch]
  dense_shape = mat.shape[n_batch + n_sparse:]
  n_rows = mat.shape[n_batch]
  data_aval = core.ShapedArray((*batch_shape, nse, *dense_shape), mat.dtype)
  indices_aval = core.ShapedArray((*batch_shape, nse), index_dtype)
  indptr_aval = core.ShapedArray((*batch_shape, n_rows + 1), index_dtype)
  return (data_aval, indices_aval, indptr_aval)
def _bcsr_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch,
                                  n_dense, index_dtype):
  """vmap rule for bcsr_fromdense_p.

  Only a leading mapped axis is supported. That axis becomes one extra BCSR
  batch dimension, so all three outputs are mapped along axis 0.
  """
  mat, = batched_args
  if batch_dims != (0,):
    raise NotImplementedError(f"{batch_dims=}")
  batched_n_batch = n_batch + 1
  n_sparse = mat.ndim - n_dense - batched_n_batch
  if n_sparse != 2:
    raise ValueError("_bcsr_fromdense_batching_rule: must have 2 sparse "
                     f"dimensions but {n_sparse} is given.")
  outputs = _bcsr_fromdense(mat, nse=nse, n_batch=batched_n_batch,
                            n_dense=n_dense, index_dtype=index_dtype)
  return outputs, (0, 0, 0)

batching.primitive_batchers[bcsr_fromdense_p] = _bcsr_fromdense_batching_rule
mlir.register_lowering(bcsr_fromdense_p, mlir.lower_fun(
    _bcsr_fromdense_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcsr_todense
bcsr_todense_p = core.Primitive('bcsr_todense')

def bcsr_todense(mat: BCSR) -> Array:
  """Densify a (possibly batched) BCSR matrix.

  Args:
    mat: BCSR matrix.

  Returns:
    A dense array with the same shape as ``mat``.
  """
  data, indices, indptr = mat.data, mat.indices, mat.indptr
  return _bcsr_todense(data, indices, indptr, shape=tuple(mat.shape))
def _bcsr_todense(data: ArrayLike, indices: ArrayLike, indptr: ArrayLike, *,
                  shape: Shape) -> Array:
  """Convert batched sparse matrix buffers to a dense matrix.

  Args:
    data : array of shape ``batch_dims + (nse,) + dense_dims``.
    indices : array of shape ``batch_dims + (nse,)``.
    indptr : array of shape ``batch_dims + (shape[len(batch_dims)] + 1,)``.
    shape : tuple; the shape of the (batched) matrix. Equal to
      ``batch_dims + sparse_dims + dense_dims``, where ``sparse_dims`` has
      exactly two entries.

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return bcsr_todense_p.bind(jnp.asarray(data), jnp.asarray(indices),
                             jnp.asarray(indptr), shape=shape)
@bcsr_todense_p.def_impl
def _bcsr_todense_impl(data, indices, indptr, *, shape):
  """Implementation of bcsr_todense_p: expand to BCOO indices and densify."""
  coo_indices = _bcsr_to_bcoo(indices, indptr, shape=shape)
  as_bcoo = bcoo.BCOO((data, coo_indices), shape=shape)
  return as_bcoo.todense()
@bcsr_todense_p.def_abstract_eval
def _bcsr_todense_abstract_eval(data, indices, indptr, *, shape):
  """Validate the buffers, then report a dense aval of ``shape``/``data.dtype``."""
  _validate_bcsr(data, indices, indptr, shape)  # raises on inconsistent buffers
  out_dtype = data.dtype
  return core.ShapedArray(shape, out_dtype)
def _bcsr_todense_batching_rule(batched_args, batch_dims, *, shape):
  """vmap rule for bcsr_todense_p.

  Mapped operands must be mapped along axis 0. Unmapped operands (batch dim
  ``None``) get a size-1 leading axis, which BCSR validation accepts as
  broadcast-compatible. Because every operand now has one more leading batch
  dimension, the static ``shape`` gains the vmapped axis size in front.
  Otherwise ``_validate_bcsr`` would infer a negative ``n_dense`` and fail.
  The dense result is mapped along axis 0.
  """
  data, indices, indptr = batched_args
  if any(b not in [0, None] for b in batch_dims):
    raise NotImplementedError(f"{batch_dims=}. Only 0 and None are supported.")
  # Size of the vmapped axis, taken from any operand that is actually mapped.
  # If nothing is mapped, every operand gets a size-1 axis, so use 1.
  axis_size = next((arg.shape[0] for arg, b in zip(batched_args, batch_dims)
                    if b is not None), 1)
  if batch_dims[0] is None:
    data = data[None, ...]
  if batch_dims[1] is None:
    indices = indices[None, ...]
  if batch_dims[2] is None:
    indptr = indptr[None, ...]
  batched_shape = (axis_size, *shape)
  return _bcsr_todense(data, indices, indptr, shape=batched_shape), 0

batching.primitive_batchers[bcsr_todense_p] = _bcsr_todense_batching_rule
mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
    _bcsr_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcsr_extract
bcsr_extract_p = core.Primitive('bcsr_extract')

def bcsr_extract(indices: ArrayLike, indptr: ArrayLike, mat: ArrayLike) -> Array:
  """Extract values from a dense matrix at given BCSR (indices, indptr).

  Args:
    indices: An ndarray; see BCSR indices.
    indptr: An ndarray; see BCSR indptr.
    mat: A dense matrix.

  Returns:
    An ndarray; see BCSR data.
  """
  # Convert the ArrayLike operands to JAX arrays before binding, as
  # _bcsr_todense and _bcsr_dot_general do. The abstract-eval rule reads
  # .dtype and .shape from the operands.
  return bcsr_extract_p.bind(jnp.asarray(indices), jnp.asarray(indptr),
                             jnp.asarray(mat))
@bcsr_extract_p.def_impl
def _bcsr_extract_impl(indices, indptr, mat):
  """Implementation of bcsr_extract_p: gather via the equivalent BCOO indices."""
  dense = jnp.asarray(mat)
  coo_indices = _bcsr_to_bcoo(indices, indptr, shape=dense.shape)
  return bcoo.bcoo_extract(coo_indices, dense)
@bcsr_extract_p.def_abstract_eval
def _bcsr_extract_abstract_eval(indices, indptr, mat):
  """Result aval: ``batch_shape + (nse,) + dense_shape`` with ``mat``'s dtype."""
  props = _validate_bcsr_indices(indices, indptr, mat.shape)
  batch_shape = mat.shape[:props.n_batch]
  dense_shape = mat.shape[mat.ndim - props.n_dense:]
  return core.ShapedArray((*batch_shape, props.nse, *dense_shape), mat.dtype)

mlir.register_lowering(bcsr_extract_p, mlir.lower_fun(
    _bcsr_extract_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcsr_dot_general
bcsr_dot_general_p = core.Primitive('bcsr_dot_general')

def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
                     dimension_numbers: DotDimensionNumbers,
                     precision: None = None,
                     preferred_element_type: None = None) -> Array:
  """A general contraction operation.

  Only two operand combinations are currently supported:

  * dense ``lhs`` and dense ``rhs``: dispatched to ``lax.dot_general``;
  * BCSR ``lhs`` and dense ``rhs``: dispatched to ``bcsr_dot_general_p``.

  Args:
    lhs: An ndarray or BCSR-format sparse array.
    rhs: An ndarray. BCSR-format ``rhs`` is not yet supported.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused

  Returns:
    An ndarray containing the (dense) result.

  Raises:
    NotImplementedError: for any other combination of operand types.
  """
  del precision, preferred_element_type  # unused
  if isinstance(rhs, (np.ndarray, jnp.ndarray)):
    if isinstance(lhs, (np.ndarray, jnp.ndarray)):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
    if isinstance(lhs, BCSR):
      lhs_data, lhs_indices, lhs_indptr = lhs._bufs
      return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs,
                               dimension_numbers=dimension_numbers,
                               lhs_spinfo=lhs._info)
  raise NotImplementedError("bcsr_dot_general currently implemented for BCSR "
                            "lhs and ndarray rhs.")
def _bcsr_dot_general(lhs_data: jnp.ndarray, lhs_indices: jnp.ndarray,
                      lhs_indptr: jnp.ndarray, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      lhs_spinfo: SparseInfo) -> Array:
  """Normalize ``dimension_numbers`` to index tuples and bind the primitive.

  Each of the four dimension groups is passed through
  ``api_util._ensure_index_tuple``, and every operand through ``jnp.asarray``.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  to_index_tuple = api_util._ensure_index_tuple
  normalized_dims = (
      (to_index_tuple(lhs_contract), to_index_tuple(rhs_contract)),
      (to_index_tuple(lhs_batch), to_index_tuple(rhs_batch)),
  )
  operands = [jnp.asarray(operand)
              for operand in (lhs_data, lhs_indices, lhs_indptr, rhs)]
  return bcsr_dot_general_p.bind(*operands,
                                 dimension_numbers=normalized_dims,
                                 lhs_spinfo=lhs_spinfo)
@bcsr_dot_general_p.def_impl
def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
                           dimension_numbers, lhs_spinfo):
  """Implementation of bcsr_dot_general_p via the BCOO dot_general kernel."""
  lhs_data, lhs_bcsr_indices, lhs_bcsr_indptr, rhs = (
      jnp.asarray(x) for x in (lhs_data, lhs_indices, lhs_indptr, rhs))
  # Re-express the lhs sparsity pattern as BCOO indices; data is unchanged.
  coo_indices = _bcsr_to_bcoo(lhs_bcsr_indices, lhs_bcsr_indptr,
                              shape=lhs_spinfo.shape)
  return bcoo._bcoo_dot_general_impl(lhs_data, coo_indices, rhs,
                                     dimension_numbers=dimension_numbers,
                                     lhs_spinfo=lhs_spinfo)
@bcsr_dot_general_p.def_abstract_eval
def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
                                    dimension_numbers, lhs_spinfo):
  """Shape/dtype rule for bcsr_dot_general_p.

  Raises:
    ValueError: if ``lhs_data`` and ``rhs`` dtypes differ, or if any lhs
      buffer (including ``lhs_data``) is inconsistent with ``lhs_spinfo.shape``.
    NotImplementedError: if an lhs batch dimension is not among the batch
      dimensions of the sparse representation, or a dense dimension is
      contracted.
  """
  if lhs_data.dtype != rhs.dtype:
    raise ValueError("bcsr_dot_general requires arguments to have matching "
                     f"dtypes; got lhs.dtype={lhs_data.dtype}, "
                     f"rhs.dtype={rhs.dtype}")
  (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
  # Validate data as well as indices/indptr, as _bcsr_todense_abstract_eval
  # does. Checking only the index buffers would let a data buffer whose shape
  # disagrees with lhs_spinfo through to the BCOO kernel unchecked.
  props = _validate_bcsr(lhs_data, lhs_indices, lhs_indptr, lhs_spinfo.shape)
  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
                                           dimension_numbers)
  if lhs_batch and max(lhs_batch) >= props.n_batch:
    raise NotImplementedError(
      "bcsr_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
      f"got {lhs_batch=}, {props.n_batch=}")
  # TODO: support contraction of dense dimensions?
  if any(d >= props.n_batch + 2 for d in lhs_contracting):
    raise NotImplementedError("bcsr_dot_general: contracting over dense dimensions.")
  return core.ShapedArray(out_shape, lhs_data.dtype)
# def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr,
# rhs, *, dimension_numbers, lhs_spinfo):
# del lhs_data
# return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
# dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)
# def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs,
# *, dimension_numbers, lhs_spinfo):
# del rhs
# return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
# dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)
# def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_inptr, rhs, *,
# dimension_numbers, lhs_spinfo):
# lhs_bcoo_indices = _bcsr_to_bcoo(
# lhs_indices, lhs_inptr, shape=lhs_spinfo.shape)
# return bcoo._bcoo_dot_general_transpose(
# ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)
# def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
# dimension_numbers, lhs_spinfo):
# lhs_data, lhs_indices, lhs_indptr, rhs = batched_args
# lhs_bcoo_indices = _bcsr_to_bcoo(
# lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
# return bcoo._bcoo_dot_general_batch_rule(
# (lhs_data, lhs_bcoo_indices, rhs), batch_dims,
# dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
# ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None,
# _bcsr_dot_general_jvp_rhs)
# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
# Register the MLIR lowering for bcsr_dot_general_p by wrapping its Python
# implementation (_bcsr_dot_general_impl) with mlir.lower_fun. The JVP,
# transpose and batching registrations for this primitive are commented out
# above, so only the impl, abstract eval and this lowering are registered.
_bcsr_dot_general_default_lowering = mlir.lower_fun(
    _bcsr_dot_general_impl, multiple_results=False)
mlir.register_lowering(
    bcsr_dot_general_p, _bcsr_dot_general_default_lowering)
@tree_util.register_pytree_node_class
class BCSR(JAXSparse):
  """Experimental batched CSR matrix implemented in JAX.

  The matrix is held in three buffers, checked by ``_validate_bcsr`` on
  construction:

  * ``data``: shape ``batch_shape + (nse,) + dense_shape``
  * ``indices``: shape ``batch_shape + (nse,)``
  * ``indptr``: shape ``batch_shape + (n_rows + 1,)``

  where ``shape == batch_shape + (n_rows, n_cols) + dense_shape``.
  """
  data: jnp.ndarray
  indices: jnp.ndarray
  indptr: jnp.ndarray
  shape: Shape

  @property
  def nse(self):
    return self.indices.shape[-1]

  @property
  def dtype(self):
    return self.data.dtype

  @property
  def n_batch(self):
    return self.indices.ndim - 1

  @property
  def n_sparse(self):
    return 2

  @property
  def n_dense(self):
    return self.data.ndim - self.indices.ndim

  @property
  def _bufs(self):
    return (self.data, self.indices, self.indptr)

  @property
  def _info(self):
    return SparseInfo(self.shape)

  @property
  def _sparse_shape(self):
    start = self.n_batch
    return tuple(self.shape[start:start + 2])

  def __init__(self, args, *, shape):
    # Buffers are coerced to JAX arrays, the base class records the shape,
    # and the whole structure is validated.
    self.data, self.indices, self.indptr = (jnp.asarray(arg) for arg in args)
    super().__init__(args, shape=shape)
    _validate_bcsr(self.data, self.indices, self.indptr, self.shape)

  def __repr__(self):
    name = self.__class__.__name__
    try:
      nse, n_batch, n_dense = self.nse, self.n_batch, self.n_dense
      dtype = self.dtype
      shape = list(self.shape)
    except Exception:  # pylint: disable=broad-except
      repr_ = f"{name}(<invalid>)"
    else:
      fields = [f"{nse=}"]
      if n_batch:
        fields.append(f"{n_batch=}")
      if n_dense:
        fields.append(f"{n_dense=}")
      repr_ = f"{name}({dtype}{shape}, {', '.join(fields)})"
    if isinstance(self.data, core.Tracer):
      repr_ = f"{type(self.data).__name__}[{repr_}]"
    return repr_

  def transpose(self, *args, **kwargs):
    raise NotImplementedError("Tranpose is not implemented.")

  def tree_flatten(self):
    # TODO(tianjianlu): Unflatten SparseInfo with self._info._asdict().
    children = (self.data, self.indices, self.indptr)
    return children, {'shape': self.shape}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Bypass __init__ so that no conversion or validation runs on the leaves.
    instance = object.__new__(cls)
    instance.data, instance.indices, instance.indptr = children
    if aux_data.keys() != {'shape'}:
      raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}")
    vars(instance).update(aux_data)
    return instance

  @classmethod
  def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0,
             n_batch=0, nse=0):
    """Create an empty BCSR instance. Public method is sparse.empty()."""
    shape = tuple(shape)
    if n_dense < 0 or n_batch < 0 or nse < 0:
      raise ValueError(f"Invalid inputs: {shape=}, {n_dense=}, {n_batch=}, {nse=}")
    n_sparse = len(shape) - n_dense - n_batch
    if n_sparse != 2:
      raise ValueError("BCSR sparse.empty: must have 2 sparse dimensions.")
    batch_shape, sparse_shape, dense_shape = split_list(shape,
                                                        [n_batch, n_sparse])
    n_rows, n_cols = sparse_shape
    data = jnp.zeros((*batch_shape, nse, *dense_shape), dtype)
    # Every stored slot gets column index n_cols, one past the last column.
    indices = jnp.full((*batch_shape, nse), jnp.array(n_cols), index_dtype)
    indptr = jnp.zeros((*batch_shape, n_rows + 1), index_dtype)
    return cls((data, indices, indptr), shape=shape)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0,
                n_batch=0):
    """Create a BCSR array from a (dense) :class:`DeviceArray`."""
    return bcsr_fromdense(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                          index_dtype=index_dtype)

  def todense(self):
    """Create a dense version of the array."""
    return bcsr_todense(self)

  @classmethod
  def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
    """Create a BCSR array from a :mod:`scipy.sparse` array."""
    if n_dense != 0 or n_batch != 0:
      raise NotImplementedError("BCSR from_scipy_sparse with nonzero n_dense/n_batch.")
    if mat.ndim != 2:
      raise ValueError(f"BCSR from_scipy_sparse requires 2D array; {mat.ndim}D is given.")
    csr = mat.tocsr()
    target_index_dtype = index_dtype or jnp.int32
    data = jnp.asarray(csr.data)
    indices = jnp.asarray(csr.indices).astype(target_index_dtype)
    indptr = jnp.asarray(csr.indptr).astype(target_index_dtype)
    return cls((data, indices, indptr), shape=csr.shape)
#--------------------------------------------------------------------
# vmappable handlers
def _bcsr_to_elt(cont, _, val, axis):
if axis is None:
return val
if axis >= val.n_batch:
raise ValueError(f"Cannot map in_axis={axis} for BCSR array with n_batch="
f"{val.n_batch}. in_axes for batched BCSR operations must "
"correspond to a batched dimension.")
return BCSR((cont(val.data, axis),
cont(val.indices, axis),
cont(val.indptr, axis)),
shape=val.shape[:axis] + val.shape[axis + 1:])
def _bcsr_from_elt(cont, axis_size, elt, axis):
if axis is None:
return elt
if axis > elt.n_batch:
raise ValueError(f"BCSR: cannot add out_axis={axis} for BCSR array with "
f"n_batch={elt.n_batch}. BCSR batch axes must be a "
"contiguous block of leading dimensions.")
return BCSR((cont(axis_size, elt.data, axis),
cont(axis_size, elt.indices, axis),
cont(axis_size, elt.indptr, axis)),
shape=elt.shape[:axis] + (axis_size,) + elt.shape[axis:])
# Register BCSR with vmap's vmappable machinery. _bcsr_to_elt removes a mapped
# batch axis from all three buffers and the shape; _bcsr_from_elt inserts one.
batching.register_vmappable(BCSR, int, int, _bcsr_to_elt, _bcsr_from_elt, None)