Merge pull request #444 from karpathy/feature/sharded_data
extend dataloader to be sharded
karpathy committed May 22, 2024
2 parents 967420d + 099d30f commit 69f1221
Showing 5 changed files with 236 additions and 127 deletions.
104 changes: 69 additions & 35 deletions README.md
@@ -18,7 +18,7 @@ make train_gpt2fp32cu
./train_gpt2fp32cu
```

The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C/CUDA and train for one epoch on tinyshakespeare with AdamW (using batch size 4, context length 1024, total of 74 steps), evaluate validation loss, and sample some text. Note that in this quickstart we are using the fp32 version [train_gpt2_fp32.cu](train_gpt2_fp32.cu) of the CUDA code. Below in the CUDA section we document the current "mainline" [train_gpt2.cu](train_gpt2.cu), which is still being very actively developed, uses mixed precision, and runs ~2X faster.
The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C/CUDA and train for one epoch on tinyshakespeare with AdamW (using batch size 4, context length 1024, total of 74 steps), evaluate validation loss, and sample some text. Note that in this quickstart we are using the fp32 version [train_gpt2_fp32.cu](train_gpt2_fp32.cu) of the CUDA code. In the next section we document the current "mainline" [train_gpt2.cu](train_gpt2.cu), which uses mixed precision and runs ~2X faster.

## quick start (GPU, fast bleeding edge)

@@ -45,75 +45,102 @@ Note that the default batch size is very low (4). If you have enough memory on y
./train_gpt2cu -b 32
```

My standard "prod" run with a nice GPU (e.g. A100 40GB) actually trains on TinyStories instead of TinyShakespeare, and looks like this:
My standard single-GPU "prod" run (e.g. with an A100 40GB) trains on TinyStories instead of TinyShakespeare and looks like this:

```bash
python dev/data/tinystories.py
make train_gpt2cu USE_CUDNN=1
./train_gpt2cu -i dev/data/tinystories/TinyStories -v 250 -s 250 -g 144 -o stories.log -b 32
./train_gpt2cu -i dev/data/tinystories/TinyStories_train.bin \
-j dev/data/tinystories/TinyStories_val.bin \
-v 250 -s 250 -g 144 -o stories.log -b 32
```

Where I decrease the frequency of validation loss and sampling to every 250 steps, sample 144 tokens during the sampling stage (to fit ~one story), and use batch size 32.
The `-i` flag is a glob pattern for the input data, `-j` for the val data. In addition, I decrease the frequency of validation loss and sampling to every 250 steps, sample 144 tokens during the sampling stage (to fit ~one story), and use batch size 32.

## quick start (CPU)

The "I am so GPU poor that I don't even have one" section. No worries, run:
If you want to train on actual, real pretraining data, check out the recently added support for the [fineweb dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb). Unlike the datasets above, where the train/val tokens fit into a single .bin file, we now have multiple data shards as well. Here is an example:

```bash
pip install -r requirements.txt
python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2
OMP_NUM_THREADS=8 ./train_gpt2
```
# write fineweb data in 100M token shards to dev/data/fineweb10B
python dev/data/fineweb.py -s 100000000
# compile and run
./train_gpt2cu -i "dev/data/fineweb10B/fineweb_train_*.bin" \
-j "dev/data/fineweb10B/fineweb_val_*.bin" \
-v 250 -s 250 -g 144 -o fineweb.log -b 32
```

The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C and train for 40 steps on tinyshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference.
Where you will notice the use of glob pattern `*` to match all the train shards.
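
This glob-based sharding is the core of the dataloader change in this PR: instead of pointing at a single .bin file, `-i`/`-j` can match many shard files. As a rough, hypothetical sketch of the idea (not the repo's actual dataloader code), POSIX `glob(3)` can expand such a pattern into an ordered list of shards that the loader then streams through one by one:

```c
// Hypothetical sketch (not the repo's actual dataloader code): resolve a
// shard glob pattern like "fineweb_train_*.bin" into an ordered list of
// shard files using POSIX glob(3).
#include <glob.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv) {
    const char *pattern = (argc > 1) ? argv[1] : "dev/data/fineweb10B/fineweb_train_*.bin";
    glob_t g;
    if (glob(pattern, 0, NULL, &g) != 0 || g.gl_pathc == 0) {
        fprintf(stderr, "no shards match pattern: %s\n", pattern);
        return EXIT_FAILURE;
    }
    // a sharded loader would stream batches from shard 0, advance to shard 1
    // when it is exhausted, and wrap around at the end of an epoch
    for (size_t i = 0; i < g.gl_pathc; i++) {
        printf("shard %zu: %s\n", i, g.gl_pathv[i]);
    }
    globfree(&g);
    return EXIT_SUCCESS;
}
```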

## quick start (multiple GPUs)

You'll be using the (more bleeding edge) mixed precision version of the code:
Great, let's get even more serious. We're using MPI and NCCL for multi-GPU training. Everything in the section above applies, with the following changes:

```bash
# example to install MPI:
sudo apt install openmpi-bin openmpi-doc libopenmpi-dev
# the run command is now preceded by `mpirun`:
mpirun -np <number of GPUs on your machine> ./train_gpt2cu
```

Sub in the number of GPUs you'd like to run on in the last command. All of the flags discussed in the section above apply here as well.

## quick start (CPU)

The "I am so GPU poor that I don't even have one" section. You can still train! But you won't go too far. You can still finetune a GPT-2 small (124M parameter model) to output Shakespeare-like text, as an example:

```bash
pip install -r requirements.txt
python dev/data/tinyshakespeare.py
python train_gpt2.py
make train_gpt2cu
mpirun -np <number of GPUs on your machine> ./train_gpt2cu
make train_gpt2
OMP_NUM_THREADS=8 ./train_gpt2
```

Sub in the number of GPUs you'd like to run on in the last command.
The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C and train for 40 steps on tinyshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference.

## training: more detail

Download and tokenize a dataset. The [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset is the fastest to download and tokenize:
The data scripts inside `dev/data/(dataset).py` are responsible for downloading, tokenizing, and saving the tokens to .bin files. So, for example, when you run:

```bash
python dev/data/tinyshakespeare.py
```

This prints:
We download and tokenize the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. The output looks like this:

```
Saved 32768 tokens to (...)/tiny_shakespeare_val.bin
Saved 305260 tokens to (...)/tiny_shakespeare_train.bin
writing 32,768 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_val.bin
writing 305,260 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_train.bin
```

The .bin files are raw byte streams of int32 numbers indicating the token ids with the GPT-2 tokenizer. Alternatively you could also tokenize the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset with `tinystories.py`.
The .bin files contain a short header (1024 bytes) followed by a stream of uint16 token ids from the GPT-2 tokenizer. More datasets are available in `dev/data`.
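
As a minimal sketch of that layout, assuming only what is stated here (a 1024-byte header followed by uint16 token ids; the exact header fields are defined by the writer scripts in `dev/data/`), a .bin file can be read back like this:

```c
// Minimal sketch of reading a tokenized .bin file. Assumes only what the
// README states: a 1024-byte header followed by uint16 GPT-2 token ids.
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define HEADER_BYTES 1024

int main(void) {
    const char *path = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
    FILE *f = fopen(path, "rb");
    if (f == NULL) { perror("fopen"); return EXIT_FAILURE; }
    fseek(f, 0, SEEK_END);
    long file_bytes = ftell(f);
    fseek(f, HEADER_BYTES, SEEK_SET);  // skip the header
    long num_tokens = (file_bytes - HEADER_BYTES) / (long)sizeof(uint16_t);
    uint16_t *tokens = malloc(num_tokens * sizeof(uint16_t));
    if (fread(tokens, sizeof(uint16_t), (size_t)num_tokens, f) != (size_t)num_tokens) {
        fprintf(stderr, "short read\n");
        return EXIT_FAILURE;
    }
    printf("read %ld tokens; first four: %d %d %d %d\n",
           num_tokens, tokens[0], tokens[1], tokens[2], tokens[3]);
    free(tokens);
    fclose(f);
    return EXIT_SUCCESS;
}
```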

In principle, we'd be ready to train the model right here. However, the baseline CPU/fp32 reference code is so inefficient that it's not practical to train these models from scratch yet. Instead, we initialize with the GPT-2 weights released by OpenAI and just do finetuning. For that, we have to download the GPT-2 weights and save them as a checkpoint we can load in C:
In principle, once we have the tokens, we'd be ready to train the model right here. However, the current code can't start training from scratch just yet (coming very soon), so we initialize training from the pretrained models released by OpenAI and do finetuning. For that, we have to download the GPT-2 weights and save them as a checkpoint we can load in C. This is what happens when you run this script:

```bash
python train_gpt2.py
```

You'll recognize this code from nanoGPT as a simple GPT-2 reference implementation in PyTorch. This script will download the GPT-2 (124M) model, overfit a single batch of data for 10 iterations, run a few steps of generation, and most importantly it will save three files: 1) the `gpt2_124M.bin` file that contains the raw model weights for loading in C, 2) the `gpt2_124M_debug_state.bin`, which also contains more debug state: the inputs, targets, logits and loss (useful for debugging and unit testing), and finally 3) the `gpt2_tokenizer.bin`, which stores the vocabulary for the GPT-2 tokenizer, translating token ids to byte sequences of UTF-8 encoded string pieces. We can now initialize with these model weights and continue training in raw C. First compile the code:
You'll recognize this code from nanoGPT as a simple GPT-2 reference implementation in PyTorch. This script will download the GPT-2 (124M) model, overfit a single batch of data for 10 iterations, run a few steps of generation, and most importantly it will save three files: 1) the `gpt2_124M.bin` file that contains the raw model weights for loading in C, 2) the `gpt2_124M_debug_state.bin`, which also contains more debug state: the inputs, targets, logits and loss (useful for debugging and unit testing), and finally 3) the `gpt2_tokenizer.bin`, which stores the vocabulary for the GPT-2 tokenizer, translating token ids to byte sequences of UTF-8 encoded string pieces. The script also saves both fp32 and bfloat16 versions of the above, for mixed precision training. We can now initialize with these model weights and continue training in raw C. Then we compile the training programs with `make`. There are currently three parallel implementations:

```bash
# the simple, CPU, reference code version
make train_gpt2
# the single-GPU fp32 CUDA version
make train_gpt2fp32cu
# the multi-GPU mixed precision CUDA version
make train_gpt2cu
```

You can have a look inside the `Makefile` and its comments. It will try to autodetect if OpenMP is available on your system, which is very helpful for speeding up the code at a very low cost of code complexity. Some people seem to experience problems compiling on Ubuntu; have a look at [Issue 19](https://github.com/karpathy/llm.c/issues/19). TL;DR: you'll want to modify the `CFLAGS`:
You can have a look inside the `Makefile` and its comments. It will try to autodetect a lot of tools and libraries (e.g. cuDNN, OpenMP, OpenMPI, nvcc), and you want to get as many checkmarks as possible. For example, when I run `make train_gpt2cu USE_CUDNN=1` on my fully configured machine, I see:

```
✓ cuDNN found, will run with flash-attention
✓ OpenMP found
✓ OpenMPI found, OK to train with multiple GPUs
✓ nvcc found, including GPU/CUDA support
```

Some people seem to experience problems compiling on Ubuntu; have a look at [Issue 19](https://github.com/karpathy/llm.c/issues/19). TL;DR: you'll want to modify the `CFLAGS`:

```
# try this first
@@ -122,7 +149,7 @@ CFLAGS="-Ofast -fno-finite-math-only -Wno-unused-result -march=native" make trai
CFLAGS="-O3 -Wno-unused-result -march=native" make train_gpt2
```

Once `train_gpt2` is compiled, you can run it:
Once the binary is compiled, we can run it. For example, the simplest CPU reference version runs as:

```bash
OMP_NUM_THREADS=8 ./train_gpt2
@@ -164,26 +191,35 @@ Allay
---
```

I like how Netflix comes up, it's clear that the shadow of the training past is still lurking in the model. I did not attempt to tune the finetuning hyperparameters so it's quite likely this can be improved quite a bit. I also noticed that slightly different platforms (e.g. MacOS / Linux) will (sadly) give very slightly different results, so perhaps don't expect to get the exact numbers or generation above. Also note that if you are seeing token ids instead of text in the generation, it might be because your code is out of date, as Tokenizer decoding was added April 14, 2024. `git pull` the updates, and then re-run `python train_gpt2.py`, which will now also save the tokenizer, which C can read and then use to print text instead of token ids.
I like how Netflix comes up, it's clear that the shadow of the training past is still lurking in the model. I did not attempt to tune the finetuning hyperparameters so it's quite likely this can be improved quite a bit. I also noticed that slightly different platforms (e.g. MacOS / Linux) will (sadly) give very slightly different results, so perhaps don't expect to get the exact numbers or generation above.

Finally, the code is in flux. If anything weird happens that you didn't expect or that worked previously, try to `git pull`, re-run all the commands above, reference back to this README, etc.

## test

I am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. Compile and run with:
I am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. On the CPU as an example, compile and run with:

```bash
make test_gpt2
./test_gpt2
```

This now loads the `gpt2_124M_debug_state.bin` file, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch.
This now loads the `gpt2_124M_debug_state.bin` file, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch. To test the GPU version I run:

```bash
# fp32 test (cudnn not supported)
make test_gpt2cu PRECISION=FP32 && ./test_gpt2cu
# mixed precision cudnn test
make test_gpt2cu USE_CUDNN=1 && ./test_gpt2cu
```
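
Conceptually, these tests boil down to element-wise tolerance checks of the C tensors against the PyTorch reference values saved in the debug state file. A hypothetical helper along these lines (not the repo's actual test code) illustrates the idea:

```c
// Hypothetical illustration of the kind of check the tests perform: compare
// a tensor computed in C against a PyTorch reference within a tolerance.
#include <math.h>
#include <stdio.h>

int check_tensor(const float *actual, const float *expected, int n,
                 float tol, const char *label) {
    int ok = 1;
    float max_diff = 0.0f;
    for (int i = 0; i < n; i++) {
        float diff = fabsf(actual[i] - expected[i]);
        if (diff > max_diff) max_diff = diff;
        if (diff > tol) ok = 0;
    }
    printf("%-10s %s (max diff %e)\n", label, ok ? "OK" : "MISMATCH", max_diff);
    return ok;
}

int main(void) {
    float ref[4]  = {1.0f, 2.0f, 3.0f, 4.0f};     // e.g. reference logits from PyTorch
    float ours[4] = {1.0f, 2.0f, 3.001f, 4.0f};   // e.g. logits computed in C
    int all_ok = check_tensor(ours, ref, 4, 1e-2f, "logits");
    printf("overall okay: %d\n", all_ok);
    return 0;
}
```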

## tutorial

I attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layernorm/layernorm.md). It's a simple, step-by-step guide to implementing a single layer of the GPT-2 model, the layernorm layer. This is a good starting point to understand how the layers are implemented in C.

## CUDA

The full training loop is also implemented in pure CUDA in one file, but optimizations of the kernels are ongoing. Currently, we roughly match the speed of PyTorch. The way we organize code is that we have a growing collection of kernels of increasing complexity in the `dev/cuda` folder, see [dev/cuda/README.md](dev/cuda/README.md). We then copy-paste the best kernels into the main training loop in the single training file `train_gpt2.cu`.
The full training loop is also implemented in pure CUDA in one file, but optimizations of the kernels are ongoing. Currently, we slightly exceed the speed of PyTorch Nightly. The way we organize code is that we have a growing collection of kernels of increasing complexity in the `dev/cuda` folder, see [dev/cuda/README.md](dev/cuda/README.md). We then copy-paste the best kernels into the main training loop in the single training file `train_gpt2.cu`.

**WIP alert, April 23**. We merged the first version of mixed precision training code. I checkpointed the fp32 version to separate files that include `_fp32` in their filename, and would like to preserve this version in the root of the repo because it 1) doesn't require the most up-to-date CUDA, is a lot more likely to compile, and is more portable, 2) is a lot simpler and acts as a reference. In fact, we'd like to diverge the fp32 version in the direction of being pure CUDA (e.g., not even calling cuBLAS by default), to be used as an educational reference, maybe even a kernel of a course on CUDA. The "mainline" development concerned with speed will from there on move to the [train_gpt2.cu](train_gpt2.cu) file, which includes mixed precision training.

@@ -198,7 +234,7 @@ make test_gpt2fp32cu

This prints `overall okay: 1`. So the forward activations, backward gradients, and the individual loss values for 10 iterations all match exactly.

**Training**. To train GPT-2 in a single file of CUDA, run the train script:
**Training**. To train on single GPU in fp32:

```bash
make train_gpt2fp32cu
@@ -228,9 +264,7 @@ For on his rock shall he be opencast.
Keep on with me, my
```

This runs on my A100 in ~10 seconds. This training loop in the PyTorch script is about 80ms/iteration, so we are slightly better than PyTorch here. However, this is measured against a slightly stale PyTorch (I'm on 2.1.0), and we're not yet including FlashAttention or the PyTorch scaled_dot_product_attention fused operation.

We can compare to naive PyTorch like this, where we turn on `torch.compile` and the use of TensorCores, which use tf32 type:
This runs on my A100 in ~10 seconds. We can compare to naive PyTorch like this, where we turn on `torch.compile` and the use of TensorCores, which use the tf32 type:

```bash
python train_gpt2.py --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 --tensorcores 1
```
