Use proper GeLU on CPU

This change removes the tanh GeLU approximation. This gives us the
benefits of better accuracy, roughly equal performance, and strict
standard conformance, since we no longer need any compiler-specific
tricks.
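
For reference, here are the two formulations side by side as standalone
helpers (an illustrative sketch; the function names are mine, not part of
the diff below):

    #include <math.h>

    // tanh approximation used before this change:
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    static float gelu_tanh_approx(float x) {
        float cube = 0.044715f * x * x * x;
        return 0.5f * x * (1.0f + tanhf(sqrtf(2.0f / M_PI) * (x + cube)));
    }

    // exact GeLU via the error function, introduced by this change:
    // x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    static float gelu_exact(float x) {
        return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
    }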

Here are the last lines of train_gpt2 output before this change:

    step 37: train loss 3.739647 (took 598.548076 ms)
    step 38: train loss 4.611735 (took 596.626145 ms)
    step 39: train loss 3.970751 (took 598.439552 ms)
    val loss 4.016658
    generating:
    ---
    Come Running Away,
    Greater conquer
    With the Imperial blood
    the heaviest host of the gods
    into this wondrous world beyond.
    I will not back thee, for how sweet after birth
    Netflix against repounder,
    will not
    flourish against the earlocks of
    Allay
    ---
    step 40: train loss 4.377756 (took 592.704936 ms)

Here are the last lines of train_gpt2 output after this change:

    step 37: train loss 3.731596 (took 594.893995 ms)
    step 38: train loss 4.561646 (took 600.064035 ms)
    step 39: train loss 3.933512 (took 599.666173 ms)
    val loss 4.014135
    generating:
    ---
    Whether Hipocrates,
    Bigon Nicinius, or rep'd
    With Thy fair winter-tail your outraged hand,
    The richness of the good smour
    Nine years by turns covered my Member. Thou art
    Nay, I fear be; but
    Lets o' thee know, if it
    ---
    step 40: train loss 4.358461 (took 597.594065 ms)

This change has the disadvantage of diverging from PyTorch. I view
this as being justified and worthwhile, for numerous reasons, e.g.

  "I used the tanh approximation simply because the error function
   erf was slow in tensorflow some years ago. If the exact version
   is fast enough now and does not have numerical issues, I do not
   see a reason to use an inexact version."  ──Quoth Dan Hendrycks

See pytorch/pytorch#39853
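
To sanity-check the analytic derivative used in gelu_backward below, a
small standalone program (my own sketch, not part of this commit) can
compare it against a central finite difference:

    #include <math.h>
    #include <stdio.h>

    // exact GeLU and its analytic derivative: d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
    static float gelu(float x) {
        return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
    }
    static float gelu_grad(float x) {
        float cdf = 0.5f * (1.0f + erff(x / sqrtf(2.0f)));     // Phi(x)
        float pdf = expf(-0.5f * x * x) / sqrtf(2.0f * M_PI);  // phi(x)
        return cdf + x * pdf;
    }

    int main(void) {
        // compare the analytic gradient against a central finite difference
        for (float x = -3.0f; x <= 3.0f; x += 0.5f) {
            float h = 1e-3f;
            float fd = (gelu(x + h) - gelu(x - h)) / (2.0f * h);
            printf("x=%+.1f  analytic=%.6f  finite-diff=%.6f\n", x, gelu_grad(x), fd);
        }
        return 0;
    }
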
jart committed May 23, 2024
1 parent 69f1221 commit 34e2a98
Showing 2 changed files with 18 additions and 29 deletions.
24 changes: 12 additions & 12 deletions test_gpt2.c
@@ -85,18 +85,18 @@ int main(int argc, char *argv[]) {
     // overall OK signal for the test
     int allok = 1;

-    // let's do 10 training iterations, following the pytorch code
+    // let's do 10 training iterations
     float expected_losses[10] = {
-        5.270007133483887,
-        4.059706687927246,
-        3.3751230239868164,
-        2.8007826805114746,
-        2.315382242202759,
-        1.8490285873413086,
-        1.3946564197540283,
-        0.9991465210914612,
-        0.6240804195404053,
-        0.37651097774505615
+        5.270957,
+        4.209763,
+        3.635266,
+        3.099755,
+        2.646840,
+        2.240989,
+        1.831270,
+        1.422460,
+        1.050359,
+        0.729938,
     };
     for (int step = 0; step < 10; step++) {

@@ -178,7 +178,7 @@ int main(int argc, char *argv[]) {
         allok = allok && step_loss_ok;

         // print the timing information at the end
-        printf("step %d: loss %f (took %f ms) OK = %d\n", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);
+        printf("step %2d: loss %f (took %f ms) OK = %d\n", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);
     }

     // final judgement
23 changes: 6 additions & 17 deletions train_gpt2.c
@@ -360,34 +360,23 @@ void attention_backward(float* dinp, float* dpreatt, float* datt,
     }
 }

-#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
 void gelu_forward(float* out, float* inp, int N) {
-    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
+    // GeLU elementwise non-linearity in the MLP block of Transformer
     #pragma omp parallel for
     for (int i = 0; i < N; i++) {
         float x = inp[i];
-        float cube = 0.044715f * x * x * x;
-        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));
+        out[i] = .5f * x * (1.f + erff(x / sqrtf(2)));
     }
 }

-// we want to use -Ofast optimization, but sadly GeLU breaks, so disable this flag just for it (#168)
-#pragma float_control(precise, on, push)
-#if defined(__GNUC__) && !defined(__clang__)
-__attribute__((optimize("no-finite-math-only")))
-#endif
 void gelu_backward(float* dinp, float* inp, float* dout, int N) {
     for (int i = 0; i < N; i++) {
         float x = inp[i];
-        float cube = 0.044715f * x * x * x;
-        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
-        float tanh_out = tanhf(tanh_arg);
-        float coshf_out = coshf(tanh_arg);
-        float sech_out = 1.0f / (coshf_out * coshf_out);
-        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
-        dinp[i] += local_grad * dout[i];
+        float cdf = .5f * (1.f + erff(x / sqrtf(2)));        // Phi(x), the standard normal CDF
+        float pdf = expf(-.5f * x * x) / sqrtf(2.f * M_PI);  // phi(x), the standard normal density
+        dinp[i] = dout[i] * (cdf + x * pdf);                 // d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
     }
 }
-#pragma float_control(pop)

 void residual_forward(float* out, float* inp1, float* inp2, int N) {
     for (int i = 0; i < N; i++) {
