-
Next up "this is how to train Llama 3 8B in 72 hours for 1500$"🫡🫡🫡
-
400B token run: "This model dramatically beats GPT-2 and GPT-3 of its size on HellaSwag (it gets up to ~61%), but sadly becomes unstable there on and explodes. " Would you be able to release the model right before the explosion? I would be interested to learn what instability and explosion look like in a model.
-
I love the simplicity, power, and attention to detail. Well done! I hope to experiment with this code myself someday soon.
-
What about new flash attention 3? Will it slash price in half?
-
Thanks again for your work to democratize AI. I'm only part way through makemore now, but am blown away by how simple you can make these tough topics.
-
Also want to thank Ubicloud (www.ubicloud.com) for providing the GitHub Nvidia GPU Runners for CI and Nvidia for sponsoring this. Thank you!!! CC: @karpathy
-
Hi, Andrej I'm watching your videos. Especially llama.c. Can you please share resources how you created own transformer.
-
@karpathy I have an integrated GPU(ryzen) with 16GB VRAM. Would like to train a 256m gpt with 32k context length (using https://github.com/anthonix/llm.c).
-
If I only have 1 GPU can I use
-
Currently running this on one H100 and it's on step 20654 after 20 days of training I was having problems with memory so had to tweak the object size.
-
Just completed training and it took one month on one H100.
-
In this post we are reproducing GPT-2 in llm.c. This is "the GPT-2", the full, 1558M parameter version that was introduced in OpenAI's blog post Better Language Models and their Implications on February 14, 2019. llm.c does so directly in C/CUDA (total of ~5,000 lines of code), without the typical training stack that would involve the Python interpreter and a significantly more complex deep learning library like PyTorch/JAX, huggingface/transformers, etc. In 2019, training GPT-2 was an involved project for an entire team and was considered a big model run but, ~5 years later, due to improvements in compute (H100 GPUs), software (CUDA, cuBLAS, cuDNN, FlashAttention) and data (e.g. the FineWeb-Edu dataset), we can reproduce this model on a single 8XH100 node in 24 hours, and for $672, which is quite incredible. There are some caveats and asterisks involved though - llm.c is still not perfectly tuned or sufficiently stabilized (we still see loss spikes and bad activation ranges now and then), and our evals are not comprehensive (e.g. we do not carefully eval multilingual, code, math). A lot of this work is still ongoing.
Unicorns. The natural place to start is probably with unicorns. In the GPT-2 blog post, OpenAI prompted GPT-2 with this unicorn prompt. Let's see what our model thinks about English-speaking unicorns in the Andes mountains:
Prompt: In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.
Completion: It may sound like a hoax, or just a dream of the mythical animals, but it appears this may be a real case of the unicorn language. The Andean area has an abundance of minerals and is also very rich in biodiversity. According to the researchers' study, the unicorns' language is called Elveseo. To produce it, the animal forms a consonant from the lips and makes a vowel sound with the teeth, then produces a consonant sound with the mouth and makes an 'ow' sound from its chest. "The question is, what caused this, and what does this tell us about intelligence and communication?" commented study co-author Paul Marek. Speaking in the video, the unicorns had the right way to speak, the animals were able to communicate and talk in perfect English.
Fun! :) The model is fairly coherent and qualitatively somewhere around the level of GPT-2. You can find 20 samples from both GPT-2 and the llm.c model here, or generate many more using instructions down below.
Training. Training a GPT-2 with llm.c is quite simple because it is written in C/CUDA, so there is no need for miniconda, Python, PyTorch, etc. You will want an 8XH100 GPU box; I recommend spinning one up from Lambda labs. But llm.c is flexible on its compute - if you have only 1 GPU you can still get your GPT-2, you'll just have to wait 8 days instead of 1. If you have 16 GPUs (e.g. using the new Lambda 1 Click Clusters), you'll be able to train multinode and only have to wait 12 hours. Once you spin up your node, here are the complete instructions to train your GPT-2 (this only takes about a minute from a blank box to the first training step):
I will describe the args in a second. You'll see a bunch of prints scroll through and then the optimization will begin:
We can see that each step is about 2.75 seconds and there are 32,000 of them, so now we wait ~24 hours. At every step, this training run takes a chunk of ~1 million tokens of FineWeb-EDU (these are educational web pages from the internet), and updates the 1558 million weights of the model to be slightly better at predicting the next token in a sequence. By the end we'll have processed 32,000 * 1048576 = 33.6B tokens in total. The loss goes down as we do a better job predicting the next token. The gradient norm will stabilize around 0.1-1, and the learning rate is warmed up over the first steps of training. Our model flops utilization (MFU) is around 50%, i.e. quite efficient.
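The arithmetic in the paragraph above can be checked with a few lines of Python. This is only a back-of-the-envelope sketch: the step time, step count, batch size and cost are the figures quoted in the post, while the helper name `hours_on`, the assumption of perfectly linear scaling across GPU counts, and the assumption that the $672 covers exactly 24 billed hours are mine.

```python
# Figures quoted in the post.
SECONDS_PER_STEP = 2.75        # observed time per optimization step on 8XH100
NUM_STEPS = 32_000             # total steps (-x 32000)
TOKENS_PER_STEP = 1_048_576    # total batch size in tokens (-d 1048576 = 2**20)
NUM_GPUS = 8
RUN_COST_USD = 672

total_hours = SECONDS_PER_STEP * NUM_STEPS / 3600
total_tokens = NUM_STEPS * TOKENS_PER_STEP
tokens_per_second = TOKENS_PER_STEP / SECONDS_PER_STEP


def hours_on(num_gpus, baseline_hours=24.0, baseline_gpus=NUM_GPUS):
    """Estimated wall clock on a different GPU count.

    Assumption (mine): throughput scales perfectly linearly with GPU count.
    """
    return baseline_hours * baseline_gpus / num_gpus


# Assumption (mine): the quoted cost covers exactly 24 hours of an 8-GPU node.
implied_gpu_hour_usd = RUN_COST_USD / (24 * NUM_GPUS)

print(f"wall clock:        {total_hours:.1f} hours")
print(f"tokens processed:  {total_tokens / 1e9:.1f}B")
print(f"throughput:        {tokens_per_second:,.0f} tokens/s")
print(f"1 GPU estimate:    {hours_on(1) / 24:.0f} days")
print(f"16 GPU estimate:   {hours_on(16):.0f} hours")
print(f"implied price:     ${implied_gpu_hour_usd:.2f} per GPU-hour")
```

Under these assumptions the 1-GPU and 16-GPU estimates line up with the "8 days" and "12 hours" figures mentioned in the Training paragraph.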
Now wait 24 hours for this to finish, then you can visualize the `main.log` log file using the dev/vislog.ipynb jupyter notebook. For this you will need to also have Python and matplotlib installed, and you will see the following:
Evals. On the left we are tracking the loss on FineWeb-EDU validation data. If you simply run the GPT-2 released by OpenAI and evaluate its loss on this split, you get the red horizontal line (loss 2.83). You see that our run outperforms this very quickly, by step ~5,000. However, this is not a fair comparison because GPT-2 was trained on the never-released WebText dataset, so there is a possibly large distribution shift. So e.g. if you finetune the OpenAI model for 1,000 steps at LR 1e-4, the loss quickly plunges to the blue line (loss 2.61), because it's quickly adapting to the new data statistics. I like to look at the validation loss as a sanity check, but for the actual comparison we'd want to look at fixed, 3rd party evaluations. One of the well-behaved, smooth, common, often-cited evals that also offers early signal is the HellaSwag eval. These are simple common sense scenarios and the model has to pick the correct continuation. We evaluate HellaSwag on the right pane, where we see that we cross over the GPT-2 model around step ~25K. This is earlier than GPT-2, which is estimated to have been trained on ~100B tokens; it possibly has to do with increased data quality, as we also observed in our earlier 124M run. The green line is the GPT-3 model of the same size, which is pretty much the same model architecture as GPT-2 with minor differences (context length 1024 -> 2048) but trained for 300B tokens (i.e. ~10X more tokens than what we trained on here). I should say that even HellaSwag is not an ideal single point of comparison because it tests simple English and common sense; it does not test e.g. multilingual, math or code. It could have been that the WebText data mixture was a lot heavier on these, and these domains were "stealing" model capacity to some extent. We don't know, because it was never released.
Lastly, in general, good evals are harder at low model capability like GPT-2 because e.g. the models don't understand multiple choice, and their samples are not high enough quality to make an above-chance dent in standard math or code evals.
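To illustrate why a task like HellaSwag still gives signal for a model that can't follow a multiple-choice prompt, here is a minimal sketch of completion-style scoring: each candidate continuation is scored by the model's average per-token loss and the lowest one is taken as the prediction. The post does not describe the exact procedure llm.c uses, so treat this as an assumption about the general approach; the function names and the log-probability values are made up for illustration and are not model outputs.

```python
def average_loss(token_logprobs):
    """Mean negative log-likelihood of one candidate continuation."""
    return -sum(token_logprobs) / len(token_logprobs)


def pick_continuation(candidates):
    """Return (index of the lowest-loss candidate, list of all losses)."""
    losses = [average_loss(logprobs) for logprobs in candidates]
    best = min(range(len(losses)), key=lambda i: losses[i])
    return best, losses


# Hypothetical per-token log-probabilities for four candidate endings.
# In a real eval these would come from the model's forward pass.
example_logprobs = [
    [-2.1, -3.0, -1.7, -2.6],
    [-1.9, -2.8, -3.3],
    [-0.9, -1.2, -1.1, -0.8, -1.0],
    [-2.5, -2.4, -2.9, -3.1],
]

predicted, losses = pick_continuation(example_logprobs)
print("predicted ending:", predicted)
print("average losses:", [round(loss, 3) for loss in losses])
```

Averaging over tokens rather than summing keeps longer endings from being penalized just for their length; whether to normalize this way is a design choice, not something the post specifies.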
Args guide. Let's now look in more detail at the args we passed into the training. The GPT-2 release from OpenAI included model weights but very few details, while the GPT-3 release had no weights but many details. So in many cases, we follow the GPT-3 paper hyperparameters because the GPT-2 paper has very little information:
- `mpirun -np 8 ./train_gpt2cu \` is the launch command: we're using mpi to launch 8 processes (each process runs training on 1 GPU, for 8 GPUs total on this example 8XH100 node). If you have 4 GPUs, use `-np 4`. If you have 1 GPU, you can skip mpi, i.e. simply change this to `./train_gpt2cu`.
- `-i -j` are the training and validation split token files, downloaded earlier with `edu_fineweb.sh`.
- `-o` is the output directory to write logs and checkpoints into.
- `-v 250` asks to evaluate and log the validation loss every 250 steps.
- `-s 300000` asks to sample some tokens every 300000 steps. Because the total number of steps will be less than this, this is a hacky way to turn sampling off, and we will only sample a single time at the very end.
- `-g 384` sets the number of tokens to be sampled at the end to be 384.
- `-h 1` asks to evaluate the HellaSwag accuracy.
- `-b 16` sets the micro-batch size to 16. If you are running out of memory, decrease this value, e.g. try 8, 4, 2, all the way down to 1 potentially.
- `-t 1024` sets the maximum sequence length to 1024, as GPT-2 did.
- `-d 1048576` asks that the total batch size be 2 to the power 20, following the GPT-3 paper hyperparameters table. The code will make sure to meet this desired total batch size and calculate the needed gradient accumulation "inner loop" steps of the optimization. For example up above, we saw that we have 8 GPUs each doing 16 X 1024 tokens, so that is 8 X 16 X 1024 = 131,072 tokens per micro-step (a single forward backward), so the code calculated gradient accumulation steps of 8 to meet the desired 1M batch size per step, i.e. it does forward+backward 8 times and then a single update.
- `-r 0` sets recompute to zero. Recompute is a way to trade off compute and memory. If `-r 1`, then we recompute a piece of the forward pass (the GeLU) during backward. This means we don't have to cache it and save memory, at the cost of some more compute. So if you're running out of memory, try `-r 1`, or `-r 2` (also recompute layernorms).
- `-z 1` turns on ZeRO-1 (i.e. optimizer state sharding) across multiple GPUs. If you're training with > 1 GPU, this setting is a no-brainer and should basically always be on. On 1 GPU this setting is a no-op.
- `-c 0.1` sets the weight decay to 0.1. Only (2D) weights are decayed, exactly as in GPT-2, and this number comes from the GPT-3 paper.
- `-k "cosine"` sets the cosine learning rate schedule, which is the default, so this is a bit spurious.
- `-l 0.0006` sets the maximum learning rate to 6e-4. The GPT-3 paper says to use 2e-4 for this model size, but here we triple it and it seems to train faster and without any issues. This wasn't tuned very carefully yet.
- `-q 0.1` says that we will decay the learning rate to 10% of max LR over the course of training, following the GPT-3 paper.
- `-u 700` says that we will ramp up the learning rate from 0 to max learning rate over the first 700 iterations, which at the total batch size of ~1M tokens used here is ~734M tokens (it would be ~350M tokens at a total batch size of 0.5M), loosely following the GPT-3 paper.
- `-n 2000` asks to save model checkpoints every 2000 steps.
- `-x 32000` asks for 32K steps in total. I chose this number because it is a nice number, and just fits into 24 hours.
- `-ge 1` sets a very recently merged gelu recompute setting for cuBLASLt (optional).
- `-y 1` sets the "resume" flag on. If your training for any reason crashes or hangs, you can CTRL+C and re-run this command, and it will attempt to resume the optimization. llm.c is bitwise-deterministic, so you'll get the identical result as if you didn't crash.
- `-e "d48"` asks to initialize a depth 48 GPT-2 model from scratch.
Memory guide. The biggest constraint most people will probably face is that their GPU doesn't have 80GB. That's okay, you should still be able to run everything above if you are patient, it would just run slower. So if the model doesn't fit, what do you play with? The most important one is the micro batch size `-b`. Try to decrease it but keep it to nice numbers, so e.g. 16 -> 8 -> 4 -> 2 -> 1. From there, try to also play with the recompute setting `-r`, which is 0 (fastest, a lot of memory), 1 (very slightly slower, but a huge memory saving), or 2 (slightly slower, smaller memory saving). The next thing you can do is disable master weights in fp32, which you can do with `-w 0` (1 is default). This means we won't maintain an fp32 copy of the params. Empirically, in a few earlier runs, this seems to be okay, likely due to our use of stochastic rounding. If even that doesn't fit (that's unlikely, right?), you could try to decrease the maximum sequence length with `-t`. The default is 1024 and you can take it down to 512, 256, etc., but now you are making your model worse because you're decreasing its maximum attention span.
Code. Certainly I feel biased but llm.c is quite beautiful:
The main entry point and the majority of the code is in the file train_gpt2.cu. It contains the GPT-2 model definition and the training loop in ~2,000 LOC, and it imports a bunch of helper files with various utilities and the individual layer implementations from the `llmc` directory. `cloc llmc` reports 23 files with 3170 LOC, and `cloc train_gpt2.cu` is 1353 LOC atm.
Multi-node training. If you are part of the privileged GPU-rich upper class, llm.c supports multi-node training and the most GPUs I've seen someone train llm.c with is ~500 GPUs. The biggest run I've done personally so far is on Lambda's new 1-click cluster feature with 16XH100 GPUs in 2 nodes. The downsides of unemployment. The Lambda team has put up detailed instructions on how you can train llm.c models on their 1-click clusters. E.g. with the 512-GPU H100 cluster for $2,300/hr, you might be able to train your GPT-2 in ~30 minutes. You'd have to increase the total batch size (e.g. to ~8M) and possibly tune the hyperparameters a little. I haven't tried but it probably works and would be very cool :)
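To make the batch-size and learning-rate args above more concrete, and to show why a 512-GPU run would need a larger total batch size, here is a small Python sketch. It is not llm.c's code: the function names, the divisibility check, and the exact warmup and cosine formulas are my assumptions, written to match the behaviour described in the args guide (gradient accumulation derived from `-d`, `-b`, `-t` and the GPU count; linear warmup over `-u` steps; cosine decay to the `-q` fraction of the max LR `-l` by step `-x`).

```python
import math


def grad_accum_steps(total_batch_tokens, micro_batch, seq_len, num_gpus):
    """Forward+backward passes per optimizer update.

    Assumption (mine): the total batch must be an exact multiple of the
    tokens processed in one micro-step across all GPUs.
    """
    tokens_per_micro_step = micro_batch * seq_len * num_gpus
    if total_batch_tokens % tokens_per_micro_step != 0:
        raise ValueError(
            f"total batch {total_batch_tokens} is not a multiple of "
            f"{tokens_per_micro_step} tokens per micro-step"
        )
    return total_batch_tokens // tokens_per_micro_step


def learning_rate(step, max_lr=6e-4, warmup_steps=700,
                  total_steps=32_000, final_frac=0.1):
    """Linear warmup to max_lr, then cosine decay to final_frac * max_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    min_lr = final_frac * max_lr
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)


# The 8-GPU run from the post, and the 16-GPU two-node setup.
for gpus in (8, 16):
    steps = grad_accum_steps(1_048_576, micro_batch=16, seq_len=1024, num_gpus=gpus)
    print(f"{gpus} GPUs -> {steps} gradient accumulation steps")

# With 512 GPUs at micro-batch 16, a single micro-step already covers ~8M tokens,
# so the ~1M total batch no longer fits and has to be raised.
print("512 GPUs, tokens per micro-step:", 16 * 1024 * 512)

for step in (0, 699, 16_000, 32_000):
    print(f"step {step:>6}: lr = {learning_rate(step):.2e}")
```

Keeping `-b 16` at 512 GPUs is itself an assumption; lowering the micro-batch size would be another way to keep a smaller total batch, at the cost of less work per GPU per step.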
PyTorch comparison. A relatively comparable run in PyTorch would I think look something like this, using our parallel PyTorch implementation:
The PyTorch code is meant as a testing reference, not an actual implementation, so the training loop is a little bit different in some places (e.g. the dataloader doesn't permute the shards, etc.), but this is still possibly useful as a point of reference. I also hacked the default vocab size from 50257 -> 50304 for added efficiency. With that, the current PyTorch nightly gives:
Now I wouldn't say I have full confidence that the PyTorch script is maximally tuned, but the following observations can be made. PyTorch seems to be taking a lot more memory (this run is ~80GB), while llm.c is at 57GB (29% improvement). Memory is important because it allows you to crank up the batch size (e.g. llm.c can go up to 24 microbatch here), which goes a bit faster. Second, we're seeing about 3386 vs. 2750ms per iteration, so llm.c is stepping ~19% faster. Some of the gains here have known origin, e.g. llm.c includes optimizations like the Fused classifier that kicks off the backward pass, which is something torch.compile does not do today afaik. But it's also possible that this script isn't fully maximally tuned, but in any case I'm showing the comparison in case 1) others would like to take a look, play with, compare, help tune and 2) to just say that llm.c is quite optimized and fast - in the specific case of GPT-2/3 training.
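For readers who want to see where the two percentages come from, here is a quick Python check using only the numbers quoted above. The helper name `percent_reduction` is mine. Note that "~19% faster" is the reduction in time per iteration; expressed as a throughput gain, the same measurements give a somewhat larger figure.

```python
def percent_reduction(baseline, improved):
    """How much smaller `improved` is than `baseline`, in percent."""
    return 100.0 * (baseline - improved) / baseline


# Numbers quoted in the post.
PYTORCH_MEMORY_GB, LLMC_MEMORY_GB = 80, 57
PYTORCH_MS_PER_ITER, LLMC_MS_PER_ITER = 3386, 2750

memory_saving = percent_reduction(PYTORCH_MEMORY_GB, LLMC_MEMORY_GB)
time_saving = percent_reduction(PYTORCH_MS_PER_ITER, LLMC_MS_PER_ITER)
throughput_gain = 100.0 * (PYTORCH_MS_PER_ITER / LLMC_MS_PER_ITER - 1.0)

print(f"memory saving:        {memory_saving:.1f}%")
print(f"time per iteration:   {time_saving:.1f}% lower")
print(f"iterations per second: {throughput_gain:.1f}% higher")
```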
The final model. A few links that may be helpful, for posterity:
Model export. The model export can be done as follows, for example:
This then lets you run the Eleuther eval harness, or run the huggingface sampling pipeline to get model samples:
Also have a look at dev/eval for instructions on how to run the Eleuther Evaluation Harness, the evals from the HuggingFace Open LLM Leaderboard, etc.
400B token run. I have also made the attempt to train GPT-2 for significantly longer than 33B tokens. In particular, I changed -x to 400,000 to train for 420B tokens (even more than the GPT-3 model of this size, which was trained with 300B). This model run looked great until about step 330,000:
This model dramatically beats GPT-2 and GPT-3 of its size on HellaSwag (it gets up to ~61%), but sadly becomes unstable from there on and explodes. There are more smaller spikes along the way, but the code is configured to detect the simpler instantaneous instabilities and skip the update (I used the flags `-sl 5.0 -sg 5.0`), which helps mitigate and defer issues. However, I think we're not yet being sufficiently careful with our initialization, activation ranges, and overall model training stability, and there are deeper issues that gradually drift the model into instability, especially for larger models and over long training durations. To be continued. If you have ideas or recommendations for stabilizing LLM training, please contribute your experience in the discussion below.
FAQ:
GPT-2 (124M). I wanted to also link to an earlier post on training the GPT-2 (124M) model in llm.c, which has more information related to llm.c runs. 124M is a smaller model in the GPT-2 miniseries, only 124M parameters compared to 1558M parameters.
Authors
Substantial contributions to llm.c came from what now feels like the llm.c core dev team, in addition to self:
Coming up. Some of the next big steps we are interested in and looking at these days:
The goal of llm.c remains to have a simple, minimal, clean training stack for a full-featured LLM agent, in direct C/CUDA, and companion educational materials to bring many people up to speed in this awesome field.
Please feel free to use the Discussions for any FAQ and related, or if you'd like something faster, #llmc on Discord, or #llmdotc on CUDA MODE Discord.
We'll see you next time!