[shape_poly] linalg.svd: shape polymorphism for native serialization on CPU

PiperOrigin-RevId: 542483203
gnecula authored and jax authors committed Jun 22, 2023
1 parent 1535fa0 commit 92288d3
Showing 6 changed files with 123 additions and 49 deletions.
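For context, a minimal usage sketch of what this change enables; it is not taken from the diff and assumes jaxlib >= 0.4.13 (per the version check added in linalg.py below), the jax2tf.convert native-serialization path, and an installed TensorFlow:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Convert an SVD computation with a symbolic batch dimension; only the batch
# dimension "b" is polymorphic, the matrix dimensions (4, 5) stay static.
svd_fn = jax2tf.convert(
    lambda x: jnp.linalg.svd(x, full_matrices=False),
    polymorphic_shapes=["(b, 4, 5)"],
    native_serialization=True)

# Tracing with an unknown batch size now lowers on CPU to a lapack_*gesdd
# custom call whose result shapes are computed at run time.
_ = tf.function(svd_fn, autograph=False).get_concrete_function(
    tf.TensorSpec([None, 4, 5], tf.float32))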
22 changes: 21 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -21,6 +21,7 @@
from functools import partial
import io
import itertools
import operator
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
@@ -559,11 +560,14 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
ctx = ctx.replace(
primitive="eval_dynamic_shape",
avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))

res = lower_fun(
partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore
return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
for d, d_ir in zip(shape, util.flatten(res))) # type: ignore

# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Value, ...]:
"""Evaluates the dynamic shapes as int32 values."""
@@ -579,6 +583,22 @@ def convert_dim(d: Union[int, Value]):
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))


def eval_dynamic_shape_as_ivals(
ctx: LoweringRuleContext, shape: core.Shape
) -> Tuple[Union[int, Value], ...]:
"""Evaluates the dynamic shapes as int or ir.int32 values."""
def convert_dim(d: Union[int, Value]) -> Union[int, ir.Value]:
if type(d) is int:
return d
else:
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
if d.type != i32_type: # type: ignore
return hlo.ConvertOp(i32_type, d).result
else:
return d
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))


class LoweringResult(NamedTuple):
module: ir.Module
keepalive: Optional[Any]
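To show the intent of the new helper, a hedged sketch of a lowering rule using it; the rule name and body are illustrative, not part of the commit:

from jax._src.interpreters import mlir  # internal API, subject to change

def _example_lowering(ctx: mlir.LoweringRuleContext, operand):
  # eval_dynamic_shape_as_ivals keeps static dimensions as Python ints (so
  # they remain usable in compile-time computations such as workspace sizes)
  # and produces an i32 ir.Value only for symbolic dimensions. For a shape
  # (b, 4, 5) with b symbolic this returns (<ir.Value>, 4, 5), whereas
  # eval_dynamic_shape_as_vals would return three ir.Values.
  operand_aval, = ctx.avals_in
  shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
  del shape_vals  # a real rule would pass these to a custom-call builder
  return [operand]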
37 changes: 28 additions & 9 deletions jax/_src/lax/linalg.py
@@ -1715,21 +1715,37 @@ def _empty_svd(a, *, full_matrices, compute_uv):
return s, u, v

def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
compute_uv):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (svd); b/261671778")
compute_uv, platform: str):
operand_aval, = ctx.avals_in
s_aval = ctx.avals_out[0]
m, n = operand_aval.shape[-2:]
# Since the last two dimensions (m, n) are used to compute the workspace
# size, we support dynamic dimensions only for the batch size for now.
if not is_constant_shape([m, n]):
raise NotImplementedError(
"Shape polymorphism for native serialization for svd on CPU and GPU is "
f"implemented only for the batch dimensions: {operand_aval.shape}")
batch_dims = operand_aval.shape[:-2]

if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)

s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv)
if platform in ["cuda", "rocm"] or jaxlib_version < (0, 4, 13):
if not is_constant_shape(operand_aval.shape):
# TODO(necula): remove the platform kwarg when we implement GPU support.
raise NotImplementedError(
"Shape polymorphism for native serialization for SVD is not "
f"implemented, try to upgrade jaxlib; b/261671778; {operand_aval.shape}")
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
a_shape_vals=a_shape_vals)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
@@ -1805,13 +1821,16 @@ def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
batching.primitive_batchers[svd_p] = _svd_batching_rule

mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo),
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo,
platform='cpu'),
platform='cpu')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd,
platform='cuda'),
platform='cuda')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd,
platform='rocm'),
platform='rocm')

mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
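A hedged illustration of the restriction enforced above (only the batch dimensions may be symbolic); the snippet is not part of the commit:

import jax.numpy as jnp
from jax.experimental import jax2tf

# Symbolic batch dimensions are supported (see the sketch near the top of the
# page), but a symbolic m or n is rejected, because the LAPACK workspace sizes
# are computed from m and n at lowering time.
svd_all_poly = jax2tf.convert(jnp.linalg.svd,
                              polymorphic_shapes=["(b, m, n)"],
                              native_serialization=True)
# Tracing svd_all_poly on CPU, e.g. with
#   tf.function(svd_all_poly, autograph=False).get_concrete_function(
#       tf.TensorSpec([None, None, None], tf.float32))
# is expected to fail with the NotImplementedError raised above.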
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax_export.py
@@ -721,6 +721,8 @@ def _check_lowering(lowering) -> None:
"cusolver_geqrf", "cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# svd on CPU
"lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
# "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
5 changes: 4 additions & 1 deletion jax/experimental/jax2tf/tests/back_compat_test.py
@@ -414,7 +414,10 @@ def test_custom_call_coverage(self):

covered_targets = covered_targets.union({
# TODO(necula): add tests for eig on CPU
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev"})
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
# TODO(necula): add tests for svd on CPU
"lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered)

2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -2811,7 +2811,7 @@ def test_harness(self, harness: PolyHarness):
# custom_linear_solve uses lu
"vmap_custom_linear_solve:cpu", "vmap_custom_linear_solve:gpu",
"vmap_qr:cpu", "vmap_qr:gpu",
"vmap_svd:cpu", "vmap_svd:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
104 changes: 67 additions & 37 deletions jaxlib/lapack.py
@@ -19,7 +19,7 @@
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np
from typing import Tuple
from typing import List, Optional, Sequence, Tuple, Union
from jaxlib import xla_client

from .hlo_helpers import (
@@ -48,6 +48,38 @@ def _hlo_s32(x):
np.array(x, dtype=np.int32),
type=ir.IntegerType.get_signless(32))).result

def _ensure_hlo_s32(x):
return _hlo_s32(x) if isinstance(x, int) else x

# When we generate custom calls with dynamic shapes we have to pass both the
# result_types, with ir.ShapedType.get_dynamic_size() in place of the dynamic
# dimensions, and the result_shapes, which are ir.Values representing 1-D
# int32 tensors. If all shapes are static we can use result_shapes=None.
# We first construct, for each result, a pair of the shape and the element
# type, where the shape contains either Python ints or ir.Values.
DimensionSize = Union[int, ir.Value]  # an ir.Value if the dimension is not static
ShapeTypePair = Tuple[Sequence[DimensionSize], ir.Type]

def mk_result_types_and_shapes(
shape_type_pairs: Sequence[ShapeTypePair]
) -> Tuple[List[ir.Type], Optional[List[ir.Value]]]:
result_types: List[ir.Type] = []
result_shapes: List[ir.Value] = []
has_dynamic_shapes = any(
any(not isinstance(d, int) for d in rshape)
for rshape, _ in shape_type_pairs)
for (rshape, rtype) in shape_type_pairs:
if has_dynamic_shapes:
result_shapes.append(shape_tensor(rshape))
result_types.append(
ir.RankedTensorType.get(
[d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
for d in rshape],
rtype))
return (result_types,
result_shapes if has_dynamic_shapes else None)


# TODO(phawkins): it would be nice to avoid duplicating code for each type.

# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
@@ -298,83 +330,80 @@ def potrf_hlo(dtype, a, lower=False):
return out[:2]



# # ?gesdd: Singular value decomposition

def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
a_shape_vals: Tuple[DimensionSize, ...]):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
assert len(a_shape_vals) >= 2
m, n = a_shape_vals[-2:]
assert type(m) is int
assert type(n) is int
batch_dims_vals = a_shape_vals[:-2]
num_bd = len(batch_dims_vals)
batch_size_val = ir_constant_i32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, _ensure_hlo_s32(b_v)).result

i32_type = ir.IntegerType.get_signless(32)
workspace: List[ShapeTypePair]
if dtype == np.float32:
fn = b"lapack_sgesdd"
singular_vals_type = ir.F32Type.get()
lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0]]
elif dtype == np.float64:
fn = b"lapack_dgesdd"
singular_vals_type = ir.F64Type.get()
lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0]]
elif dtype == np.complex64:
fn = b"lapack_cgesdd"
singular_vals_type = ir.F32Type.get()
lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
ir.RankedTensorType.get(
[_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
ir.F32Type.get()),
ir.RankedTensorType.get([lwork], a_type.element_type),
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0], [0]]
elif dtype == np.complex128:
fn = b"lapack_zgesdd"
singular_vals_type = ir.F64Type.get()
lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
workspace = [
ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
ir.RankedTensorType.get(
[_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
ir.F64Type.get()),
ir.RankedTensorType.get([lwork], a_type.element_type),
([_lapack.gesdd_iwork_size(m, n)], i32_type),
([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()),
([lwork], a_type.element_type),
]
workspace_layouts = [[0], [0], [0]]
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (min(m, n),), singular_vals_type),
(batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type),
(batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type),
(batch_dims_vals, i32_type),
] + workspace
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
out = custom_call(
fn,
[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(
batch_dims + (m, m if full_matrices else min(m, n)),
a_type.element_type),
ir.RankedTensorType.get(
batch_dims + (n if full_matrices else min(m, n), n),
a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
] + workspace,
[_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), _hlo_s32(b),
result_types,
[_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), batch_size_val,
_hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 6 + [layout],
result_layouts=[
@@ -385,6 +414,7 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
tuple(range(num_bd - 1, -1, -1)),
] + workspace_layouts,
operand_output_aliases={6: 0},
result_shapes=result_shapes
)
return out[1:5]

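To make the result_types/result_shapes convention above concrete, a hypothetical sketch of calling mk_result_types_and_shapes; it assumes an active ir.Context and ir.Location (as during a JAX lowering), and dyn_batch is a stand-in name for an i32 ir.Value holding a symbolic batch size:

f32 = ir.F32Type.get()
i32 = ir.IntegerType.get_signless(32)
shape_type_pairs = [
    ((dyn_batch, 3, 5), f32),  # dynamic leading dim -> tensor<?x3x5xf32>
    ((dyn_batch,), i32),       #                     -> tensor<?xi32>
]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
# result_types: ranked tensor types with ir.ShapedType.get_dynamic_size() in
#   place of every non-int dimension.
# result_shapes: one 1-D i32 shape tensor per result (built via shape_tensor);
#   it would be None if every dimension in every pair were a Python int.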
