diff --git a/Makefile b/Makefile index 9ea6866df..2536d57c3 100644 --- a/Makefile +++ b/Makefile @@ -188,23 +188,21 @@ else endif endif -# Check if OpenMPI and NCCL are available, include them if so, for multi-GPU training +# Check if NCCL is available for multi-GPU training ifeq ($(NO_MULTI_GPU), 1) - $(info → Multi-GPU (OpenMPI + NCCL) is manually disabled) + $(info → Multi-GPU is manually disabled) else ifneq ($(OS), Windows_NT) # Detect if running on macOS or Linux ifeq ($(SHELL_UNAME), Darwin) - $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping OpenMPI + NCCL support) - else ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists) - $(info ✓ OpenMPI found, OK to train with multiple GPUs) - NVCC_INCLUDES += -I/usr/lib/x86_64-linux-gnu/openmpi/include - NVCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ - NVCC_LDLIBS += -lmpi -lnccl + $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping NCCL support) + else ifeq ($(shell dpkg -l | grep -q nccl && echo "exists"), exists) + $(info ✓ NCCL found, OK to train with multiple GPUs) + NVCC_LDLIBS += -lnccl NVCC_FLAGS += -DMULTI_GPU else - $(info ✗ OpenMPI is not found, disabling multi-GPU support) - $(info ---> On Linux you can try install OpenMPI with `sudo apt install openmpi-bin openmpi-doc libopenmpi-dev`) + $(info ✗ NCCL is not found, disabling multi-GPU support) + $(info ---> On Linux you can try install NCCL with `sudo apt install libnccl2 libnccl-dev`) endif endif endif diff --git a/README.md b/README.md index d3dea874f..c6ea127bd 100644 --- a/README.md +++ b/README.md @@ -128,19 +128,21 @@ sudo apt-get -y install libcudnn9-dev-cuda-12 On top of this you need the [cuDNN frontend](https://github.com/NVIDIA/cudnn-frontend/tree/main), but this is just header files. Simply clone the repo to your disk. The Makefile currently looks for it in either your home directory or the current directory. If you have put it elsewhere, add `CUDNN_FRONTEND_PATH=/path/to/your/cudnn-frontend/include` to the `make` command-line. -**multi-GPU training**. As of April 26, 2024 there is now also support for multi-GPU training using MPI and NCCL. Make sure you install MPI, e.g. on Linux: +**multi-GPU training**. Support for multi-GPU training is availabel using NCCL. Make sure you download and install [NCCL](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html), e.g. on Linux: ```bash -sudo apt install openmpi-bin openmpi-doc libopenmpi-dev +sudo sudo apt install libnccl2 libnccl-dev ``` and then: ```bash make train_gpt2cu -mpirun -np ./train_gpt2cu +mpirun -np bach -c './train_gpt2cu -pn -pr $OMPI_COMM_WORLD_RANK' ``` +**multi-node training**. For SLURM enabled cluster, use the sample script in [scripts/run_gpt2_124M.sbatch](scripts/run_gpt2_124M.sbatch) + ## experiments / sweeps Just as an example process to sweep learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`): diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 4a14ac49f..b5bf9f31d 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -11,7 +11,6 @@ endif # Compiler flags CFLAGS = -O3 --use_fast_math NVCCFLAGS = -lcublas -lcublasLt -std=c++17 -MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ # Default rule for our CUDA files %: %.cu @@ -52,7 +51,7 @@ global_norm: global_norm.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu - $(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce + $(NVCC) -lnccl $(NVCCFLAGS) nccl_all_reduce.cu -o nccl_all_reduce # Run all targets run_all: all diff --git a/dev/cuda/nccl_all_reduce.cu b/dev/cuda/nccl_all_reduce.cu index 260ba02ba..1e18d8b8a 100644 --- a/dev/cuda/nccl_all_reduce.cu +++ b/dev/cuda/nccl_all_reduce.cu @@ -5,17 +5,16 @@ Fills a vector with 1s on the first GPU, 2s on the second, etc. Then aggregates the values in the resulting vectors. Compile example: -nvcc -lmpi -lnccl -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lcublas -lcublasLt nccl_all_reduce.cu -o nccl_all_reduce +nvcc -lnccl -lcublas -lcublasLt nccl_all_reduce.cu -o nccl_all_reduce Run on 2 local GPUs (set -np to a different value to change GPU count): -mpirun -np 2 ./nccl_all_reduce +mpirun -np 2 bash -c './nccl_all_reduce $OMPI_COMM_WORLD_RANK' */ #include "common.h" #include #include -#include #include #include #include @@ -31,99 +30,52 @@ void nccl_check(ncclResult_t status, const char *file, int line) { } #define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__)) -void mpi_check(int status, const char *file, int line) { - if (status != MPI_SUCCESS) { - char mpi_error[4096]; - int mpi_error_len = 0; - assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) == - MPI_SUCCESS); - printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, - mpi_error); - exit(EXIT_FAILURE); - } -} -#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__)) - -// Sets a vector to a predefined value -__global__ void set_vector(float *data, int N, float value) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - - // Check for out-of-bounds access - if (i < N) { - data[i] = value; - } -} - size_t cdiv(size_t a, size_t b) { return (a + b - 1) / b; } // Parameters specific to training on multiple GPUs. typedef struct { - int process_rank; // Rank of this process among all MPI processes on all hosts. 0 if no multi-GPU. + int process_rank; // Rank of this process among all processes on all hosts. 0 if no multi-GPU. int num_processes; // Total number of processes on all hosts. 1 if no multi-GPU. - int local_device_idx; // This process GPU index on current machine. 0 if no multi-GPU. + int device_idx; // This process GPU index on current machine. 0 if no multi-GPU. ncclComm_t nccl_comm; // NCCL communication primitive, used for collective mutli-GPU work. } MultiGpuConfig; -// Determine which GPU this process should use. -// Processes on the same machines use different GPU indicies. Processes on other machines don't. -// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread -int multi_gpu_get_local_device_idx(int process_rank, int num_processes) { - char hostname[1024]; - hostname[1023] = '\0'; - // All processes on the same machine will share the same hostname. - gethostname(hostname, 1023); - for (int i=0; i < 1024; i++) { - if (hostname[i] == '.') { - hostname[i] = '\0'; - break; - } - } - uint64_t hostname_hash = 5381; - for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c]; } - - // Distribute all hostname hashes to all processes. - uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t)); - all_hostsname_hashes[process_rank] = hostname_hash; - mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); - - // Identify which GPU we need to use. - int local_device_idx = 0; - for (int current_process = 0; current_process < num_processes; ++current_process) { - if (current_process == process_rank) { - // Found my gpu, local_device_idx now has my target GPU index. - break; - } - if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) { - // This process ID runs on the same machine, but it's not me, skip this GPU - local_device_idx++; - } - } - - free(all_hostsname_hashes); - return local_device_idx; -} - -MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) { - // Initialize MPI. +MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char *dfs_path) { MultiGpuConfig result; - mpiCheck(MPI_Init(argc, argv)); - mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank)); - mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes)); - result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes); - printf("[Process rank %d] Using GPU %d\n", result.process_rank, result.local_device_idx); - cudaCheck(cudaSetDevice(result.local_device_idx)); ncclUniqueId nccl_id; - if (result.process_rank == 0) { + + result.process_rank = process_rank; + result.num_processes = num_processes; + result.device_idx = process_rank % gpus_per_node; + + FILE* idFile; + static char filename[256]; + snprintf(filename, sizeof(filename), "%s/ncclUniqueId.dat", dfs_path); + + if (result.process_rank == 0) { // Generate the NCCL unique ID at rank 0 and write it to a file ncclCheck(ncclGetUniqueId(&nccl_id)); + idFile = fopen(filename, "wb"); + assert(idFile != NULL); + fwrite(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); + } else { // Other ranks wait until the file is available and read the unique ID + do { + usleep(1000000); + idFile = fopen(filename, "rb"); + if (idFile != NULL) break; + } while (idFile == NULL); + fread(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); } - mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + + printf("ProcessID:%d, NumProcess::%d, DeviceId:%d\n", result.process_rank, result.num_processes, result.device_idx); + cudaCheck(cudaSetDevice(result.device_idx)); ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank)); return result; } void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) { ncclCommDestroy(multi_gpu_config->nccl_comm); - mpiCheck(MPI_Finalize()); } float get_mean(float *arr, size_t size, int process_rank) { @@ -134,12 +86,20 @@ float get_mean(float *arr, size_t size, int process_rank) { return sum / size; } -int main(int argc, char **argv) { +// CUDA kernel to set each element of the array to a specific value +__global__ void set_vector(float *array, float value, size_t num_elements) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_elements) { + array[idx] = value; + } +} + +int main(int argc, char *argv[]) { // Some constants const size_t all_reduce_buffer_size = 32 * 1024 * 1024; const size_t threads_per_block = 1024; - MultiGpuConfig multi_gpu_config = multi_gpu_config_init(&argc, &argv); + MultiGpuConfig multi_gpu_config = multi_gpu_config_init(2, atoi(argv[1]), 8, "."); // Allocating buffers on each of the devices. float *all_reduce_buffer; diff --git a/llmc/zero.cuh b/llmc/zero.cuh index 160dae7ac..772186c13 100644 --- a/llmc/zero.cuh +++ b/llmc/zero.cuh @@ -12,7 +12,6 @@ Utilities for ZeRO sharding #include #ifdef MULTI_GPU -#include #include #endif @@ -36,97 +35,57 @@ void nccl_check(ncclResult_t status, const char *file, int line) { } #define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__)) -void mpi_check(int status, const char *file, int line) { - if (status != MPI_SUCCESS) { - char mpi_error[4096]; - int mpi_error_len = 0; - assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) == MPI_SUCCESS); - printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, mpi_error); - exit(EXIT_FAILURE); - } -} -#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__)) - #endif // MULTI_GPU // ---------------------------------------------------------------------------- -// MPI / multi-processing setup - // Parameters specific to training on multiple GPUs. typedef struct { - int process_rank; // Rank of this process among all MPI processes. 0 if no multi-GPU. + int process_rank; // Rank of this process among all processes launched. 0 if no multi-GPU. int num_processes; // Total number of processes. 1 if no multi-GPU. - int local_device_idx; // This process GPU index on current machine. 0 if no multi-GPU. + int device_idx; // This process GPU index on current machine. 0 if no multi-GPU. // Zero Redundancy Optimizer stage - https://fairscale.readthedocs.io/en/stable/deep_dive/oss_sdp_fsdp.html - // 0-Disabled - // 1-Optimizer State Sharding (OSS) - // 2-Optimizer + Gradient State Sharding (SDP) - // 3-Optimizer + Gradient + Horizontal Model Sharding (FSDP) - int zero_stage; + int zero_stage; // 0-Disabled, 1-OSS, 2-SDP, 3-FSDP size_t shard_num_parameters; + size_t shard_offset; #ifdef MULTI_GPU - ncclComm_t nccl_comm; // NCCL communication primitive, used for collective multi-GPU work. + ncclComm_t nccl_comm; // NCCL communication primitive, used for collective multi-GPU work. cudaStream_t nccl_stream; // CUDA Stream to perform NCCL operations. cudaEvent_t compute_nccl_sync; // Event used to synchronize NCCL with the compute #endif } MultiGpuConfig; +MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char *dfs_path) { #ifdef MULTI_GPU -// Determine which GPU this process should use. -// Processes on the same machines use different GPU indicies. Processes on other machines don't. -// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread -int multi_gpu_get_local_device_idx(int process_rank, int num_processes) { - char hostname[1024]; - hostname[1023] = '\0'; - // All processes on the same machine will share the same hostname. - gethostname(hostname, 1023); - for (int i=0; i < 1024; i++) { - if (hostname[i] == '.') { - hostname[i] = '\0'; - break; - } - } - uint64_t hostname_hash = 5381u; - for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5u) + hostname_hash) ^ hostname[c]; } - - // Distribute all hostname hashes to all processes. - uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t)); - all_hostsname_hashes[process_rank] = hostname_hash; - mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); - - // Identify which GPU we need to use. - int local_device_idx = 0; - for (int current_process = 0; current_process < num_processes; ++current_process) { - if (current_process == process_rank) { - // Found my gpu, local_device_idx now has my target GPU index. - break; - } - if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) { - // This process ID runs on the same machine, but it's not me, skip this GPU - local_device_idx++; - } - } - - free(all_hostsname_hashes); - return local_device_idx; -} -#endif - -MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) { -#ifdef MULTI_GPU - // Initialize MPI. MultiGpuConfig result; - mpiCheck(MPI_Init(argc, argv)); - mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank)); - mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes)); - result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes); - cudaCheck(cudaSetDevice(result.local_device_idx)); ncclUniqueId nccl_id; - if (result.process_rank == 0) { + + result.process_rank = process_rank; + result.num_processes = num_processes; + result.device_idx = process_rank % gpus_per_node; + + FILE* idFile; + static char filename[256]; + snprintf(filename, sizeof(filename), "%s/ncclUniqueId.dat", dfs_path); + + if (result.process_rank == 0) { // Generate the NCCL unique ID at rank 0 and write it to a file ncclCheck(ncclGetUniqueId(&nccl_id)); + idFile = fopen(filename, "wb"); + assert(idFile != NULL); + fwrite(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); + } else { // Other ranks wait until the file is available and read the unique ID + do { + usleep(1000000); + idFile = fopen(filename, "rb"); + if (idFile != NULL) break; + } while (idFile == NULL); + fread(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); } - mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + + printf("ProcessID:%d, NumProcess::%d, DeviceId:%d\n", result.process_rank, result.num_processes, result.device_idx); + cudaCheck(cudaSetDevice(result.device_idx)); ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank)); cudaCheck(cudaStreamCreate(&result.nccl_stream)); // event without timing for maximum performance @@ -140,7 +99,7 @@ MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) { MultiGpuConfig result; result.process_rank = 0; result.num_processes = 1; - result.local_device_idx = 0; + result.device_idx = 0; return result; #endif } @@ -150,16 +109,17 @@ void multi_gpu_config_free(MultiGpuConfig* multi_gpu_config) { ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm)); cudaCheck(cudaStreamDestroy(multi_gpu_config->nccl_stream)); cudaCheck(cudaEventDestroy(multi_gpu_config->compute_nccl_sync)); - mpiCheck(MPI_Finalize()); #endif } -void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) { +void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config, float *unified_buffer) { #ifdef MULTI_GPU if (multi_gpu_config->num_processes > 1) { - mpiCheck(MPI_Barrier(MPI_COMM_WORLD)); + if (unified_buffer == NULL) cudaCheck(cudaMallocManaged(&unified_buffer, sizeof(float))); + ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, 0)); } #endif + cudaCheck(cudaDeviceSynchronize()); } // Offset and size of a tensor shard diff --git a/profile_gpt2.cu b/profile_gpt2.cu index 8a7ddbfac..1d038a444 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -28,7 +28,7 @@ the profile.ncu-rep from a cloud box to local to pretty view. #include "train_gpt2.cu" int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); + multi_gpu_config = multi_gpu_config_init(1, 0, 8, "."); common_start(true, true); // build the GPT-2 model from a checkpoint diff --git a/scripts/README.md b/scripts/README.md index 876005955..0cd53d932 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -21,7 +21,7 @@ Long story short, try `-r 1` (recompute GeLU, trading off speed and memory) to c It might be that you only have one GPU and not a whole box of them. Every script is fairly easy to change for just a single GPU. For llm.c, simply change line 1 to line 2 and leave everything else the same: ```bash -mpirun -np 8 ./train_gpt2cu \ +mpirun -np 8 bach -c './train_gpt2cu -pn 8 -pr $OMPI_COMM_WORLD_RANK' ./train_gpt2cu \ ``` diff --git a/scripts/run_gpt2_124M.sbatch b/scripts/run_gpt2_124M.sbatch new file mode 100644 index 000000000..d5158600a --- /dev/null +++ b/scripts/run_gpt2_124M.sbatch @@ -0,0 +1,38 @@ +#!/bin/bash +#SBATCH --job-name=llmc-multinode +#SBATCH --output=/dfs/llm.c/log124M/%x_%j_%t.log +#SBATCH --ntasks=32 # total number of processes to launch +#SBATCH --ntasks-per-node=8 # assuming each node has 8 gpus +#SBATCH --gres=gpu:8 # request 8 gpus from each node +#SBATCH --nodelist=node[000-003] # list of the nodes to dispatch processes (32/8=4) + +cd /dfs/llm.c/ # path to the repo in distributed file system +mkdir -p log124M + +# export NCCL_SOCKET_IFNAME=ib0 # network interface Ethernet or InifiniBand which enables gpu direct rdma +# export NCCL_IB_HCA=mlx5_0,mlx5_1 # list of all InfiniBand devices available if available + +# GPT-2 (124M) repro on FineWeb100B +# Global batch size is set to (1024 * 64) * 32 +srun bash -c " + ./train_gpt2cu \ + -i 'dev/data/fineweb100B/fineweb_train_*.bin' \ + -j 'dev/data/fineweb100B/fineweb_val_*.bin' \ + -o "log124M" \ + -v 250 -s 20000 -g 144 \ + -h 1 \ + -b 64 -t 1024 \ + -d 2097152 \ + -r 0 \ + -z 1 \ + -c 0.1 \ + -l 0.0006 \ + -q 0.0 \ + -u 700 \ + -n 10000 \ + -y 1 \ + -e d12 \ + -pn 32 \ + -pr \$SLURM_PROCID \ + -pg 8 \ + -pd "/dfs/llm.c/log124M"" diff --git a/scripts/run_gpt2_124M.sh b/scripts/run_gpt2_124M.sh index 9ca1f0822..d2b41fa4d 100755 --- a/scripts/run_gpt2_124M.sh +++ b/scripts/run_gpt2_124M.sh @@ -20,9 +20,10 @@ while true; do # run python dev/data/fineweb.py --version 10B to prepro data # run python dev/data/hellaswag.py to prepro hellaswag eval - mpirun -np 8 ./train_gpt2cu \ - -i "dev/data/fineweb10B/fineweb_train_*.bin" \ - -j "dev/data/fineweb10B/fineweb_val_*.bin" \ + mpirun -np 8 bash -c " + ./train_gpt2cu \ + -i 'dev/data/fineweb10B/fineweb_train_*.bin' \ + -j 'dev/data/fineweb10B/fineweb_val_*.bin' \ -o $out_dir \ -v 250 -s 20000 -g 144 \ -h 1 \ @@ -36,7 +37,9 @@ while true; do -u 700 \ -n 5000 \ -y 1 \ - -e "d12" + -e "d12" \ + -pn 8 \ + -pr \$OMPI_COMM_WORLD_RANK" sleep 1 done diff --git a/scripts/run_gpt2_350M.sh b/scripts/run_gpt2_350M.sh index 1f9defc12..d144ac1b1 100644 --- a/scripts/run_gpt2_350M.sh +++ b/scripts/run_gpt2_350M.sh @@ -20,9 +20,10 @@ while true; do # run python dev/data/fineweb.py --version 100B to prepro data # run python dev/data/hellaswag.py to prepro hellaswag eval - mpirun -np 8 ./train_gpt2cu \ - -i "dev/data/fineweb100B/fineweb_train_*.bin" \ - -j "dev/data/fineweb100B/fineweb_val_*.bin" \ + mpirun -np 8 bash -c " + ./train_gpt2cu \ + -i 'dev/data/fineweb100B/fineweb_train_*.bin' \ + -j 'dev/data/fineweb100B/fineweb_val_*.bin' \ -o $out_dir \ -v 250 -s 100000 -g 144 \ -h 1 \ @@ -37,7 +38,9 @@ while true; do -n 2000 \ -x 60000 \ -y 1 \ - -e "d24" + -e "d24" \ + -pn 8 \ + -pr \$OMPI_COMM_WORLD_RANK" sleep 1 done diff --git a/scripts/run_gpt3_124M.sh b/scripts/run_gpt3_124M.sh index bde1e6859..426d83701 100644 --- a/scripts/run_gpt3_124M.sh +++ b/scripts/run_gpt3_124M.sh @@ -20,9 +20,10 @@ while true; do # run python dev/data/fineweb.py --version 10B to prepro data # run python dev/data/hellaswag.py to prepro hellaswag eval - mpirun -np 8 ./train_gpt2cu \ - -i "dev/data/fineweb100B/fineweb_train_*.bin" \ - -j "dev/data/fineweb100B/fineweb_val_*.bin" \ + mpirun -np 8 bash -c " + ./train_gpt2cu \ + -i 'dev/data/fineweb100B/fineweb_train_*.bin' \ + -j 'dev/data/fineweb100B/fineweb_val_*.bin' \ -o $out_dir \ -v 250 -s 20000 -g 144 \ -h 1 \ @@ -37,7 +38,9 @@ while true; do -n 10000 \ -y 1 \ -x 565950 \ - -e "d12" + -e "d12" \ + -pn 8 \ + -pr \$OMPI_COMM_WORLD_RANK" sleep 1 done diff --git a/test_gpt2.cu b/test_gpt2.cu index 71f9d0704..dbdc7b2da 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -89,7 +89,7 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size } int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); + multi_gpu_config = multi_gpu_config_init(1, 0, 8, "."); common_start(false, true); // set the right paths diff --git a/train_gpt2.cu b/train_gpt2.cu index e7b871834..bf693747e 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -325,6 +325,7 @@ typedef struct { int* targets; // the target tokens for the current forward pass float mean_loss; // after a forward pass with targets, will be populated with the mean loss float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs + float* unified_buffer; // GPU buffer to avg loss across process floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost float* cpu_losses_fp32; // same but fp32 unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. @@ -346,6 +347,7 @@ void gpt2_init_common(GPT2 *model) { model->targets = NULL; model->cpu_losses = NULL; model->cpu_losses_fp32 = NULL; + model->unified_buffer = NULL; // the B,T params are determined and set, fixed on first batch in forward() model->batch_size = 0; model->seq_len = 0; @@ -894,12 +896,17 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) { } // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled. -float multi_gpu_cpu_float_sum(float value) { +float multi_gpu_float_sum(float value, float *unified_buffer, const MultiGpuConfig* multi_gpu_config) { #ifdef MULTI_GPU - // note MPI doesn't support all reduce with mean, only sum - float result; - mpiCheck(MPI_Allreduce(&value, &result, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD)); - return result; + if (multi_gpu_config->num_processes == 1) return value; + + if (unified_buffer == NULL) cudaCheck(cudaMallocManaged(&unified_buffer, sizeof(float))); + *unified_buffer = value; + cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), multi_gpu_config->device_idx, 0)); + ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, 0)); + cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), cudaCpuDeviceId, 0)); + cudaCheck(cudaDeviceSynchronize()); + return *unified_buffer; #else return value; #endif @@ -913,7 +920,7 @@ void gpt2_multi_gpu_loss_reduce(GPT2* model, MultiGpuConfig* multi_gpu_config) { // If there's only one process, there is nothing to do if (multi_gpu_config->num_processes == 1) { return; } // Average all losses. - model->accumulated_mean_loss = multi_gpu_cpu_float_sum(model->mean_loss) / multi_gpu_config->num_processes; + model->accumulated_mean_loss = multi_gpu_float_sum(model->mean_loss, model->unified_buffer, multi_gpu_config) / multi_gpu_config->num_processes; #endif cudaCheck(cudaDeviceSynchronize()); } @@ -992,7 +999,7 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl global_norm_squared_aggregate(grad_norm_squared, max_num_block_sums, main_stream); cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); // further sum the (partial) squared norm across all GPUs (see comment ^1 above) - grad_norm_squared_cpu = multi_gpu_cpu_float_sum(grad_norm_squared_cpu); + grad_norm_squared_cpu = multi_gpu_float_sum(grad_norm_squared_cpu, model->unified_buffer, multi_gpu_config); } else { // in regular DDP, backward has averaged the gradients across all GPUs // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed @@ -1124,10 +1131,10 @@ void gpt2_free(GPT2 *model) { void common_start(bool override_enable_tf32 = true, bool print_device_info = true) { // get CUDA device infos - cudaGetDeviceProperties(&deviceProp, multi_gpu_config.local_device_idx); + cudaGetDeviceProperties(&deviceProp, multi_gpu_config.device_idx); if (print_device_info) { printf("[System]\n"); - printf("Device %d: %s\n", multi_gpu_config.local_device_idx, deviceProp.name); + printf("Device %d: %s\n", multi_gpu_config.device_idx, deviceProp.name); } // set up the cuda streams. atm everything is on the single main stream @@ -1326,7 +1333,6 @@ void error_usage() { // ---------------------------------------------------------------------------- // main training loop int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); // read in the (optional) command line arguments const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; @@ -1353,10 +1359,14 @@ int main(int argc, char *argv[]) { int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training int hellaswag_eval = 0; + int num_processes = 1; + int process_rank = 0; + int gpus_per_node = 8; + char dfs_path[256] = "."; for (int i = 1; i < argc; i+=2) { if (i + 1 >= argc) { error_usage(); } // must have arg after flag if (argv[i][0] != '-') { error_usage(); } // must start with dash - if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter) + if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x (one dash, one letter) // read in the args if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; } else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; } @@ -1382,8 +1392,15 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); } else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); } else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'n') { num_processes = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'r') { process_rank = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'g') { gpus_per_node = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'd') { strcpy(dfs_path, argv[i+1]); } else { error_usage(); } } + + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, dfs_path); + // should do a bit more error checking here assert(warmup_iterations >= 0); if (output_log_dir != NULL) { @@ -1582,7 +1599,7 @@ int main(int argc, char *argv[]) { val_loss += model.mean_loss; } val_loss /= val_num_batches; - val_loss = multi_gpu_cpu_float_sum(val_loss) / multi_gpu_config.num_processes; + val_loss = multi_gpu_float_sum(val_loss, model.unified_buffer, &multi_gpu_config) / multi_gpu_config.num_processes; printf0("val loss %f\n", val_loss); logger_log_val(&logger, step, val_loss); } @@ -1601,7 +1618,7 @@ int main(int argc, char *argv[]) { eval_acc_norm += (float)correct; } // careful because not all ranks may have the exact same allocation of number of examples - eval_acc_norm = multi_gpu_cpu_float_sum(eval_acc_norm); + eval_acc_norm = multi_gpu_float_sum(eval_acc_norm, model.unified_buffer, &multi_gpu_config); printf0("HellaSwag: %d/%d = %f\n", (int)eval_acc_norm, eval_loader.num_examples, eval_acc_norm / eval_loader.num_examples); logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples); } @@ -1666,13 +1683,13 @@ int main(int argc, char *argv[]) { snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, step, multi_gpu_config.process_rank); save_state(filename_buffer, step, &model, &train_loader); // DONE file is a signal that this checkpoint as a whole is complete - multi_gpu_barrier(&multi_gpu_config); + multi_gpu_barrier(&multi_gpu_config, model.unified_buffer); if (multi_gpu_config.process_rank == 0) { snprintf(filename_buffer, 512, "%s/DONE_%08d", output_log_dir, step); FILE* done_file = fopenCheck(filename_buffer, "w"); fclose(done_file); } - multi_gpu_barrier(&multi_gpu_config); + multi_gpu_barrier(&multi_gpu_config, model.unified_buffer); } resuming = 0;