I have the low-resolution model running locally for inference on a single GPU (RTX 4090), and can also run the high-resolution model (37 pressure levels) for a couple of timesteps before running out of memory.
Does anyone have advice on parallelising across multiple GPUs, or on using a TPU v3-8 instance in GCP and utilising all of its TPU cores?
I see there is an xarray_jax.pmap function, which I assume can be used for this, but I'm not sure how to use it properly.
As near as I can tell, there is currently no mechanism for splitting GraphCast across multiple GPUs. I have made some attempts via mpirun and by setting OMP thread counts, but I have not managed to achieve anything other than a 'serial' run. I'll follow this issue along with you and see what the folks at Google have to say.
In the paper they describe generating a 10-day forecast on a single TPU v4 device. My mistake was assuming that a v3 device would also work: the v4 has 32 GB of memory per chip, whereas the v3 has only 16 GB. I didn't realise this discrepancy initially and assumed the model would need to be distributed across multiple devices to run. I was able to get it running on NVIDIA GPUs with >= 24 GB of memory (I don't have access to v4 TPUs).
The xarray_jax.pmap function can be used to distribute work across multiple devices, but only at the batch level (replicating the model and splitting the batch across devices), not at the model level (sharding the model itself).
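To illustrate what batch-level parallelism means here, below is a minimal sketch using plain `jax.pmap`, which is the mechanism `xarray_jax.pmap` wraps for xarray-labelled data. The `forecast_step` function is a hypothetical stand-in for one model step; the key point is that the model function is replicated on every device and each device receives one slice of the leading batch axis, so this helps only when you have multiple batch elements, not when a single forecast exceeds one device's memory.

```python
# Sketch of batch-level data parallelism with jax.pmap (not model sharding).
# `forecast_step` is a hypothetical placeholder for a single model step.
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def forecast_step(x):
    # Stand-in for the real model: each device runs this on its own slice.
    return x * 2.0

# The leading axis must equal the number of devices; each device gets one slice.
batched = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
out = jax.pmap(forecast_step)(batched)

print(out.shape)  # (n_devices, 4) — one result slice per device
```

Because the full model (and its activations for one batch element) must still fit on each individual device, this does not work around the 16 GB per-chip limit of a v3; it only speeds up running several forecasts at once.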