Use proper GeLU on CPU #441
base: master
Conversation
Sure works! ... Is this applicable to the CUDA version? How will this affect fine-tuning? (Or a hypothetical retraining run of the base model?) (I'm still learning A LOT here.) It's even a few ms faster 😅 Good way to run benchmarks:
Benchmark:
(Note to self, different activation functions and resources: #168 and optimizations: master...dagelf:llm.c:activation_function_tests_cpu)
Wow, what CPU is that?! Also, maybe this would pique your interest: #253
Haven't tried.
No idea.
It's an AMD Ryzen Threadripper PRO 7995WX.
Your fastest activation function is going to be vectorized SiLU. https://news.ycombinator.com/item?id=40371612 erff() is a lot simpler than tanhf(), but SiLU uses expf(), which is even simpler and less branchy.

#include <math.h>  /* fmaf, fabsf, expf */

/* Efficient implementation of erff()
   using either a pure polynomial approximation or
   the exponential of a polynomial.
   Worst-case error is 1.09ulps at 0x1.c111acp-1.
   From the Optimized Routines by Arm Limited. */
float erff(float x) {
    union {
        float f;
        unsigned i;
    } pun = {x};
    float r, x2, u;
    unsigned ix = pun.i;
    unsigned sign = ix >> 31;
    unsigned ia12 = (pun.i >> 20) & 0x7ff;
    if (ia12 < 0x3f6) {
        /* small |x|: pure polynomial approximation */
        if (ia12 >= 0x318) {
            x2 = x * x;
            r = -0x1.3a1a82p-11f;
            r = fmaf(r, x2, +0x1.473f48p-08f);
            r = fmaf(r, x2, -0x1.b68bd2p-06f);
            r = fmaf(r, x2, +0x1.ce1a46p-04f);
            r = fmaf(r, x2, -0x1.8126e0p-02f);
            r = fmaf(r, x2, +0x1.06eba6p-03f);
            r = fmaf(r, x, x);
        } else {
            /* tiny |x|: erf(x) ~= (2/sqrt(pi)) * x */
            if (ia12 >= 0x040)
                r = x + 0x1.06eba8p-3f * x;
            else
                r = fmaf(0x1.06eba8p-3f, x, x);
        }
    } else if (ia12 < 0x408) {
        /* mid-range |x| (up to 4.0): exponential of a polynomial */
        float a = fabsf(x);
        r = fmaf(0x1.222900p-16f, a, -0x1.91d2ccp-12f);
        u = fmaf(0x1.fd1336p-9f, a, -0x1.8d6300p-6f);
        x2 = x * x;
        r = fmaf(r, x2, u);
        r = fmaf(r, a, 0x1.b55cb0p-4f);
        r = fmaf(r, a, 0x1.450aa0p-1f);
        r = fmaf(r, a, 0x1.079d0cp-3f);
        r = fmaf(r, a, a);
        r = expf(-r);
        if (sign)
            r = -1.f + r;
        else
            r = 1.f - r;
    } else {
        /* large |x|: saturate to +-1; handle inf/nan */
        if (ia12 < 0x7f8) {
            if (sign)
                r = -1.f;
            else
                r = 1.f;
        } else {
            r = (1.f - (float)((ix >> 31) << 1)) + 1.f / x;
        }
    }
    return r;
}
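For comparison, here is what vectorized SiLU looks like in its simplest form: x * sigmoid(x) = x / (1 + exp(-x)), which only needs expf() and a division, so a branch-free loop over it can typically be auto-vectorized. This is an illustrative sketch under that claim, not code from this PR; the function name is made up.

#include <math.h>

/* SiLU (swish) forward pass: out[i] = inp[i] * sigmoid(inp[i]).
   Illustrative only; the loop body has no branches, so compilers with a
   vector math library can usually vectorize the expf() call. */
void silu_forward(float* out, const float* inp, int n) {
    for (int i = 0; i < n; i++) {
        float x = inp[i];
        out[i] = x / (1.0f + expf(-x));
    }
}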
This change removes the tanh GeLU approximation. This gives us the benefit of better accuracy, roughly equal perf, and strict standard conformance, since we no longer need any compiler-specific tricks.

Here are the last lines of train_gpt2 output before this change:

step 37: train loss 3.739647 (took 598.548076 ms)
step 38: train loss 4.611735 (took 596.626145 ms)
step 39: train loss 3.970751 (took 598.439552 ms)
val loss 4.016658
generating:
---
Come Running Away, Greater conquer With the Imperial blood the heaviest host of the gods into this wondrous world beyond. I will not back thee, for how sweet after birth Netflix against repounder, will not flourish against the earlocks of Allay
---
step 40: train loss 4.377756 (took 592.704936 ms)

Here are the last lines of train_gpt2 output after this change:

step 37: train loss 3.731596 (took 594.893995 ms)
step 38: train loss 4.561646 (took 600.064035 ms)
step 39: train loss 3.933512 (took 599.666173 ms)
val loss 4.014135
generating:
---
Whether Hipocrates, Bigon Nicinius, or rep'd With Thy fair winter-tail your outraged hand, The richness of the good smour Nine years by turns covered my Member. Thou art Nay, I fear be; but Lets o' thee know, if it
---
step 40: train loss 4.358461 (took 597.594065 ms)

This change has the disadvantage of diverging from PyTorch. I view this as being justified and worthwhile, for numerous reasons, e.g.:

"I used the tanh approximation simply because the error function erf was slow in tensorflow some years ago. If the exact version is fast enough now and does not have numerical issues, I do not see a reason to use an inexact version." ──Quoth Dan Hendrycks

See pytorch/pytorch#39853
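For reference, these are the two formulas being swapped, written out as a standalone sketch. The function names here are illustrative, not necessarily the identifiers used in the repo's CPU code.

#include <math.h>

/* Exact GeLU, as proposed in this change: 0.5 * x * (1 + erf(x / sqrt(2))). */
float gelu_exact(float x) {
    return 0.5f * x * (1.0f + erff(x * 0.70710678f));  /* 1/sqrt(2) */
}

/* The tanh approximation GPT-2 was trained with:
   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). */
float gelu_tanh_approx(float x) {
    float cube = 0.044715f * x * x * x;
    return 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + cube)));  /* sqrt(2/pi) */
}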
Hi @jart, it's nice to see you stop by! I don't think I can merge this because (for educational and historic reasons) I am trying to be compatible with GPT-2 and the checkpoints that OpenAI has released, in the current version of the code. It's possible that in the future we'll diverge from exact GPT-2 and this would make a lot more sense then, but in that case we'd probably also shift from GeLU to something that (probably?) works a bit better - GeGLU / SwiGLU or the like.
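For context on that last point: GeGLU/SwiGLU replace the MLP's plain activation with an elementwise gate over two up-projections. A minimal sketch of the SwiGLU gate, assuming the two projections have already been computed; this is illustrative only, not code from llm.c.

#include <math.h>

/* SwiGLU gate, elementwise: out[i] = SiLU(a[i]) * b[i], where a and b are the
   two halves of the up-projection. GeGLU has the same shape but gates with
   GeLU instead of SiLU. Illustrative sketch only. */
void swiglu_forward(float* out, const float* a, const float* b, int n) {
    for (int i = 0; i < n; i++) {
        float x = a[i];
        float silu = x / (1.0f + expf(-x));  /* SiLU(x) = x * sigmoid(x) */
        out[i] = silu * b[i];
    }
}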