diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index a2b64aebf54c..c4c17ebfd594 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1905,17 +1905,62 @@ def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
 mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)

+
 def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t):
-  return [lowering(dl, d, du, b, m=m, n=n, ldb=ldb,
-                   t=dtypes.canonicalize_dtype(t))]
+  _, _, _, b_aval = ctx.avals_in
+  b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
+  return [lowering(
+      dl, d, du, b, m=m, n=n, ldb=ldb, t=dtypes.canonicalize_dtype(t),
+      b_shape_vals=b_shape_vals)]
+
+
+def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b, *, m, n, ldb, t):
+  del m, n, ldb, t
+  # Tridiagonal solve is nonlinear in the tridiagonal arguments and linear
+  # otherwise.
+  assert not (ad.is_undefined_primal(dl) or ad.is_undefined_primal(d) or
+              ad.is_undefined_primal(du)) and ad.is_undefined_primal(b)
+  if type(cotangent) is ad_util.Zero:
+    cotangent_b = ad_util.Zero(b.aval)
+  else:
+    cotangent_b = tridiagonal_solve(dl, d, du, cotangent)
+  return [None, None, None, cotangent_b]
+
+
+def _tridiagonal_solve_batching_rule(
+    batched_args, batch_dims, *, m, n, ldb, t):
+  del m, n, ldb, t
+  dl, d, du, b = batched_args
+  bdl, bd, bdu, bb = batch_dims
+  if (bdl is batching.not_mapped and
+      bd is batching.not_mapped and
+      bdu is batching.not_mapped):
+
+    b = batching.moveaxis(b, bb, -2)
+    b_flat = b.reshape(b.shape[:-3] + (b.shape[-3], b.shape[-2] * b.shape[-1]))
+    bdim_out = b.ndim - 2
+    out_flat = tridiagonal_solve(dl, d, du, b_flat)
+    return out_flat.reshape(b.shape), bdim_out
+  else:
+    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
+                if i is not None)
+    dl = batching.bdim_at_front(dl, bdl, size)
+    d = batching.bdim_at_front(d, bd, size)
+    du = batching.bdim_at_front(du, bdu, size)
+    b = batching.bdim_at_front(b, bb, size)
+    return tridiagonal_solve(dl, d, du, b), 0
+

 tridiagonal_solve_p = Primitive('tridiagonal_solve')
 tridiagonal_solve_p.multiple_results = False
 tridiagonal_solve_p.def_impl(
     functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
 tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
+ad.primitive_transposes[tridiagonal_solve_p] = _tridiagonal_solve_transpose_rule
+batching.primitive_batchers[tridiagonal_solve_p] = _tridiagonal_solve_batching_rule
 # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
+
 mlir.register_lowering(
     tridiagonal_solve_p,
     partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
@@ -1928,12 +1973,26 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t

 def _tridiagonal_solve_jax(dl, d, du, b, **kw):
   """Pure JAX implementation of `tridiagonal_solve`."""
-  prepend_zero = lambda x: jnp.append(jnp.zeros([1], dtype=x.dtype), x[:-1])
+  def prepend_zero(x):
+    return jnp.append(
+        jnp.zeros((1,) + x.shape[1:], dtype=x.dtype),
+        x[:-1], axis=0)
   fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
-  fwd2 = lambda b_, x: (x[0] - x[3] * b_) / (x[1] - x[3] * x[2])
-  bwd1 = lambda x_, x: x[0] - x[1] * x_
+
+  def fwd2(b_, x):
+    return (x[0] - x[3][jnp.newaxis, ...] * b_) / (
+        x[1] - x[3] * x[2])[jnp.newaxis, ...]
+
+  bwd1 = lambda x_, x: x[0] - x[1][jnp.newaxis, ...] * x_
   double = lambda f, args: (f(*args), f(*args))

+  # Move relevant dimensions to the front for the scan.
+  dl = jnp.moveaxis(dl, -1, 0)
+  d = jnp.moveaxis(d, -1, 0)
+  du = jnp.moveaxis(du, -1, 0)
+  b = jnp.moveaxis(b, -1, 0)
+  b = jnp.moveaxis(b, -1, 0)
+
   # Forward pass.
   _, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)),
                     du[0] / d[0],
@@ -1941,7 +2000,7 @@ def _tridiagonal_solve_jax(dl, d, du, b, **kw):
                     unroll=32)

   _, b_ = lax.scan(lambda b_, x: double(fwd2, (b_, x)),
-                   b[0] / d[0],
+                   b[0] / d[0:1],
                    (b, d, prepend_zero(tu_), dl),
                    unroll=32)

@@ -1951,7 +2010,10 @@ def _tridiagonal_solve_jax(dl, d, du, b, **kw):
                    (b_[::-1], tu_[::-1]),
                    unroll=32)

-  return x_[::-1]
+  result = x_[::-1]
+  result = jnp.moveaxis(result, 0, -1)
+  result = jnp.moveaxis(result, 0, -1)
+  return result


 mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
@@ -1967,31 +2029,30 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:

     A . X = B

   Args:
-    dl: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
+
+    dl: A batch of vectors with shape ``[..., m]``.
+      The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
       Note that ``dl[0] = 0``.
-    d: The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
-    du: The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
+    d: A batch of vectors with shape ``[..., m]``.
+      The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
+    du: A batch of vectors with shape ``[..., m]``.
+      The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
       Note that ``dl[m - 1] = 0``.
     b: Right hand side matrix.

   Returns:
     Solution ``X`` of tridiagonal system.
   """
-  if dl.ndim != 1 or d.ndim != 1 or du.ndim != 1:
-    raise ValueError('dl, d and du must be vectors')
-
   if dl.shape != d.shape or d.shape != du.shape:
     raise ValueError(
         f'dl={dl.shape}, d={d.shape} and du={du.shape} must all be `[m]`')

-  if b.ndim != 2:
-    raise ValueError(f'b={b.shape} must be a matrix')
-
-  m, = dl.shape
+  m = dl.shape[-1]
   if m < 3:
     raise ValueError(f'm ({m}) must be >= 3')

-  ldb, n = b.shape
+  ldb = b.shape[-2]
+  n = b.shape[-1]
   if ldb < max(1, m):
     raise ValueError(f'Leading dimension of b={ldb} must be ≥ max(1, {m})')
diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc
index 67166b58ba47..7bd1d463d7b5 100644
--- a/jaxlib/gpu/sparse.cc
+++ b/jaxlib/gpu/sparse.cc
@@ -548,8 +548,8 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(

 #endif  // if JAX_GPU_HAVE_SPARSE

-py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
-  return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
+py::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) {
+  return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb});
 }

 template <typename F>
diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc
index 1d7ece842c03..93c6aef17008 100644
--- a/jaxlib/gpu/sparse_kernels.cc
+++ b/jaxlib/gpu/sparse_kernels.cc
@@ -557,16 +557,17 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
   auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
   const Gtsv2Descriptor& descriptor = **s;
+  int batch = descriptor.batch;
   int m = descriptor.m;
   int n = descriptor.n;
   int ldb = descriptor.ldb;

-  const T* dl = (const T*)(buffers[0]);
-  const T* d = (const T*)(buffers[1]);
-  const T* du = (const T*)(buffers[2]);
-  const T* B = (T*)(buffers[3]);
-  T* X = (T*)(buffers[4]);
-  void* buffer = buffers[5];
+  T* dl = static_cast<T*>(buffers[0]);
+  T* d = static_cast<T*>(buffers[1]);
+  T* du = static_cast<T*>(buffers[2]);
+  T* B = static_cast<T*>(buffers[3]);
+  T* X = static_cast<T*>(buffers[4]);
+  void* buffer = static_cast<void*>(buffers[5]);

   // The solution X is written in place to B. We need to therefore copy the
   // contents of B into the output buffer X and pass that into the kernel as B.
@@ -575,13 +576,18 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
   // and X might alias, but today we know they will not.
   // TODO(b/182906199): Update the comment here once copy insertion is WAI.
   if (X != B) {
-    size_t B_bytes = ldb * n * sizeof(T);
+    size_t B_bytes = ldb * n * sizeof(T) * batch;
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
         gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
   }
-
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
+  for (int i = 0; i < batch; ++i) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2(
+        handle.get(), m, n, dl, d, du, X, ldb, buffer)));
+    dl += m;
+    d += m;
+    du += m;
+    X += m * n;
+  }

   return absl::OkStatus();
 }
diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h
index aa7e8215bc46..2180767b0cf7 100644
--- a/jaxlib/gpu/sparse_kernels.h
+++ b/jaxlib/gpu/sparse_kernels.h
@@ -140,7 +140,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
 #endif // JAX_GPU_HAVE_SPARSE

 struct Gtsv2Descriptor {
-  int m, n, ldb;
+  int batch, m, n, ldb;
 };

 void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py
index 791f573083fc..d4c16e93de55 100644
--- a/jaxlib/gpu_sparse.py
+++ b/jaxlib/gpu_sparse.py
@@ -15,6 +15,7 @@
 cusparse wrappers for performing sparse matrix computations in JAX
 """

+import math
 from functools import partial

 import jaxlib.mlir.ir as ir
@@ -23,7 +24,7 @@
 from jaxlib import xla_client

-from .hlo_helpers import custom_call
+from .hlo_helpers import custom_call, mk_result_types_and_shapes

 try:
   from .cuda import _sparse as _cusparse  # pytype: disable=import-error
@@ -338,26 +339,37 @@ def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape,
 rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse)


-def _gtsv2_hlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
+def _gtsv2_hlo(
+    platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None):
   """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`."""
+  assert len(b_shape_vals) >= 2
+  batch_dim_vals = b_shape_vals[:-2]
+  batch_size = math.prod(batch_dim_vals)
+  num_bd = len(b_shape_vals) - 2
   f32 = (t == np.float32)
   if f32:
     buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb)
   else:
     buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb)
+
+  b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
+  d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
+  b_type = ir.RankedTensorType(B.type)
+
+  shape_type_pairs = [
+      (batch_dim_vals + (ldb, n), b_type.element_type),
+      ((buffer_size,), ir.IntegerType.get_signless(8))
+  ]
+  result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
   out = custom_call(
       f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64"),
-      [
-          ir.RankedTensorType.get(
-              [ldb, n], ir.F32Type.get() if f32 else ir.F64Type.get()),
-          ir.RankedTensorType.get([buffer_size],
-                                  ir.IntegerType.get_signless(8)),
-      ],
+      result_types,
       [dl, d, du, B],
-      backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
-      operand_layouts=[[0]] * 3 + [[1, 0]],
-      result_layouts=[[1, 0], [0]],
-      operand_output_aliases={3: 0})
+      backend_config=gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb),
+      operand_layouts=[d_layout] * 3 + [b_layout],
+      result_layouts=[b_layout, [0]],
+      operand_output_aliases={3: 0},
+      result_shapes=result_shapes)
   return out[0]

 cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse)
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 1a51c3eb830c..7b090f205db3 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -641,6 +641,33 @@ def testLaxLinalgTriangularSolve(self):
         [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
     self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)

+  def testLaxLinalgTridiagonalSolve(self):
+    dl = self.rng().randn(4, 10).astype(np.float32)
+    d = self.rng().randn(4, 10).astype(np.float32) + 1.
+    du = self.rng().randn(4, 10).astype(np.float32)
+    b = self.rng().randn(4, 5, 10).astype(np.float32)
+
+    ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(1, 1, 1, 2))(dl, d, du, b)
+    expected = np.stack(
+        [lax.linalg.tridiagonal_solve(
+            dl[:, i], d[:, i], du[:, i], b[..., i]) for i in range(10)])
+    self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
+
+    ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(None, None, None, 2))(
+        dl[:, 0], d[:, 0], du[:, 0], b)
+    expected = np.stack(
+        [lax.linalg.tridiagonal_solve(
+            dl[:, 0], d[:, 0], du[:, 0], b[..., i]) for i in range(10)])
+    self.assertAllClose(ans, expected)
+
+    ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(1, 1, 1, None))(
+        dl, d, du, b[..., 0])
+    expected = np.stack(
+        [lax.linalg.tridiagonal_solve(
+            dl[:, i], d[:, i], du[:, i], b[..., 0]) for i in range(10)])
+    self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
+
+
   @parameterized.named_parameters(
       {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
           jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
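
For readers skimming the patch, a small usage sketch of the batched API this change enables. It is illustrative only (not part of the diff), assumes a backend where `lax.linalg.tridiagonal_solve` is supported, and vmaps over a leading batch axis of every operand, mirroring `testLaxLinalgTridiagonalSolve` above:

    # Illustrative usage (not part of the patch): solve a batch of tridiagonal systems.
    import numpy as np
    import jax
    from jax import lax

    batch, m, nrhs = 4, 10, 5
    rng = np.random.RandomState(0)
    dl = rng.randn(batch, m).astype(np.float32)
    d = rng.randn(batch, m).astype(np.float32) + 10.0  # keep the systems well conditioned
    du = rng.randn(batch, m).astype(np.float32)
    dl[:, 0] = 0.0   # dl[0] must be zero, per the docstring
    du[:, -1] = 0.0  # du[m - 1] must be zero, per the docstring
    b = rng.randn(batch, m, nrhs).astype(np.float32)  # each right-hand side is [m, nrhs]

    x = jax.vmap(lax.linalg.tridiagonal_solve)(dl, d, du, b)
    print(x.shape)  # (4, 10, 5)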
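
The b-only fast path in `_tridiagonal_solve_batching_rule` relies on the right-hand-side columns of a single system being independent, so a batch of right-hand sides can be folded into extra columns of one solve. A quick sketch of that equivalence (again illustrative, not from the patch; shapes and tolerances are arbitrary), mirroring the moveaxis-then-reshape the rule performs:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import lax

    m, n, batch = 8, 3, 5
    rng = np.random.RandomState(1)
    dl = rng.randn(m).astype(np.float32)
    d = rng.randn(m).astype(np.float32) + 10.0
    du = rng.randn(m).astype(np.float32)
    dl[0] = 0.0
    du[-1] = 0.0
    b = rng.randn(batch, m, n).astype(np.float32)

    # Batched solve against one shared system, as vmap would produce it.
    via_vmap = jax.vmap(lax.linalg.tridiagonal_solve,
                        in_axes=(None, None, None, 0))(dl, d, du, b)

    # Same computation as a single solve: move the batch axis next to the
    # column axis and flatten the two into one block of columns.
    b_flat = jnp.moveaxis(b, 0, 1).reshape(m, batch * n)
    x_flat = lax.linalg.tridiagonal_solve(dl, d, du, b_flat)
    via_flat = jnp.moveaxis(x_flat.reshape(m, batch, n), 1, 0)

    np.testing.assert_allclose(via_vmap, via_flat, rtol=1e-4, atol=1e-4)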
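
`_tridiagonal_solve_jax` implements the Thomas algorithm with two `lax.scan` sweeps (`fwd1`/`fwd2` forward, `bwd1` backward). For reference only, here are the same recurrences for a single system in plain NumPy; `thomas_reference` is a hypothetical helper written for this note, and `cp`/`xp` play the roles of `tu_` and `b_` in the scans:

    import numpy as np

    def thomas_reference(dl, d, du, b):
        # Single-system Thomas algorithm; assumes dl[0] == 0 and du[-1] == 0.
        m = d.shape[0]
        cp = np.zeros_like(d)   # modified upper diagonal (tu_ in the scans)
        xp = np.zeros_like(b)   # modified right-hand side (b_ in the scans)
        cp[0] = du[0] / d[0]
        xp[0] = b[0] / d[0]
        for i in range(1, m):   # forward sweep: fwd1 and fwd2
            denom = d[i] - dl[i] * cp[i - 1]
            cp[i] = du[i] / denom
            xp[i] = (b[i] - dl[i] * xp[i - 1]) / denom
        x = np.zeros_like(b)
        x[-1] = xp[-1]
        for i in range(m - 2, -1, -1):  # backward substitution: bwd1
            x[i] = xp[i] - cp[i] * x[i + 1]
        return x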
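
The layouts built in `_gtsv2_hlo` are XLA minor-to-major permutations. A tiny check of what the expressions evaluate to for a single batch dimension (illustrative only; my reading is that each ``[ldb, n]`` right-hand side stays column-major, as cuSPARSE gtsv2 expects, with batch elements laid out contiguously one after another, which is what the ``X += m * n`` and ``dl += m`` strides in the kernel loop assume):

    # Evaluate the layout expressions from _gtsv2_hlo for num_bd == 1.
    num_bd = 1
    b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
    print(b_layout)  # (1, 2, 0): the m axis is minor-most, then n, then the batch axis
    print(d_layout)  # (1, 0):    m minor-most, batch major-most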