
Introduce bfloat16 support #6412

Merged
merged 8 commits into ggerganov:master from the bf16 branch on May 8, 2024

Conversation

jart
Contributor

@jart jart commented Mar 31, 2024

Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───┐
    0b0000000000000000 brain16

This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───────────────────┐
    0b00000000000000000000000000000000 IEEE binary32

The issue is that converting weights from bf16 to fp16 will cause 3 bits of knowledge to be lost. There is currently no way to evaluate models like Mistral at full fidelity, without f32, using llama.cpp.

      ┌sign
      │
      │  ┌exponent
      │  │
      │  │    ┌mantissa
      │  │    │
      │┌─┴─┐┌─┴──────┐
    0b0000000000000000 IEEE binary16

This change fixes that, by adding a bf16 data type to GGML. Support for CPU inference has been implemented along with optimizations for the AVX2, AVX512F, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16
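As an illustration of the range difference described above (not part of the original description; it assumes a compiler with the _Float16 extension, e.g. recent GCC or Clang):

#include <stdio.h>

int main(void) {
    // Both magnitudes sit comfortably inside bf16's exponent range (the same
    // 8-bit exponent as fp32), but outside fp16's range of roughly 6e-8..65504.
    float big  = 1e20f;
    float tiny = 1e-20f;
    printf("fp16(1e20)  = %g\n", (float)(_Float16)big);   // prints inf
    printf("fp16(1e-20) = %g\n", (float)(_Float16)tiny);  // prints 0
    return 0;
}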

@jart jart force-pushed the bf16 branch 3 times, most recently from 436956a to e52d5e5 Compare March 31, 2024 15:09
Contributor

github-actions bot commented Mar 31, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3: 523 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8964.0ms p(90)=25761.73ms fails=0, finish reason: stop=523 truncated=0
  • Prompt processing (pp): avg=237.3tk/s p(90)=697.6tk/s total=203.61tk/s
  • Token generation (tg): avg=101.36tk/s p(90)=283.09tk/s total=132.93tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=bf16 commit=44d5c7070f3b33714c3d92b6e3c757e00877b4e1
Time series

[Time-series charts omitted: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing]

@JohannesGaessler
Collaborator

The issue is that converting weights from bf16 to fp16 will cause 3 bits of knowledge to be lost. There is currently no way to evaluate models like Mistral at full fidelity, without f32, using llama.cpp.

IEEE 754 half precision floats can store values in the range $5.96 \cdot 10^{-8}$ to $65504$. For all values within this range there is no precision loss whatsoever when converting from BF16. And I would be very surprised if even a single model weight were to be outside this range since these would also be leading to vanishing/exploding gradients.

Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16

I think this is not due to any change in the weights but rather due to a difference in rounding error in the accumulator. I expect this improvement to not be consistent across models/text corpuses and I also expect there to be no statistically significant improvement at all for a large enough sample size.

@sorasoras

There are some differences between quantizing from BF16 via FP32 and via FP16.
It's not the same model when you compare PPL between FP16 and FP32, and it behaves differently.
It would be interesting to run inference on BF16 directly.

@jart
Contributor Author

jart commented Mar 31, 2024

@JohannesGaessler Only 13% of bf16 numbers can be represented accurately by a bf16 -> fp16 conversion. https://justine.lol/tmp/bf16-to-fp16.txt Yes, the vast majority of weights cluster within that 13%. By my calculation, only 0.29101% of Mistral 7b's numbers are broken. I want those numbers. I also don't want to accept limits on what's possible based on what's normal. Someone might find those broken intervals useful. But if that doesn't persuade you, consider this. I recently bought a Threadripper and it offers hardware acceleration for bf16 but not fp16. So this change is not just good for accuracy, it can be good for performance too.
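As an illustration (not from the original thread), the linked table can be approximated with a brute-force sweep along these lines; this is a sketch, not the code used to generate that table. It assumes the _Float16 compiler extension, and NaN payload changes count as losses since NaN never compares equal:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    int exact = 0;
    for (uint32_t bits = 0; bits < 0x10000; bits++) {
        union { float f; uint32_t i; } u;
        u.i = bits << 16;                          // bf16 -> fp32 (shift into high half)
        float roundtrip = (float)(_Float16)u.f;    // fp32 -> fp16 -> fp32
        if (roundtrip == u.f) exact++;
    }
    printf("%d of 65536 bf16 bit patterns survive an fp16 round trip\n", exact);
    return 0;
}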

@JohannesGaessler
Collaborator

By my calculation, only 0.29101% of Mistral 7b's numbers are broken.

Broken in what sense? Numbers being flushed to zero is not an issue because the difference between 0 and almost 0 is negligible for matrix multiplication.

I recently bought a Threadripper and it offers hardware acceleration for bf16 but not fp16. So this change is not just good for accuracy, it can be good for performance too.

The performance point is valid.

In terms of numerical precision, this is the bottom line for me: I very much expect the difference between IEEE 754 half precision and bfloat to be completely negligible. I'm not telling you this out of malice but because I want contributors to spend their time in a way that is useful. If it turns out I'm wrong I will happily accept it.

@jart
Contributor Author

jart commented Apr 1, 2024

You might find the differences negligible, but it's important to me. I want llamafile to be able to deliver, to the best of its ability, whatever number of bits are claimed, even if those extra bits are only good for audiophiles. In my day-to-day work as a developer, I feel more comfortable being able to compare my tradeoffs with the master copies. Furthermore, I need this data type in order to be able to exploit the full capabilities of my hardware.

Am I correct in understanding you won't merge this? That surprises me. This project recently accepted nine novel "IQ" quantization formats, which I know very little about. So I was under the impression there was a certain level of inclusiveness. Why would you not support the data type that companies like Mistral and Google widely use?

@JohannesGaessler
Collaborator

Am I correct in understanding you won't merge this? That surprises me. This project recently accepted nine novel "IQ" quantization formats, which I know very little about. So I was under the impression there was a certain level of inclusiveness. Why would you not support the data type that companies like Mistral and Google widely use?

The ultimate decision of what gets merged is not up to me. And I am not at all opposed to adding bfloat support. I only want to stress that I do not expect the gains from this feature to be in any way proportional to the amount of effort it will take. As such I personally will not invest time into bfloat support by e.g. modifying the CUDA code. If other devs want to do it that is their decision.


@jart
Contributor Author

jart commented Apr 1, 2024

I don't hold any demands on your time. In terms of resources, Mozilla is sponsoring me to help llama.cpp so you've got a lot more resources than before. At the moment, I only need this to work on CPU however I'll likely get personal enjoyment at some point in getting this to work on CUDA and Metal too. Particularly Metal, since I've been looking for a good reason to learn it for some time.

@sorasoras

I don't hold any demands on your time. In terms of resources, Mozilla is sponsoring me to help llama.cpp so you've got a lot more resources than before. At the moment, I only need this to work on CPU however I'll likely get personal enjoyment at some point in getting this to work on CUDA and Metal too. Particularly Metal, since I've been looking for a good reason to learn it for some time.

I would imagine older CUDA hardware wouldn't support it, since Pascal has no bf16 support. What's the solution to that?

@jart
Contributor Author

jart commented Apr 1, 2024

Here's the decoding process for bfloat16:

typedef struct {
    uint16_t x;
} ggml_bf16_t;

/**
 * Converts brain16 to float32.
 */
static inline float ggml_bf16_to_fp32(ggml_bf16_t h) {
    union {
        float f;
        uint32_t i;
    } u;
    u.i = (uint32_t)h.x << 16;
    return u.f;
}

So the only thing old CUDA hardware needs to do is left-shift the bf16 number by 16 bits, and then it becomes a float.
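As a quick sanity check (not from the original thread; it assumes the ggml_bf16_t and ggml_bf16_to_fp32 definitions above are in scope, and the bit patterns are illustrative):

#include <assert.h>

int main(void) {
    ggml_bf16_t one = { 0x3f80 };  // high 16 bits of 0x3f800000, i.e. 1.0f
    ggml_bf16_t neg = { 0xc2f7 };  // high 16 bits of 0xc2f70000, i.e. -123.5f
    assert(ggml_bf16_to_fp32(one) == 1.0f);
    assert(ggml_bf16_to_fp32(neg) == -123.5f);
    return 0;
}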

@Artefact2
Collaborator

I think bf16 support is nice to have in GGUF, if only because it makes quantizing a lot of models much less I/O intensive. Consider changing convert.py to make use of it.

@JohannesGaessler
Collaborator

Relevant for discussion: Mozilla-Ocho/llamafile@ef0307e

It seems there are at least some values above the maximum representable by IEEE 754 half-precision floats. @jart, do you know in which specific matrices these weights show up? Depending on where they are relative to softmax this could be an issue.

@cpumaxx
Contributor

cpumaxx commented Apr 1, 2024

Is there anything special needed to see performance gains? I cloned/built/tested this PR branch and am seeing no change in performance on CPU (CUDA support flags disabled at compile time)

@sorasoras

Is there anything special needed to see performance gains? I cloned/built/tested this PR branch and am seeing no change in performance on CPU (CUDA support flags disabled at compile time)

For CPU, I think you need something that supports bf16 acceleration, like AVX512VNNI?
Also, you need a conversion script that copies the BF16 weights straight from the PyTorch model to GGUF to get any benefit.

@cpumaxx
Contributor

cpumaxx commented Apr 1, 2024

For CPU, I think you need something that supports bf16 acceleration, like AVX512VNNI?
Also, you need a conversion script that copies the BF16 weights straight from the PyTorch model to GGUF to get any benefit.

system_info: n_threads = 55 / 128 | AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |

Hardware-wise I think I have what's needed.
Is the conversion script already available? I don't see it in any obvious place in this PR

@sorasoras

For CPU, I think you need something that supports bf16 acceleration, like AVX512VNNI?
Also, you need a conversion script that copies the BF16 weights straight from the PyTorch model to GGUF to get any benefit.

system_info: n_threads = 55 / 128 | AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |

Hardware-wise I think I have what's needed. Is the conversion script already available? I don't see it in any obvious place in this PR

https://justine.lol/matmul/
I think the full implementation is on the llamafile side.
LLM Performance on AMD Ryzen Threadripper PRO 7995WX w/ 96 cores ($10,000):

[benchmark chart image]

@cpumaxx
Contributor

cpumaxx commented Apr 2, 2024

I think the full implementation is on the llamafile side.

What should be expected in llama.cpp from this patch specifically? I'm seeing about a 6% speed increase on prompt processing and inference, and I've pulled and built the master, avx512vnni, sgemm, and bf16 branches. Each of them performs almost identically on a Q8 70b.
I'm on EPYC Genoa, so if anything I'd expect better results than that Threadripper system.

@jart
Contributor Author

jart commented Apr 2, 2024

@Artefact2 I've updated gguf-py/gguf/constants.py so that BF16 is listed. I have no idea how to make the Python script generate BF16 GGML files. What I've been doing is running convert.py --outtype f32 and then running the ./quantize ggml-model-f32.gguf ggml-model-bf16.gguf bf16 program. Please take a look.

@jart
Contributor Author

jart commented Apr 2, 2024

@cpumaxx This change only adds support for bf16. Once #6414 is merged, the next thing I'll do is upstream the llamafile bfloat16 kernels. Here's what one of them looks like:

[screenshot of a llamafile bf16 kernel]

I'm working on ARM64 bfloat16 kernels tonight.
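As a rough sketch (not the actual llamafile kernel; the function name is made up): the approach this PR describes for CPUs without native bf16 support widens bf16 to fp32 with a 16-bit shift and accumulates in fp32, which with plain AVX512F intrinsics might look like this, assuming n is a multiple of 16:

#include <immintrin.h>
#include <stdint.h>

static float bf16_dot_avx512f(const uint16_t *x, const uint16_t *y, int n) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < n; i += 16) {
        // Zero-extend 16 bf16 values to 32-bit lanes, shift them into the high
        // half so each lane becomes the corresponding fp32 value, then FMA.
        __m512 vx = _mm512_castsi512_ps(_mm512_slli_epi32(
            _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(x + i))), 16));
        __m512 vy = _mm512_castsi512_ps(_mm512_slli_epi32(
            _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(y + i))), 16));
        acc = _mm512_fmadd_ps(vx, vy, acc);
    }
    return _mm512_reduce_add_ps(acc);
}

On AVX512BF16 hardware, the shift-and-FMA pair can instead be done with the VDPBF16PS instruction, which consumes pairs of bf16 values directly and accumulates into fp32.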

@cpumaxx
Contributor

cpumaxx commented Apr 2, 2024

the next thing I'll do is upstream the llamafile bfloat16 kernels

Nice. I'll keep an eye out for them. Is there a relevant branch on your llama.cpp fork I can test prior to a PR, or do you still need to merge changes already in llamafile?

@jart
Contributor Author

jart commented Apr 2, 2024

@cpumaxx Could you download https://huggingface.co/jartine/Mistral-7B-Instruct-v0.2-llamafile/blob/main/mistral-7b-instruct-v0.2.BF16.gguf and then build the code in the branch I just created https://github.com/jart/llama.cpp/tree/unified which unifies #6412 and #6414? Thanks!

@jart
Contributor Author

jart commented Apr 2, 2024

Here's an example of what you should expect to see with that branch.

wget https://huggingface.co/jartine/Mistral-7B-Instruct-v0.2-llamafile/resolve/main/mistral-7b-instruct-v0.2.BF16.gguf
wget https://justine.lol/tmp/getty.txt
make -j32 main && ./main -m /disk/mistral/mistral-7b-instruct-v0.2.BF16.gguf -f ~/getty.txt -n 22 --temp 0
[...]
It is for us, the living, rather to be dedicated here to the unfinished work which they who fought here have thus far so nobly advanced.
llama_print_timings:        load time =     773.90 ms
llama_print_timings:      sample time =       0.46 ms /    22 runs   (    0.02 ms per token, 48034.93 tokens per second)
llama_print_timings: prompt eval time =     407.51 ms /   215 tokens (    1.90 ms per token,   527.59 tokens per second)
llama_print_timings:        eval time =    1230.05 ms /    21 runs   (   58.57 ms per token,    17.07 tokens per second)
llama_print_timings:       total time =    1643.99 ms /   236 tokens
Log end

EPYC is for servers so I've heard they generally run at much lower clock rates than Threadripper Pro. So if you get a lower number than 530 tok/sec then try comparing it to llama.cpp at HEAD using the Mistral 7b f16 weights.

@cpumaxx
Contributor

cpumaxx commented Apr 2, 2024

Here's an example of what you should expect to see with that branch.

llama_print_timings:        load time =     773.90 ms
llama_print_timings:      sample time =       0.46 ms /    22 runs   (    0.02 ms per token, 48034.93 tokens per second)
llama_print_timings: prompt eval time =     407.51 ms /   215 tokens (    1.90 ms per token,   527.59 tokens per second)
llama_print_timings:        eval time =    1230.05 ms /    21 runs   (   58.57 ms per token,    17.07 tokens per second)
llama_print_timings:       total time =    1643.99 ms /   236 tokens
Log end

EPYC is for servers so I've heard they generally run at much lower clock rates than Threadripper Pro. So if you get a lower number than 530 tok/sec then try comparing it to llama.cpp at HEAD using the Mistral 7b f16 weights.

My system is a dual 64-core 9334 running with a 3.9 GHz boost clock.
I've got NPS set to 4 (so 8 NUMA nodes) for development reasons, which may be affecting results.
I tested your unified branch vs. ggerganov master, and I'm seeing a severe speed regression:

/usr/src/llama.cpp.jart# ./main -m /media/models/bf16/mistral-7b-instruct-v0.2.BF16.gguf -n 22 --temp 0

system_info: n_threads = 64 / 128 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |

Question: Let i = 11 + -11. Let g = 1.1 + -
llama_print_timings: sample time = 0.53 ms / 22 runs ( 0.02 ms per token, 41666.67 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 1 tokens ( 0.00 ms per token, inf tokens per second)
llama_print_timings: eval time = 4618.55 ms / 22 runs ( 209.93 ms per token, 4.76 tokens per second)
llama_print_timings: total time = 4624.48 ms / 23 tokens

vs

/usr/src/llama.cpp.master.clean# ./main -m /media/models/bf16/ggml-model-f16.gguf -n 22 --temp 0

llama_print_timings: sample time = 0.61 ms / 22 runs ( 0.03 ms per token, 36184.21 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 1 tokens ( 0.00 ms per token, inf tokens per second)
llama_print_timings: eval time = 1363.28 ms / 22 runs ( 61.97 ms per token, 16.14 tokens per second)
llama_print_timings: total time = 1369.01 ms / 23 tokens

This was with identical build flags and after dropping all caches for a level playing field.

Anything else I should be trying in order to see the speedup?

@jart
Contributor Author

jart commented Apr 2, 2024

Could you pass the flag -f getty.txt please after you've downloaded that file from the link above? Then re-post your results.

@jart
Contributor Author

jart commented Apr 26, 2024

@ggerganov Thanks for showing me how to do that. I'll be sure to run those tests on future changes. All your comments have been addressed. PTAL.

Contributor

github-actions bot commented Apr 26, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 519 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=9024.52ms p(95)=22221.76ms fails=, finish reason: stop=454 truncated=65
  • Prompt processing (pp): avg=106.05tk/s p(95)=466.16tk/s
  • Token generation (tg): avg=30.9tk/s p(95)=47.97tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=bf16 commit=632624e9d79b536084dda885f355ce393f77e38e

[Time-series charts omitted: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing]

@jart jart force-pushed the bf16 branch 2 times, most recently from ed0f47b to 82aebcf Compare May 1, 2024 16:59
@unicomp21

#6412 (comment)

browser? webgpu? webassembly? mesh networking w/ rtcdatachannel?

@Srihari-mcw

Hi @jart, when we tried building and running this PR on Windows, the build gave issues. In PR 1 of your fork, the same was addressed and fixed. Could you please take a look? Thanks

@Srihari-mcw

Srihari-mcw commented May 3, 2024

@jart, we also tried running the prompt speedup code from https://github.com/jart/llama.cpp/tree/unified. With the current code in the fork, the second input (operand) of the mulmat functions goes through as GGML_TYPE_F32. We modified the code so that the second input is GGML_TYPE_BF16 for the mulmat kernels and removed the GGML_TYPE_F32 case, which lets the second operand (Btype) be quantized to BF16 and hence use the BF16 intrinsics for the dot product. A significant speedup was observed compared with the original version in the fork, where the second operand of the mulmat operation is in FP32 format.

| model | size | params | backend | threads | test | t/s | speedup | commit |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B BF16 (without prompt speedup) | 12.55 GiB | 6.74 B | CPU | 6 | pp 512 | 39.111 ± 0.03 | | 4e57aa6 |
| llama 7B BF16 (prompt speedup, both inputs BF16) | 12.55 GiB | 6.74 B | CPU | 6 | pp 512 | 103.343 ± 0.14 | 164.23% | b25ba28 |
| llama 7B BF16 (prompt speedup, BF16 x FP32, second input FP32) | 12.55 GiB | 6.74 B | CPU | 6 | pp 512 | 45.126 ± 0.04 | 15.379% | 4e57aa6 |

The code was tested on an AMD Raphael 7600X machine, which has AVX512_BF16 support, running Linux. The original unquantized model is taken from https://huggingface.co/TheBloke/wizardLM-7B-HF. Please find the updated code in PR 2 of your fork of llama.cpp (jart#2); the changes from jart#1 (PR 1) were included while testing.

Could you please share your thoughts? Is prompt speedup for BF16 models planned for future commits of the prompt speedup changes or the BF16 model PR? Thanks

@jart
Contributor Author

jart commented May 3, 2024

@Srihari-mcw this change doesn't modify sgemm.cpp because then it would overlap with my other change:

So BF16 optimizations are blocked on review. As for your pull request, the canonical location of the code you're modifying is here:

I've done a lot of work in the past month identifying other performance opportunities.

jart added 8 commits May 7, 2024 22:26
Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as
their canonical floating point format.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───┐
    0b0000000000000000 brain16

This encoding has the same number of exponent bits as float32. That
makes conversion relatively straightforward, even in the absence of
hardware support. For example, converting brain16 to binary32 means
simply shifting 16 bits to the left.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───────────────────┐
    0b00000000000000000000000000000000 IEEE binary32

The issue is that converting bf16 to fp16 can result in information
loss. Only 13% of bf16 numbers can be precisely represented in fp16,
which in practice covers 99.71% of Mistral 7b v0.2's weights; however,
there is currently no way other than fp32 to get the others.

      ┌sign
      │
      │  ┌exponent
      │  │
      │  │    ┌mantissa
      │  │    │
      │┌─┴─┐┌─┴──────┐
    0b0000000000000000 IEEE binary16

This change fixes that, by adding a bf16 data type to GGML. Support
for CPU inference has been implemented along with optimizations for
the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2
improves somewhere around -0.0024 to -0.0046 compared to using fp16
@ggerganov ggerganov merged commit 3855416 into ggerganov:master May 8, 2024
64 checks passed
@ddh0

ddh0 commented May 8, 2024

So happy to see this land! Will convert.py and convert-hf-to-gguf.py need to be updated?

@arch-btw

arch-btw commented May 8, 2024

So happy to see this land! Will convert.py and convert-hf-to-gguf.py need to be updated?

I'm wondering the same thing

@jart
Contributor Author

jart commented May 8, 2024

The Python scripts do need to be updated. I was only able to add the IDs. I wasn't able to successfully figure out how to get the raw bfloat16 data from Torch because Numpy doesn't support it. Someone who knows more than me will need to figure that out.

So happy to see this merged @ggerganov! Thank you!

@jart
Contributor Author

jart commented May 8, 2024

By the way, the workaround I'm currently using is to:

  1. Use Python to create an --outtype f32 gguf file.
  2. Run ./quantize ggml-model-f32.gguf ggml-model-bf16.gguf bf16 to create bfloat16 weights.

@teleprint-me
Contributor

teleprint-me commented May 8, 2024

The Python scripts do need to be updated. I was only able to add the IDs. I wasn't able to successfully figure out how to get the raw bfloat16 data from Torch because Numpy doesn't support it. Someone who knows more than me will need to figure that out.

We'll need to use a custom wrapper to implement it. I tried doing this last year in pure Python and it was a no-go. Probably a ctypes interface to add FFI support? Not sure how the community feels about this. I have experience with C and my expertise is in Python, but my C++ is limited and I've been picking it up as I go.

@compilade
Collaborator

compilade commented May 8, 2024

Note

My implementation of bfloat16 conversion was too naïve: it didn't round to nearest even and didn't handle subnormals, so I've decided to exclude my flawed bfloat16 conversion from #7075.

I've added support for bfloat16 conversion as part of #7075; doing the conversion with Numpy is possible even though it doesn't support the bfloat16 type. More explanations in #7075 (comment), and the relevant changes are in 1eccde6.

Important

EDIT: I've made a proper implementation in #7158 which does properly handle subnormals, and rounding.
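For readers following along in C rather than Python, a sketch of the conversion being described (round to nearest even, NaN kept quiet, subnormals handled by the same arithmetic); this is not the exact code in #7075 or #7158:

#include <stdint.h>

typedef struct { uint16_t x; } bf16_t;  // illustrative; mirrors ggml_bf16_t

static inline bf16_t fp32_to_bf16_rne(float f) {
    union { float f; uint32_t i; } u = { f };
    bf16_t h;
    if ((u.i & 0x7fffffff) > 0x7f800000) {
        h.x = (uint16_t)((u.i >> 16) | 64);  // NaN: truncate, force the quiet bit
        return h;
    }
    // Round to nearest even: add 0x7fff plus the lowest kept bit, then truncate.
    h.x = (uint16_t)((u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16);
    return h;
}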
