Skip to content

Commit

Permalink
Merge pull request #10072 from jakevdp:fromiter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 438141629
  • Loading branch information
jax authors committed Mar 29, 2022
2 parents 23c783a + fbfc3d8 commit b31cf89
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ namespace; they are listed below.
fmod
frexp
frombuffer
fromfile
fromfunction
fromiter
fromstring
full
full_like
Expand Down
36 changes: 36 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,42 @@ def frombuffer(buffer, dtype=float, count=-1, offset=0):
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))


def fromfile(*args, **kwargs):
"""Unimplemented JAX wrapper for jnp.fromfile.
This function is left deliberately unimplemented because it may be non-pure and thus
unsafe for use with JIT and other JAX transformations. Consider using
``jnp.asarray(np.fromfile(...))`` instead, although care should be taken if ``np.fromfile``
is used within jax transformations because of its potential side-effect of consuming the
file object; for more information see `Common Gotchas: Pure Functions
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.
"""
raise NotImplementedError(
"jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use "
"with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) "
"instead, although care should be taken if np.fromfile is used within a jax transformations "
"because of its potential side-effect of consuming the file object; for more information see "
"https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")


def fromiter(*args, **kwargs):
"""Unimplemented JAX wrapper for jnp.fromiter.
This function is left deliberately unimplemented because it may be non-pure and thus
unsafe for use with JIT and other JAX transformations. Consider using
``jnp.asarray(np.fromiter(...))`` instead, although care should be taken if ``np.fromiter``
is used within jax transformations because of its potential side-effect of consuming the
iterable object; for more information see `Common Gotchas: Pure Functions
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.
"""
raise NotImplementedError(
"jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use "
"with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) "
"instead, although care should be taken if np.fromiter is used within a jax transformations "
"because of its potential side-effect of consuming the iterable object; for more information see "
"https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")


@_wraps(np.fromfunction)
def fromfunction(function, shape, *, dtype=float, **kwargs):
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@
fmax as fmax,
fmin as fmin,
frombuffer as frombuffer,
fromfile as fromfile,
fromfunction as fromfunction,
fromiter as fromiter,
fromstring as fromstring,
full as full,
full_like as full_like,
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6245,7 +6245,7 @@ def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.

# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {'broadcast_arrays', 'vectorize'}
known_exceptions = {'broadcast_arrays', 'fromfile', 'fromiter', 'vectorize'}

for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
Expand Down

0 comments on commit b31cf89

Please sign in to comment.