
Support multinode training on GPU #2731

Closed
py4 opened this issue Apr 15, 2020 · 18 comments · Fixed by #11803

py4 commented Apr 15, 2020

I don't have a node with 8 GPUs; I have two nodes, each with 4 GPUs. So is it possible to train a model on multiple nodes?

@mattjj added the "question (Questions for the JAX team)" label Apr 16, 2020
@hawkinsp added the "enhancement (New feature or request)" label Apr 16, 2020
@hawkinsp (Member)

This is actually something that does work right now but it's still experimental. There's also no real public-facing API for it yet; you have to type in some obscure and fairly magical things to set it all up correctly.

We should polish it off and document it!

@hawkinsp (Member)

Can you say a bit more about your model, though? Would gradient all-reductions across multiple nodes suffice?

@hawkinsp changed the title from "training on multiple nodes?" to "Support multinode training on GPU" Apr 16, 2020
py4 (Author) commented Apr 16, 2020

@hawkinsp Technically, I'm training a Reformer model using the Trax library.

@hawkinsp (Member)

And I assume you're just looking for data parallelism, i.e., partitioning a minibatch across GPUs, not partitioning in any other way (e.g., model parallelism)?

py4 (Author) commented Apr 16, 2020

@hawkinsp Yeah, my concern is data parallelism.
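(For context: in JAX, this kind of data parallelism is typically expressed with `pmap` plus a `lax.pmean` over gradients. Below is a minimal sketch with illustrative names, a toy least-squares loss, and a hard-coded SGD step; it is not code from this thread.)

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy least-squares loss; any per-example loss works the same way.
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

@partial(jax.pmap, axis_name="batch")
def update(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Gradient all-reduction: average gradients across all devices
    # (and, in a multi-process run, across all nodes).
    grads = jax.lax.pmean(grads, axis_name="batch")
    return params - 0.1 * grads  # plain SGD step

# Replicate the parameters and shard the minibatch over local devices.
n_local = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros((3,)), jax.local_devices())
x = jnp.ones((n_local, 8, 3))  # [devices, per-device batch, features]
y = jnp.ones((n_local, 8))
params = update(params, (x, y))
```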

@hawkinsp removed the "question (Questions for the JAX team)" label Apr 16, 2020
@powderluv

@hawkinsp Can you please share your notes on this (we don't need a stable API)? We are trying some hybrid data/model/pipeline parallelism, so our case is a little different from @py4's, but we would love to get started with data parallelism.

@brettkoonce (Contributor)

Data parallelism would be of value to other projects that use XLA as well (e.g., https://www.tensorflow.org/swift). Exposing this functionality in a standardized way would help drive progress in the broader ecosystem!

yxd886 commented Dec 9, 2020

> I don't have a node with 8 GPUs; I have two nodes, each with 4 GPUs. So is it possible to train a model on multiple nodes?

Hello py4, I am running into the same problem. Have you found a solution?

yxd886 commented Dec 9, 2020

> This is actually something that does work right now but it's still experimental. There's also no real public-facing API for it yet; you have to type in some obscure and fairly magical things to set it all up correctly.
>
> We should polish it off and document it!

Hello hawkinsp, could you please provide more details about how to run data-parallel training across multiple GPU nodes?

@connection-on-fiber-bundles

@hawkinsp We are also interested in running JAX code on multiple nodes. Anything (hacky or not) that you can share would be appreciated. Thanks!

jramapuram commented Feb 21, 2021

I really enjoyed JAX during my DM internship and wanted to use it on my university SLURM cluster, but the lack of a clear (official) data-parallel (multi-node) solution is a huge blocker to increasing JAX adoption outside of Google, where you can't just grab a TPU pod and pmap across it. A single 8-GPU replica setup can barely train a ResNet-50 ImageNet classifier. Training SimCLR or any other large SOTA model is currently impossible without multi-node data parallelism.

@StellaAthena

I would love this feature! I enjoy JAX, but I've been largely using DeepSpeed due to its ability to distribute across clusters.

jrabary commented Sep 17, 2021

Any progress on this issue? Using JAX to train a model on multiple nodes with multiple GPUs is becoming a very important feature for us.

sudhakarsingh27 (Collaborator) commented Nov 18, 2021

@hawkinsp This is a significant bottleneck for scaling on multi-node GPU clusters. Is there any update on this issue?
Also, there was a recent pjit tutorial that explains multi-node TPU scaling but doesn't mention GPUs. Is it planned to be updated in the future?

@cloudhan (Contributor)

@sudhakarsingh27 I have been constantly monitoring the JAX releases, and there is work in progress that you might be interested in: #8364

@brettkoonce (Contributor)

See also: #9582

hawkinsp (Member) commented Mar 8, 2022

Yes indeed. We haven't advertised it that much yet, but (a) you need to initialize the cluster using that API, and (b) you need to follow the same rules of multi-host programming that also apply on TPU, documented here: https://jax.readthedocs.io/en/latest/multi_process.html

I suspect we can consider this issue closed when we've documented (a) in the document (b).
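
(A minimal sketch of step (a), assuming the `jax.distributed.initialize` API from #8364; the coordinator address, process count, and process id below are placeholders, and one copy of the script runs per process.)

```python
import jax

# (a) Initialize the cluster: every process passes the same coordinator
#     address and total process count, but its own unique process_id.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder IP:port of process 0
    num_processes=2,
    process_id=0,                         # 0 on the first node, 1 on the second
)

# (b) Then follow the usual multi-process rules documented for TPU:
#     each process only controls its local devices, but collectives
#     (e.g. inside pmap) run over the global device set.
print("global devices:", jax.devices())
print("local devices:", jax.local_devices())
```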

sudhakarsingh27 (Collaborator) commented May 20, 2022

@hawkinsp @zhangqiaorjc
Multi-node (or multi-process) execution doesn't seem to work with the following jax/jaxlib versions:

jax                           0.3.13                                                                                                                                                                       
jaxlib                        0.3.10+cuda11.cudnn82

I ran the attached minimal code on a single node with 8 V100 GPUs as follows (2 processes with 4 GPUs each):

CUDA_VISIBLE_DEVICES="0,1,2,3" python jax_multi_node_experiment.py 0 &
CUDA_VISIBLE_DEVICES="4,5,6,7" python jax_multi_node_experiment.py 1

I could verify that multi-process (host/node) execution first fails with jax[cuda]==0.3.12, installed with the following command:

pip install jax[cuda]==0.3.12 -f https://storage.googleapis.com/jax-releases/jax_releases.html

This is the output I get when I run the multi-process commands above (note that each process sees only its own 4 GPUs as global devices):

127.0.0.1:65432 2 1
I0525 00:05:16.228919 139978761119552 distributed.py:59] Connecting to JAX distributed service on 127.0.0.1:65432
I0525 00:05:16.245648 139978761119552 xla_bridge.py:330] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0525 00:05:16.246569 139742444975936 xla_bridge.py:330] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0525 00:05:18.227763 139978761119552 xla_bridge.py:330] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0525 00:05:18.228022 139978761119552 xla_bridge.py:330] Unable to initialize backend 'cuda': make_gpu_client() got an unexpected keyword argument 'platform_name'
I0525 00:05:18.228085 139978761119552 xla_bridge.py:330] Unable to initialize backend 'rocm': make_gpu_client() got an unexpected keyword argument 'platform_name'
global devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
local devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
I0525 00:05:18.246024 139742444975936 xla_bridge.py:330] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0525 00:05:18.246273 139742444975936 xla_bridge.py:330] Unable to initialize backend 'cuda': make_gpu_client() got an unexpected keyword argument 'platform_name'
I0525 00:05:18.246334 139742444975936 xla_bridge.py:330] Unable to initialize backend 'rocm': make_gpu_client() got an unexpected keyword argument 'platform_name'
global devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
local devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]

For reference, here's the output from jax[cuda]==0.3.10, where multi-process execution seems to be working okay:

127.0.0.1:65432 2 1
I0525 00:09:03.394093 140366043674432 distributed.py:59] Connecting to JAX distributed service on 127.0.0.1:65432
I0525 00:09:03.410755 140366043674432 xla_bridge.py:263] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0525 00:09:03.410994 140588577851200 xla_bridge.py:263] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0525 00:09:05.517608 140366043674432 xla_bridge.py:263] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
global devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=1), GpuDevice(id=5, process_index=1), GpuDevice(id=6, process_index=1), GpuDevice(id=7, process_index=1)]
I0525 00:09:05.517817 140588577851200 xla_bridge.py:263] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
local devices= [GpuDevice(id=4, process_index=1), GpuDevice(id=5, process_index=1), GpuDevice(id=6, process_index=1), GpuDevice(id=7, process_index=1)]
global devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=1), GpuDevice(id=5, process_index=1), GpuDevice(id=6, process_index=1), GpuDevice(id=7, process_index=1)]
local devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]

sudhakarsingh27 added a commit to sudhakarsingh27/t5x that referenced this issue Jun 23, 2022
To run T5x on multiple nodes and multiple GPUs, `jax.distributed.initialize`
needs to be called with the appropriate setup, as mentioned here:
google/jax#8364.
Added a command-line flag, `multiprocess`, to enable multi-process T5x runs
on GPUs. Also added command-line flags for the arguments to
`jax.distributed.initialize`, namely `coordinator_address`,
`num_processes`, and `process_id`.

Example usage 1 (2 processes, running on 2 separate nodes, 8 GPUs each):
On the first node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=0
```

On the second node:
```
python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=i.p.ad.dr:port \
  --num_processes=2 \
  --process_id=1
```
Notice that `process_id` differs between the two processes. Also,
substitute the appropriate coordinator address for `i.p.ad.dr:port`.

Example usage 2 (1 node, 2 processes, 4 GPUs each):
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=0 & \
  && CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR} \
  --multiprocess \
  --coordinator_address=127.0.0.1:12345 \
  --num_processes=2 \
  --process_id=1
```

More information about multiprocess JAX runs:
google/jax#2731

Note: T5x partitioning fix: google-research#608
complements this change.

Fixes google-research#410/google-research#89
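
(For readers who don't want to dig through the T5x change itself, the flags described above presumably end up in a call like the sketch below; the flag wiring and the helper name are illustrative, and the actual T5x code may differ.)

```python
import jax
from absl import flags

# Flag names taken from the commit message above; definitions are illustrative.
flags.DEFINE_bool("multiprocess", False, "Enable multi-process JAX on GPUs.")
flags.DEFINE_string("coordinator_address", None, "IP:port of process 0.")
flags.DEFINE_integer("num_processes", 1, "Total number of processes.")
flags.DEFINE_integer("process_id", 0, "Index of this process.")
FLAGS = flags.FLAGS

def maybe_initialize_multiprocess():
    """Hypothetical helper: call jax.distributed.initialize from the flags."""
    if FLAGS.multiprocess:
        jax.distributed.initialize(
            coordinator_address=FLAGS.coordinator_address,
            num_processes=FLAGS.num_processes,
            process_id=FLAGS.process_id,
        )
```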
hawkinsp added a commit to hawkinsp/jax that referenced this issue Aug 8, 2022