ggml : rewrite silu and softmax for cpu #7154
Conversation
Not deeply analysing the changes, but here are some general observations in case they help other reviewers:
On AMD Ryzen 9 5950X and M2 Ultra, using the following command to benchmark:

make -j tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf
I'm glad to hear that. Here are the avx2 and avx512 variations if you want to try them out:

inline __m256 llamafile_expf_avx2(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = MADD256(x, _mm256_set1_ps(0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps(z, r);
const __m256 b = NMADD256(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
NMADD256(n, _mm256_set1_ps(0x1.62e4p-1f), x));
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps(b, b);
const __m256 j = MADD256(MADD256(MADD256(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)),
u,
MADD256(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return MADD256(j, k, k);
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(MADD256(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), MADD256(k, j, k)))));
}
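As an aside, the MADD256/NMADD256 helpers above (and the MADD512/NMADD512 ones below) are not shown in this excerpt; presumably they are thin FMA wrappers along these lines - a guess based on how they are used, not the actual macros from the patch:

#include <immintrin.h>

// Presumed fused multiply-add helpers (require FMA / AVX512F):
#define MADD256(x, y, z)   _mm256_fmadd_ps(x, y, z)    //  (x * y) + z
#define NMADD256(x, y, z)  _mm256_fnmadd_ps(x, y, z)   // -(x * y) + z
#define MADD512(x, y, z)   _mm512_fmadd_ps(x, y, z)    //  (x * y) + z
#define NMADD512(x, y, z)  _mm512_fnmadd_ps(x, y, z)   // -(x * y) + z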
inline __m512 llamafile_expf_avx512(__m512 x) {
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = MADD512(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b = NMADD512(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
NMADD512(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
const __m512 k = _mm512_castsi512_ps(
_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
const __mmask16 c =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = MADD512(MADD512(MADD512(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
MADD512(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
if (_mm512_kortestz(c, c))
return MADD512(j, k, k);
const __m512i g = _mm512_and_si512(
_mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
_mm512_set1_epi32(0x82000000u));
const __m512 s1 =
_mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
return _mm512_mask_blend_ps(
d,
_mm512_mask_blend_ps(c, MADD512(k, j, k),
_mm512_mul_ps(MADD512(s2, j, s2), s1)),
_mm512_mul_ps(s1, s1));
}

Here are the numbers I got with the script I used for developing these functions:
@ggerganov Running your command, I'm noticing the advantage here increases from 1.5x to 1.9x if we include AVX2. On znver4, if we also include AVX512, that goes up to 2.1x. I'd expect that to go higher in the future, since znver4 only really implements the AVX512 ISA, using 2 cycles for each 512-bit vector operation. So I've gone ahead and included the code for you.
This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make this operation go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It makes make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA, and 2.1x faster on AVX512.
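For context, SiLU is x * sigmoid(x) and softmax is a max-subtracted exponential followed by a normalization, so both ops boil down to one expf per element. A scalar sketch of what the vectorized kernels compute (illustrative reference code, not the ggml source):

#include <math.h>

// Scalar reference for the two ops this PR vectorizes (illustration only).
static inline float silu_ref(float x) {
    return x / (1.0f + expf(-x));              // x * sigmoid(x)
}

static void softmax_ref(float *out, const float *logits, int n) {
    float max_logit = logits[0];
    for (int i = 1; i < n; i++)
        if (logits[i] > max_logit) max_logit = logits[i];
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = expf(logits[i] - max_logit);  // subtract max for numerical stability
        sum += out[i];
    }
    for (int i = 0; i < n; i++)
        out[i] /= sum;
}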
With AVX512, you may want to use vscalefps. It overflows and underflows properly, letting you remove the checks + blends. I have an implementation in Julia; here is its compiled inner loop, with 4x unrolling and interleaving:

L304:
vmovups zmm15, zmmword ptr [r11 + 4*rax]
vmovups zmm14, zmmword ptr [r11 + 4*rax + 64]
vmovups zmm13, zmmword ptr [r11 + 4*rax + 128]
vmovups zmm12, zmmword ptr [r11 + 4*rax + 192]
vmovaps zmm16, zmm1
vfmadd213ps zmm16, zmm15, zmm0 # zmm16 = (zmm15 * zmm16) + zmm0
vmovaps zmm17, zmm1
vfmadd213ps zmm17, zmm14, zmm0 # zmm17 = (zmm14 * zmm17) + zmm0
vmovaps zmm18, zmm1
vfmadd213ps zmm18, zmm13, zmm0 # zmm18 = (zmm13 * zmm18) + zmm0
vmovaps zmm19, zmm1
vfmadd213ps zmm19, zmm12, zmm0 # zmm19 = (zmm12 * zmm19) + zmm0
vaddps zmm16, zmm16, zmm2
vaddps zmm17, zmm17, zmm2
vaddps zmm18, zmm18, zmm2
vaddps zmm19, zmm19, zmm2
vfmadd231ps zmm15, zmm16, zmm3 # zmm15 = (zmm16 * zmm3) + zmm15
vfmadd231ps zmm14, zmm17, zmm3 # zmm14 = (zmm17 * zmm3) + zmm14
vfmadd231ps zmm13, zmm18, zmm3 # zmm13 = (zmm18 * zmm3) + zmm13
vfmadd231ps zmm12, zmm19, zmm3 # zmm12 = (zmm19 * zmm3) + zmm12
vfmadd231ps zmm15, zmm16, zmm4 # zmm15 = (zmm16 * zmm4) + zmm15
vfmadd231ps zmm14, zmm17, zmm4 # zmm14 = (zmm17 * zmm4) + zmm14
vfmadd231ps zmm13, zmm18, zmm4 # zmm13 = (zmm18 * zmm4) + zmm13
vfmadd231ps zmm12, zmm19, zmm4 # zmm12 = (zmm19 * zmm4) + zmm12
vmovaps zmm20, zmm6
vfmadd213ps zmm20, zmm15, zmm5 # zmm20 = (zmm15 * zmm20) + zmm5
vmovaps zmm21, zmm6
vfmadd213ps zmm21, zmm14, zmm5 # zmm21 = (zmm14 * zmm21) + zmm5
vmovaps zmm22, zmm6
vfmadd213ps zmm22, zmm13, zmm5 # zmm22 = (zmm13 * zmm22) + zmm5
vmovaps zmm23, zmm6
vfmadd213ps zmm23, zmm12, zmm5 # zmm23 = (zmm12 * zmm23) + zmm5
vfmadd213ps zmm20, zmm15, zmm7 # zmm20 = (zmm15 * zmm20) + zmm7
vfmadd213ps zmm21, zmm14, zmm7 # zmm21 = (zmm14 * zmm21) + zmm7
vfmadd213ps zmm22, zmm13, zmm7 # zmm22 = (zmm13 * zmm22) + zmm7
vfmadd213ps zmm23, zmm12, zmm7 # zmm23 = (zmm12 * zmm23) + zmm7
vfmadd213ps zmm20, zmm15, zmm8 # zmm20 = (zmm15 * zmm20) + zmm8
vfmadd213ps zmm21, zmm14, zmm8 # zmm21 = (zmm14 * zmm21) + zmm8
vfmadd213ps zmm22, zmm13, zmm8 # zmm22 = (zmm13 * zmm22) + zmm8
vfmadd213ps zmm23, zmm12, zmm8 # zmm23 = (zmm12 * zmm23) + zmm8
vfmadd213ps zmm20, zmm15, zmm9 # zmm20 = (zmm15 * zmm20) + zmm9
vfmadd213ps zmm21, zmm14, zmm9 # zmm21 = (zmm14 * zmm21) + zmm9
vfmadd213ps zmm22, zmm13, zmm9 # zmm22 = (zmm13 * zmm22) + zmm9
vfmadd213ps zmm23, zmm12, zmm9 # zmm23 = (zmm12 * zmm23) + zmm9
vfmadd213ps zmm20, zmm15, zmm10 # zmm20 = (zmm15 * zmm20) + zmm10
vfmadd213ps zmm21, zmm14, zmm10 # zmm21 = (zmm14 * zmm21) + zmm10
vfmadd213ps zmm22, zmm13, zmm10 # zmm22 = (zmm13 * zmm22) + zmm10
vfmadd213ps zmm23, zmm12, zmm10 # zmm23 = (zmm12 * zmm23) + zmm10
vfmadd213ps zmm20, zmm15, zmm11 # zmm20 = (zmm15 * zmm20) + zmm11
vfmadd213ps zmm21, zmm14, zmm11 # zmm21 = (zmm14 * zmm21) + zmm11
vfmadd213ps zmm22, zmm13, zmm11 # zmm22 = (zmm13 * zmm22) + zmm11
vfmadd213ps zmm23, zmm12, zmm11 # zmm23 = (zmm12 * zmm23) + zmm11
vfmadd213ps zmm20, zmm15, zmm11 # zmm20 = (zmm15 * zmm20) + zmm11
vfmadd213ps zmm21, zmm14, zmm11 # zmm21 = (zmm14 * zmm21) + zmm11
vfmadd213ps zmm22, zmm13, zmm11 # zmm22 = (zmm13 * zmm22) + zmm11
vfmadd213ps zmm23, zmm12, zmm11 # zmm23 = (zmm12 * zmm23) + zmm11
vscalefps zmm12, zmm20, zmm16, {rn-sae}
vscalefps zmm13, zmm21, zmm17, {rn-sae}
vscalefps zmm14, zmm22, zmm18, {rn-sae}
vscalefps zmm15, zmm23, zmm19, {rn-sae}
vmovups zmmword ptr [r14 + 4*rax], zmm12
vmovups zmmword ptr [r14 + 4*rax + 64], zmm13
vmovups zmmword ptr [r14 + 4*rax + 128], zmm14
vmovups zmmword ptr [r14 + 4*rax + 192], zmm15
add rax, 64
cmp rax, r10
jl L304

These gave me a significant performance improvement. What hardware are you on? I'm using skylake-avx512/cascadelake with 2x FMA units. Note that it doesn't use a lookup table.
Since this PR was merged, the server has been producing nondeterministic results when using >1 slot. Minimal example for reproduction:

make clean && make server
./server -m models/opt/llama_2-7b-q4_0.gguf --parallel 2 --threads 1

In another shell:

curl --request POST --url http://localhost:8080/completion --header "Content-Type: application/json" --data '{"prompt": "", "n_predict":10, "n_probs": 2, "temperature": -1}' | python3 -m json.tool

The token probabilities for the last token cycle between two values with every request.
@chriselrod Could you help me modify my avx512 intrinsics to use _mm512_scalef_ps (vscalefps) like your code? I'm currently talking to ARM Limited about getting these functions into Glibc, since our code goes faster. ARM-software/optimized-routines#69
@jart Sure. If it helps, I just wrote a C++ implementation you can look at; the README contains benchmarks. I haven't done much analysis other than a glance, enough to see that unrolling with interleaving, and especially larger SIMD vectors, boost performance on the smaller array sizes that fit in cache on my desktop.

The basic algorithm for expf is the usual one: compute n = round(x / ln 2), reduce to b = x - n*ln 2, approximate exp(b) with a short polynomial, then scale the result by 2^n. With that math in mind, the point is that vscalefps does the final 2^n scaling directly, handling overflow and underflow for you, so the mask-and-blend special cases in the intrinsics above are no longer needed.
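To make that concrete, here is a rough, untested sketch of llamafile_expf_avx512 from above rewritten around _mm512_scalef_ps. It keeps the same constants and polynomial and simply drops the special-case blends, on the assumption that vscalefps saturates 2^n to zero/infinity on its own (the function name is mine; treat this as an illustration, not the implementation being discussed):

#include <immintrin.h>

// exp(x) ~= (1 + j) * 2^n, where n = round(x / ln 2) and j is a polynomial
// in b = x - n*ln 2 (same coefficients as llamafile_expf_avx512 above).
inline __m512 llamafile_expf_avx512_scalef(__m512 x) {
    const __m512 r = _mm512_set1_ps(0x1.8p23f);   // round-to-nearest magic constant
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);         // n = round(x / ln 2)
    const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                     _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));  // b = x - n*ln 2
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
    // vscalefps computes (1 + j) * 2^n with proper overflow/underflow behaviour,
    // so no extra checks or blends are needed here.
    return _mm512_scalef_ps(_mm512_add_ps(j, _mm512_set1_ps(1.0f)), n);
}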
Perhaps I made a mistake when merging it in. I will look through it again. Thanks.
@ggerganov @jart @leejet Not entirely sure why, but I found the cause! The reason is that I have been building with a more aggressive optimization preset than the default. Before this PR that had been working fine for me, and everything else (except CUDA) still works fine with it; switching back to the default optimization flags makes the problem go away.

Also, I did dig a bit further into the various SIMD versions... if I force a no-intrinsics build using only the pure f32 CPU path, then everything works well even with that preset. Building directly from sd.cpp's cmake would by default pick the safer flags. This is pretty interesting - do you think you could get it working with the more aggressive preset? Or maybe there's a more elegant solution?
If you want to dig into this more, look at the individual GCC compiler flags that optimization preset turns on.
I've narrowed it down to one of the individual flags that preset enables.
Okay! I think I nailed down the problematic flag that causes this PR to break: it is -ffinite-math-only. I guess I could get rid of that flag... but I am curious whether this is something that might be solvable, considering all the other ops are fine.
What could be happening is that the exponential function in SiLU, instead of flushing small values to 0, returns NaN or some other garbage. I've essentially had this same issue in CUDA for the exponentiation in FlashAttention. I fixed it by explicitly zeroing the exponentials of all values < -20.0f since those are going to be < 2e-9 and therefore negligible anyways. The relevant FP32 code looks something like this:
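The snippet below is a reconstruction of that idea rather than the actual kernel code; the names and loop shape are illustrative:

#include <math.h>

// Flush-to-zero guard for the softmax exponentials in FlashAttention.
// vals[k] holds a KQ score relative to the running row maximum, so it is <= 0;
// anything below -20.0f would give expf() < ~2e-9 and is treated as exactly 0.
static void exp_with_ftz(float *vals, int n, float kq_max) {
    for (int k = 0; k < n; ++k) {
        const float diff = vals[k] - kq_max;
        vals[k] = diff >= -20.0f ? expf(diff) : 0.0f;
    }
}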
In my particular case (CUDA FP16 FlashAttention) I'm not sure the issue is fixable by a compiler flag; the NVCC documentation does not seem to mention any equivalent flag. I have not actually encountered the issue for FP32 precision, but I still implemented the same flush-to-zero behavior for consistency. I generally agree that simply not using -ffinite-math-only is the better solution.
Yeah, now that I know what's wrong I'd probably just not use that flag. I'm actually more surprised that there have been no (apparent) issues with finite math prior to this one - so I guess 'what we stand to gain' is simply retaining existing compatibility with that flag. In any case, I think we can consider the issue resolved then.
Your accuracy tests run almost 4x faster in wall time and use 1.5x less CPU time under the fast-math build.
Note that denormals are extremely slow on Intel hardware (I believe the penalty is much lower on AMD). To confirm, I built clang-19 from source, since it adds an option I needed for this experiment.
I see that essentially the entirety of the performance improvement in that benchmark comes from flushing denormals to 0. @LostRuins, it would be nice if you could pin down what is causing the performance improvements you see.
I concur. I tested every single one of those flags individually. Then I tried putting just this in my program:

#ifdef __x86_64__
//
// Enable hardware optimizations in violation of the IEEE standard.
//
// - 0x0040 enables "DAZ: Denormals Are Zeros" in MXCSR. This causes the
// processor to turn denormal inputs into zero, before computing them.
// See Intel Manual Vol. 1 §10.2.3.4
//
// - 0x8000 enables "FTZ: Flush To Zero" in MXCSR. This means a floating
// point operation that results in underflow will be set to zero, with
// the same sign, rather than producing a denormalized output. It will
// happen only if underflow trapping hasn't been enabled. See the Intel
// Manual Vol. 1 §10.2.3.3.
//
unsigned mxcsr;
asm("stmxcsr\t%0" : "=m"(mxcsr));
mxcsr |= 0x8040;
asm("ldmxcsr\t%0" : /* no inputs */ : "m"(mxcsr));
#endif

Then I saw the same 20% improvement. So I think what was happening is that the fast-math build was turning on these MXCSR bits at startup, and that is where the speedup came from. That doesn't mean this will help with inference: I've yet to find an LLM that underflows enough for FTZ to matter. For example, if I ask Mistral 7b v0.3 to process a 215 token prompt, the process underflows 4500 times, with no difference in clock() time or latency.
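For what it's worth, the same MXCSR bits can also be set without inline assembly by using the SSE control-register intrinsics, e.g.:

#include <xmmintrin.h>   // _mm_getcsr / _mm_setcsr

// Equivalent to the inline asm above: set DAZ (0x0040) and FTZ (0x8000) in MXCSR.
static void enable_daz_ftz(void) {
    _mm_setcsr(_mm_getcsr() | 0x8040u);
}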
Can we enforce a compile error if -ffinite-math-only is used?
@ggerganov I think there is a flag that can be checked for at compile time.
I was thinking of a change in the source code rather than in the build scripts - the build system is not standardised, so nothing prevents 3rd-party projects from building with -ffinite-math-only anyway.
Yes, both GCC and Clang define __FINITE_MATH_ONLY__:

#if __FINITE_MATH_ONLY__
#error "llama.cpp needs infinity (try passing -fno-finite-math-only)"
#endif

So saying the above somewhere is what I would do. If someone figures out a place where the finite-math assumption actually helps, that can always be revisited.
Yeah, that would work too. Probably better than just failing to compile completely.
This enforces a check that -fno-finite-math-only was set and that the operating compiling mode is not in finite maths mode. This is because during rewriting of silu and softmax for cpu #7154 there emerged an issue where the result observed with >1 slot was nondeterministic, as found by @JohannesGaessler. @LostRuins narrowed the problem down to -ffinite-math-only, which was theorised to be due to SiLU returning NaN or some other garbage instead of flushing small values to 0. @jart proposed a fix that @ggerganov then implemented; ref #7154 (comment)
Relevant log message from the source repo:

commit 6d1616944d9efd342ed2a4fd318722adfc9febcd
Author: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue Jun 4 10:01:09 2024 +0300

ggml : prevent builds with -ffinite-math-only (#7726)
In our experience with CPU testing, we have been able to use these flags without problems. In addition, this is all based on the strict testing and scrutiny that whisper-cpp and llama-cpp are under as candidate benchmarks for SPEC CPUv8, using the generic code paths for CPU.
@heshpdx There are much better ways to make whisper.cpp go faster than using those flags. For example:
That first one in particular is probably going to have extreme impact on your 512 core CPUs. |
Here's another one:
@jart Thanks for the links. We did know about the thread scaling issue, and we've capped the benchmark to use 32 threads, as that showed the best performance across a variety of systems (ISAs/compilers) with minimal locking overhead. I'll check the other links, but we are only benchmarking inference. I don't like depending on those flags either. I'm speaking while wearing my SPEC CPU hat, not my Ampere hat :-)
This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make this operation go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. I wrote avx2 and avx512 implementations as well, but they didn't offer enough of an advantage over sse2+fma to be worth the code complexity.