Skip to content

Commit

Permalink
Add batching rules to jax.lax.linalg.tridiagonal_solve.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555700103
  • Loading branch information
srvasude authored and jax authors committed Aug 10, 2023
1 parent 60c3fdf commit 7dfc8ff
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 43 deletions.
97 changes: 79 additions & 18 deletions jax/_src/lax/linalg.py
Expand Up @@ -1905,17 +1905,62 @@ def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):

mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)


def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t):
  """Lowers `tridiagonal_solve` by delegating to a GPU `gtsv2` lowering.

  Args:
    lowering: callable that emits the custom call; it is invoked as
      `lowering(dl, d, du, b, m=..., n=..., ldb=..., t=..., b_shape_vals=...)`.
    ctx: lowering rule context; only `ctx.avals_in` is read here.
    dl, d, du, b: lowered operands, forwarded unchanged.
    m, n, ldb: primitive parameters, forwarded unchanged.
    t: element dtype; canonicalized before being forwarded.

  Returns:
    A single-element list holding the value returned by `lowering`.
  """
  # The shape of `b` is evaluated from its abstract value and must be
  # forwarded on every call: a stale early return placed before this point
  # would leave `b_shape_vals` unset in the callee.
  _, _, _, b_aval = ctx.avals_in
  b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
  return [lowering(
      dl, d, du, b, m=m, n=n, ldb=ldb, t=dtypes.canonicalize_dtype(t),
      b_shape_vals=b_shape_vals)]


def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b, *, m, n, ldb, t):
  """Transpose rule for `tridiagonal_solve` with respect to `b`.

  The primitive computes `X = A^{-1} B`, where `A` is described by
  `dl[i] = A[i, i-1]`, `d[i] = A[i, i]` and `du[i] = A[i, i+1]`. The
  transpose of that linear map is `ct -> (A^T)^{-1} ct`, so the cotangent is
  obtained by solving against the transposed matrix.

  Args:
    cotangent: cotangent of the output, or an `ad_util.Zero`.
    dl, d, du: the three diagonals; all must be defined primals.
    b: the right-hand side; must be the undefined primal.
    m, n, ldb, t: primitive parameters (unused; recomputed by
      `tridiagonal_solve`).

  Returns:
    `[None, None, None, cotangent_b]`.
  """
  del m, n, ldb, t
  # Tridiagonal solve is nonlinear in the tridiagonal arguments and linear
  # otherwise.
  assert not (ad.is_undefined_primal(dl) or ad.is_undefined_primal(d) or
              ad.is_undefined_primal(du)) and ad.is_undefined_primal(b)
  if type(cotangent) is ad_util.Zero:
    cotangent_b = ad_util.Zero(b.aval)
  else:
    # Build the diagonals of A^T:
    #   A^T[i, i-1] = A[i-1, i] = du[i-1]   (with a leading zero), and
    #   A^T[i, i+1] = A[i+1, i] = dl[i+1]   (with a trailing zero).
    # Reusing (dl, d, du) unchanged is only correct for symmetric A.
    dl_trans = jnp.concatenate(
        [jnp.zeros_like(du[..., :1]), du[..., :-1]], axis=-1)
    du_trans = jnp.concatenate(
        [dl[..., 1:], jnp.zeros_like(dl[..., :1])], axis=-1)
    cotangent_b = tridiagonal_solve(dl_trans, d, du_trans, cotangent)
  return [None, None, None, cotangent_b]


def _tridiagonal_solve_batching_rule(
    batched_args, batch_dims, *, m, n, ldb, t):
  """Batching (vmap) rule for `tridiagonal_solve`.

  Two cases are handled:

  * None of `dl`, `d`, `du` is mapped: the mapped axis of `b` is moved next to
    its last axis and merged into it, one solve is issued, and the result is
    reshaped back. The mapped axis of the output is then `b.ndim - 2`.
  * Otherwise: every operand gets its mapped axis at the front (via
    `batching.bdim_at_front`) and a single call is made; the output is mapped
    along axis 0.

  Returns:
    A `(result, output_batch_dim)` pair.
  """
  del m, n, ldb, t
  dl, d, du, b = batched_args
  bdl, bd, bdu, bb = batch_dims

  any_diagonal_mapped = any(
      axis is not batching.not_mapped for axis in (bdl, bd, bdu))
  if any_diagonal_mapped:
    # Size of the mapped axis, read from the first operand that is mapped.
    size = next(operand.shape[axis]
                for operand, axis in zip(batched_args, batch_dims)
                if axis is not None)
    dl = batching.bdim_at_front(dl, bdl, size)
    d = batching.bdim_at_front(d, bd, size)
    du = batching.bdim_at_front(du, bdu, size)
    b = batching.bdim_at_front(b, bb, size)
    return tridiagonal_solve(dl, d, du, b), 0

  # Only `b` is mapped: treat the mapped axis as extra right-hand-side columns.
  b = batching.moveaxis(b, bb, -2)
  leading = b.shape[:-3]
  rows, stacked, cols = b.shape[-3], b.shape[-2], b.shape[-1]
  b_flat = b.reshape(leading + (rows, stacked * cols))
  bdim_out = b.ndim - 2
  solved_flat = tridiagonal_solve(dl, d, du, b_flat)
  return solved_flat.reshape(b.shape), bdim_out


