
[Multi-GPU] llm.c now runs on multiple GPUs with NCCL #248

Merged
1 commit merged into karpathy:master on Apr 26, 2024

Conversation

@PeterZhizhin (Contributor) commented Apr 24, 2024

I have tested this on a vast.ai setup with 2 RTX A2000. This shows that the code works, but my setup is not good for profiling, since it doesn't have NVLink.

Here are my results:

On 1 GPU:

./train_gpt2cu -b 4

step    1/74: train loss 4.368361 (acc 4.368361) (246.222282 ms, 16635 tok/s)
step    2/74: train loss 4.506210 (acc 4.506210) (241.274053 ms, 16976 tok/s)
step    3/74: train loss 4.418672 (acc 4.418672) (242.210043 ms, 16910 tok/s)
step    4/74: train loss 3.966732 (acc 3.966732) (242.601008 ms, 16883 tok/s)
step    5/74: train loss 3.596187 (acc 3.596187) (242.081950 ms, 16919 tok/s)

On 2 GPUs (lines with the same step number show results from the two processes):

$ mpirun -np 2 ./train_gpt2cu -b 2
step    1/74: train loss 4.462181 (acc 4.368362) (198.245372 ms, 10330 tok/s)
step    1/74: train loss 4.274543 (acc 4.368362) (198.247060 ms, 10330 tok/s)
step    2/74: train loss 4.545842 (acc 4.509134) (192.055940 ms, 10663 tok/s)
step    2/74: train loss 4.472425 (acc 4.509134) (192.059574 ms, 10663 tok/s)
step    3/74: train loss 4.573365 (acc 4.420272) (192.845851 ms, 10619 tok/s)
step    3/74: train loss 4.267179 (acc 4.420272) (192.827421 ms, 10620 tok/s)
step    4/74: train loss 3.920646 (acc 3.966377) (192.742148 ms, 10625 tok/s)
step    4/74: train loss 4.012107 (acc 3.966377) (192.772412 ms, 10623 tok/s)
step    5/74: train loss 3.553988 (acc 3.596775) (188.588768 ms, 10859 tok/s)
step    5/74: train loss 3.639561 (acc 3.596775) (188.597767 ms, 10859 tok/s)
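
As a sanity check, the acc column is evidently the loss averaged across the two processes: at step 1, for example, (4.462181 + 4.274543) / 2 = 4.368362, which matches the printed acc and essentially reproduces the single-GPU loss for the same total batch.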

It does run faster at the same total batch size (~190 ms vs. ~240 ms per step).

More importantly, it now runs with larger total batch sizes:

./train_gpt2cu -b 24
<OOMs>

mpirun -np 2 --allow-run-as-root ./train_gpt2cu -b 12
step    1/12: train loss 4.210306 (acc 4.316231) (721.931293 ms, 17021 tok/s)
step    1/12: train loss 4.422156 (acc 4.316231) (721.934773 ms, 17020 tok/s)
step    2/12: train loss 4.624632 (acc 4.576700) (711.685577 ms, 17266 tok/s)
step    2/12: train loss 4.528768 (acc 4.576700) (711.695615 ms, 17265 tok/s)                                                                                                                                                        
step    3/12: train loss 4.408728 (acc 4.198446) (719.000947 ms, 17090 tok/s)
step    3/12: train loss 3.988165 (acc 4.198446) (718.976635 ms, 17090 tok/s)
step    4/12: train loss 3.946835 (acc 4.027372) (719.855278 ms, 17070 tok/s)
step    4/12: train loss 4.107909 (acc 4.027372) (719.890087 ms, 17069 tok/s)

@karpathy (Owner):

Very exciting! Looking forward to stepping through this in detail tomorrow.

#ifdef MULTI_GPU
// Average all gradients.
char* grads_memory_iterator = (char*)model->grads_memory;
for (int i = 0; i < NUM_PARAMETER_TENSORS; ++i) {
@karpathy (Owner), on the diff above:

Ugh, really sad to have this loop here :( With @ngc92's changes this will be just two calls, but ideally it would be 1.

@PeterZhizhin (Contributor, Author):

I agree it is ugly. Looking forward to the rearrangements.

I think we would need at least 2 calls, as long as we have both 16-bit and 32-bit gradients.
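
For context, a minimal sketch of what a two-call version could look like, assuming the FP32 and BF16 gradients each live in one contiguous device buffer and that an NCCL communicator was already created during multi-GPU init; the function and parameter names here are illustrative, not the PR's actual layout:

#include <stddef.h>
#include <cuda_bf16.h>
#include <nccl.h>

// Hedged sketch: average gradients across ranks with one NCCL call per dtype bucket.
// ncclAvg requires NCCL >= 2.10; ncclBfloat16 requires NCCL built with BF16 support.
void average_gradients(ncclComm_t comm, cudaStream_t stream,
                       float* grads_fp32, size_t n_fp32,
                       __nv_bfloat16* grads_bf16, size_t n_bf16) {
    // In-place all-reduce: every rank ends up with the averaged FP32 gradients.
    ncclAllReduce(grads_fp32, grads_fp32, n_fp32, ncclFloat, ncclAvg, comm, stream);
    // Same for the BF16 gradients, so two calls in total instead of one per tensor.
    ncclAllReduce(grads_bf16, grads_bf16, n_bf16, ncclBfloat16, ncclAvg, comm, stream);
}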

@karpathy (Owner):

Sadly I'm not able to test these changes yet because Lambda brought down my box during a re-image :(. I expect I'll get it back later today and will run & take a closer look.

@ademeure (Contributor) commented Apr 26, 2024

640K tokens/s on 8xA100 with -np 8 -b 3 (BF16 mode)! I couldn't get higher batch sizes to work at -np 8, which I assume is related to the existing batch-size bug and not an issue with this PR.

Perf drops from ~86K tok/s (NO_MULTI_GPU) to ~82K (1 GPU with MULTI_GPU) to ~80K per GPU when going to 8x multi-GPU, which seems really good for a first implementation! We'd probably want to avoid the multi-GPU overhead when there's only 1 GPU though, even if for some reason the binary is built with NCCL.

I needed to fix 2 small issues to get it working:

  1. "all_hostsname_hashes" is not initialised before the MPI_Allgather() ==> I added all_hostsname_hashes[process_rank] = hostname_hash; but not 100% sure that's the correct way to initialise it.

  2. It was running inference on all GPUs for the last step (which generates a confusing mess, repeating every word 8x) because of the precedence of || vs. && in the conditional; it needs parentheses like this (see the sketch below): if (multi_gpu_config.process_rank == 0 && (step > 0 && (step % sample_every) == 0 || last_step)) {
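
For reference, a minimal sketch of the precedence issue in point 2, using the names from the snippet above:

// && binds tighter than ||, so without the outer parentheses the condition
//   process_rank == 0 && step > 0 && step % sample_every == 0 || last_step
// parses as
//   (process_rank == 0 && step > 0 && step % sample_every == 0) || last_step
// and every rank runs inference on the last step. Grouping the sampling test
// keeps the rank-0 check in front:
if (multi_gpu_config.process_rank == 0 && (step > 0 && (step % sample_every) == 0 || last_step)) {
    // generate samples on rank 0 only
}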

@karpathy merged commit 8389fba into karpathy:master on Apr 26, 2024
@rosslwheeler (Contributor):

@PeterZhizhin - ran into an interesting problem with the multi-GPU change above in train_gpt2.cu. Designated initializers (which have been in GCC's C for a while now) are not supported in C++ (e.g. CUDA) until C++20. Microsoft's C++ compiler errors out on the code below unless you specify C++20, which is arguably correct.

Question: would you be okay if we changed this code to simple C-style code (see diff below)? That would keep us from forcing the NVCC builds on Windows to C++20.

[screenshot of the proposed diff]

I can do the PR if you're okay with the change? Thanks!
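
For illustration, a hedged sketch of the pattern in question; the struct and field names are placeholders, not the exact code from train_gpt2.cu:

// Placeholder struct for illustration only.
typedef struct {
    int process_rank;
    int num_processes;
    int local_device_idx;
} MultiGpuConfig;

// C99-style designated initializer: MSVC's C++/CUDA mode rejects this below /std:c++20.
MultiGpuConfig make_config_designated(void) {
    MultiGpuConfig cfg = { .process_rank = 0, .num_processes = 1, .local_device_idx = 0 };
    return cfg;
}

// Plain member-by-member assignment: compiles under any C++ standard NVCC supports.
MultiGpuConfig make_config_plain(void) {
    MultiGpuConfig cfg;
    cfg.process_rank = 0;
    cfg.num_processes = 1;
    cfg.local_device_idx = 0;
    return cfg;
}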
