When installing JAX 0.5.3, pip resolves nvidia-cudnn-cu12 to 9.11 by default, which produces the following error during model compilation. Downgrading to nvidia-cudnn-cu12<=9.10 appears to fix the issue.
This is in addition to the problem addressed in https://github.com/jax-ml/jax/pull/28897 and the need to downgrade to nvidia-cublas-cu12<12.9.
This occurs on Python 3.12.11.
```
INFO:phaser.utils.num:JIT-compiling kernel 'run_model'...
2025-07-28 20:58:26.880340: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:864] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
```
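To check whether an environment has picked up one of the versions reported above, a minimal sketch like the following reads the installed wheel versions with the standard-library `importlib.metadata`. The helper names (`AFFECTED_FROM`, `major_minor`, `find_affected`, `installed_versions`) are hypothetical, and the sketch assumes that any release at or above cuDNN 9.11 or cuBLAS 12.9 is affected, although this report only names those two series.

```python
from importlib import metadata

# Hypothetical table: (major, minor) series reported in this issue as problematic
# with JAX 0.5.3. Assumption: anything at or above these series is treated as affected.
AFFECTED_FROM = {
    "nvidia-cudnn-cu12": (9, 11),
    "nvidia-cublas-cu12": (12, 9),
}


def major_minor(version: str) -> tuple[int, int]:
    """Return the (major, minor) pair of a dotted version string such as '9.11.0.98'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def find_affected(installed: dict[str, str]) -> list[str]:
    """Return names of packages whose installed version is in an affected series."""
    return [
        name
        for name, threshold in AFFECTED_FROM.items()
        if name in installed and major_minor(installed[name]) >= threshold
    ]


def installed_versions() -> dict[str, str]:
    """Look up installed versions of the NVIDIA wheels; packages not installed are skipped."""
    versions = {}
    for name in AFFECTED_FROM:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass
    return versions


if __name__ == "__main__":
    found = installed_versions()
    for name in find_affected(found):
        print(f"{name} {found[name]} is in a series reported to break JAX 0.5.3 compilation")
```

The check compares only the (major, minor) pair, so it flags a whole release series rather than reproducing pip's exact constraint semantics; the actual fix remains pinning the packages as described above.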