
Error: must forward with targets before backward #19

Closed
lizhipengpeng opened this issue Apr 9, 2024 · 38 comments · Fixed by #200

@lizhipengpeng

[screenshot of the error output]

When I run ./train_gpt2, I get the error shown above.

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

did you check if you have OpenMP installed?

# on macOS
$ brew install libomp
# on ubuntu
$ sudo apt-get install libomp-dev

After installing it, build again with make train_gpt2, and tune OMP_NUM_THREADS=8 based on how many cores your CPU has.
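If you want to confirm OpenMP is actually being picked up after rebuilding, a tiny standalone check like the sketch below (my own, not part of llm.c) prints the core and thread counts the runtime sees; OMP_NUM_THREADS is what overrides the max thread count.

// omp_check.c -- hypothetical helper, not part of llm.c
// Build: gcc -fopenmp omp_check.c -o omp_check
#include <stdio.h>
#include <omp.h>

int main(void) {
    // cores visible to the OpenMP runtime vs. threads a parallel region would use
    printf("procs: %d, max threads: %d\n", omp_get_num_procs(), omp_get_max_threads());
    #pragma omp parallel
    {
        #pragma omp single
        printf("parallel region running with %d threads\n", omp_get_num_threads());
    }
    return 0;
}

Running it as OMP_NUM_THREADS=8 ./omp_check should report 8 threads if the override is taking effect.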

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

We get this error when mean_loss is still -1.0f, which is its initial value; this attribute is only updated in the gpt2_forward function when the targets param is not NULL. So somewhere in the training loop or in the dataloader, gpt2_backward is being called before gpt2_forward.
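For anyone following along, here is a minimal, self-contained sketch of that sentinel logic (hypothetical names, paraphrased rather than the verbatim llm.c source):

#include <stdio.h>
#include <stdlib.h>

typedef struct { float mean_loss; } Model;

void model_forward(Model *m, const int *targets) {
    // ... forward pass would run here ...
    if (targets != NULL) {
        m->mean_loss = 5.0f;   // stand-in for the real averaged cross-entropy loss
    } else {
        m->mean_loss = -1.0f;  // sentinel: forward ran without targets, no loss available
    }
}

void model_backward(Model *m) {
    if (m->mean_loss == -1.0f) {   // the check that fires the error in this issue
        printf("Error: must forward with targets before backward\n");
        exit(1);
    }
    // ... backprop would run here ...
}

int main(void) {
    Model m = { -1.0f };   // -1.0f is also the initial value
    model_backward(&m);    // no forward-with-targets yet, so this prints the error and exits
    return 0;
}

So if the error fires even though the training loop does call forward with targets, the sentinel comparison itself becomes a suspect, which is where the compiler-flag discussion below comes in.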

@aircwt

aircwt commented Apr 9, 2024

I got the same error.
Edit the Makefile and remove the -Ofast flag as below:
CFLAGS = -O3 -Wno-unused-result

@karpathy
Owner

karpathy commented Apr 9, 2024

Does anyone understand how -O3 can possibly be causing this issue?

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

-Ofast already enables all the optimizations from -O3 plus other, more aggressive ones; maybe it's better to use either -O3 or -Ofast, not both?

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

@karpathy or maybe the optimizations could be custom-made LLVM passes targeting train_gpt2, rather than automatically generated by the -O flags

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

The -Ofast flag enables the -ffast-math flag.
The -ffast-math flag assumes there are no non-finite values (NaN/Inf) and can also ignore mathematical rules that normally forbid certain optimizations.
It's possible that -Ofast and -ffast-math are causing errors in the floating-point arithmetic.
It works with only -O3 enabled.

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

OK, I think I found the problem: now -Ofast and -O3 work together. The issue was caused by -ffast-math messing with the floating-point arithmetic.
This also works:

CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result

@Infatoshi

I still get the "backward before forward" error when using CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result. I'm not a systems guy, but omitting fast math seemed to work. Try changing to the following at the top of the Makefile:

# CFLAGS = -O3 -Ofast -Wno-unused-result
CFLAGS = -O3 -Wno-unused-result
LDFLAGS =
LDLIBS = -lm
INCLUDES =

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

@Infatoshi Yes, if CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result doesn't work,
using only CFLAGS = -O3 -Wno-unused-result without -Ofast should work.

@mikepapadim

Disabling the fast-math flag seems to solve the issue for me as well:
-O3 -Ofast -fno-fast-math -Wno-unused-result

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

OK, so from what I saw, and to recap, on Ubuntu machines (I am on 22.04):

CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result

will work. If this doesn't work just remove -Ofast and everything should be fine:

CFLAGS = -O3 -Wno-unused-result

@karpathy
Owner

karpathy commented Apr 9, 2024

I can't repro this issue on my computer. Would it maybe work to change the check to:

if (model->mean_loss < 0) {
    printf("Error: must forward with targets before backward\n");
    exit(1);
}

as the loss can never be negative.

@ent0n29
Contributor

ent0n29 commented Apr 9, 2024

OK, I will try this change in train_gpt2.c and then build it with your original Makefile, since I also had the issue, like a lot of other Ubuntu machines.
I updated train_gpt2.c with:

if (model->mean_loss < 0) {
    printf("Error: must forward with targets before backward\n");
    exit(1);
}

I used the original Makefile with -Ofast and -O3 enabled (without disabling -ffast-math) and it throws nan:

step 0: train loss 5.356172 (took 11253.211517 ms)
step 1: train loss nan (took 9569.404843 ms)
step 2: train loss nan (took 9525.026318 ms)
step 3: train loss nan (took 9518.282173 ms)
step 4: train loss nan (took 9924.498914 ms)

@jrrk2

jrrk2 commented Apr 10, 2024

The previously suggested fix works for me on Ubuntu 22.04.4 LTS (I have OpenMP installed)

CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result

I had no joy on macOS due to the complexity of achieving a consistent Python environment. However, I suppose the Python stages could be run on another machine and just the results copied to macOS to be used as input to the C phase.

With this change my final output was the following two poems:

<|endoftext|>I was so frightened with your face: to come and though they would not do it any more than as
Let us; but who ever can turn
Against a world so full,
That there'll have been none of our fightmen but
Weaver-bats and tearing men, and stir them utterly;

<|endoftext|>CLAUSE:
I cannot, sir, can; as I would
do Obama to please
many poor tinted statesmen, boiling the o' the air. Newton:
Now let my neck as disengaged as
the of a farmer
mungs the soil,
during her westerly

Not precisely Booker Prize material, but it nevertheless begins to give an insight into the work involved in producing useful research results.

@gudeming

Using WSL, only CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result works.

@ent0n29
Contributor

ent0n29 commented Apr 10, 2024

So @karpathy, the issue seems to occur only on Ubuntu systems (or Linux in general); on macOS it works fine.
On Ubuntu it's solved with the previous suggestions. Maybe a comment about this in the README, or a little comment in the Makefile to underline the issue? No need for a PR for such a small fix.
But I'm still confused about the use of -O3 and -Ofast together: since -Ofast already enables all the optimizations from -O3 and adds other, more aggressive ones, I think it's redundant. Or did you notice some perf improvement using them together?

@bexcite

bexcite commented Apr 10, 2024

Stumbled on the same Error: must forward with targets before backward on my Ubuntu 22.04 machine. The fix apparently is not to use both -O3 and -Ofast together, or, if they are used together, to add -fno-fast-math.

And the thing is that model->mean_loss and the underlying values in model->acts.probs that are used in crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, V); are nan :( which suggests that it's indeed some underlying math optimization (one that assumes no NaNs) that messed things up and ruined the expected invariant.
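If anyone wants to confirm where the NaNs first appear, here is a small debugging sketch of mine (not part of llm.c) that counts NaNs in a float buffer using the raw bit pattern, so the check keeps working even where -ffast-math/-ffinite-math-only would be allowed to fold an isnan() call away:

#include <stdio.h>
#include <string.h>

// count NaNs in a float buffer via the IEEE-754 bit pattern
// (exponent all ones + nonzero mantissa), independent of math flags
static long count_nans(const float *buf, long n) {
    long nans = 0;
    for (long i = 0; i < n; i++) {
        unsigned int bits;
        memcpy(&bits, &buf[i], sizeof bits);
        if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0u) {
            nans++;
        }
    }
    return nans;
}

int main(void) {
    float probs[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    unsigned int qnan = 0x7FC00000u;          // quiet-NaN bit pattern
    memcpy(&probs[2], &qnan, sizeof qnan);    // inject one NaN for demonstration
    printf("NaNs found: %ld\n", count_nans(probs, 4));
    return 0;
}

In train_gpt2.c one could call something like count_nans(model->acts.probs, (long)B * T * V) right before crossentropy_forward to see whether the activations are already bad at that point.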

Some info on my machine/compiler for others (I believe it also depends on the specific compiler used):

cc -v
Using built-in specs.
COLLECT_GCC=cc
COLLECT_LTO_WRAPPER=/usr/lib/gcc/x86_64-linux-gnu/11/lto-wrapper
OFFLOAD_TARGET_NAMES=nvptx-none:amdgcn-amdhsa
OFFLOAD_TARGET_DEFAULT=1
Target: x86_64-linux-gnu
Configured with: ../src/configure -v --with-pkgversion='Ubuntu 11.4.0-1ubuntu1~22.04' --with-bugurl=file:///usr/share/doc/gcc-11/README.Bugs --enable-languages=c,ada,c++,go,brig,d,fortran,objc,obj-c++,m2 --prefix=/usr --with-gcc-major-version-only --program-suffix=-11 --program-prefix=x86_64-linux-gnu- --enable-shared --enable-linker-build-id --libexecdir=/usr/lib --without-included-gettext --enable-threads=posix --libdir=/usr/lib --enable-nls --enable-bootstrap --enable-clocale=gnu --enable-libstdcxx-debug --enable-libstdcxx-time=yes --with-default-libstdcxx-abi=new --enable-gnu-unique-object --disable-vtable-verify --enable-plugin --enable-default-pie --with-system-zlib --enable-libphobos-checking=release --with-target-system-zlib=auto --enable-objc-gc=auto --enable-multiarch --disable-werror --enable-cet --with-arch-32=i686 --with-abi=m64 --with-multilib-list=m32,m64,mx32 --enable-multilib --with-tune=generic --enable-offload-targets=nvptx-none=/build/gcc-11-XeT9lY/gcc-11-11.4.0/debian/tmp-nvptx/usr,amdgcn-amdhsa=/build/gcc-11-XeT9lY/gcc-11-11.4.0/debian/tmp-gcn/usr --without-cuda-driver --enable-checking=release --build=x86_64-linux-gnu --host=x86_64-linux-gnu --target=x86_64-linux-gnu --with-build-config=bootstrap-lto-lean --enable-link-serialization=2
Thread model: posix
Supported LTO compression algorithms: zlib zstd
gcc version 11.4.0 (Ubuntu 11.4.0-1ubuntu1~22.04)

cat /etc/os-release
PRETTY_NAME="Ubuntu 22.04.4 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.4 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu

@ent0n29
Contributor

ent0n29 commented Apr 10, 2024

@bexcite Yes, apparently the -ffast-math flag messes up the floating-point arithmetic, but does it work on your machine with -Ofast enabled and -fno-fast-math to disable -ffast-math?
Also, -ffast-math assumes the following:

- Floating-point math obeys regular algebraic rules for real numbers 
  (e.g. + and * are associative, x/y == x * (1/y), and (a + b) * c == a * c + b * c),

- Operands to floating-point operations are not equal to NaN and Inf, and

- +0 and -0 are interchangeable.
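A tiny illustration of the second point (my own sketch, not from this repo): under -ffinite-math-only the compiler may assume NaN never appears, so the classic x != x test can be folded to a constant, while a bit-level test keeps working. Whether the fold actually happens depends on the compiler and version:

// nan_fold.c -- try building with -O3 and then with -Ofast and compare the output
#include <stdio.h>
#include <string.h>
#include <math.h>

// classic NaN test; -ffinite-math-only permits folding this to 0
static int naive_is_nan(float x) { return x != x; }

// bit-level NaN test; unaffected by math flags
static int bitwise_is_nan(float x) {
    unsigned int bits;
    memcpy(&bits, &x, sizeof bits);
    return (bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0u;
}

int main(void) {
    float x = NAN;   // from <math.h>; the stored bit pattern is a NaN regardless of flags
    printf("naive:   %s\n", naive_is_nan(x)   ? "NaN" : "not NaN");
    printf("bitwise: %s\n", bitwise_is_nan(x) ? "NaN" : "not NaN");
    return 0;
}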

@bexcite

bexcite commented Apr 10, 2024

@ent0n29 Yes, disabling -ffast-math works (when -Ofast is used).

On my machine I've tried these combos:

-O3 -Ofast -> failure (default)
-O3 -> good
-O3 -Ofast -fno-fast-math -> good

@ent0n29
Contributor

ent0n29 commented Apr 10, 2024

Yes, this is what I mentioned here for Ubuntu machines:

OK, so from what I saw, and to recap, on Ubuntu machines (I am on 22.04):

CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result

will work. If this doesn't work just remove -Ofast and everything should be fine:

CFLAGS = -O3 -Wno-unused-result

@karpathy
Owner

I added a comment in the README for now. I don't have a good sense of when the code works or doesn't work, so it feels hard to change the Makefile generically atm. @ent0n29 does removing these flags make the code slower? (I'd expect yes)

@ent0n29
Contributor

ent0n29 commented Apr 10, 2024

Yes, removing -O3 or -Ofast will make the code slower. I don't know how much -ffast-math influences the speed, and -Ofast is just a wrapper around -O3 with more aggressive opts; I would use either one or the other.

@modigeko

I've tried both options on my laptop (i7-6600U CPU @ 2.60GHz, 16GB RAM) running Debian Bookworm:

  • CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result
  • CFLAGS = -O3 -Wno-unused-result

and both ended up with the same results. Each step took ~40 seconds:

$ make train_gpt2
NICE Compiling with OpenMP support
cc -O3 -Ofast -fno-fast-math -Wno-unused-result -fopenmp -DOMP   train_gpt2.c -lm -lgomp -o train_gpt2

$ OMP_NUM_THREADS=8 ./train_gpt2
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 1192
val dataset num_batches: 128
num_activations: 73323776
val loss 5.251911
step 0: train loss 5.356082 (took 25964.746220 ms)
step 1: train loss 4.300639 (took 24825.226970 ms)
step 2: train loss 4.623087 (took 25338.944616 ms)
step 3: train loss 4.599362 (took 49262.943949 ms)
step 4: train loss 4.616664 (took 41084.071553 ms)
step 5: train loss 4.231427 (took 39915.190249 ms)
...
step 36: train loss 4.101098 (took 45423.354491 ms)
step 37: train loss 3.740978 (took 45066.928727 ms)
step 38: train loss 4.618744 (took 42166.453508 ms)
step 39: train loss 3.972259 (took 54226.899437 ms)
val loss 4.109253
generated: 50256 16827 19108 25 198 40 2314 11 15967 11 460 26 355 314 561 198 4598 2486 284 3387 198 21834 3595 34791 276 2585 3653 11 24372 262 267 6 262 1633 13 17321 25 198 3844 1309 616 7393 355 39402 1886 355 198 1169 286 257 18739 198 76 2150 82 262 9260 11 198 42122 607 266 7834 306 
step 40: train loss 4.378441 (took 41650.695340 ms)

Meanwhile this guy used a Raspberry Pi 5 and it took him ~13 seconds per step.
Something isn't right.

@ent0n29
Contributor

ent0n29 commented Apr 11, 2024

[#21] What I have is:

  • CFLAGS = -Ofast -fno-fast-math -Wno-unused-result :
       step 0: train loss 5.356082 (took 19942.187432 ms)
       step 1: train loss 4.300639 (took 20333.230115 ms)
       step 2: train loss 4.623087 (took 20157.211989 ms)
       step 3: train loss 4.599362 (took 19900.653274 ms)
       step 4: train loss 4.616664 (took 19071.652778 ms)
       step 5: train loss 4.231427 (took 19976.814785 ms)
       step 6: train loss 3.753161 (took 20760.803743 ms)
       step 7: train loss 3.650458 (took 19308.340202 ms)
       step 8: train loss 4.182242 (took 19559.064261 ms)
       step 9: train loss 4.199580 (took 18556.248236 ms)

-Ofast without -O3, with -ffast-math disabled, averages ~20 seconds per step.

  • CFLAGS = -O3 -Wno-unused-result :
step 0: train loss 5.356082 (took 28424.368136 ms)
step 1: train loss 4.300639 (took 20445.511701 ms)
step 2: train loss 4.623087 (took 22656.468311 ms)
step 3: train loss 4.599362 (took 19115.014434 ms)
step 4: train loss 4.616664 (took 19833.797978 ms)
step 5: train loss 4.231427 (took 18573.217460 ms)
step 6: train loss 3.753161 (took 18102.854112 ms)
step 7: train loss 3.650458 (took 18000.311629 ms)
step 8: train loss 4.182242 (took 28836.764671 ms)
step 9: train loss 4.199580 (took 24153.199814 ms)

Only -O3 without -Ofast is a bit slower, but still around ~20 seconds per step on average.

  • CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result is the fastest:
step 0: train loss 5.356082 (took 17718.687714 ms)
step 1: train loss 4.300639 (took 17256.573805 ms)
step 2: train loss 4.623087 (took 16764.518172 ms)
step 3: train loss 4.599362 (took 16864.526737 ms)
step 4: train loss 4.616664 (took 16765.048234 ms)
step 5: train loss 4.231427 (took 16944.676229 ms)
step 6: train loss 3.753161 (took 20110.965357 ms)
step 7: train loss 3.650458 (took 18992.590776 ms)
step 8: train loss 4.182242 (took 19528.572922 ms)
step 9: train loss 4.199580 (took 17612.805042 ms)

Every step is around ~17 seconds; every other step until step 40 is also around 16-17 seconds.
So apparently using both -O3 and -Ofast makes the code faster, but I can't test these other two scenarios:
CFLAGS = -O3 -Ofast -Wno-unused-result vs CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result
because the first (the original Makefile) doesn't work on my machine. I will try some other sample code to see whether -ffast-math makes a significant difference in terms of speed.

@snow-ghost

CFLAGS = -O3 -Wno-unused-result

step 0: train loss 5.356082 (took 6041.560620 ms)
step 1: train loss 4.300639 (took 5486.315285 ms)
step 2: train loss 4.623087 (took 5652.058295 ms)
step 3: train loss 4.599362 (took 5609.653087 ms)
step 4: train loss 4.616664 (took 5935.522067 ms)
step 5: train loss 4.231427 (took 6457.001960 ms)
step 6: train loss 3.753161 (took 5523.612809 ms)
step 7: train loss 3.650458 (took 5445.255213 ms)
step 8: train loss 4.182242 (took 5285.022331 ms)
step 9: train loss 4.199580 (took 5515.513513 ms)

CFLAGS = -Ofast -Wno-unused-result

step 0: train loss 5.356082 (took 5844.298649 ms)
step 1: train loss 4.300639 (took 5520.787391 ms)
step 2: train loss 4.623087 (took 5472.389458 ms)
step 3: train loss 4.599362 (took 5507.571574 ms)
step 4: train loss 4.616664 (took 5529.754121 ms)
step 5: train loss 4.231427 (took 5490.549419 ms)
step 6: train loss 3.753161 (took 5749.419360 ms)
step 7: train loss 3.650458 (took 5714.291643 ms)
step 8: train loss 4.182242 (took 5503.192129 ms)
step 9: train loss 4.199580 (took 5517.630225 ms)

CFLAGS = -O3 -Ofast -fno-fast-math -Wno-unused-result

step 0: train loss 5.356082 (took 6111.609870 ms)
step 1: train loss 4.300639 (took 5805.856485 ms)
step 2: train loss 4.623087 (took 5695.227162 ms)
step 3: train loss 4.599362 (took 5646.601722 ms)
step 4: train loss 4.616664 (took 5530.085910 ms)
step 5: train loss 4.231427 (took 5462.099164 ms)
step 6: train loss 3.753161 (took 5424.970717 ms)
step 7: train loss 3.650458 (took 5531.754343 ms)
step 8: train loss 4.182242 (took 5513.774748 ms)
step 9: train loss 4.199580 (took 5466.057209 ms)

@ent0n29
Contributor

ent0n29 commented Apr 12, 2024

Guys wake up, new optimization just dropped:
check [#21], we found a better alternative to -fno-fast-math for an almost 2x performance improvement.
Try:
CFLAGS = -Ofast -fno-finite-math-only -Wno-unused-result

@azret
Contributor

azret commented Apr 14, 2024

I'm working on Windows and see the same with MSVC. I use /O2 and /fp:fast.

I think I've nailed it down to the GELU layer. Something is being over-optimized here, which results in blown-out grads.

So first I disabled the optimization just there for proof:

#pragma optimize( "", off )
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
void gelu_forward(float* out, float* inp, int N) {
    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));
    }
}

void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}
#pragma optimize( "", on )

And it passed.

[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
[State]
batch_size: 4
seq_len: 64
num_activations: 73323776
-43.431702 -43.431740
-39.836426 -39.836460
-43.065937 -43.066002
OK (LOGITS)
LOSS OK: 5.269984 5.270009
step 0: loss 5.269984 (took 10125.000000 ms)
step 1: loss 4.059654 (took 15625.000000 ms)
step 2: loss 3.374920 (took 16672.000000 ms)
step 3: loss 2.800694 (took 12406.000000 ms)
step 4: loss 2.315437 (took 9984.000000 ms)
step 5: loss 1.849182 (took 9969.000000 ms)
step 6: loss 1.394892 (took 10156.000000 ms)
step 7: loss 0.999076 (took 10391.000000 ms)
step 8: loss 0.624470 (took 10391.000000 ms)
step 9: loss 0.376848 (took 11422.000000 ms)
loss ok at step 0: 5.269984 5.270007
loss ok at step 1: 4.059654 4.059707
loss ok at step 2: 3.374920 3.375123
loss ok at step 3: 2.800694 2.800783
loss ok at step 4: 2.315437 2.315382
loss ok at step 5: 1.849182 1.849029
loss ok at step 6: 1.394892 1.394656
loss ok at step 7: 0.999076 0.999147
loss ok at step 8: 0.624470 0.624080
loss ok at step 9: 0.376848 0.376511
overall okay: 1

I've rearranged the code a bit to trick the compiler. Would someone please try this on their end and see if it solves the issue for you on Linux or Mac as well? If so, we can set up a PR.

#define GELU_SCALING_FACTOR 0.7978845608 // sqrtf(2.0f / M_PI)
void gelu_forward(float* out, float* inp, int N) {
    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        out[i] = 0.5 * x * (1 + tanhf(x * 0.7978845608 * (1 + 0.044715 * x * x)));
    }
}

float gelu_grad(float x) {
    float square = 0.044715f * x * x;
    float cube = square * x;
    float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
    float tanh_out = tanhf(tanh_arg);
    float coshf_out = coshf(tanh_arg);
    float sech_out = 1.0 / (coshf_out * coshf_out);
    float local_grad = 0.5 * (1.0 + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0 + 3.0 * square);
    return local_grad;
}

void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        dinp[i] += gelu_grad(inp[i]) * dout[i];
    }
}

[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
[State]
batch_size: 4
seq_len: 64
num_activations: 73323776
-43.431702 -43.431694
-39.836426 -39.836422
-43.065937 -43.065956
OK (LOGITS)
LOSS OK: 5.269984 5.270009
step 0: loss 5.269984 (took 10016.000000 ms)
step 1: loss 4.059654 (took 9765.000000 ms)
step 2: loss 3.374920 (took 9656.000000 ms)
step 3: loss 2.800694 (took 9828.000000 ms)
step 4: loss 2.315437 (took 10000.000000 ms)
step 5: loss 1.849182 (took 9968.000000 ms)
step 6: loss 1.394891 (took 9719.000000 ms)
step 7: loss 0.999076 (took 9750.000000 ms)
step 8: loss 0.624470 (took 9781.000000 ms)
step 9: loss 0.376849 (took 10047.000000 ms)
loss ok at step 0: 5.269984 5.270007
loss ok at step 1: 4.059654 4.059707
loss ok at step 2: 3.374920 3.375123
loss ok at step 3: 2.800694 2.800783
loss ok at step 4: 2.315437 2.315382
loss ok at step 5: 1.849182 1.849029
loss ok at step 6: 1.394891 1.394656
loss ok at step 7: 0.999076 0.999147
loss ok at step 8: 0.624470 0.624080
loss ok at step 9: 0.376849 0.376511
overall okay: 1

@azret
Contributor

azret commented Apr 14, 2024

Another thing to note: if compiling without fast fp, one of the check tensors fails compared to the PyTorch version. It's a very small delta, but it definitely goes away when compiling with fast fp.

OK (LOGITS)
LOSS OK: 5.269893 5.270009
dln2w: NOT OK
[8053] 0.951992 0.962970, DIFF: 0.010978

step 0: loss 5.269893 (took 24500.000000 ms)

int check_tensor(float* a, float* b, int n, char* label) {
    int print_upto = 5;
    int ok = 1;
    int labelPrinted = 0;
    for (int i = 0; i < n; i++) {
        if (fabsf(a[i] - b[i]) > 1e-2) {
            // only report if mismatch
            if (!labelPrinted) {
                printf("%s: NOT OK\n", label);
                labelPrinted = 1;
            }
            if (print_upto-- > 0) {
                printf("\t[%d] %f %f, DIFF: %f\n", i, a[i], b[i], fabsf(a[i] - b[i]));
            }
            ok = 0;
        }
    }
    return ok;
}

@dagelf
Contributor

dagelf commented Apr 16, 2024

@azret These changes make no difference on Linux.

FYI compiling with clang always works, no matter what flags are used, albeit a few % slower as compared to gcc with -fno-finite-math-only

@azret
Contributor

azret commented Apr 16, 2024

@azret These changes make no difference on Linux.

FYI compiling with clang always works, no matter what flags are used, albeit a few % slower as compared to gcc with -fno-finite-math-only

Thank you for trying!

#pragma optimize( "", off )
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
void gelu_forward(float* out, float* inp, int N) {
    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));
    }
}

void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}
#pragma optimize( "", on )

Thank you! Can you please also try wrapping just this block with #pragma optimize( "", off/on )? If this works for you, then we can be sure that it is in fact this function that is being over-optimized. It would be of much help, as I am trying to get /fp:fast working on Windows as well.

#pragma optimize( "", off )
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
for (int i = 0; i < N; i++) {
float x = inp[i];
float cube = 0.044715f * x * x * x;
float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
float tanh_out = tanhf(tanh_arg);
float coshf_out = coshf(tanh_arg);
float sech_out = 1.0f / (coshf_out * coshf_out);
float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
dinp[i] += local_grad * dout[i];
}
}
#pragma optimize( "", on )

@dagelf
Contributor

dagelf commented Apr 17, 2024

Good idea @azret! The optimizations seem to work on clang and elsewhere, just had an issue on gcc.

I looked at this again tonight. After finding no NaNs returned from any -ffast-math functions by hooking them, noticing that the error doesn't occur when I log the return values to a file, and backtracking to only adding an asm volatile ( "nop" ) after tanhf (which invalidates some optimizations)... I found the culprit to be the tanhf call in gelu_backward indeed. So we can still get the speedup of -ffast-math, just not for that function!

Someone can still optimize further; I couldn't figure out a way to get all the optimizations except on that single tanhf call. You might have to compile and link an overloaded or custom version of that function separately to do it.
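One GCC-specific way to try that (an untested sketch on my side; GCC's optimize attribute is also documented as somewhat unreliable across versions) is to keep -Ofast globally but ask the compiler to build just gelu_backward without fast-math. Clang would need a different mechanism, e.g. __attribute__((optnone)) or compiling the function in a separate translation unit with its own flags:

#include <math.h>

#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

// same gelu_backward as above, but compiled without fast-math on GCC
// while the rest of the file keeps -Ofast
__attribute__((optimize("no-fast-math")))
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}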

In this process I noticed a lot of 0.0 and 1.0 inputs and outputs to the math functions, which I'm pretty sure they don't optimize for, so there's more low-hanging fruit for CPU (and GPU?) optimization.

@dagelf
Contributor

dagelf commented Apr 17, 2024

@jrrk2 re "complexity of achieving Python environment", at risk of stating the obvious, this should work if your Python version is good:

cd llm.c
python -m venv .                  # venv ships with Python 3.3+; create a local environment under the current directory
source bin/activate               # activate the separate environment
pip install -r requirements.txt   # install the required packages and/or versions

Also, Python compiles really easily from source and is surprisingly small, with minimal dependencies, last I checked (on Linux at least), so it's worth trying if just installing a new binary version seems too mundane 😅

@dagelf
Contributor

dagelf commented Apr 17, 2024

@ent0n29: You prompted me to look deeper into the compiler flags; I guess this is where you found and tested -ffinite-math-only. Anyway, playing with compiler flags is a pastime in the Gentoo/Funtoo communities, and I guess other ports-based OSes.

~$  gcc -O3 -Q --help=optimizers > a
~$  gcc -Ofast -Q --help=optimizers > b
~$ diff --suppress-common-lines --side-by-side a b

  -fallow-store-data-races              [disabled]            |   -fallow-store-data-races              [enabled]
  -fassociative-math                    [disabled]            |   -fassociative-math                    [enabled]
  -fcx-limited-range                    [disabled]            |   -fcx-limited-range                    [enabled]
  -fexcess-precision=[fast|standard|16]         [default]     |   -fexcess-precision=[fast|standard|16]         fast
  -ffinite-math-only                    [disabled]            |   -ffinite-math-only                    [enabled]
  -fmath-errno                          [enabled]             |   -fmath-errno                          [disabled]
  -freciprocal-math                     [disabled]            |   -freciprocal-math                     [enabled]
  -fsemantic-interposition              [enabled]             |   -fsemantic-interposition              [disabled]
  -fsigned-zeros                        [enabled]             |   -fsigned-zeros                        [disabled]
  -ftrapping-math                       [enabled]             |   -ftrapping-math                       [disabled]
  -funsafe-math-optimizations           [disabled]            |   -funsafe-math-optimizations           [enabled]

TIL 🙄 And of course the last one on the command line takes precedence 😅

The only difference between -O3 -ffast-math and -Ofast:

~$ gcc -O3 -ffast-math -Q --help=optimizers > a
~$ gcc -Ofast -Q --help=optimizers > b
~$ diff --suppress-common-lines --side-by-side a b
  -fallow-store-data-races              [disabled]            |   -fallow-store-data-races              [enabled]
  -fsemantic-interposition              [enabled]             |   -fsemantic-interposition              [disabled]
~$ gcc --help=optimizers |grep allow-store-data-races 
  -fallow-store-data-races    Allow the compiler to introduce new data races on stores.
~$ gcc --help=optimizers |grep semantic-interposition
  -fsemantic-interposition    Allow interposing function (or variables) by ones with different semantics (or initializer) respectively by dynamic linker.

Sounds like -Ofast should in theory actually be faster than -O3. The former one (-fallow-store-data-races) makes me nervous though... maybe at some point we'll have to go back to -O3 -ffast-math 😬


@ent0n29
Contributor

ent0n29 commented Apr 18, 2024

@dagelf Yeah, lately I'm deep into compilers, writing opt passes directly in the LLVM source code; that's why I opened #18 (comment) too: why use flags like -Ofast or -O3 when we could write our own optimizations as LLVM passes targeting all the functions and operations in train_gpt2.c?

@dagelf
Contributor

dagelf commented Apr 20, 2024

I made a cleaner PR here: #200

I also worked on some more CPU optimizations, details here #168

@dagelf
Contributor

dagelf commented Apr 25, 2024

Let's continue the discussion about performance improvement efforts on CPU here: #253