# Primitive behind `tridiagonal_solve`; it is declared as having one result.
tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
# The implementation is `dispatch.apply_primitive` bound to this primitive.
tridiagonal_solve_p.def_impl(
    functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
# Abstract evaluation returns the abstract value of `b` unchanged, i.e. the
# result is described exactly like the right-hand side.
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
# Register the transpose and batching rules defined for this primitive.
ad.primitive_transposes[tridiagonal_solve_p] = _tridiagonal_solve_transpose_rule
batching.primitive_batchers[tridiagonal_solve_p] = _tridiagonal_solve_batching_rule
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?


mlir.register_lowering(
tridiagonal_solve_p,
partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
Expand All @@ -1928,20 +1973,34 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t

def _tridiagonal_solve_jax(dl, d, du, b, **kw):
"""Pure JAX implementation of `tridiagonal_solve`."""
prepend_zero = lambda x: jnp.append(jnp.zeros([1], dtype=x.dtype), x[:-1])
def prepend_zero(x):
return jnp.append(
jnp.zeros((1,) + x.shape[1:], dtype=x.dtype),
x[:-1], axis=0)
fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
fwd2 = lambda b_, x: (x[0] - x[3] * b_) / (x[1] - x[3] * x[2])
bwd1 = lambda x_, x: x[0] - x[1] * x_

def fwd2(b_, x):
return (x[0] - x[3][jnp.newaxis, ...] * b_) / (
x[1] - x[3] * x[2])[jnp.newaxis, ...]

bwd1 = lambda x_, x: x[0] - x[1][jnp.newaxis, ...] * x_
double = lambda f, args: (f(*args), f(*args))

# Move relevant dimensions to the front for the scan.
dl = jnp.moveaxis(dl, -1, 0)
d = jnp.moveaxis(d, -1, 0)
du = jnp.moveaxis(du, -1, 0)
b = jnp.moveaxis(b, -1, 0)
b = jnp.moveaxis(b, -1, 0)

# Forward pass.
_, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)),
du[0] / d[0],
(d, du, dl),
unroll=32)

_, b_ = lax.scan(lambda b_, x: double(fwd2, (b_, x)),
b[0] / d[0],
b[0] / d[0:1],
(b, d, prepend_zero(tu_), dl),
unroll=32)

Expand All @@ -1951,7 +2010,10 @@ def _tridiagonal_solve_jax(dl, d, du, b, **kw):
(b_[::-1], tu_[::-1]),
unroll=32)

return x_[::-1]
result = x_[::-1]
result = jnp.moveaxis(result, 0, -1)
result = jnp.moveaxis(result, 0, -1)
return result


mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
Expand All @@ -1967,31 +2029,30 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
A . X = B
Args:
dl: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
dl: A batch of vectors with shape ``[..., m]``.
The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
Note that ``dl[0] = 0``.
d: The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
du: The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
d: A batch of vectors with shape ``[..., m]``.
The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
du: A batch of vectors with shape ``[..., m]``.
The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
Note that ``du[m - 1] = 0``.
b: Right hand side matrix.
Returns:
Solution ``X`` of tridiagonal system.
"""
if dl.ndim != 1 or d.ndim != 1 or du.ndim != 1:
raise ValueError('dl, d and du must be vectors')

if dl.shape != d.shape or d.shape != du.shape:
raise ValueError(
f'dl={dl.shape}, d={d.shape} and du={du.shape} must all be `[m]`')

if b.ndim != 2:
raise ValueError(f'b={b.shape} must be a matrix')

m, = dl.shape
m = dl.shape[-1]
if m < 3:
raise ValueError(f'm ({m}) must be >= 3')

ldb, n = b.shape
ldb = b.shape[-2]
n = b.shape[-1]
if ldb < max(1, m):
raise ValueError(f'Leading dimension of b={ldb} must be ≥ max(1, {m})')

Expand Down
4 changes: 2 additions & 2 deletions jaxlib/gpu/sparse.cc
Expand Up @@ -548,8 +548,8 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(

#endif // if JAX_GPU_HAVE_SPARSE

py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
py::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb});
}

template <typename F>
Expand Down
26 changes: 16 additions & 10 deletions jaxlib/gpu/sparse_kernels.cc
Expand Up @@ -557,16 +557,17 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const Gtsv2Descriptor& descriptor = **s;
int batch = descriptor.batch;
int m = descriptor.m;
int n = descriptor.n;
int ldb = descriptor.ldb;

const T* dl = (const T*)(buffers[0]);
const T* d = (const T*)(buffers[1]);
const T* du = (const T*)(buffers[2]);
const T* B = (T*)(buffers[3]);
T* X = (T*)(buffers[4]);
void* buffer = buffers[5];
T* dl = static_cast<T*>(buffers[0]);
T* d = static_cast<T*>(buffers[1]);
T* du = static_cast<T*>(buffers[2]);
T* B = static_cast<T*>(buffers[3]);
T* X = static_cast<T*>(buffers[4]);
void* buffer = static_cast<void *>(buffers[5]);

// The solution X is written in place to B. We need to therefore copy the
// contents of B into the output buffer X and pass that into the kernel as B.
Expand All @@ -575,13 +576,18 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
// and X might alias, but today we know they will not.
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
size_t B_bytes = ldb * n * sizeof(T);
size_t B_bytes = ldb * n * sizeof(T) * batch;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
}

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
for (int i = 0; i < batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2(
handle.get(), m, n, dl, d, du, X, ldb, buffer)));
dl += m;
d += m;
du += m;
X += m * n;
}
return absl::OkStatus();
}

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/gpu/sparse_kernels.h
Expand Up @@ -140,7 +140,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
#endif // JAX_GPU_HAVE_SPARSE

struct Gtsv2Descriptor {
int m, n, ldb;
int batch, m, n, ldb;
};

void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
Expand Down
36 changes: 24 additions & 12 deletions jaxlib/gpu_sparse.py
Expand Up @@ -15,6 +15,7 @@
cusparse wrappers for performing sparse matrix computations in JAX
"""

