
AVX2 is dimwitted compared to AVX512 #23

Open
jart opened this issue Feb 23, 2024 · 8 comments
Labels
type:bug Something isn't working


@jart

jart commented Feb 23, 2024

On a $10,000 AMD Ryzen 7995WX (znver4 avx512), Gemma 7b instruct sfp is able to solve mathematical riddles.


But on a $600 Intel i9-14900K (raptorlake avx2), the same Gemma model gives the fool's answer.


I expected both machines to produce identical responses, since I set the temperature to zero. However, the behavior of gemma.cpp appears to differ in a pernicious way depending on the ISA. It'd be great if people without AVX512 privilege could experience the same level of impressive brilliance from Gemma that I'm seeing on my Threadripper.

@jan-wassenberg
Member

Interesting, thanks for making us aware. I see that the Highway targets used are AVX3_ZEN4 vs AVX2. The likeliest cause that comes to mind is native bf16 in the former, whereas we are using emulated bf16 with truncation in the latter.

google/highway#1962 switches to proper rounding, but unfortunately merging is delayed due to a compiler bug/crash. I would appreciate it if you could test with that patched in, and/or after it lands :)

@jart
Author

jart commented Feb 24, 2024

I changed CMakeLists.txt to use `FetchContent_Declare(highway GIT_REPOSITORY https://github.com/johnplatts/jep_google_highway GIT_TAG 9626396e4a80e2a0c0dec24c2e5927279a4fb3ff)` and rebuilt gemma.cpp. I asked it the same question and got exactly the same answer.

@jan-wassenberg
Member

Bummer, thanks for confirming. I also tried with AVX3 (Skylake, so no native bf16) and got the better answer.
It's not clear to me at the moment what else it could be; I will continue to think about it :)

@jart
Author

jart commented Feb 24, 2024

What I want to do is add code to the end of your MatVec function (https://github.com/google/highway/blob/master/hwy/contrib/matvec/matvec-inl.h) that serializes the output vector to disk as an array of floats. I would then write a program that does a lockstep comparison of the two machines' outputs, to get a better idea of what's going wrong and where. If you can explain to me how to turn `T* HWY_RESTRICT out` into an array of floats, then I'll do this.

jart added a commit to jart/gemma3 that referenced this issue Feb 25, 2024
@jart
Author

jart commented Feb 25, 2024

Reading your codebase has been a fun learning experience so far. I think your trick for supporting multiple microarchitectures by having a file repeatedly #include itself is quite possibly the most wickedly cool hack since CRTP. Your libm replacement functions look interesting and could potentially benefit Cosmopolitan Libc. I'd like to see how well they perform compared to the Sun/FreeBSD/Musl/OpenBSD/ARM code we're currently using.

Anyway, here's my first attempt at analyzing what's different about the data under AVX512 versus AVX2: https://github.com/jart/gemma3/blob/main/report1.txt. So far the outputs appear to be somewhat different, although there are still numerous things I need to confirm to make sure I'm measuring this right. I'm still in the process of understanding, but I'll post updates here as I learn more.

@jan-wassenberg
Copy link
Member

Great idea! I very much appreciate you looking into this. To go from T* to float, you can call the following:

```cpp
// Signature only (body omitted).
template <typename MatT, size_t kCapacity, typename OutT>
HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
                           size_t compressed_ofs, OutT* out, size_t num,
                           hwy::ThreadPool& pool);
```

MatT is your T (e.g. SfpStream), kCapacity is an upper bound on the number of elements, compressed is a thin wrapper over std::array, compressed_ofs can be 0, out is your float*, and num is how many elements to actually decompress.
You can pass through our ThreadPool or create a new ThreadPool(0).

> Reading your codebase has been a fun learning experience so far. I think your trick for supporting multiple microarchitectures by having a file repeatedly #include itself is quite possibly the most wickedly cool hack since CRTP.

Thank you :) Having once fiddled with PE internals, I also respect what you have achieved with the single portable binary :)

I suspect many libm functions are based on Cephes, which is quite old and might benefit from a redesign.
You might be interested in google/highway#1650, which compares our libm with SLEEF. The latter is generally more accurate, but can be considerably slower.

BTW, you can generate AVX2 outputs on a newer machine by calling `hwy::DisableTargets(HWY_AVX2 - 1)` in `main()` or before the first dispatch.

I just had an idea: it might not be the instructions (you ruled out BF16 already) but rather the vector length. More per-lane accumulators can change the numerics, because they change the order in which floating-point sums are rounded.
We can test this by also using half-length vectors in AVX3.

In gemma.cc there is one `using DF = hn::ScalableTag<float>;`, and in ops.h there are a bunch of `const hn::ScalableTag<float> df;` plus three `using D = hn::ScalableTag<float>;`.
We can wrap all the `hn::ScalableTag<T>` in `hn::Half<hn::ScalableTag<T>>`. To make this easy to toggle, you could add a template alias (template typedef).

Your results do not necessarily look like destructive cancellation, though:
trace/000_000_00021.dat: sad 0.796044 [gold -51.4416 .. 43.7024] [out -29.2457 .. 26.0193]
That's a huge difference. But I've trawled through the AVX3-specific parts and do not see anything that could cause such divergence :/

@austinvhuang added the type:bug (Something isn't working) label Feb 26, 2024
@jan-wassenberg
Member

An idea: I notice some of the lines in your output file have a low discrepancy, so it's not just a case of error accumulating over time. It may be helpful to segregate the results by call site, i.e., by which MatVec, to understand which ones are more sensitive/broken.

Is it feasible to move your logging to the call site, or should we pass through some kind of caller/line number into MatVec itself?

@jan-wassenberg
Member

An update: even with CoT prompting (appending "Think step by step and check your work"), we're currently seeing the incorrect answer of 15 days with AVX3 as well. I plan to experiment with higher-precision arithmetic.
