Skip to content

Commit

Permalink
Document jax.lax.Precision
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 8, 2021
1 parent f126100 commit f2a9590
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 40 deletions.
102 changes: 71 additions & 31 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,20 +527,60 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
"""
return concatenate_p.bind(*operands, dimension=dimension)

Precision = xla_client.PrecisionConfig.Precision
Precision.__str__ = lambda precision: precision.name # type: ignore

class _enum_descriptor(object):
def __init__(self, val):
self.val = val
def __get__(self, _, owner):
return owner(self.val)


class Precision(xla_client.PrecisionConfig.Precision): # type: ignore
"""Precision enum for lax functions
The `precision` argument to JAX functions generally controls the tradeoff
between speed and accuracy for array computations on accelerator backends,
(i.e. TPU and GPU). Members are:
DEFAULT:
Fastest mode, but least accurate. Performs computations in bfloat16.
Aliases: ``'default'``, ``'fastest'``, ``'bfloat16'``.
HIGH:
Slower but more accurate. Performs float32 computations in 3 bfloat16
passes, or using tensorfloat32 where available. Aliases: ``'high'`,
``'bfloat16_3x'``, ``'tensorfloat32'``.
HIGHEST:
Slowest but most accurate. Performs computations in float32 or float64
as applicable. Aliases: ``'highest'``, ``'float32'``.
"""
# Wrap enum values with this class.
DEFAULT = _enum_descriptor('default')
HIGH = _enum_descriptor('high')
HIGHEST = _enum_descriptor('highest')

_strings = {
'highest': xla_client.PrecisionConfig.Precision.HIGHEST,
'float32': xla_client.PrecisionConfig.Precision.HIGHEST,
'high': xla_client.PrecisionConfig.Precision.HIGH,
'bfloat16_3x': xla_client.PrecisionConfig.Precision.HIGH,
'tensorfloat32': xla_client.PrecisionConfig.Precision.HIGH,
'default': xla_client.PrecisionConfig.Precision.DEFAULT,
'bfloat16': xla_client.PrecisionConfig.Precision.DEFAULT,
'fastest': xla_client.PrecisionConfig.Precision.DEFAULT,
None: xla_client.PrecisionConfig.Precision.DEFAULT,
}
def __init__(self, arg0):
arg0 = self._strings.get(arg0, arg0)
super().__init__(arg0)

def __str__(self) -> str:
return self.name


PrecisionType = Any
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]
_precision_strings = {
'highest': Precision.HIGHEST,
'float32': Precision.HIGHEST,
'bfloat16_3x': Precision.HIGH,
'tensorfloat32': Precision.HIGH,
'bfloat16': Precision.DEFAULT,
'fastest': Precision.DEFAULT,
None: Precision.DEFAULT,
}


class ConvDimensionNumbers(NamedTuple):
"""Describes batch, spatial, and feature dimensions of a convolution.
Expand Down Expand Up @@ -595,10 +635,10 @@ def conv_general_dilated(
feature_group_count: integer, default 1. See XLA HLO docs.
batch_group_count: integer, default 1. See XLA HLO docs.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
'fastest', see the ``jax.default_matmul_precision`` context manager), or a
tuple of two ``lax.Precision`` enums or strings indicating precision of
tuple of two :class:`~jax.lax.Precision` enums or strings indicating precision of
``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
Expand Down Expand Up @@ -674,9 +714,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
lhs: an array of rank 1 or 2.
rhs: an array of rank 1 or 2.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Expand Down Expand Up @@ -712,9 +752,9 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Expand Down Expand Up @@ -1841,9 +1881,9 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
strides.
padding: either the string `'SAME'`, the string `'VALID'`.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Expand Down Expand Up @@ -1879,9 +1919,9 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Expand Down Expand Up @@ -1958,9 +1998,9 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
applied to the same kernel. For typical use in neural nets this is completely
pointless and just makes input/output channel specification confusing.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Expand Down Expand Up @@ -7212,28 +7252,28 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[Precision
if config.jax_default_matmul_precision is None:
return None
try:
precision = _precision_strings[config.jax_default_matmul_precision]
precision = Precision(config.jax_default_matmul_precision)
return (precision, precision)
except KeyError:
except TypeError:
raise ValueError(
"jax_default_matmul_precision flag must be set to None or a value in "
f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
) from None
elif isinstance(precision, str) and precision in _precision_strings:
precision = _precision_strings.get(precision)
elif isinstance(precision, str) and precision in Precision._strings:
precision = Precision(precision)
return (precision, precision)
elif isinstance(precision, Precision):
elif isinstance(precision, xla_client.PrecisionConfig.Precision):
return (precision, precision)
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(p, Precision) for p in precision)):
all(isinstance(p, xla_client.PrecisionConfig.Precision) for p in precision)):
return precision # type: ignore[return-value]
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(s, str) for s in precision)):
s1, s2 = precision
return (canonicalize_precision(s1)[0], canonicalize_precision(s2)[0]) # type: ignore
else:
raise ValueError(
f"Precision argument must be None, a string in {_precision_strings}, "
f"Precision argument must be None, a string in {list(Precision._strings)}, "
"a lax.Precision value or a tuple of two lax.Precision values or "
f"strings; got {precision}.")

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def conv_general_dilated_patches(
`(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`. `None` defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
the backend, or a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
If side is "right" then `a = up`. If side is "left" then `a = pu`. The
default is "right".
method: Determines the algorithm used, as described above.
precision: Controls the TPU matrix multiplication precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
The remaining arguments are only meaningful if method is "qdwh".
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a ``lax.Precision`` enum value
default precision for the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two ``lax.Precision`` enums indicating separate precision for each argument.
of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument.
"""

# We replace some builtin names to follow Numpy's API, so we capture here.
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/scipy/eigh.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _split_spectrum_jittable(P, H, V0, rank, precision):
H: Matrix to be projected.
V0: Accumulates the isometries into the projected subspaces.
rank: Rank of P.
precision: The matmul precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
Returns:
H1, V1: Projection of H into the column space of P, and the accumulated
isometry performing that projection.
Expand Down Expand Up @@ -128,7 +128,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
H: The Hermitian matrix to split.
split_point: The eigenvalue to split along.
V0: Matrix of isometries to be updated.
precision: TPU matmul precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
Returns:
Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
`split_point`.
Expand Down Expand Up @@ -164,7 +164,7 @@ def _eigh_work(
Args:
H: The Hermitian input.
V: Stores the isometries projecting H into its subspaces.
precision: The matmul precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
Returns:
H, V: The result of the projection.
Expand Down Expand Up @@ -197,7 +197,7 @@ def eigh(
Args:
H: The `n x n` Hermitian input.
precision: The matmul precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
termination_size: Recursion ends once the blocks reach this linear size.
Returns:
Expand Down Expand Up @@ -225,7 +225,7 @@ def svd(A, precision=lax.Precision.HIGHEST):
Args:
A: The `m` by `n` input matrix.
precision: TPU matmul precision.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
Returns:
U: An `m` by `m` unitary matrix of `A`'s left singular vectors.
S: A length-`min(m, n)` vector of `A`'s singular values.
Expand Down

0 comments on commit f2a9590

Please sign in to comment.