Hi,
the notebook on the jax + torch tutorial is very nice and useful for me, but it uses a certain flag:
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
I understand this flag prevents a CUDA OOM issue, but the JAX team mentions that it also strongly slows down computation:
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
I tried to remove the flag and solve the memory issues in other ways, but I haven't been successful so far.
Is there any update from your team on this, or at least a current guess on where this memory leak originates? Any kind of info would be very helpful.
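For reference, a minimal sketch of the allocator settings described on the linked JAX page. The "platform" line is the one used in the tutorial; the commented alternatives and the 0.6 memory fraction are illustrative values, not a recommendation from the notebook authors:

import os

# The allocator is chosen when the JAX backend initializes,
# so these variables must be set before importing jax.

# Option used in the tutorial notebook: allocate and free on demand
# (slow, but avoids holding a large preallocated pool).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Alternative 1: keep the default allocator but disable preallocating
# ~75% of GPU memory at startup.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative 2: preallocate only a fraction of GPU memory
# (0.6 is an arbitrary example value).
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"

import jax  # imported after the environment variables are set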
Hi @P-Schumacher - back when we wrote this colab, we did not see a significant change in performance with or without this particular flag. Our training workloads spend >99% of their time on device after initial setup, and don't release their DeviceArray buffers in a way that would cause deallocations or thrashing during training.
But I could be wrong! I'm going to close this for now, but if you find evidence that this flag is significantly impacting performance, please let us know.
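For anyone who wants to check this on their own workload, a rough timing sketch; the step function and array sizes here are placeholders, not the notebook's actual training step:

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Placeholder for a training step; swap in the real jitted update function.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((2048, 2048))
step(x).block_until_ready()  # warm-up run triggers compilation

start = time.perf_counter()
for _ in range(100):
    out = step(x)
out.block_until_ready()  # wait for async dispatch before stopping the clock
print(f"mean step time: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms")

Running this once with the "platform" allocator and once without it (in separate processes, since the allocator cannot be changed after startup) gives a direct comparison.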