You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
__________________ LaxBackedNumpyTests.testClipStaticBounds18 __________________
[gw6] linux -- Python 3.12.0 /home/kbuilder/.pyenv/versions/3.12.0/bin/python
tests/lax_numpy_test.py:877: in testClipStaticBounds
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
jax/_src/test_util.py:1180: in _CheckAgainstNumpy
lax_ans = lax_op(*args)
tests/lax_numpy_test.py:875: in <lambda>
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
jax/_src/pjit.py:304: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
jax/_src/pjit.py:171: in _python_pjit_helper
attrs_tracked) = _infer_params(jit_info, args, kwargs)
jax/_src/pjit.py:605: in _infer_params
jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
jax/_src/pjit.py:1222: in _pjit_jaxpr
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
jax/_src/linear_util.py:350: in memoized_fun
ans = call(fun, *args)
jax/_src/pjit.py:1170: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
jax/_src/profiler.py:335: in wrapper
return func(*args, **kwargs)
jax/_src/interpreters/partial_eval.py:2326: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
jax/_src/interpreters/partial_eval.py:2348: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
jax/_src/linear_util.py:192: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jax/_src/numpy/lax_numpy.py:1343: in clip
warnings.warn(
E DeprecationWarning: Clip received a complex value either through the input or the min/max keywords. Complex values have no ordering and cannot be clipped. Attempting to clip using complex numbers is deprecated and will soon raise a ValueError. Please convert to a real value or array by taking the real or imaginary components via jax.numpy.real/imag respectively.
System info (python version, jaxlib version, accelerator, etc.)
Linux x86_64
Python 3.9-3.12
CPU
The text was updated successfully, but these errors were encountered:
Description
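This appears to have been introduced by #20550. As the traceback above shows, `jnp.clip` (jax/_src/numpy/lax_numpy.py:1343) now emits a `DeprecationWarning` when it receives a complex value, either as the input or through the `min`/`max` keywords. The test run treats that warning as an error (`E DeprecationWarning`), so this parameterization of `testClipStaticBounds`, which clips complex values, now fails.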
Test location:
https://github.com/google/jax/blob/f5cc272615ce2795f9133e63b7b535ec5ada7e52/tests/lax_numpy_test.py#L869
System info (python version, jaxlib version, accelerator, etc.)
Linux x86_64
Python 3.9-3.12
CPU
The text was updated successfully, but these errors were encountered: