
Gain another 10-20%+ on CPU performance on gcc by moving -fno-finite-math-only to only gelu_backwards #168

Closed

Conversation

@dagelf (Contributor) commented Apr 17, 2024

More targeted flag optimizations for gcc.

It's the tanhf function in gelu_backwards that causes the model to fail with -ffast-math on gcc on Linux.
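
For reference, the change pins the one problematic flag on the offending function only. A minimal sketch of the idea (the body is paraphrased from train_gpt2.c and may not match the PR exactly):

#include <math.h>
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

// GCC only: compile just this function without the finite-math assumption,
// so tanhf/coshf keep correct NaN/Inf handling even when the rest of the
// build uses -Ofast.
#if defined(__GNUC__) && !defined(__clang__)
__attribute__((optimize("no-finite-math-only")))
#endif
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out)
                         + x * 0.5f * sech_out * GELU_SCALING_FACTOR
                         * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}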

$ grep name /proc/cpuinfo | head -1
model name      : Intel(R) Core(TM) i3-9100F CPU @ 3.60GHz

Before:
step 0: train loss 5.356086 (took 6167.853384 ms)
step 1: train loss 4.300644 (took 5460.413776 ms)
step 2: train loss 4.623082 (took 5276.372294 ms)

After:
step 0: train loss 5.356185 (took 5714.622339 ms)
step 1: train loss 4.301033 (took 4814.820671 ms)
step 2: train loss 4.623316 (took 4813.711103 ms)

$ grep name /proc/cpuinfo | head -1
model name      : AMD Ryzen 5 3600 6-Core Processor

Before:
step 0: train loss 5.356085 (took 3397.901288 ms)
step 1: train loss 4.300644 (took 2810.743621 ms)
step 2: train loss 4.623083 (took 2813.287769 ms)

After:
step 0: train loss 5.356185 (took 2639.362407 ms)
step 1: train loss 4.301032 (took 2258.179942 ms)
step 2: train loss 4.623315 (took 2261.548428 ms)

Timings obtained with:

( kill -STOP -1  # Stop all processes, NB don't run this outside a script!
timeout 40s ./train_gpt2
kill -CONT -1 )

Also noted:

~$ gcc -Ofast -Q --help=optimizers|grep enabled > a
~$ gcc -O3 -Ofast -Q --help=optimizers|grep enabled > b
~$ diff a b
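
The line that matters in that diff is the finite-math assumption; roughly the following, though the exact listing varies by GCC version:

$ gcc -O3 -Q --help=optimizers | grep finite-math
  -ffinite-math-only                    [disabled]
$ gcc -Ofast -Q --help=optimizers | grep finite-math
  -ffinite-math-only                    [enabled]

-Ofast implies -ffast-math, which in turn enables -ffinite-math-only, letting GCC assume no NaNs or infinities ever occur.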

@dagelf (Contributor, Author) commented Apr 17, 2024

Also resolves #19 for good I think

@karpathy (Owner) commented:

So maybe this is ok to merge...
1. It looks a little funny; is there no way to combine the double nested #if into one condition?
2. I think a comment explaining this would go a long way.

@azret (Contributor) commented Apr 18, 2024

Please don't forget about the MSVC/Windows. MSVC uses pragma to turn off the optimization.

#pragma optimize( "", off )
/* unoptimized code section */
#pragma optimize( "", on )

This is really ugly. I know.

@rosslwheeler (Contributor) commented Apr 18, 2024

My issue with adding pragmas to source files (OpenMP excluded) is that you will keep adding more per platform/compiler. One suggestion was to split this function off into its own file; then you can use the Makefile to compile it with whatever flags are suitable for the platform/compiler, as sketched below. Makefiles typically have platform dependencies in them anyway. It might be easier from a maintenance standpoint to keep the source code as clean as possible.
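
A hypothetical sketch of that approach (file and target names assumed, not from the PR): move gelu_backward into its own translation unit and give only that object file the safe flag.

# Makefile fragment: fast math everywhere except gelu_backward.o
CC = gcc
CFLAGS = -O3 -Ofast -fopenmp

gelu_backward.o: gelu_backward.c
	$(CC) $(CFLAGS) -fno-finite-math-only -c $< -o $@

train_gpt2: train_gpt2.c gelu_backward.o
	$(CC) $(CFLAGS) $^ -lm -o $@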

@ent0n29 (Contributor) commented Apr 18, 2024

@dagelf I knew we could still go further with the CPU, thanks! Looking into it.

@ent0n29 (Contributor) commented Apr 18, 2024

> So maybe this is ok to merge... 1. It looks a little funny; is there no way to combine the double nested #if into one condition? 2. I think a comment explaining this would go a long way.

Yes, you can write it like this, @dagelf:

#if defined(__GNUC__) && !defined(__clang__)
    __attribute__((optimize("no-finite-math-only"))) 
#endif

@dagelf (Contributor, Author) commented Apr 18, 2024

@karpathy ifdefs squashed and comment added

@dagelf (Contributor, Author) commented Apr 18, 2024

> Please don't forget about the MSVC/Windows. MSVC uses pragma to turn off the optimization.
> #pragma optimize( "", off ) /* unoptimized code section */ #pragma optimize( "", on )
> This is really ugly. I know.

Does it bug out on MSVC with -Ofast too?

@azret (Contributor) commented Apr 18, 2024

> Does it bug out on MSVC with -Ofast too?

yep

@dagelf (Contributor, Author) commented Apr 19, 2024

Tested to work with, and speed up, MSVC too.

@karpathy (Owner) commented:

I'm sorry, this is too weird and ugly to merge, I think.
Can someone try alternative strategies? For example, tanh can be written as a function of exp quite trivially; maybe calling it that way makes it ok?
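
For what it's worth, the identity is tanh(x) = (e^(2x) - 1) / (e^(2x) + 1); a quick sketch of that idea (not the PR's code):

#include <math.h>

// tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
// Note: for large |x|, e2x overflows to infinity and the division yields
// inf/inf; under -ffinite-math-only GCC may assume that cannot happen,
// so expf would need the same treatment as tanhf (as noted below).
static inline float tanhf_via_expf(float x) {
    float e2x = expf(2.0f * x);
    return (e2x - 1.0f) / (e2x + 1.0f);
}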

@dagelf (Contributor, Author) commented Apr 20, 2024

Tried that; it will need to cover both tanhf and expf, and I'm busy with the latter... but it might be even uglier. It's really the MSVC part that makes it ugly, IMHO 😄

Simply adding:

 __attribute__((optimize("no-finite-math-only")))

Fixes it for gcc. clang always works, but is slow.

MSVC needs the pragmas before and after. The #ifdefs are just there to eliminate warnings about foreign pragmas when compiling.
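
Putting both compilers together, the guarded function ends up shaped roughly like this (a sketch; body omitted, and the exact form in the PR may differ):

#ifdef _MSC_VER
#pragma optimize("", off)   // MSVC: disable optimizations for the code below
#endif
#if defined(__GNUC__) && !defined(__clang__)
__attribute__((optimize("no-finite-math-only")))   // GCC: drop only the finite-math assumption
#endif
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    /* tanhf/coshf math as in the PR description */
}
#ifdef _MSC_VER
#pragma optimize("", on)    // MSVC: restore optimizations for the rest of the file
#endif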

@dagelf (Contributor, Author) commented Apr 20, 2024

For now I'm just going to remove the ifdefs to get this down to only two lines, to keep it clean.

Going down the route of performant custom math functions means breaking cross-platform compatibility, unless we start exploring lookup tables for CPU inference, which I will explore next.

There sure is more performance to be gained. I quickly realized that a faster activation function might lead to slower convergence and more training steps, negating the benefits. This is my cue to learn more about what makes the activation function work so that I can develop a better intuition for it. (Any pointers appreciated!)

(screenshot: activation function graphs)

For the record, it's actually the exponential in the coshf that has the biggest influence on whatever makes gelu_backward break the model. Looking at the activation function graphs above, I think I can see why 😄

If anybody else wants to explore platform specific math function optimizations, here is a good start: https://github.com/bminor/glibc/tree/master/sysdeps/x86_64/fpu

Before playing with lookup tables, I'll compare performance of different activation functions.
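
For concreteness, a lookup-table tanhf might look something like this (hypothetical sketch, not part of this PR; the table size and range are arbitrary choices):

#include <math.h>

#define TANH_LUT_SIZE 1024
#define TANH_LUT_MAX  8.0f   // tanh is within float epsilon of +/-1 beyond ~8

static float tanh_lut[TANH_LUT_SIZE + 1];

// Call once at startup: sample tanhf uniformly on [-TANH_LUT_MAX, TANH_LUT_MAX].
static void tanh_lut_init(void) {
    for (int i = 0; i <= TANH_LUT_SIZE; i++) {
        float x = (2.0f * i / TANH_LUT_SIZE - 1.0f) * TANH_LUT_MAX;
        tanh_lut[i] = tanhf(x);
    }
}

// Clamp to the saturated region, otherwise linearly interpolate between samples.
static inline float tanhf_lut(float x) {
    if (x >= TANH_LUT_MAX) return 1.0f;
    if (x <= -TANH_LUT_MAX) return -1.0f;
    float t = (x / TANH_LUT_MAX + 1.0f) * 0.5f * TANH_LUT_SIZE;  // map to [0, SIZE)
    int i = (int)t;
    float frac = t - (float)i;
    return tanh_lut[i] + frac * (tanh_lut[i + 1] - tanh_lut[i]);
}

Whether something like this actually beats the libm call would depend on cache behavior and on how much accuracy the backward pass can tolerate.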

@azret (Contributor) commented Apr 20, 2024

Lookup tables are a great idea
