Test failure in scico.jax with latest jax version 0.4.29 #535

Closed
bwohlberg opened this issue Jun 13, 2024 · 1 comment · Fixed by #541
Labels
bug Something isn't working developer Developer environment: issues related to CI, git, etc.

Comments

@bwohlberg
Collaborator

JAX release 0.4.29 appears to have once again broken a component of scico.jax (full log):

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /home/runner/work/scico/scico
configfile: pytest.ini
testpaths: scico/test, docs
plugins: split-0.8.2
collected 3329 items / 3 skipped
scico/test/flax/test_apply.py .......                                    [  0%]
scico/test/flax/test_checkpoints.py ....                                 [  0%]
scico/test/flax/test_clu.py .....                                        [  0%]
scico/test/flax/test_examples_flax.py ss..ssssF......................... [  1%]
.....                                                                    [  1%]
scico/test/flax/test_flax.py ..........................                  [  2%]
[...]
=================================== FAILURES ===================================
__________________________ test_blur_data_generation ___________________________
>   ???
_mt19937.pyx:180: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:766: in __index__
    raise self.aval._index(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = ShapedArray(int64[])
arg = Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
  val = Array([0], dtype=int64)
  batch_dim = 0
    def error(self, arg):
>     raise TracerIntegerConversionError(arg)
E     jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
E     This BatchTracer with object id 140001038232176 was created on line:
E       /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:1508: TracerIntegerConversionError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 766, in __index__
    raise self.aval._index(self)
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 1508, in error
    raise TracerIntegerConversionError(arg)
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
    def test_blur_data_generation():
        N = 32
        nimg = 8
        n = 3  # convolution kernel size
        blur_kernel = np.ones((n, n)) / (n * n)
    
        def random_img_gen(seed, size, ndata):
            np.random.seed(seed)
            return np.random.randn(ndata, size, size, 1)
    
>       img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen)
scico/test/flax/test_examples_flax.py:157: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
scico/flax/examples/data_generation.py:318: in generate_blur_data
    img = distributed_data_generation(imgfunc, size, nimg, False)
scico/flax/examples/data_generation.py:382: in distributed_data_generation
    imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)
scico/test/flax/test_examples_flax.py:154: in random_img_gen
    np.random.seed(seed)
numpy/random/mtrand.pyx:4806: in numpy.random.mtrand.seed
    ???
numpy/random/mtrand.pyx:250: in numpy.random.mtrand.RandomState.seed
    ???
_mt19937.pyx:168: in numpy.random._mt19937.MT19937._legacy_seeding
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
>   ???
E   jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
E   This BatchTracer with object id 140001038232176 was created on line:
E     /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E   See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
_mt19937.pyx:185: TracerArrayConversionError
=========================== short test summary info ============================
FAILED scico/test/flax/test_examples_flax.py::test_blur_data_generation - jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
====== 1 failed, 3280 passed, 24 skipped, 27 xfailed in 294.89s (0:04:54) ======
Error: Process completed with exit code 1.
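The traceback points to the cause. Inside distributed_data_generation, random_img_gen is called under jax.vmap, so its seed argument is a BatchTracer rather than a concrete integer. np.random.seed first tries __index__ (TracerIntegerConversionError), then falls back to __array__ (TracerArrayConversionError), and both conversions fail for a tracer. Below is a minimal sketch of a generator that does work under vmap. It is not necessarily the approach taken in #541, and random_img_gen_jax is a hypothetical name. The sketch assumes i.i.d. Gaussian test images are acceptable. Because it draws them with jax.random, the values will differ from those produced by np.random.randn.

```python
from functools import partial

import jax
import jax.numpy as jnp


def random_img_gen_jax(seed, size, ndata):
    """Hypothetical traceable variant of the test's random_img_gen.

    The seed may be a traced integer (e.g. under jax.vmap), so the PRNG
    key is derived with jax.random instead of calling np.random.seed.
    """
    key = jax.random.PRNGKey(seed)
    # size and ndata must be concrete Python ints because they fix the shape.
    return jax.random.normal(key, (ndata, size, size, 1))


# Bind the shape arguments so that vmap maps only over the seed.
gen = partial(random_img_gen_jax, size=32, ndata=4)
seeds = jnp.arange(2)
imgs = jax.vmap(gen)(seeds)
print(imgs.shape)  # (2, 4, 32, 32, 1)
```

If the NumPy generator has to be kept unchanged, a simpler alternative is to avoid tracing it at all. You can call it eagerly with concrete integer seeds in an ordinary Python loop and stack the results, at the cost of losing vmap's batching.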
@bwohlberg bwohlberg added bug Something isn't working developer Developer environment: issues related to CI, git, etc. labels Jun 13, 2024
@bwohlberg
Collaborator Author

The tests in scico/test/flax/test_examples_flax.py also fail with jax 0.4.28 (nominally supported according to the current requirements.txt) on a GPU device.

@bwohlberg bwohlberg linked a pull request Jul 12, 2024 that will close this issue