
JAX GPU Test aborts in PyDatasetAdapter if multiprocessing=True #18431

Open
sampathweb opened this issue Jul 26, 2023 · 1 comment

@sampathweb
Collaborator

Repro steps, in a Colab Pro terminal with GPU enabled:

pip install -U tensorflow  # Update TF to 2.13
git clone https://github.com/keras-team/keras-core.git
cd keras-core

KERAS_BACKEND=jax pytest keras_core --ignore keras_core/applications

The run aborts at 98% progress, in PyDatasetAdapterTest, when that test runs with the multiprocessing=True option.

[Screenshot of the aborted test run: jax-colab-test-run-error2]

But when the test runs on its own, or even when pytest is pointed at higher-level folders, it doesn't abort. It also doesn't abort with the TensorFlow or Torch backends. I am not sure why it aborts only on JAX GPU when the entire test suite runs. Maybe multiprocessing under pytest doesn't play well with XLA / JAX / CUDA?
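One possible explanation, offered only as an assumption and not confirmed anywhere in this thread: on Linux, Python's multiprocessing has traditionally defaulted to the "fork" start method, and POSIX fork() copies the parent's memory but duplicates only the calling thread. If JAX has already initialized its multithreaded XLA/CUDA runtime earlier in the same pytest session, forked workers can inherit that state in a broken form. That would also fit the observation that the test passes when run alone. A common plain-Python mitigation is to create workers with the "spawn" start method, which launches a fresh interpreter instead of forking. The sketch below shows that technique with the standard library only; `make_worker_context` and `run_in_workers` are hypothetical helper names, not Keras, keras-core, or JAX APIs, and this is not how PyDatasetAdapter is actually implemented.

```python
import multiprocessing


def make_worker_context(start_method: str = "spawn"):
    """Return a multiprocessing context that uses the given start method.

    "spawn" starts a fresh Python interpreter for each worker instead of
    fork()-ing the parent, so workers do not inherit runtime state (threads,
    device handles) that the parent may have already initialized.
    """
    return multiprocessing.get_context(start_method)


def run_in_workers(func, items, processes: int = 2):
    """Map `func` over `items` using a pool built from a spawn context.

    `func` must be picklable by reference (defined at module top level or
    importable), because spawned workers re-import it rather than sharing
    a copy of the parent's memory.
    """
    ctx = make_worker_context("spawn")
    with ctx.Pool(processes=processes) as pool:
        return pool.map(func, items)


if __name__ == "__main__":
    # The __main__ guard is required with "spawn": each worker re-imports
    # this module, and without the guard every worker would try (and fail)
    # to start its own pool.
    print(run_in_workers(abs, [-1, -2, 3]))  # prints [1, 2, 3]
```

The trade-off is that spawned workers start more slowly and can only receive picklable, importable callables, which is why many libraries keep "fork" as the default despite its hazards with multithreaded runtimes.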

@fchollet
Member

Multiprocessing in Python is pretty broken in general, so I wouldn't be surprised if there's some weird interaction going on.

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023