
Inference on multiple TPU cores / GPUs #33

Closed · ChrisAGBlake opened this issue Dec 4, 2023 · 3 comments

@ChrisAGBlake commented Dec 4, 2023
I have the low-resolution model running locally in inference on a GPU (RTX 4090), and I can also run the high-resolution model (37 pressure levels) for a couple of timesteps before running out of memory.
Does anyone have advice on parallelising across multiple GPUs, or on using a TPU v3-8 instance in GCP and utilising all of its TPU cores?
I see there is the xarray_jax.pmap function, which I assume can be used for this, but I'm not sure how to use it properly.

@Dadoof commented Dec 4, 2023

As near as I can tell, there is not presently a mechanism for dividing GraphCast across GPUs. I have made some attempts via mpirun and by setting OMP threads, but I have not been able to achieve anything other than a 'serial' run. I'll follow this issue along with you and see what the folks at Google have to say.

@oubahe commented Dec 12, 2023

I also encountered the same problem, and I'll follow this issue along with you.

@ChrisAGBlake (Author)

In the paper they describe being able to generate a 10-day forecast on a single TPU v4 device. My mistake was assuming that a v3 device would also work: the v4 has 32 GB of memory per chip, whereas the v3 has only 16 GB. I didn't realise this discrepancy initially and assumed the model would need to be distributed across multiple devices to run. I was able to get it running on NVIDIA GPUs that have >= 24 GB of memory (I don't have access to v4 TPUs).
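Since a single forecast has to fit on one accelerator, it can be worth confirming up front which devices JAX can see. A minimal, generic check (not GraphCast-specific):

```python
import jax

# Each forecast step must fit in one device's memory
# (e.g. a 32 GB TPU v4 chip or a >= 24 GB GPU);
# the model itself is not sharded across devices.
for d in jax.devices():
    print(d.id, d.platform, d.device_kind)
```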

The xarray_jax.pmap function can be used to distribute work across multiple devices, but only at the batch level, not at the model level.
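For context, batch-level parallelism here just means mapping independent forecast samples onto separate devices, the same way plain jax.pmap does; xarray_jax.pmap applies the same idea to xarray inputs. A minimal sketch using plain jax.pmap (the forward function and parameters below are placeholders, not GraphCast's API):

```python
import jax
import jax.numpy as jnp

# Placeholder single-sample forward pass; GraphCast's real forward
# function takes xarray Datasets via xarray_jax instead.
def forward_fn(params, inputs):
    return inputs * params["scale"]

n_devices = jax.local_device_count()

# Replicate the parameters onto every local device.
params = {"scale": jnp.float32(2.0)}
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# The leading batch axis must equal the number of devices: each device
# gets one slice of the batch, so the model is copied, not split.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

parallel_forward = jax.pmap(forward_fn)
out = parallel_forward(replicated_params, batch)
print(out.shape)  # (n_devices, 4)
```

This reduces wall-clock time when there are many independent forecasts to run, but it does not reduce the per-device memory needed for a single forecast.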
