
CUDA OOM with jax/pytorch notebook #489

Closed
P-Schumacher opened this issue Jun 5, 2024 · 1 comment

Comments

@P-Schumacher

Hi,
the notebook in the JAX + PyTorch tutorial is very nice and useful to me, but it sets the following flag:
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

I understand this flag prevents a CUDA OOM issue, but the JAX team has mentioned that it also slows down computation considerably:
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
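
For reference, the options documented on that page are environment variables that must be set before importing JAX. A minimal sketch of the alternatives to the "platform" allocator (the values shown are illustrative, not the notebook's settings):

import os

# Must be set before importing jax (see the linked docs page).
# Option A: keep JAX's caching allocator, but don't preallocate 75% of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option B: preallocate a smaller fixed fraction of GPU memory instead.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

# Option C (what the notebook uses): allocate and free exactly on demand.
# This avoids OOM when sharing the GPU with PyTorch, but can be slow.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax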

I tried removing the flag and solving the memory issues in other ways, but I haven't been successful so far.
Is there any update from your team on this, or at least a current guess as to where this memory leak originates? Any info would be very helpful.

@erikfrey
Collaborator

erikfrey commented Jul 2, 2024

Hi @P-Schumacher - back when we wrote this colab, we did not see a significant change in performance with or without this particular flag. Our training workloads spend >99% of their time on device after initial setup, and they don't release their DeviceArray buffers in a way that would cause deallocations or thrashing during training.

But I could be wrong! I'm going to close this for now, but if you find evidence that this flag is significantly impacting performance, please let us know.
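
One rough way to gather such evidence (a hypothetical sketch; train_step below is a stand-in, not the colab's actual update step) is to time the same jitted loop with and without the flag:

import os, time

# Toggle this line between runs to compare wall-clock time.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # Placeholder update step; replace with the real one.
    grads = jax.grad(lambda p: jnp.mean((p * batch) ** 2))(params)
    return params - 1e-3 * grads

params = jnp.ones((1024, 1024))
batch = jnp.ones((1024, 1024))

params = train_step(params, batch)  # warm-up / compilation
params.block_until_ready()

start = time.perf_counter()
for _ in range(100):
    params = train_step(params, batch)
params.block_until_ready()
print(f"100 steps: {time.perf_counter() - start:.3f}s")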

erikfrey closed this as completed Jul 2, 2024