Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6215,9 +6215,86 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)

def fromfunction(function: Callable[..., Array], shape: Any,
*, dtype: DTypeLike = float, **kwargs) -> Array:
"""Create an array from a function applied over indices.

JAX implementation of :func:`numpy.fromfunction`. The JAX implementation
differs in that it dispatches via :func:`jax.vmap`, and so unlike in NumPy
the function logically operates on scalar inputs, and need not explicitly
handle broadcasted inputs (See *Examples* below).

Args:
function: a function that takes *N* dynamic scalars and outputs a scalar.
shape: a length-*N* tuple of integers specifying the output shape.
dtype: optionally specify the dtype of the inputs. Defaults to floating-point.
kwargs: additional keyword arguments are passed statically to ``function``.

Returns:
An array of shape ``shape`` if ``function`` returns a scalar, or in general
a pytree of arrays with leading dimensions ``shape``, as determined by the
output of ``function``.

See also:
- :func:`jax.vmap`: the core transformation that the :func:`fromfunction`
API is built on.

Examples:
Generate a multiplication table of a given shape:

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0, 0, 0, 0, 0, 0],
[ 0, 1, 2, 3, 4, 5],
[ 0, 2, 4, 6, 8, 10]], dtype=int32)

When ``function`` returns a non-scalar the output will have leading
dimension of ``shape``:

>>> def f(x):
... return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
[0., 2., 4.]], dtype=float32)

``function`` may return multiple results, in which case each is mapped
independently:

>>> def f(x, y):
... return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
[1. 2. 3. 4. 5.]
[2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
[0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.]]

The JAX implementation differs slightly from NumPy's implementation. In
:func:`numpy.fromfunction`, the function is expected to explicitly operate
element-wise on the full grid of input values:

>>> def f(x, y):
... print(f"{x.shape = }\\n{y.shape = }")
... return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
[1., 2., 3.]])

In :func:`jax.numpy.fromfunction`, the function is vectorized via
:func:`jax.vmap`, and so is expected to operate on scalar values:

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
[1., 2., 3.]], dtype=float32)
"""
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
for i in range(len(shape)):
in_axes = [0 if i == j else None for j in range(len(shape))]
Expand Down
Loading