import math
from functools import partial

import jaxlib.mlir.ir as ir
Expand All @@ -23,7 +24,7 @@

from jaxlib import xla_client

from .hlo_helpers import custom_call
from .hlo_helpers import custom_call, mk_result_types_and_shapes

try:
from .cuda import _sparse as _cusparse # pytype: disable=import-error
Expand Down Expand Up @@ -338,26 +339,37 @@ def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape,
rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse)


def _gtsv2_hlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
def _gtsv2_hlo(
platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None):
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
assert len(b_shape_vals) >= 2
batch_dim_vals = b_shape_vals[:-2]
batch_size = math.prod(batch_dim_vals)
num_bd = len(b_shape_vals) - 2
f32 = (t == np.float32)
if f32:
buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb)
else:
buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb)

b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
b_type = ir.RankedTensorType(B.type)

shape_type_pairs = [
(batch_dim_vals + (ldb, n), b_type.element_type),
((buffer_size,), ir.IntegerType.get_signless(8))
]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
out = custom_call(
f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64"),
[
ir.RankedTensorType.get(
[ldb, n], ir.F32Type.get() if f32 else ir.F64Type.get()),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
],
result_types,
[dl, d, du, B],
backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
operand_layouts=[[0]] * 3 + [[1, 0]],
result_layouts=[[1, 0], [0]],
operand_output_aliases={3: 0})
backend_config=gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb),
operand_layouts=[d_layout] * 3 + [b_layout],
result_layouts=[b_layout, [0]],
operand_output_aliases={3: 0},
result_shapes=result_shapes)
return out[0]

cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse)
Expand Down
27 changes: 27 additions & 0 deletions tests/batching_test.py
Expand Up @@ -641,6 +641,33 @@ def testLaxLinalgTriangularSolve(self):
[lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)

def testLaxLinalgTridiagonalSolve(self):
  """Checks vmap of `tridiagonal_solve` against a per-example Python loop.

  Three mappings are covered: all operands mapped, only the right-hand side
  mapped, and only the diagonals mapped.
  """
  lower = self.rng().randn(4, 10).astype(np.float32)
  diag = self.rng().randn(4, 10).astype(np.float32) + 1.
  upper = self.rng().randn(4, 10).astype(np.float32)
  rhs = self.rng().randn(4, 5, 10).astype(np.float32)
  solve = lax.linalg.tridiagonal_solve

  # Every operand is mapped along its last axis.
  batched = vmap(solve, in_axes=(1, 1, 1, 2))(lower, diag, upper, rhs)
  looped = np.stack([
      solve(lower[:, i], diag[:, i], upper[:, i], rhs[..., i])
      for i in range(10)])
  self.assertAllClose(batched, looped, atol=1e-5, rtol=1e-5)

  # Only the right-hand side is mapped; the diagonals are shared.
  batched = vmap(solve, in_axes=(None, None, None, 2))(
      lower[:, 0], diag[:, 0], upper[:, 0], rhs)
  looped = np.stack([
      solve(lower[:, 0], diag[:, 0], upper[:, 0], rhs[..., i])
      for i in range(10)])
  self.assertAllClose(batched, looped)

  # Only the diagonals are mapped; the right-hand side is shared.
  batched = vmap(solve, in_axes=(1, 1, 1, None))(
      lower, diag, upper, rhs[..., 0])
  looped = np.stack([
      solve(lower[:, i], diag[:, i], upper[:, i], rhs[..., 0])
      for i in range(10)])
  self.assertAllClose(batched, looped, atol=1e-5, rtol=1e-5)


@parameterized.named_parameters(
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
Expand Down

0 comments on commit 7dfc8ff

Please sign in to comment.