diff --git a/test_gpt2.c b/test_gpt2.c
index e49b73fad..4b8e6b358 100644
--- a/test_gpt2.c
+++ b/test_gpt2.c
@@ -85,18 +85,18 @@ int main(int argc, char *argv[]) {
     // overall OK signal for the test
     int allok = 1;
 
-    // let's do 10 training iterations, following the pytorch code
+    // let's do 10 training iterations
     float expected_losses[10] = {
-        5.270007133483887,
-        4.059706687927246,
-        3.3751230239868164,
-        2.8007826805114746,
-        2.315382242202759,
-        1.8490285873413086,
-        1.3946564197540283,
-        0.9991465210914612,
-        0.6240804195404053,
-        0.37651097774505615
+        5.270957,
+        4.209763,
+        3.635266,
+        3.099755,
+        2.646840,
+        2.240989,
+        1.831270,
+        1.422460,
+        1.050359,
+        0.729938,
     };
 
     for (int step = 0; step < 10; step++) {
@@ -178,7 +178,7 @@ int main(int argc, char *argv[]) {
         allok = allok && step_loss_ok;
 
         // print the timing information at the end
-        printf("step %d: loss %f (took %f ms) OK = %d\n", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);
+        printf("step %2d: loss %f (took %f ms) OK = %d\n", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);
     }
 
     // final judgement
diff --git a/train_gpt2.c b/train_gpt2.c
index 57bdfe929..e41d45fa8 100644
--- a/train_gpt2.c
+++ b/train_gpt2.c
@@ -360,34 +360,23 @@ void attention_backward(float* dinp, float* dpreatt, float* datt,
     }
 }
 
-#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
 void gelu_forward(float* out, float* inp, int N) {
-    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
+    // GeLU elementwise non-linearity in the MLP block of Transformer
+    #pragma omp parallel for
     for (int i = 0; i < N; i++) {
         float x = inp[i];
-        float cube = 0.044715f * x * x * x;
-        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));
+        out[i] = .5f * x * (1.f + erff(x / sqrtf(2)));
     }
 }
 
-// we want to use -Ofast optimization, but sadly GeLU breaks, so disable this flag just for it (#168)
-#pragma float_control(precise, on, push)
-#if defined(__GNUC__) && !defined(__clang__)
-__attribute__((optimize("no-finite-math-only")))
-#endif
 void gelu_backward(float* dinp, float* inp, float* dout, int N) {
     for (int i = 0; i < N; i++) {
         float x = inp[i];
-        float cube = 0.044715f * x * x * x;
-        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
-        float tanh_out = tanhf(tanh_arg);
-        float coshf_out = coshf(tanh_arg);
-        float sech_out = 1.0f / (coshf_out * coshf_out);
-        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
-        dinp[i] += local_grad * dout[i];
+        float cdf = .5f * (1.f + erff(x / sqrtf(2)));
+        float pdf = expf(-.5f * x * x) / sqrtf(2.f * M_PI);
+        dinp[i] += dout[i] * (cdf + x * pdf);
     }
 }
-#pragma float_control(pop)
 
 void residual_forward(float* out, float* inp1, float* inp2, int N) {
     for (int i = 0; i < N; i++) {
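
A quick way to convince yourself that the new backward matches the new forward is a finite-difference check. The sketch below is not part of the patch: it is a standalone program (hypothetical file name gelu_check.c) that mirrors the erf-based kernels above and compares the analytic gradient GELU'(x) = Phi(x) + x * phi(x), with phi(x) = exp(-x^2/2) / sqrt(2*pi), against a centered difference of the forward.

// gelu_check.c (hypothetical name, not part of the patch)
// Standalone sanity check: compares the analytic gradient of the exact
// (erf-based) GELU against a centered finite difference of the forward.
// Build with e.g.: gcc -O2 gelu_check.c -lm -o gelu_check
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// exact GELU: x * Phi(x), where Phi is the standard normal CDF
static float gelu(float x) {
    return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
}

// d/dx GELU(x) = Phi(x) + x * phi(x), phi(x) = exp(-x^2/2) / sqrt(2*pi)
static float gelu_grad(float x) {
    float cdf = 0.5f * (1.0f + erff(x / sqrtf(2.0f)));
    float pdf = expf(-0.5f * x * x) / sqrtf(2.0f * (float)M_PI);
    return cdf + x * pdf;
}

int main(void) {
    const float h = 1e-2f;
    float max_err = 0.0f;
    for (float x = -4.0f; x <= 4.0f; x += 0.25f) {
        float numeric = (gelu(x + h) - gelu(x - h)) / (2.0f * h);
        float analytic = gelu_grad(x);
        float err = fabsf(numeric - analytic);
        if (err > max_err) max_err = err;
    }
    printf("max |analytic - numeric| = %e\n", max_err);
    return max_err < 1e-3f ? EXIT_SUCCESS : EXIT_FAILURE;
}

On the range tested, the analytic and numeric gradients should agree to within about 1e-3. This is a cheap guard against dropping the 0.5 factor that comes from differentiating 0.5 * (1 + erf(x / sqrt(2))) with respect to x.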