feat: support for TPU - sets environment variables correctly to use TPU with jax and launchpad #863
Codecov Report
@@ Coverage Diff @@
## develop #863 +/- ##
========================================
Coverage 93.54% 93.54%
========================================
Files 167 167
Lines 9259 9263 +4
========================================
+ Hits 8661 8665 +4
Misses 598 598
Hey Edan, thanks for the change, this is really cool. Just one thing: can you do
Will this work, or does it have to be exported before you start running Python? It's kind of hard for us to test, because we haven't done much with Mava on TPUs.
I'm not sure that would work: it seems that as soon as JAX is imported, the TPU device is registered as in use. I can try it out, but I'm not sure when I'll be able to.
Okay, so I've done some testing. Unfortunately, it looks like if someone wants to do it in code, it needs to happen in the run script before jax is imported. I've tried it in the system and in the builder. I also tried updating the jax config, and nothing seems to work. I'm sure there are other ways of doing it, but I strongly believe it will need to be done in the first script that is run. If someone is going to use a TPU, I don't think it's too much to ask them to simply run
@EdanToledo Thanks for the PR! Just to clarify, this is for when you want to run on TPUs but want certain nodes not to be on the TPU (i.e. on CPU)? Previously, we would just set them not to run on GPU, and this would mean they could still run on TPUs. So the above didn't work? You had to add
@KaleabTessera Yeah, seemingly it didn't work. I just tried it again and you get the following error:
Essentially, as soon as JAX is imported with the TPU available, it locks the TPU against use by other processes, even if you specify the CPU. Only the trainer crashes; everything else can run (with the code changes in the PR present), but if you do the
Closing for now since the change didn't have any impact.
@EdanToledo Please reopen if it is still an issue!
Hi Kale-ab, I'm not sure what you mean by no impact: the trainer is much faster on a TPU, and without this change the trainer cannot run on a TPU. But if it's not on Mava's roadmap then closing is fine.
TPU support
What?
Changed the environment variables in the lp_utils.to_device function so that only "nodes_on_gpu" can see the TPU and all other nodes can only see the CPU. This allows the trainer to run on a TPU. Additionally, a new config parameter called "use_tpu" was added and threaded through the launcher.
Why?
Launchpad processes crash if more than one process tries to use a TPU.
How?
As stated in "What", the environment variables decide which platform JAX uses.
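A minimal sketch of the idea, not the actual lp_utils.to_device code: the helper name `jax_env_for_node` and the node names are hypothetical, and the only environment variable used is `JAX_PLATFORMS`, the one mentioned in this PR. The assumption is that nodes listed in "nodes_on_gpu" get the TPU and every other node is pinned to CPU when "use_tpu" is set.

```python
from typing import Dict, Iterable


def jax_env_for_node(
    node_name: str, nodes_on_gpu: Iterable[str], use_tpu: bool
) -> Dict[str, str]:
    """Return the environment variables to set for one launchpad node.

    Hypothetical helper: only mirrors the behaviour described in this PR.
    """
    if not use_tpu:
        # Without TPU support requested, leave the environment untouched.
        return {}
    if node_name in set(nodes_on_gpu):
        # Only accelerator nodes (e.g. the trainer) may see the TPU.
        return {"JAX_PLATFORMS": "tpu"}
    # Every other node is restricted to CPU so it never claims the TPU.
    return {"JAX_PLATFORMS": "cpu"}


# Illustrative node names, assumed for this example only.
nodes = ["trainer", "executor", "evaluator"]
envs = {n: jax_env_for_node(n, ["trainer"], use_tpu=True) for n in nodes}
```

Each dictionary would be handed to the process that launches the corresponding node, so the variable is in place before that process imports JAX.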
Extra
There is a slight problem when using a TPU. The base Python environment (the one calling the training script) must be set to see only a CPU, otherwise it will crash for the same reason as stated above. This is simple to do through
export JAX_PLATFORMS="cpu"
One thing that has not been considered in this PR is putting certain nodes on the TPU and other nodes on the GPU, but that is quite fine-grained and can easily be added later. It gets complicated because a TPU can only have a single model running on it, so I'm also not sure how this will work for non-parameter-sharing situations, i.e. heterogeneous agents. One would need multiple TPUs to run multiple nodes on TPU, as only one process can occupy a TPU at a time. I'm not sure whether JAX would handle automatic assignment if there were multiple TPUs, but for now having at least some TPU support is good. I don't know how many people have multiple TPUs at their disposal.
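For reference, the CPU pinning of the launching process described above could also be done from Python rather than the shell, provided it happens at the very top of the first script that runs, before JAX is imported. This is a sketch under that assumption; the helper name `pin_process_to_cpu` is hypothetical and not part of Mava.

```python
import os
import sys


def pin_process_to_cpu() -> bool:
    """Set JAX_PLATFORMS=cpu for this process.

    Returns True if JAX had not been imported yet, i.e. the setting can
    still take effect. Hypothetical helper for illustration only.
    """
    jax_not_yet_imported = "jax" not in sys.modules
    os.environ["JAX_PLATFORMS"] = "cpu"
    return jax_not_yet_imported


# Call this before any module that imports JAX is loaded.
set_in_time = pin_process_to_cpu()
if not set_in_time:
    print("Warning: JAX was already imported; the setting may be ignored.")

# import jax  # from here on, imports of JAX would see JAX_PLATFORMS=cpu
```

Setting the variable in the shell remains the simpler option, since it cannot be defeated by import order.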