JAX GPU test failures #18447

Open
sampathweb opened this issue Jul 8, 2023 · 2 comments

@sampathweb
Collaborator

Install:

pip install tensorflow_cpu
pip install -U torch==2.0.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Run the tests with:

KERAS_BACKEND=jax pytest keras_core --ignore keras_core/applications


Errors:
1. Assertion Error
2. Thread error (possibly caused by my CUDA setup, though that is not yet clear).
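
Since the second error may come from the local CUDA setup, a quick sanity check is to ask JAX which backend and devices it actually sees. The following is a minimal sketch, not part of the original report; the helper name `describe_jax_devices` is hypothetical, and it returns `None` when JAX is not installed.

```python
import importlib.util


def describe_jax_devices():
    """Return (default_backend, [device descriptions]), or None if JAX is absent.

    Hypothetical helper for illustration only.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # default_backend() is typically "gpu" when the CUDA install is usable,
    # and "cpu" when JAX fell back to the CPU backend.
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    print(describe_jax_devices())
```

If the reported backend is `cpu` on a machine with a GPU, the CUDA installation is the first thing to look at.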
@sampathweb
Collaborator Author

  1. Assertion Error:
keras_core/layers/attention/multi_head_attention_test.py:236:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras_core/testing/test_case.py:34: in assertAllClose
    np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x7f33340ffb50>, array([[[5.68042  , 5.68042  ],
        [4.3220215, 4.3220215]]], dtype=float32), array([[[5.679, 5.679],
        [4.32 , 4.32 ]]]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.001', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-06, atol=0.001
E
E           Mismatched elements: 4 / 4 (100%)
E           Max absolute difference: 0.00202148
E           Max relative difference: 0.00046794
E            x: array([[[5.68042 , 5.68042 ],
E                   [4.322021, 4.322021]]], dtype=float32)
E            y: array([[[5.679, 5.679],
E                   [4.32 , 4.32 ]]])
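
To make the failure above easier to reason about, here is a small sketch that replays the comparison with the values copied from the log. `np.testing.assert_allclose` accepts a pair when `|actual - expected| <= atol + rtol * |expected|`. The helper `tolerance_report` is hypothetical and only for illustration. This is not a claim that loosening the tolerance is the right fix for the test.

```python
import numpy as np

# Values copied from the failure report above.
actual = np.array(
    [[[5.68042, 5.68042], [4.3220215, 4.3220215]]], dtype=np.float32
)
expected = np.array([[[5.679, 5.679], [4.32, 4.32]]])

RTOL = 1e-6  # from the log header
ATOL = 1e-3  # from the log header


def tolerance_report(actual, expected, rtol, atol):
    """Summarize how far `actual` is from `expected` under allclose rules."""
    abs_diff = np.abs(actual.astype(np.float64) - expected)
    allowed = atol + rtol * np.abs(expected)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "max_allowed": float(allowed.max()),
        "mismatched": int((abs_diff > allowed).sum()),
    }


if __name__ == "__main__":
    print(tolerance_report(actual, expected, RTOL, ATOL))
    # With an absolute tolerance of 3e-3 the same comparison passes.
    np.testing.assert_allclose(actual, expected, rtol=RTOL, atol=3e-3)
```

The largest absolute difference is about 2e-3, roughly twice the allowed `atol`, so every element is flagged. The relative error, however, is below 5e-4.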

@sampathweb
Collaborator Author

  2. Thread error: an exception that occurs repeatedly (not a test failure). It might be caused by my local setup; I will investigate further:
2023-07-07 19:53:47.065261: F external/xla/xla/stream_executor/cuda/cuda_driver.cc:149] Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error
Fatal Python error: Aborted

Current thread 0x00007f35b11e96c0 (most recent call first):
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/popen_fork.py", line 66 in _launch
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/popen_fork.py", line 19 in __init__
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/context.py", line 281 in _Popen
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/process.py", line 121 in start
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/pool.py", line 329 in _repopulate_pool_static
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/pool.py", line 340 in _maintain_pool
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/multiprocessing/pool.py", line 516 in _handle_workers
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/threading.py", line 953 in run
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/local/google/home/rameshsampath/miniconda3/envs/keras-jax2/lib/python3.10/threading.py", line 973 in _bootstrap

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, google._upb._message, tensorflow.python.framework.fast_tensor_util, charset_normalizer.md, h5py._errors, h5py.defs, h5py._objects, h5py.h5, h5py.h5r, h5py.utils, h5py.h5s, h5py.h5ac, h5py.h5p, h5py.h5t, h5py._conv, h5py.h5z, h5py._proxy, h5py.h5a, h5py.h5d, h5py.h5ds, h5py.h5g, h5py.h5i, h5py.h5f, h5py.h5fd, h5py.h5pl, h5py.h5o, h5py.h5l, h5py._selector, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.linalg._flinalg, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, jaxlib.cpu_feature_guard, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, 
pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pandas._libs.hashing, pandas._libs.tslib, pandas._libs.ops, pandas._libs.arrays, pandas._libs.sparse, pandas._libs.reduction, pandas._libs.indexing, pandas._libs.index, pandas._libs.internals, pandas._libs.join, pandas._libs.writers, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.testing, pandas._libs.parsers, pandas._libs.json, scipy.ndimage._nd_image, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, _ni_label, scipy.ndimage._ni_label, numpy.linalg.lapack_lite, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, 
scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont (total: 191)
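
The trace goes through `multiprocessing/popen_fork.py`, so the pool workers are being created with the `fork` start method. CUDA generally cannot be used in a forked child once the parent process has initialized it, which is consistent with `CUDA_ERROR_NOT_INITIALIZED`. Which component owns this pool is not shown in the trace, so the following is only a sketch of the usual workaround, assuming the pool creation can be controlled: build the pool from a `spawn` context. The helper name `get_spawn_context` is hypothetical.

```python
import multiprocessing as mp


def get_spawn_context():
    """Return a multiprocessing context whose workers start fresh interpreters.

    "spawn" avoids inheriting the parent's (possibly CUDA-initialized) state,
    unlike "fork", which is the method visible in the traceback above.
    """
    return mp.get_context("spawn")


if __name__ == "__main__":
    ctx = get_spawn_context()
    # A builtin is used as the worker function so that it can be pickled
    # by reference regardless of how this script is launched.
    with ctx.Pool(processes=2) as pool:
        print(pool.map(abs, [-1, 2, -3]))
```

Whether the code that creates the pool in this test run exposes a way to choose the start method is an open question here.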

sampathweb changed the title from "Test failure in JAX GPU with TF CPU and Torch CPU" to "JAX GPU test failures" on Jul 8, 2023
fchollet transferred this issue from keras-team/keras-core on Sep 22, 2023