Skip to content

Commit

Permalink
DOC: improve docs of transpose & matrix_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 29, 2024
1 parent d92d939 commit b55e69f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 14 deletions.
123 changes: 112 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,78 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
view of the input.
"""

@util.implements(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
"""Return a transposed version of an N-dimensional array.
JAX implementation of :func:`jax.numpy.transpose`, implemented in terms of
:func:`jax.lax.transpose`.
Args:
a: input array
axes: optionally specify the permutation using a length-`a.ndim` sequence of integers
``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e
reverses the order of all axes.
Returns:
transposed copy of the array.
See Also:
- :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method.
- :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array` property.
- :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is
suitable for working with batched 2D matrices.
- :func:`jax.numpy.swapaxes`: swap any two axes in an array.
- :func:`jax.numpy.moveaxis`: move an axis to another postion in the array.
Note:
Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a copy rather
than a view of the input array. However, under JIT, the compiler will optimize-away
such copies when possible, so this doesn't have performance impacts in practice.
Examples:
For a 1D array, the transpose is the identity:
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
[2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
The ``axes`` argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
Since swapping the last two axes is a common operation, it can be done
via its own API, :func:`jax.numpy.matrix_transpose`:
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
For convenience, transposes may also be performed using the :meth:`jax.Array.transpose`
method or the :attr:`jax.Array.T` property:
>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> x.transpose()
Array([[1, 3],
[2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
[2, 4]], dtype=int32)
"""
util.check_arraylike("transpose", a)
axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
Expand All @@ -555,19 +625,50 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
return lax.transpose(a, axes)


@util.implements(getattr(np, 'matrix_transpose', None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes the last two dimensions of x.
"""Transpose the last two dimensions of an array.
JAX implementation of :func:`jax.numpy.matrix_transpose`, implemented in terms of
:func:`jax.lax.transpose`.
Parameters
----------
x : array_like
Input array. Must have ``x.ndim >= 2``.
Args:
x: input array, Must have ``x.ndim >= 2``
Returns:
matrix-transposed copy of the array.
Returns
-------
xT : Array
Transposed array.
See Also:
- :attr:`jax.Array.mT`: same operation accessed via an :func:`~jax.Array` property.
- :func:`jax.numpy.transpose`: general multi-axis transpose
Note:
Unlike :func:`numpy.matrix_transpose`, :func:`jax.numpy.matrix_transpose` will return a
copy rather than a view of the input array. However, under JIT, the compiler will
optimize-away such copies when possible, so this doesn't have performance impacts in practice.
Examples:
Here is a 2x2x2 matrix representing a batched 2x2 matrix:
>>> x = jnp.array([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
[2, 4]],
<BLANKLINE>
[[5, 7],
[6, 8]]], dtype=int32)
For convenience, you can perform the same transpose via the :attr:`~jax.Array.mT`
property of :class:`jax.Array`:
>>> x.mT
Array([[[1, 3],
[2, 4]],
<BLANKLINE>
[[5, 7],
[6, 8]]], dtype=int32)
"""
util.check_arraylike("matrix_transpose", x)
ndim = np.ndim(x)
Expand Down
14 changes: 11 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6018,9 +6018,17 @@ def test_lax_numpy_docstrings(self):

# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {
'fromfile', 'fromiter', 'frompyfunc', 'vectorize',
'argwhere', 'where', 'nonzero', 'flatnonzero'}

'argwhere',
'flatnonzero',
'fromfile',
'fromiter',
'frompyfunc',
'matrix_transpose',
'nonzero',
'transpose',
'vectorize',
'where',
}
for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
continue
Expand Down

0 comments on commit b55e69f

Please sign in to comment.