
Support for FP16/BF16 in train_gpt2.cu (1.86x Perf) #218

Merged
merged 12 commits into karpathy:master on Apr 23, 2024

Conversation

@ademeure (Contributor) commented Apr 22, 2024

Now finished and reasonably happy with it!
1.86x speedup on my RTX 4090:

  • FP32: ~80ms
  • BF16: ~43ms (with layernorm params in FP32, but all activations in BF16)

This allows the same train_gpt2.cu to work as full FP32, full BF16, full FP16, or full (BF/FP)16 + FP32 layernorm, simply by changing the define at the top of the file. I also included stochastic rounding for the Adam kernel (but nowhere else at this point; possibly worth adding to gradients in general when we move to FP8?).
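
For intuition, here is a minimal sketch of what stochastic rounding from FP32 to BF16 can look like in CUDA. It is an illustration rather than the exact helper in this PR; the function name and the caller-supplied 16-bit random value (e.g. from a cheap hash of a seed and the element index) are assumptions.

#include <cuda_bf16.h>

// Sketch: round an FP32 value to BF16 so the result is unbiased in expectation.
// BF16 is the top 16 bits of the FP32 bit pattern, so adding a random offset in
// [0, 0xFFFF] to the discarded low bits before truncating rounds up with
// probability proportional to the discarded fraction. (NaN/Inf not handled.)
__device__ __nv_bfloat16 stochastic_round_bf16(float value, unsigned int rand16) {
    unsigned int bits = __float_as_uint(value);
    bits += (rand16 & 0xFFFFu);                        // may carry into the kept bits
    return __float2bfloat16_rz(__uint_as_float(bits)); // truncate the low 16 bits
}

// e.g. in the Adam kernel, compute the updated parameter in FP32, then:
//   params_bf16[i] = stochastic_round_bf16(param_fp32, rand16_for_element_i);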

To simplify the logic compared to the first version of the PR, all activation tensors are now always "floatX"; we cannot mix-and-match. However, because atomicAdd on 16-bit values in some of the backward kernels is HORRIBLY slow (10x slower or worse), and because this kind of flexibility seems useful in general for layernorm accuracy, layernorm is kept at FP32 by defining "floatN" as "float".

I reduced the amount of code duplication by using very lightweight templates for the kernel types. It's still a BIG change though; unfortunately, I don't think there's any way around that!
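
As a rough illustration of the template approach (a generic example, not one of the PR's kernels verbatim): the element type becomes a template parameter, loads are upcast to FP32 for the math, and the result is cast back on store, so one kernel body covers float, __half and __nv_bfloat16.

template <typename T>
__global__ void gelu_forward_kernel(T* out, const T* inp, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        float x = (float)inp[i];   // upcast: compute in FP32 regardless of storage type
        float cube = 0.044715f * x * x * x;
        out[i] = (T)(0.5f * x * (1.0f + tanhf(0.7978845608f * (x + cube))));
    }
}

// host code then only ever instantiates it with the configured type, e.g.:
//   gelu_forward_kernel<floatX><<<grid, block>>>(out, inp, N);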

@ademeure ademeure marked this pull request as draft April 22, 2024 15:23
@ademeure (Contributor Author)

It is trivial to use the exact same code with everything in FP32: at the top of train_gpt2.cu, simply replace this:

typedef __nv_bfloat16 floatX;
#define CUBLAS_LOWP CUDA_R_16BF

with this:

typedef float floatX;
#define CUBLAS_LOWP CUDA_R_32F
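
One natural way to put both configurations behind the single define mentioned above is a small preprocessor switch. This is a sketch of the idea rather than the file's exact layout, and the ENABLE_BF16 / ENABLE_FP16 macro names are placeholders:

#if defined(ENABLE_BF16)
typedef __nv_bfloat16 floatX;
#define CUBLAS_LOWP CUDA_R_16BF
#elif defined(ENABLE_FP16)
typedef __half floatX;
#define CUBLAS_LOWP CUDA_R_16F
#else
typedef float floatX;
#define CUBLAS_LOWP CUDA_R_32F
#endif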


This is now able to train in BF16 for many layers, and it kinda-sorta works for test_gpt2.cu, though the loss converges much more slowly than FP32 for now (need to debug how to improve that afterwards):

LOSS MISMATCH AT STEP 1: 4.598247 4.059707
LOSS MISMATCH AT STEP 2: 4.152971 3.375123
LOSS MISMATCH AT STEP 3: 3.828835 2.800783
LOSS MISMATCH AT STEP 4: 3.538793 2.315382
LOSS MISMATCH AT STEP 5: 3.260888 1.849029
LOSS MISMATCH AT STEP 6: 3.000814 1.394656
LOSS MISMATCH AT STEP 7: 2.768756 0.999147
LOSS MISMATCH AT STEP 8: 2.557551 0.624080
LOSS MISMATCH AT STEP 9: 2.352901 0.376511

It does eventually converge:

step 99: loss 0.001294 (took 8.712021 ms)

@ademeure (Contributor Author)

Debugged the BF16 convergence issue and fixed it by adding stochastic rounding support. Also simplified the code by making all activations the same type (params can still be a different type, partly due to severe perf issues with atomicAdd otherwise).

The PR is now in a good state in my opinion where it's worth thinking about what it would take to integrate it.

@ademeure ademeure changed the title WIP support for FP16/BF16 in train_gpt2.cu (compiles, not correct yet) Support for FP16/BF16 in train_gpt2.cu (1.86x Perf) Apr 23, 2024
@ademeure ademeure marked this pull request as ready for review April 23, 2024 05:59
@ademeure (Contributor Author)

BTW this approach should work perfectly fine for FP8 as well. The main issue (besides loss scaling) to get that working is that cuBLAS non-Lt doesn't support FP8 at all, so we can't use StridedBatched GEMMs for attention, we need padding to move the other cuBLAS calls to Lt, etc.

By hacking things so all the "cannot be FP8" GEMMs stay at BF16 while halving k (obviously not functionally correct), I got FP8 to run at ~29.5ms (vs ~43ms for BF16). So it does seem to scale reasonably well despite suffering a little bit from Amdahl's Law.

It should be possible to keep gradients as e5m2 and everything else as e4m3 by just adding a "floatG" for e5m2 and casting appropriately, since the storage requirements are the same. What would not work without a LOT more complexity is using types with different sizes (e.g. activations at FP8 and gradients at BF16), but I think we agreed we shouldn't really need that anytime soon.
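
For example (a sketch of the idea, assuming the cuda_fp8.h types; the names and casts are placeholders, not code from this PR):

#include <cuda_fp8.h>

typedef __nv_fp8_e4m3 floatX;   // activations/weights: more mantissa bits
typedef __nv_fp8_e5m2 floatG;   // gradients: more exponent range

// Both types are 1 byte, so buffer sizes and indexing stay identical; only the
// casts at load/store time differ, e.g. inside a backward kernel:
//   float g = (float)grads[i];            // e5m2 -> FP32 for the math
//   grads[i] = (floatG)(g_new * scale);   // FP32 -> e5m2 (with loss scaling)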

@karpathy (Owner)

merging this. we'll iterate in master.

@karpathy karpathy merged commit 6b6ad35 into karpathy:master Apr 23, 2024
2 of 3 checks passed