
feat: support for TPU - sets environment variables correctly to use T… #863

Closed

Conversation

@EdanToledo (Contributor) commented Jan 4, 2023

TPU support

What?

Changed the environment variables set in the lp_utils.to_device function so that only the nodes listed in "nodes_on_gpu" can see the TPU, while all other nodes can only see the CPU. This allows the trainer to run on a TPU. Additionally, a new config parameter called "use_tpu" was added and threaded through the launcher.
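For illustration, a minimal sketch of the kind of per-node environment setup described above (the function body, the PythonProcess usage, and the exact dict keys are assumptions for this sketch; the actual lp_utils.to_device in Mava differs in detail):

```python
from launchpad.nodes.python.local_multi_processing import PythonProcess

def to_device(program_nodes, nodes_on_gpu, use_tpu=False):
    """Sketch: per-node env vars so only chosen nodes can see the accelerator."""
    local_resources = {}
    for node in program_nodes:
        if node in nodes_on_gpu:
            # This node (e.g. the trainer) may claim the accelerator.
            env = {"JAX_PLATFORMS": "tpu"} if use_tpu else {}
        else:
            # Every other node is pinned to CPU and never touches libtpu.
            env = {"JAX_PLATFORMS": "cpu", "CUDA_VISIBLE_DEVICES": "-1"}
        local_resources[node] = PythonProcess(env=env)
    return local_resources

# Passed to launchpad when launching, e.g.:
# lp.launch(program, local_resources=to_device(nodes, ["trainer"], use_tpu=True))
```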

Why?

Launchpad processes crash if more than one of them tries to use the TPU, since only one process can hold the TPU at a time.

How?

As stated in "What", the environment variables decide which platform JAX uses.

Extra

There is a slight problem when using a TPU. The base Python environment (the one calling the training script) needs to be set to only see the CPU, otherwise it will crash for the same reason stated above. This is simple to do via export JAX_PLATFORMS="cpu" (see the sketch below).

One thing not considered in this PR is someone wanting to put certain nodes on the TPU and other nodes on the GPU, but that is quite fine-grained and can easily be added later. It also gets complicated because a TPU can only have a single model running on it, so I'm not sure how this would work for non-parameter-sharing situations, i.e. heterogeneous agents. One would need multiple TPUs to run multiple nodes on TPU, since only one process can occupy a TPU at a time. I'm not sure whether JAX would do automatic assignment if there were multiple TPUs, but for now having at least some TPU support is good. I don't know how many people have multiple TPUs at their disposal.
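To make that concrete, a hypothetical sketch of the top of the launching script under this constraint (the script structure is an assumption; only the env-var setting comes from the discussion above):

```python
# Hypothetical entry-point sketch: the launcher process must only see the CPU.
# Set this before anything imports jax (directly or transitively), so this
# process never loads and locks libtpu.
import os
os.environ["JAX_PLATFORMS"] = "cpu"

import launchpad as lp  # safe: any transitive jax import now sees CPU only

# ... build the program and per-node resources, then:
# lp.launch(program, local_resources=local_resources)
```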


@codecov (bot) commented Jan 4, 2023

Codecov Report

Merging #863 (d7b8020) into develop (99e41a3) will increase coverage by 0.00%.
The diff coverage is 100.00%.

```
@@           Coverage Diff            @@
##           develop     #863   +/-   ##
========================================
  Coverage    93.54%   93.54%
========================================
  Files          167      167
  Lines         9259     9263    +4
========================================
+ Hits          8661     8665    +4
  Misses         598      598
```
| Impacted Files | Coverage Δ |
| --- | --- |
| mava/components/building/distributor.py | 100.00% <100.00%> (ø) |
| mava/systems/launcher.py | 71.60% <100.00%> (+0.35%) ⬆️ |
| mava/utils/lp_utils.py | 52.08% <100.00%> (+2.08%) ⬆️ |


@sash-a (Contributor) commented Jan 4, 2023

Hey Edan, thanks for the change, this is really cool. Just one thing: can you do export JAX_PLATFORMS="cpu" in code? So somewhere before we call lp.launch, can you do something like

```python
import os
os.environ["JAX_PLATFORMS"] = "cpu"
```

Will this work, or does it have to be exported before you start running Python? It's kinda hard for us to test, because we haven't really done much with Mava on TPUs.

@EdanToledo (Contributor, Author)

I'm not exactly sure if that would work: it seems that as soon as JAX is imported, the TPU is registered as in use. I can try it out, I'm just not sure when I'll get the chance.

@EdanToledo (Contributor, Author)

Okay, so I've done some testing. Unfortunately, it looks like if someone wants to do it in code, it needs to happen in the run script before jax is imported. I've tried it in the system and in the builder, and I also tried updating the jax config; nothing seems to work. I'm sure there are other ways of doing it, but I strongly believe it needs to be done in the first script that is run. If someone is going to use a TPU, I don't think it's too much to ask them to simply run export JAX_PLATFORMS="cpu"; anyone using a TPU should be a pretty advanced user.
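To make the import-order point concrete, a sketch of the failing pattern versus the working one (the file names are hypothetical; the comments summarise the behaviour reported in this thread):

```python
# --- failing_script.py (sketch) ---
# jax is imported first, so this process loads and locks libtpu; switching
# the platform afterwards does not release the TPU for other processes.
import jax
jax.config.update("jax_platform_name", "cpu")  # too late

# --- working_script.py (sketch) ---
# The platform is pinned before jax is ever imported, so this process never
# loads libtpu and the trainer node is free to claim the TPU.
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import jax
```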

@KaleabTessera (Contributor)

@EdanToledo Thanks for the PR!

Just to clarify, this is for when you want to run on TPUs, but you want certain nodes not to be on the TPU (i.e. on CPU)? Previously, we would just set them not to run on GPU, and that would still mean they could run on TPUs.

So the above didn't work? You had to add import jax; jax.config.update('jax_platform_name', 'cpu')?
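For context, the distinction being asked about, as a sketch (the device listing noted in the comment is illustrative):

```python
import os

# Old approach: hides CUDA GPUs only; a TPU backend would still be visible
# and would get claimed when jax is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# This PR's approach: pins JAX to the CPU backend entirely, TPU included.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
print(jax.devices())  # with JAX_PLATFORMS="cpu": CPU devices only
```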

@EdanToledo (Contributor, Author) commented Jan 9, 2023

@KaleabTessera yeah, seemingly it didn't work. I just tried it again and you get the following error:

```
RuntimeError: Unable to initialize backend 'tpu': ABORTED: libtpu.so is already in use by process with pid 3569148. Not attempting to load libtpu.so in this process. (set JAX_PLATFORMS='' to automatically choose an available backend)
```

Essentially, as soon as JAX is imported with the TPU as an option, it locks the TPU against use by other processes, even if you specify the CPU.

Only the trainer crashes; everything else can run (with the code changes in this PR present), but if you do export JAX_PLATFORMS="cpu" it all works fine.

@KaleabTessera (Contributor)

Closing for now since the change didn't have any impact.

@KaleabTessera (Contributor)

@EdanToledo Please reopen if it is still an issue!

@EdanToledo (Contributor, Author)

Hi Kale-ab, I'm not sure what you mean by no impact: the trainer is much faster on a TPU, and without this change the trainer cannot run on a TPU at all. But if it's not on Mava's roadmap, then closing is fine.
