From 92288d3071ab497fe5988a15f0193c4ee86c230b Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 22 Jun 2023 01:05:23 -0700
Subject: [PATCH] [shape_poly] linalg.svd: shape polymorphism for native serialization on CPU

PiperOrigin-RevId: 542483203
---
 jax/_src/interpreters/mlir.py                 |  22 +++-
 jax/_src/lax/linalg.py                        |  37 +++++--
 jax/experimental/jax2tf/jax_export.py         |   2 +
 .../jax2tf/tests/back_compat_test.py          |   5 +-
 .../jax2tf/tests/shape_poly_test.py           |   2 +-
 jaxlib/lapack.py                              | 104 +++++++++++-------
 6 files changed, 123 insertions(+), 49 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 6e97e72d29cb..161d98083d0b 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -21,6 +21,7 @@ from functools import partial
 import io
 import itertools
+import operator
 import re
 import typing
 from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
@@ -559,11 +560,14 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
     ctx = ctx.replace(
         primitive="eval_dynamic_shape",
         avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))
+
     res = lower_fun(
         partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
         multiple_results=True)(ctx, *ctx.dim_var_values)
-    return util.flatten(res)  # type: ignore
+    return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
+                 for d, d_ir in zip(shape, util.flatten(res)))  # type: ignore
 
+# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
 def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
                                shape: core.Shape) -> Tuple[Value, ...]:
   """Evaluates the dynamic shapes as int32 values."""
@@ -579,6 +583,22 @@ def convert_dim(d: Union[int, Value]):
   return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
 
 
+def eval_dynamic_shape_as_ivals(
+    ctx: LoweringRuleContext, shape: core.Shape
+  ) -> Tuple[Union[int, Value], ...]:
+  """Evaluates the dynamic shapes as int or ir.int32 values."""
+  def convert_dim(d: Union[int, Value]) -> Union[int, ir.Value]:
+    if type(d) is int:
+      return d
+    else:
+      i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
+      if d.type != i32_type:  # type: ignore
+        return hlo.ConvertOp(i32_type, d).result
+      else:
+        return d
+  return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
+
+
 class LoweringResult(NamedTuple):
   module: ir.Module
   keepalive: Optional[Any]
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 4e4527af050c..e30c10892e2c 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1715,21 +1715,37 @@ def _empty_svd(a, *, full_matrices, compute_uv):
   return s, u, v
 
 def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
-                          compute_uv):
-  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
-    raise NotImplementedError("Shape polymorphism for custom call is not implemented (svd); b/261671778")
+                          compute_uv, platform: str):
   operand_aval, = ctx.avals_in
   s_aval = ctx.avals_out[0]
   m, n = operand_aval.shape[-2:]
+  # Since the last two dimensions (m, n) are used to compute the workspace
+  # size, we support dynamic dimensions only for the batch size for now.
+  if not is_constant_shape([m, n]):
+    raise NotImplementedError(
+        "Shape polymorphism for native serialization for svd on CPU and GPU is "
+        f"implemented only for the batch dimensions: {operand_aval.shape}")
   batch_dims = operand_aval.shape[:-2]
   if m == 0 or n == 0:
     return mlir.lower_fun(_empty_svd, multiple_results=True)(
         ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
-  s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
-                              full_matrices=full_matrices,
-                              compute_uv=compute_uv)
+  if platform in ["cuda", "rocm"] or jaxlib_version < (0, 4, 13):
+    if not is_constant_shape(operand_aval.shape):
+      # TODO(necula): remove the platform kwarg when we implement GPU support.
+      raise NotImplementedError(
+          "Shape polymorphism for native serialization for SVD is not "
+          f"implemented, try to upgrade jaxlib; b/261671778; {operand_aval.shape}")
+    s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
+                                full_matrices=full_matrices,
+                                compute_uv=compute_uv)
+  else:
+    a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
+    s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
+                                full_matrices=full_matrices,
+                                compute_uv=compute_uv,
+                                a_shape_vals=a_shape_vals)
   zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
   ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
   select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
@@ -1805,13 +1821,16 @@ def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
 
 batching.primitive_batchers[svd_p] = _svd_batching_rule
 
 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo),
+    svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo,
+                   platform='cpu'),
     platform='cpu')
 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd,
+                   platform='cuda'),
     platform='cuda')
 mlir.register_lowering(
-    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd,
+                   platform='rocm'),
     platform='rocm')
 mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
 
diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index 23181bee8b99..9af1d016beeb 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -721,6 +721,8 @@ def _check_lowering(lowering) -> None:
       "cusolver_geqrf", "cusolver_orgqr",
       # qr and svd on TPU
       "Qr", "ProductOfElementaryHouseholderReflectors",
+      # svd on CPU
+      "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
       # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
       # # lu on CPU
       # "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py
index ff93b5700da6..fbf5ed667550 100644
--- a/jax/experimental/jax2tf/tests/back_compat_test.py
+++ b/jax/experimental/jax2tf/tests/back_compat_test.py
@@ -414,7 +414,10 @@ def test_custom_call_coverage(self):
     covered_targets = covered_targets.union({
       # TODO(necula): add tests for eig on CPU
-      "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev"})
+      "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
+      # TODO(necula): add tests for svd on CPU
+      "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
+    })
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered)
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index b903f29d8593..80f8cdbeb806 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -2811,7 +2811,7 @@ def test_harness(self, harness: PolyHarness):
       # custom_linear_solve uses lu
       "vmap_custom_linear_solve:cpu", "vmap_custom_linear_solve:gpu",
       "vmap_qr:cpu", "vmap_qr:gpu",
-      "vmap_svd:cpu", "vmap_svd:gpu",
+      "vmap_svd:gpu",
     }
     if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
       raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py
index 44a0758080bb..e99347bd5f37 100644
--- a/jaxlib/lapack.py
+++ b/jaxlib/lapack.py
@@ -19,7 +19,7 @@ import jaxlib.mlir.dialects.stablehlo as hlo
 import numpy as np
 
-from typing import Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 from jaxlib import xla_client
 
 from .hlo_helpers import (
@@ -48,6 +48,38 @@ def _hlo_s32(x):
       np.array(x, dtype=np.int32),
       type=ir.IntegerType.get_signless(32))).result
 
+def _ensure_hlo_s32(x):
+  return _hlo_s32(x) if isinstance(x, int) else x
+
+# When we generate custom calls with dynamic shapes we have to pass
+# both the result_types, with ir.ShapedType.get_dynamic_size in place of
+# the dynamic dimensions, and also result_shapes, which are ir.Value representing
+# 1D int32 tensors. If all the shapes are static we can use result_shapes=None.
+# We first construct for each result a pair with the shape and element type,
+# the shape containing either integer or ir.Value.
+DimensionSize = Union[int, ir.Value]  # an ir.Value if not static dimension
+ShapeTypePair = Tuple[Sequence[DimensionSize], ir.Type]
+
+def mk_result_types_and_shapes(
+    shape_type_pairs: Sequence[ShapeTypePair]
+) -> Tuple[List[ir.Type], Optional[List[ir.Value]]]:
+  result_types: List[ir.Type] = []
+  result_shapes: List[ir.Value] = []
+  has_dynamic_shapes = any(
+      any(not isinstance(d, int) for d in rshape)
+      for rshape, _ in shape_type_pairs)
+  for (rshape, rtype) in shape_type_pairs:
+    if has_dynamic_shapes:
+      result_shapes.append(shape_tensor(rshape))
+    result_types.append(
+        ir.RankedTensorType.get(
+            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
+             for d in rshape],
+            rtype))
+  return (result_types,
+          result_shapes if has_dynamic_shapes else None)
+
+
 # TODO(phawkins): it would be nice to avoid duplicating code for each type.
 # ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
@@ -298,29 +330,31 @@ def potrf_hlo(dtype, a, lower=False):
   return out[:2]
 
-
 # # ?gesdd: Singular value decomposition
 
-def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
+def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
+              a_shape_vals: Tuple[DimensionSize, ...]):
   _initialize()
   a_type = ir.RankedTensorType(a.type)
-  dims = a_type.shape
-  assert len(dims) >= 2
-  m, n = dims[-2:]
-  batch_dims = tuple(dims[:-2])
-  num_bd = len(batch_dims)
-  b = 1
-  for d in batch_dims:
-    b *= d
+  assert len(a_shape_vals) >= 2
+  m, n = a_shape_vals[-2:]
+  assert type(m) is int
+  assert type(n) is int
+  batch_dims_vals = a_shape_vals[:-2]
+  num_bd = len(batch_dims_vals)
+  batch_size_val = ir_constant_i32(1)
+  for b_v in batch_dims_vals:
+    batch_size_val = hlo.MulOp(batch_size_val, _ensure_hlo_s32(b_v)).result
 
   i32_type = ir.IntegerType.get_signless(32)
+  workspace: List[ShapeTypePair]
   if dtype == np.float32:
     fn = b"lapack_sgesdd"
     singular_vals_type = ir.F32Type.get()
     lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
     workspace = [
-        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
-        ir.RankedTensorType.get([lwork], a_type.element_type),
+        ([_lapack.gesdd_iwork_size(m, n)], i32_type),
+        ([lwork], a_type.element_type),
     ]
     workspace_layouts = [[0], [0]]
   elif dtype == np.float64:
@@ -328,8 +362,8 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
     singular_vals_type = ir.F64Type.get()
     lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
     workspace = [
-        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
-        ir.RankedTensorType.get([lwork], a_type.element_type),
+        ([_lapack.gesdd_iwork_size(m, n)], i32_type),
+        ([lwork], a_type.element_type),
     ]
     workspace_layouts = [[0], [0]]
   elif dtype == np.complex64:
@@ -337,11 +371,9 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
     singular_vals_type = ir.F32Type.get()
     lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
     workspace = [
-        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
-        ir.RankedTensorType.get(
-            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
-            ir.F32Type.get()),
-        ir.RankedTensorType.get([lwork], a_type.element_type),
+        ([_lapack.gesdd_iwork_size(m, n)], i32_type),
+        ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()),
+        ([lwork], a_type.element_type),
     ]
     workspace_layouts = [[0], [0], [0]]
   elif dtype == np.complex128:
@@ -349,11 +381,9 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
     singular_vals_type = ir.F64Type.get()
     lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
     workspace = [
-        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
-        ir.RankedTensorType.get(
-            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
-            ir.F64Type.get()),
-        ir.RankedTensorType.get([lwork], a_type.element_type),
+        ([_lapack.gesdd_iwork_size(m, n)], i32_type),
+        ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()),
+        ([lwork], a_type.element_type),
     ]
     workspace_layouts = [[0], [0], [0]]
   else:
@@ -361,20 +391,19 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
 
   scalar_layout = []
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
+
+  shape_type_pairs: Sequence[ShapeTypePair] = [
+      (a_shape_vals, a_type.element_type),
+      (batch_dims_vals + (min(m, n),), singular_vals_type),
+      (batch_dims_vals + (m, m if full_matrices else min(m, n)),
+       a_type.element_type),
+      (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type),
+      (batch_dims_vals, i32_type),
+  ] + workspace
+  result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
   out = custom_call(
       fn,
-      [
-          a.type,
-          ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
-          ir.RankedTensorType.get(
-              batch_dims + (m, m if full_matrices else min(m, n)),
-              a_type.element_type),
-          ir.RankedTensorType.get(
-              batch_dims + (n if full_matrices else min(m, n), n),
-              a_type.element_type),
-          ir.RankedTensorType.get(batch_dims, i32_type),
-      ] + workspace,
-      [_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), _hlo_s32(b),
+      result_types,
+      [_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), batch_size_val,
        _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
       operand_layouts=[scalar_layout] * 6 + [layout],
       result_layouts=[
@@ -385,6 +414,7 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
           tuple(range(num_bd - 1, -1, -1)),
       ] + workspace_layouts,
       operand_output_aliases={6: 0},
+      result_shapes=result_shapes
   )
   return out[1:5]
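
Usage note: below is a minimal sketch of what this change enables, assuming a jaxlib that contains this patch (0.4.13 or newer), TensorFlow available for the jax2tf path, and a CPU backend. The function f, the input tensor, and the shape spec "b, 4, 3" are illustrative stand-ins, not names from the patch; only the batch dimension is symbolic, since the guard added in _svd_cpu_gpu_lowering keeps the trailing (m, n) dimensions concrete.

import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  # Batched SVD; on CPU this lowers to the lapack_*gesdd custom calls above.
  return jnp.linalg.svd(x, full_matrices=False, compute_uv=True)

# "b" is the symbolic batch dimension; m=4 and n=3 stay static because the
# LAPACK workspace sizes are computed from them at lowering time.
f_tf = jax2tf.convert(f, polymorphic_shapes=["b, 4, 3"],
                      native_serialization=True)

u, s, vh = f_tf(tf.constant(np.random.randn(8, 4, 3), dtype=tf.float32))
print(s.shape)  # (8, 3)

This is also why the shape_poly_test.py change above removes "vmap_svd:cpu" from the skip list: that harness now lowers successfully under native serialization with shape polymorphism.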