Fix conversion of unnormalized BF16->BF16 weights #7843

Open
wants to merge 7 commits into master

Conversation

@CISC CISC (Contributor) commented Jun 10, 2024

If source model tensors were unnormalized BF16 weights and you converted with outtype BF16, the target GGUF weights would deviate from source weights due to intermediate transition to FP32 (for numpy) and flushing/rounding to BF16.

This PR simply truncates the FP32 if the source is BF16.
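For illustration, a minimal numpy sketch of the truncation idea (hypothetical helper name, not the PR's exact code):

import numpy as np

def bf16_from_widened_f32(f32: np.ndarray) -> np.ndarray:
    # If the float32 values came from widening bf16 (their lower 16 bits are zero),
    # keeping only the upper 16 bits reproduces the original bf16 bit patterns exactly.
    return (f32.astype(np.float32, copy=False).view(np.uint32) >> np.uint32(16)).astype(np.uint16)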

Also fixes implicit upcasting and signed shifting in __compute_fp32_to_bf16, which fail with the latest numpy.

Fixes #8147

@github-actions github-actions bot added the python label Jun 10, 2024
@mofosyne mofosyne added the Review Complexity : Low label Jun 12, 2024
@compilade compilade (Collaborator) commented Jun 14, 2024

If source model tensors were unnormalized BF16 weights and you converted with outtype BF16, the target GGUF weights would deviate from source weights due to intermediate transition to FP32 (for numpy) and flushing/rounding to BF16.

Doing this would make the result of convert-hf-to-gguf.py --outtype bf16 different from the output of

$ ./quantize ggml-model-f32.gguf ggml-model-bf16.gguf bf16

for these models which have subnormal[1] bf16 weights, no?

The goal of #7158 was to make the two paths result in exactly the same model files.

I'd be curious to find a model affected by this. I think I'll try and compare. (EDIT: oh, right in #7825 you mention Qwen2-57B-A14B)

@jart, you might want to chime in, since you likely have more background on why subnormals are flushed to zero in ggml_compute_fp32_to_bf16:

llama.cpp/ggml-impl.h

Lines 98 to 101 in 172c825

if (!(u.i & 0x7f800000)) { /* subnormal */
    h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
    return h;
}

Footnotes

  1. by "subnormal", I mean https://en.wikipedia.org/wiki/Subnormal_number, which are very, very close to zero, but you know this.

@CISC CISC (Contributor, Author) commented Jun 14, 2024

Doing this would make the result of convert-hf-to-gguf.py --outtype bf16 different from the output of

$ ./quantize ggml-model-f32.gguf ggml-model-bf16.gguf bf16

Ah, didn't realize that was an option. Yeah, if the F32 originates from BF16 they would be different.

@jart, you might want to chime in, since you likely have more background on why subnormals are flushed to zero in ggml_compute_fp32_to_bf16

You would have to do this if the lower 16 bits are non-zero at least, but in the case of BF16->F32->BF16 they never are.
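For example, a quick numpy check of that round-trip property (illustrative only):

import numpy as np

bf16_bits = np.array([0x3f80, 0x0001, 0xc2f7], dtype=np.uint32)  # arbitrary bf16 bit patterns
f32 = (bf16_bits << np.uint32(16)).view(np.float32)              # widen bf16 -> f32
# the lower 16 bits of the widened values are always zero, so rounding never changes them
assert np.all((f32.view(np.uint32) & np.uint32(0xffff)) == 0)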

@CISC CISC (Contributor, Author) commented Jun 14, 2024

I found this online tool very handy while diagnosing F32 bits btw:
https://www.h-schmidt.net/FloatConverter/IEEE754.html

@compilade compilade (Collaborator)

why subnormals are flushed to zero in ggml_compute_fp32_to_bf16

You would have to do this if the lower 16 bits are non-zero at least, but in the case of BF16->F32->BF16 they never are.

The lower 16 bits don't matter for this, only for rounding (which indeed changes nothing in the BF16->F32->BF16 round-trip). Subnormals are when the exponent bits are 00000000. Only then is the mantissa zeroed when flushing subnormals to zero. This is what simple truncation is avoiding, which differs from the behavior of ggml_compute_fp32_to_bf16.
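To make the difference concrete, a tiny illustration (bit patterns chosen by hand):

import numpy as np

sub = np.uint32(0x0001) << np.uint32(16)                    # f32 bits 0x00010000: exponent 00000000, nonzero mantissa
flushed   = (sub & np.uint32(0x80000000)) >> np.uint32(16)  # flush to zero keeps only the sign    -> 0x0000
truncated = sub >> np.uint32(16)                            # plain truncation keeps the mantissa  -> 0x0001
print(hex(int(flushed)), hex(int(truncated)))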

I found this online tool very handy while diagnosing F32 bits btw: https://www.h-schmidt.net/FloatConverter/IEEE754.html

Nice! There's also https://float.exposed/, which supports f16, bf16, f32 and f64.

@CISC CISC (Contributor, Author) commented Jun 14, 2024

You would have to do this if the lower 16 bits are non-zero at least, but in the case of BF16->F32->BF16 they never are.
The lower 16 bits don't matter for this, only for rounding (which indeed changes nothing in the BF16->F32->BF16 round-trip). Subnormals are when the exponent bits are 00000000. Only then is the mantissa zeroed when flushing subnormals to zero. This is what simple truncation is avoiding, which differs from the behavior of ggml_compute_fp32_to_bf16.

Oh, errr, I think I see the problem now when looking more closely at ggml_compute_fp32_to_bf16, notice the returns... :)

@compilade compilade (Collaborator)

Oh, errr, I think I see the problem now when looking more closely at ggml_compute_fp32_to_bf16, notice the returns... :)

llama.cpp/ggml-impl.h

Lines 87 to 104 in 172c825

static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
    ggml_bf16_t h;
    union {
        float f;
        uint32_t i;
    } u;
    u.f = s;
    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
        h.bits = (u.i >> 16) | 64; /* force to quiet */
        return h;
    }
    if (!(u.i & 0x7f800000)) { /* subnormal */
        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
        return h;
    }
    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
    return h;
}

Yep, lots of returns. And the Numpy implementation still behaves in the same way despite not early-returning, because NaNs are not subnormals and their lower bits are cleared to avoid rounding.

def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
    n = n.astype(np.float32, copy=False).view(np.int32)
    # force nan to quiet
    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
    # flush subnormals to zero
    n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
    # round to nearest even
    n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
    return n.astype(np.int16)

If it was not working exactly in the same way, then the checksums in #7158 and #7234 would not match (but they all do).

At least going the other way (from bfloat16 to float32) is much simpler (which is good for inference):

llama.cpp/ggml-impl.h

Lines 71 to 78 in 172c825

static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
    union {
        float f;
        uint32_t i;
    } u;
    u.i = (uint32_t)h.bits << 16;
    return u.f;
}

@CISC CISC (Contributor, Author) commented Jun 14, 2024

Only problem with that is that you are right-shifting a possibly negative signed integer here:

n = (n + (0x7fff + ((n >> 16) & 1))) >> 16

Actually, just tested, and it seems numpy discards the sign and converts to int64 when you do n & 0x80000000, giving you 2147483648! So, the result of n = (n + (0x7fff + ((n >> 16) & 1))) >> 16 will be 0 regardless of sign.

Sigh, it does the same if n is uint32 too, it seems you have to do & np.uint32(0x80000000) and n must be uint32 too!

The same with n & 0xffff0000, but & 1 is fine, so it's just when you mask the sign-bit.

@CISC CISC (Contributor, Author) commented Jun 14, 2024

@compilade So, an alternate fix would be to use uint32/uint16 and masking with uint32 values, then everything should be ok (disregarding the flushing), though unnecessarily complex for BF16->BF16.

Any opinion if I should do that and remove the simpler truncation, or if it's worth having both?
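For reference, a sketch of what the uint32/uint16 variant could look like (illustrative; not necessarily the exact change that ends up in this PR):

import numpy as np

def __compute_fp32_to_bf16_u32(n: np.ndarray) -> np.ndarray:
    n = n.astype(np.float32, copy=False).view(np.uint32)
    # force nan to quiet
    n = np.where((n & np.uint32(0x7fffffff)) > np.uint32(0x7f800000),
                 (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
    # round to nearest even (no flush to zero here)
    n = (n + (np.uint32(0x7fff) + ((n >> np.uint32(16)) & np.uint32(1)))) >> np.uint32(16)
    return n.astype(np.uint16)

With unsigned types throughout, the masks stay 32-bit and the right shift has no sign to worry about, so no implicit upcasting to int64 occurs.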

@compilade compilade (Collaborator)

Only problem with that is that you are right-shifting a possible negatively signed integer [...]

Actually, just tested, and it seems numpy discards the sign and converts to int64 when you do n & 0x80000000, giving you 2147483648! So, the result of n = (n + (0x7fff + ((n >> 16) & 1))) >> 16 will be 0 regardless of sign.

@CISC Huh, you're right about the int64 implicit conversion! But the result of the rounding is still correct. Proof:

import numpy as np

rand = np.random.Generator(np.random.PCG64(seed=42))

n = rand.standard_normal((3, 4), dtype=np.float32)

print("n:")
print(n, n.dtype)

n = n.view(np.int32)

print("n:")
print(n, n.dtype)

sign = n & 0x80000000

print("sign:")
print(sign, sign.dtype)

# round to nearest even 
rounded = ((n + (0x7fff + ((n >> 16) & 1))) >> 16).astype(np.int16)
rounded_sign = ((sign + (0x7fff + ((sign >> 16) & 1))) >> 16).astype(np.int16)

print("rounded:")
print(rounded, rounded.dtype)

print("rounded_sign:")
print(rounded_sign, rounded_sign.dtype)

Output:

n:
[[ 0.14190717 -1.6685079  -1.332108    0.58255345]
 [ 0.3740366   1.7554446   0.13962899  0.41919115]
 [-0.52172035 -0.1815749  -0.85166216  2.568927  ]] float32
n:
[[ 1041322013 -1076522581 -1079344508  1058349625]
 [ 1052737978  1071690345  1041169127  1054253113]
 [-1090154633 -1103499579 -1084619128  1076128077]] int32
sign:
[[         0 2147483648 2147483648          0]
 [         0          0          0          0]
 [2147483648 2147483648 2147483648          0]] int64
rounded:
[[ 15889 -16426 -16469  16149]
 [ 16064  16353  15887  16087]
 [-16634 -16838 -16550  16420]] int16
rounded_sign:
[[     0 -32768 -32768      0]
 [     0      0      0      0]
 [-32768 -32768 -32768      0]] int16

So, an alternate fix would be to use uint32/uint16 and masking with uint32 values, then everything should be ok (disregaring the flushing), though unnecessarily complex for BF16->BF16.

I agree that the C version is using uint32_t, so this is what the Numpy version ought to use too, even if it happened to result in the correct thing with the other types.

Your suggestion does make the implementation more sane, but it won't change anything about the resulting bits in the model files (if truncation isn't used). (so it's not really "an alternate fix", it's more like a refactor)

If the previous implementation didn't result in the exact same bits as the C version, I would definitely have noticed in #7234.

@CISC CISC (Contributor, Author) commented Jun 14, 2024

Yeah, it seems to work out correctly in the end except when doing the flushing. The question then is just whether we should flush to zero. If it's from an actual F32 it might make sense to do so, but it seems wrong to do it for BF16->BF16...

@compilade compilade (Collaborator)

Yeah, it seems to work out correctly in the end except when doing the flushing

(emphasis mine)

What is it doing incorrectly? From my testing, flushing to zero keeps the sign in the Numpy implementation (even with wrong implicit conversions), as with the original C code it was based on. Or is it that you consider flushing to be incorrect in itself?

(to be clear I agree with fixing the types in the Numpy implementation of bfloat16 conversion)

the question then is just if we should flush to zero? If it's from an actual F32 it might make sense to do so,

@CISC I've searched deeper, since you made me curious.

From https://en.wikipedia.org/wiki/Subnormal_number#Performance_issues, calculations with subnormals can be slower, and from https://cloud.google.com/tpu/docs/bfloat16#details_about_format_conversion, Google Cloud TPUs don't support subnormals in bfloat16.

From https://developer.arm.com/documentation/ddi0602/2023-06/Shared-Pseudocode/shared-functions-float?lang=en#impl-shared.FPConvertBF.2, it seems Aarch64 processors do flush subnormals to zero when in altfp mode (no idea what it is).

    boolean altfp = HaveAltFP() && !UsingAArch32() && fpcr.AH == '1';
    boolean fpexc = !altfp;                         // Generate no floating-point exceptions
    if altfp then fpcr.<FIZ,FZ> = '11';             // Flush denormal input and output to zero
    if altfp then rounding = FPRounding_TIEEVEN;    // Use RNE rounding mode

    // Unpack floating-point operand, with always flush-to-zero if fpcr.AH == '1'.
    (fptype,sign,value) = FPUnpack(op, fpcr, fpexc);

And regarding AMD and Intel, https://www.felixcloutier.com/x86/vcvtne2ps2bf16 is the instruction referred to by the comment above ggml_compute_fp32_to_bf16, and it does also flush subnormals to zero.

llama.cpp/ggml-impl.h

Lines 83 to 85 in 172c825

* This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
* Subnormals shall be flushed to zero, and NANs will be quiet.
* This code should vectorize nicely if using modern compilers.

but it seems wrong to do it for BF16->BF16...

Yeah, I don't really know either. On the one hand, it makes all bf16 values in GGUF models guaranteed to be directly usable on all platforms, and on the other hand, it discards some (very small) values from the original weights (though I'm not sure how useful these actually are).

From #7825 (comment):

It doesn't really fix the Qwen2 issues, no, I mainly found these because BF16->BF16 flushed weights to 0, however once you fix that suml2 will occasionally be NaN and you will have division-by-nan instead. :P

Maybe subnormal values are problematic anyway, so it's fine to flush them to zero?

@CISC CISC (Contributor, Author) commented Jun 15, 2024

What is it doing incorrectly? From my testing, flushing to zero keeps the sign in the Numpy implementation (even with wrong implicit conversions), as with the original C code it was based on. Or is it that you consider flushing to be incorrect in itself?

Yeah, as it changes the original values.

From https://en.wikipedia.org/wiki/Subnormal_number#Performance_issues, calculations with subnormals can be slower, and from https://cloud.google.com/tpu/docs/bfloat16#details_about_format_conversion, Google Cloud TPUs don't support subnormals in bfloat16.

Ok, that's interesting.

but it seems wrong to do it for BF16->BF16...
Yeah, I don't really know either. On the one hand, it makes all bf16 values in GGUF models guaranteed to be directly usable on all platforms, and on the other hand, it discards some (very small) values from the original weights (though I'm not sure how useful these actually are).

They are probably insignificant, however see below...

From #7825 (comment):

It doesn't really fix the Qwen2 issues, no, I mainly found these because BF16->BF16 flushed weights to 0, however once you fix that suml2 will occasionally be NaN and you will have division-by-nan instead. :P

Maybe subnormal values are problematic anyway, so it's fine to flush them to zero?

They are indeed problematic, however it's not just the actual subnormals but also the ones that turn subnormal as F16; these are the ones that mess up certain quants and inference when doing KQ multiplication in F16.
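As a concrete example of such a value (numbers chosen for illustration):

import numpy as np

x = np.float32(2e-5)             # normal as float32 and bf16 (min normal ~1.2e-38)
f16 = np.float16(x)              # but below float16's min normal 2**-14 ~ 6.1e-5
print(f16, abs(f16) < np.finfo(np.float16).tiny)   # a subnormal f16 value, True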

@CISC CISC (Contributor, Author) commented Jun 27, 2024

@compilade So, looks like the implicit upcasting and signed shifting fails with numpy 2.x, see referenced issue. That means we should at least merge that part of this PR. IMO we should also merge the truncation part, but that's up to you.

BTW, I'm leaving for my annual European tour tomorrow, so will likely not be available to do any changes for a while, but I'll grant you access to the branch just in case.

@jart jart (Contributor) commented Jun 27, 2024

@jart, you might want to chime in, since you likely have more background on why subnormals are flushed to zero in ggml_compute_fp32_to_bf16

@compilade (1) it makes things faster, and (2) I wasn't able to determine a flawless, elegant, performant way to handle them in the short amount of time I had to focus on this. That function actually fully vectorizes from vanilla C. I know it's possible to make it fast and support subnormals too, and I think it's worthwhile to pursue. It's even more important to do in the conversion / quantization tools. I'll put some time aside today to see what I can do to help.

@jart jart (Contributor) commented Jun 27, 2024

OK so I looked into this. When I wrote ggml_compute_fp32_to_bf16 my goal was simply to make it behave identically to the AMD Zen4 implementation of VCVTNEPS2BF16, since that was the platform I was focusing on optimizing. Flushing to zero is normally done for performance reasons. I assumed AMD was doing that here. However, that might not be true.

So I asked Claude what to do and got a very silly answer involving __builtin_clz(). Then I looked at the TensorFlow codebase, since who would know better how brain16 works than Google Brain? Turns out they just use memcpy() when turning float32 into bfloat16. A better Google Brain codebase to get answers on C++ coding would be Highway, which has this:

// Returns the increment to add to the bits of a finite F32 value to round a                                                                                                     
// finite F32 to the nearest BF16 value                                                                                                                                          
static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr(
    const uint32_t f32_bits) {
  return static_cast<uint32_t>(((f32_bits & 0x7FFFFFFFu) < 0x7F800000u)
                                   ? (0x7FFFu + ((f32_bits >> 16) & 1u))
                                   : 0u);
}

// Converts f32_bits (which is the bits of a F32 value) to BF16 bits,                                                                                                            
// rounded to the nearest F16 value                                                                                                                                              
static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits(
    const uint32_t f32_bits) {
  // Round f32_bits to the nearest BF16 by first adding                                                                                                                          
  // F32BitsToBF16RoundIncr(f32_bits) to f32_bits and then right shifting                                                                                                        
  // f32_bits + F32BitsToBF16RoundIncr(f32_bits) by 16                                                                                                                           
  
  // If f32_bits is the bit representation of a NaN F32 value, make sure that                                                                                                    
  // bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16                                                                                                     
  // values and to prevent NaN F32 values from being converted to an infinite                                                                                                    
  // BF16 value                                                                                                                                                                  
  return static_cast<uint16_t>(
      ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16) |
      (static_cast<uint32_t>((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) << 6));
}

HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF32(float f) {
#if HWY_HAVE_SCALAR_BF16_OPERATORS
  return static_cast<bfloat16_t>(f);
#else
  return bfloat16_t::FromBits(
      detail::F32BitsToBF16Bits(BitCastScalar<uint32_t>(f)));
#endif
}

And that's binary equivalent to what I wrote with the FTZ clause removed. So you're right. The smartest thing here would have probably been to do nothing. Plus our implementation actually goes slightly faster than Highway's.

static ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
    union {
        float f;   
        unsigned i;
    } u = {s};
    ggml_bf16_t h;
    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
        h.x = u.i >> 16; /* unrounded conversion */
        h.x |= 64; /* maintain nan and have it be quiet */
        return h;
    }
    h.x = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
    return h;
}

The question is probably better asked of Intel and AMD. It's probably just because x86 is famous for having very slow handling of subnormals. But it should be dependent on the MXCSR register FTZ bit being set.

So we should probably just remove that if statement.

@CISC CISC (Contributor, Author) commented Jun 27, 2024

Fun fact: numpy 2.x will not do the implicit upcasting as long as the resulting value fits inside the original type; it will, however, issue a runtime warning on overflow. So I committed a small change that works for both 1.x and 2.x.

@CISC CISC (Contributor, Author) commented Jun 27, 2024

@jart If I understood you correctly you're suggesting that the flush-to-zero should be removed from both conversion and quantizing (in ggml)?

@jart jart (Contributor) commented Jun 28, 2024

Yes. It makes the f32->bf16 algorithm simpler, faster, and more consistent with brain codebases. With normal floating point, flush to zero is only enabled in -funsafe-math-optimizations mode.

Please note functions like ggml_fp32_to_bf16_row() would still do the FTZ though on Zen4. So if Intel and AMD had a good reason for making their AVX512 BF16 instructions always flush to zero, then that would still happen in practice.

In any case, it's likely of little consequence, since subnormals are rare in the models whose floats I've analyzed, e.g. Mistral.

* ggml : add reference fp32 to bf16 conversion

The fast version is no longer equivalent for all platforms
because of the handling of subnormal values.

* gguf-py : remove flush to zero for bf16 subnormals

* gguf-py : remove float32 truncation to bf16

Rounding achieves the same thing in the cases where this was used.
@compilade compilade (Collaborator) commented Jun 28, 2024

So we should probably just remove that if statement.

@jart Doing so would require changing that masterfully aligned comment.

I might have made it bad though.

diff --git a/ggml-impl.h b/ggml-impl.h
index 5e77471f..397e22a6 100644
--- a/ggml-impl.h
+++ b/ggml-impl.h
@@ -80,8 +80,9 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
 /**
  * Converts float32 to brain16.
  *
- * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
- * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This is binary identical with Google Brain float conversion.
+ * Floats shall round to nearest even, and NANs shall be quiet.
+ * Subnormals aren't flushed to zero, except perhaps when used.
  * This code should vectorize nicely if using modern compilers.
  */
 static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
@@ -95,10 +96,6 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
         h.bits = (u.i >> 16) | 64; /* force to quiet */
         return h;
     }
-    if (!(u.i & 0x7f800000)) { /* subnormal */
-        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
-        return h;
-    }
     h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
     return h;
 }

Please note functions like ggml_fp32_to_bf16_row() would still do the FTZ though on Zen4. So if Intel and AMD had a good reason for making their AVX512 BF16 instructions always flush to zero, then that would still happen in practice.

This seems to be used by ggml_quantize_chunk, which is indirectly used by examples/quantize.cpp. I'm not sure if it's acceptable for model files to be different when quantized on different platforms. This wasn't a problem before because FTZ was always done.

A workaround is to add ggml_fp32_to_bf16_row_reference() and make it use ggml_compute_fp32_to_bf16() on all platforms, a bit like q8_0 has a reference implementation used when quantizing model files, and a platform-specific optimized implementation when used as a vec_dot_type.

Although I don't know if f16 also has this problem (probably not?)

@github-actions github-actions bot added the ggml label Jun 28, 2024
@compilade compilade (Collaborator) commented Jun 28, 2024

I ran some tests with some small models I had (less than 3B), and I can't find one with subnormals. It seems like they are usually already flushed to zero upstream. Qwen2-57B-A14B is too big for my system. I'd like to find a smaller model with subnormals.

(EDIT: Qwen1.5-MoE-A2.7B-Chat does not have subnormals)
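For anyone else looking, a small numpy check along these lines may help (hypothetical helper, not from this PR):

import numpy as np

def count_bf16_subnormals(t: np.ndarray) -> int:
    # bf16 shares float32's exponent range, so the values that the bf16 conversion
    # would flush are the float32 subnormals: exponent bits all zero, but not zero itself
    bits = t.astype(np.float32, copy=False).view(np.uint32)
    is_sub = ((bits & np.uint32(0x7f800000)) == 0) & ((bits & np.uint32(0x7fffffff)) != 0)
    return int(np.count_nonzero(is_sub))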

@jart jart (Contributor) commented Jun 28, 2024

I recommend having the quantize example program set a global variable, that tells the ggml function to not use the AVX512 version.

@compilade compilade (Collaborator)

I recommend having the quantize example program set a global variable, that tells the ggml function to not use the AVX512 version.

No need, there's already a clean way to handle reference implementations:

llama.cpp/ggml/src/ggml.c

Lines 894 to 895 in 675a741

.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row_reference,

And during inference, ggml_compute_forward_mul_mat already picks the fast one for the vec_dot_type:

llama.cpp/ggml/src/ggml.c

Lines 12141 to 12142 in 675a741

enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

Which is then used shortly after:

llama.cpp/ggml/src/ggml.c

Lines 12203 to 12205 in 675a741

from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                      (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                      ne10);

ggml_quantize_chunk is only used[1] in llama_tensor_quantize_internal, which in turn is only used in llama_model_quantize_internal, only used by llama_model_quantize, which is then only used by examples/quantize/quantize.cpp.

So I think simply using the reference implementation in ggml_quantize_chunk is appropriate, because it's also what's already done for most of the other types (e.g. Q8_0).

Footnotes

  1. ggml_quantize_chunk is actually also used in examples/benchmark/benchmark-matmult.cpp, but only for Q4_0 and Q4_1.
    It's also used in clip_model_quantize in examples/llava/clip.cpp for the same purpose as llama_model_quantize.

jart added a commit to Mozilla-Ocho/llamafile that referenced this pull request Jun 29, 2024
jart added a commit to jart/gguf-tools that referenced this pull request Jul 3, 2024
After closely analyzing Google Brain codebases, we decided that flushing
to zero was the wrong thing to do. Intel and AMD probably designed their
microprocessors to always flush to zero for the wrong reasons. It should
have been made conditional on FTZ being set in MXCSR like other opcodes.

See ggerganov/llama.cpp#7843
Successfully merging this pull request may close these issues.

Bug: Cannot quantize a model to BF16 due to an overflow in gguf/quants.py