SOTA 3-bit quants #5196

Merged: @ikawrakow merged 14 commits from ik/iq3_xxs into master on Jan 30, 2024
Conversation

@ikawrakow (Contributor) commented Jan 29, 2024

TL;DR

This PR adds "true" 3-bit quants (3.0625 bpw due to the block structure) as IQ3_XXS. Both model size and perplexity are lower than for Q3_K_XS.

Details

The table below compares IQ3_XXS (added by this PR) with Q3_K_XS for several models. Sizes are in GiB; perplexity is for a context of 512 tokens and uses an importance matrix from wiki.train.raw.

| Model | Model size Q3_K_XS (GiB) | Model size IQ3_XXS (GiB) | Perplexity Q3_K_XS | Perplexity IQ3_XXS |
| --- | --- | --- | --- | --- |
| Mistral-7B | 2.786 | 2.705 | 6.1632 | 6.0578 |
| Mixtral-8x7B | 17.75 | 17.05 | 4.5174 | 4.4570 |
| LLaMA-v2-7B | 2.581 | 2.503 | 6.3633 | 6.3013 |
| LLaMA-v2-13B | 4.947 | 4.782 | 5.5148 | 5.4469 |
| LLaMA-v2-70B | 26.31 | 25.17 | 3.7808 | 3.7268 |
| LLaMA-v1-7B | 2.581 | 2.503 | 6.5129 | 6.3780 |
| LLaMA-v1-13B | 4.947 | 4.782 | 5.6853 | 5.6126 |
| LLaMA-v1-30B | 12.27 | 11.84 | 4.5903 | 4.5223 |
| LLaMA-v1-65B | 24.58 | 23.67 | 3.9516 | 3.8880 |

Even though nobody uses LLaMA-v1 these days, I have included those results because the early quantization work in this repository was done with LLaMA-v1. When k-quants were first released in PR #1684, the smallest quantized model at the time was Q2_K, with a size of 2.67 GiB and a perplexity of 6.773.

To avoid confusion with the PPL tables in the 2-bit quant PRs, here is a table of PPL values for the new IQ3_XXS quantization type for a context of 4096 tokens:

| Model | PPL IQ3_XXS (ctx 4096) |
| --- | --- |
| LLaMA-v2-7B | 5.3724 |
| Mistral-7B | 5.0700 |
| LLaMA-v2-13B | 4.6948 |
| Mixtral-8x7B | 3.7689 |
| LLaMA-v2-70B | 3.2757 |

How

The approach follows in the footsteps of IQ2_XXS and IQ2_XS (#4773, #4856, #4897):

  • The same trick is used to save 1 bit per group of 8 weights when encoding the signs.
  • A lattice is used for groups of weights. In IQ2_XXS and IQ2_XS it is the E8-lattice for groups of 8 quants; here it is the D4-lattice (https://en.wikipedia.org/wiki/16-cell_honeycomb) for groups of 4 quants.
  • 256 D4-lattice points are used, so for a block of 32 weights one needs 8 x 8 + 4 x 7 = 92 bits to encode the quant magnitudes and signs. This leaves 4 spare bits for an unsigned block scale to arrive at exactly 3 bpw. The fp16 super-block scale needs another 16 bits per super-block of 256 weights, so we end up at 3.0625 bpw (a small worked example follows right after this list).
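For anyone who wants to double-check that arithmetic, here is a tiny standalone sketch (illustration only, not llama.cpp code) that reproduces the 3.0625 bpw figure:

```cpp
#include <cstdio>

int main() {
    // Per block of 32 weights:
    const int grid_bits  = 8 * 8; // 8 groups of 4 quants, each an 8-bit index into the 256-entry D4 grid
    const int sign_bits  = 4 * 7; // 4 groups of 8 weights, 7 sign bits each (the 8th sign is implied)
    const int scale_bits = 4;     // unsigned block scale -> 64 + 28 + 4 = 96 bits, i.e. exactly 3 bpw
    // Per super-block of 256 weights (8 blocks): one extra fp16 scale
    const int superblock_scale_bits = 16;

    const double bpw = (8.0 * (grid_bits + sign_bits + scale_bits) + superblock_scale_bits) / 256.0;
    printf("IQ3_XXS: %.4f bits per weight\n", bpw); // prints 3.0625
    return 0;
}
```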

Commits pushed to this PR (message excerpts; the full list is in the merge commit below):

* iq3_xxs: quantize/dequantize
  RMSE seems a bit high-ish, at about half-way between q2_K and q3_K, so need to check more.
* iq3_xxs: starting to look better
  PPL on wiki.test.raw: LLaMA-v1-7B 6.4218, LLaMA-v2-7B 6.3560, Mistral-7B 6.0717. This is better than Q3_K_XS, with a 5% reduction in quantized model size.
* iq3_xxs: CUDA dot product
  We have PP-512: 5891 t/s, TG-128: 143.9 t/s.
* iq3_xxs: ARM_NEON and Metal
  Metal performance is decent, ARM_NEON is pathetic.
@Artefact2 (Collaborator):

Build fails on ROCm.

g++ -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS  -std=c++11 -fPIC -O3 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -pthread  -march=native -mtune=native -Wno-array-bounds -Wno-format-truncation -Wextra-semi -c common/console.cpp -o console.o
ggml-cuda.cu:4437:28: error: use of undeclared identifier '__vsub4'
        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
                           ^
ggml-cuda.cu:4438:28: error: use of undeclared identifier '__vsub4'
        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
                           ^
ggml-cuda.cu:4447:28: error: use of undeclared identifier '__vsub4'
        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
                           ^
ggml-cuda.cu:4448:28: error: use of undeclared identifier '__vsub4'
        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
                           ^
ggml-cuda.cu:4481:28: error: use of undeclared identifier '__vsub4'
        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
                           ^
ggml-cuda.cu:4482:28: error: use of undeclared identifier '__vsub4'
        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);

@sorasoras commented Jan 29, 2024

The 2.31 bpw quant got a PPL of 5.822 for Mistral-7B, but Mistral-7B IQ3_XXS has 6.0578.
That's pretty strange, tbh.
Q2_K and Q2_K_S should also be included in the comparison.

@ikawrakow (Contributor, Author):

> The 2.31 bpw quant got a PPL of 5.822 for Mistral-7B, but Mistral-7B IQ3_XXS has 6.0578. That's pretty strange, tbh. Q2_K and Q2_K_S should also be included in the comparison.

5.822 is for a context size of 4096 tokens; 6.0578 is for 512. The IQ3_XXS PPL for a context of 4096 is 5.0700.

I have always been careful to state the context length of a PPL result. When I published the 2-bit quants, I stated PPLs for a context of 4096 because that allowed a direct comparison with the PPL values from a recent paper claiming SOTA performance. Sorry that this is causing confusion.
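For reference, the context length is just the -c argument passed to the perplexity tool, so the two numbers come from runs along the lines of (model path is a placeholder):

```sh
./perplexity -m mistral-7b-iq3_xxs.gguf -f wiki.test.raw -c 512
./perplexity -m mistral-7b-iq3_xxs.gguf -f wiki.test.raw -c 4096
```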

@ikawrakow (Contributor, Author):

@Artefact2 Does the change I pushed fix it? I don't have an AMD card to test with. I used __vsub4 instead of __vsubss4 because it is slightly faster (no need to worry about sign saturation). I have now added a ROCm __vsub4 that just calls __vsubss4.
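Conceptually, the shim is a one-liner along these lines (a sketch, not necessarily the exact lines now in ggml-cuda.cu):

```cpp
#if defined(GGML_USE_HIPBLAS)
// ROCm does not expose __vsub4, but the saturating per-byte subtraction is available.
// Saturation is not needed for correctness here, so this gives the same result.
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
    return __vsubss4(a, b);
}
#endif
```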

@sorasoras:

> @Artefact2 Does the change I pushed fix it? I don't have an AMD card to test with. I used __vsub4 instead of __vsubss4 because it is slightly faster (no need to worry about sign saturation). I have now added a ROCm __vsub4 that just calls __vsubss4.

The build is fine on Windows with ROCm 5.7.1 now.

Commit pushed: iq3_xxs: failing tests
This time the dot product accuracy did find an actual bug in the AVX2 implementation.
@Artefact2 (Collaborator) commented Jan 29, 2024

Some KL-divergence data on bagel-34b. IQ3_XXS seems to completely obsolete Q3_K_XS and Q3_K_S.

[KL-divergence plot for bagel-34b]

Edit: pure speculation, but IQ4 would be promising:
[plot]

@JianbangZ:

Are the number of columns in the ffn_down tensor still required to be a multiple of 256? Can we get support for multiples of 128 in addition to 256? If not, what would the issues be?
For the Qwen model I have been testing, the ffn_down dimension is 13696x5120, so 13696 columns.

@sorasoras commented Jan 29, 2024

> Are the number of columns in the ffn_down tensor still required to be a multiple of 256? Can we get support for multiples of 128 in addition to 256? If not, what would the issues be? For the Qwen model I have been testing, the ffn_down dimension is 13696x5120, so 13696 columns.

Yup, I tried it; it's still 256. We need quants that can handle multiples of 128 instead of falling back to legacy quants.
The weights are grouped into super-blocks of 256 for size reasons.
You could build with LLAMA_QKK_64, which gives you k-quants with a super-block size of 64 (see below).
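That is a compile-time switch, so you have to rebuild; something like the following should work (check the flag name against the Makefile/CMakeLists of your checkout):

```sh
# Makefile build
make clean && make LLAMA_QKK_64=1

# CMake build
cmake -B build -DLLAMA_QKK_64=ON && cmake --build build --config Release
```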

@JiHa-Kim:

Hello. Is it possible to use the techniques from the Additive Quantization for Language Models (AQLM) paper? It seems to have excellent results.

https://arxiv.org/abs/2401.06118

@ikawrakow (Contributor, Author):

> Hello. Is it possible to use the techniques from the Additive Quantization for Language Models (AQLM) paper? It seems to have excellent results.
>
> https://arxiv.org/abs/2401.06118

More excellent than these results?

From the quoted paper, Table 2 (3-bit quantization, which is the subject of this PR), I see WikiText perplexities of 5.46, 4.83, and 3.36 for the LLaMA-v2 models. The authors of such papers always use a context of 4096 (even when it is not stated explicitly in the text, one can deduce it from the PPL of the fp16 model). The second table in my post shows the 3-bit results for a context of 4096 for the three LLaMA-v2's as 5.37, 4.68, and 3.27. In fairness, the PPLs that llama.cpp calculates are slightly lower than those obtained in the Python universe (~3% lower for the LLaMA-v2's), so their results are comparable to, or perhaps slightly better than, the results reported here.

What is the difference between what they have done and what is being done here? Basically:

  • I use 256 D4-lattice points; they use 4096 E8-lattice points.
  • The 256 D4-lattice points are determined by minimizing the distance to the full set of D4 points that comes out of quantizing a bunch of models. Once this is done, the same set is used for all quantizations. Their codebook is determined individually for every tensor in every model by running an optimization according to their Algorithm 1.
  • IQ3_XXS quantization runs in 1.5 minutes on my CPU for 7B models; theirs requires a training run on a GPU. They don't say how long it takes, only that a single GPU is enough and that it takes significantly longer than GPTQ. My guess is that "significantly longer" is so significant that a) they didn't want to put it in the paper and b) it is not practical for quantizing the myriad of LLMs and fine-tunes out there.
  • IQ3_XXS inference is ready and can be used today; the AQLM inference implementation is left as future work. I get ~160 tokens/second for a 7B model quantized with IQ3_XXS on an RTX-4080. I'm curious to see what they will get, given the large size of their E8-lattice codebook (the codebook becoming too large for 3-bit quants, and the associated performance concerns, was the main reason I went with D4 rather than E8).

So, in short, I selected the approach in this PR for a) inference performance (D4-lattice instead of the theoretically better E8-lattice) and b) practicality (llama.cpp users can quantize their model of choice on consumer-grade hardware). Can one do better? Yes, by sacrificing b) and some of a).
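To make the codebook idea concrete, here is a toy sketch (nothing like the actual optimized llama.cpp kernels) of the operation both approaches share during quantization: mapping a group of weights to the nearest entry of a fixed codebook, which for IQ3_XXS is a 256-entry table of D4-lattice points covering groups of 4 weights.

```cpp
#include <cstdint>

// Toy illustration: brute-force nearest-codebook-entry search for one group
// of 4 (scaled) weights. In the real IQ3_XXS code the 256-entry grid holds
// magnitudes only; signs and block scales are handled separately, which this
// sketch glosses over.
static uint8_t nearest_grid_point(const float w[4], const uint8_t grid[][4], int n_grid) {
    int best = 0;
    float best_d2 = 3.4e38f;
    for (int i = 0; i < n_grid; ++i) {
        float d2 = 0.0f;
        for (int j = 0; j < 4; ++j) {
            const float d = w[j] - (float) grid[i][j];
            d2 += d * d;
        }
        if (d2 < best_d2) { best_d2 = d2; best = i; }
    }
    return (uint8_t) best; // 8-bit index, matching the 8 bits per group of 4 described above
}
```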

@ggerganov (Owner) left a comment

You can add IQ3_XXS to test-backend-ops like this:

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 55ce14e0..3eec0554 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1481,6 +1481,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
         GGML_TYPE_Q6_K,
         GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
+        GGML_TYPE_IQ3_XXS,
     };
 
     // unary ops
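With that entry added, the existing GET_ROWS and MUL_MAT cases cover the new type. If I remember the tool's flags correctly, a single op can be exercised with something like the following (check the usage output of your build for the exact syntax and binary path):

```sh
./tests/test-backend-ops test -o GET_ROWS
./tests/test-backend-ops test -o MUL_MAT
```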

@ikawrakow merged commit f4d7e54 into master on Jan 30, 2024 (51 of 54 checks passed)
@ikawrakow deleted the ik/iq3_xxs branch on Jan 30, 2024 at 13:14

@ggerganov (Owner) commented Jan 30, 2024

The IQ3_XXS tests are failing on Metal - looks like some floating-point rounding differences between the CPU and GPU:

  GET_ROWS(type=iq3_xxs,n=256,m=5,r=4,b=1,v=0): [GET_ROWS] NMSE = 0.000186960 > 0.000000100
    0  0.338793  0.338793, diff =  0.000000
    1 -0.203276 -0.203276, diff =  0.000000
    2 -0.067759 -0.067759, diff =  0.000000
    3 -0.474310 -0.474310, diff =  0.000000
    4  0.880861  0.880861, diff =  0.000000
    5 -0.474310 -0.474310, diff =  0.000000
    6 -0.880861 -0.880861, diff =  0.000000
    7 -0.474310 -0.474310, diff =  0.000000
    8  0.338793  0.338793, diff =  0.000000
    9  0.338793  0.338793, diff =  0.000000
   10  1.016378  1.050258, diff = -0.033879
   11  0.203276  0.203276, diff =  0.000000
   12 -0.609827 -0.609827, diff =  0.000000
   13  0.203276  0.203276, diff =  0.000000
   14 -1.016378 -1.050258, diff =  0.033879
   15  0.067759  0.067759, diff =  0.000000
   16  0.203276  0.203276, diff =  0.000000
   17 -0.880861 -0.880861, diff =  0.000000
   18  1.016378  1.050258, diff = -0.033879
   19  0.067759  0.067759, diff =  0.000000
   20 -0.067759 -0.067759, diff =  0.000000
   21 -0.474310 -0.474310, diff =  0.000000
   22  0.474310  0.474310, diff =  0.000000
   23 -0.880861 -0.880861, diff =  0.000000
   24 -0.745344 -0.745344, diff =  0.000000
   25 -0.067759 -0.067759, diff =  0.000000
   26 -0.203276 -0.203276, diff =  0.000000
   27  0.338793  0.338793, diff =  0.000000
   28  0.067759  0.067759, diff =  0.000000
   29  0.203276  0.203276, diff =  0.000000
   30 -0.338793 -0.338793, diff =  0.000000
   31  0.474310  0.474310, diff =  0.000000
   32 -0.474310 -0.474310, diff =  0.000000
   33 -0.203276 -0.203276, diff =  0.000000
   34  0.609827  0.609827, diff =  0.000000
   35  0.067759  0.067759, diff =  0.000000
   36  0.338793  0.338793, diff =  0.000000
   37 -0.067759 -0.067759, diff =  0.000000
   38  0.745344  0.745344, diff =  0.000000
   39 -1.016378 -1.050258, diff =  0.033879
   40 -0.067759 -0.067759, diff =  0.000000
   41  0.338793  0.338793, diff =  0.000000
   42 -0.745344 -0.745344, diff =  0.000000
   43  0.203276  0.203276, diff =  0.000000
   44 -0.609827 -0.609827, diff =  0.000000
   45 -0.609827 -0.609827, diff =  0.000000
   46 -0.203276 -0.203276, diff =  0.000000
   47 -0.474310 -0.474310, diff =  0.000000
   48  0.203276  0.203276, diff =  0.000000
   49  0.067759  0.067759, diff =  0.000000
   50 -0.609827 -0.609827, diff =  0.000000
   51  0.203276  0.203276, diff =  0.000000
   52  0.067759  0.067759, diff =  0.000000
   53 -0.745344 -0.745344, diff =  0.000000
   54  0.880861  0.880861, diff =  0.000000
   55  0.745344  0.745344, diff =  0.000000
   56 -0.203276 -0.203276, diff =  0.000000
   57 -0.745344 -0.745344, diff =  0.000000
   58  0.338793  0.338793, diff =  0.000000
   59  0.067759  0.067759, diff =  0.000000
   60  1.016378  1.050258, diff = -0.033879
   61 -0.067759 -0.067759, diff =  0.000000
   62  0.474310  0.474310, diff =  0.000000
   63 -0.609827 -0.609827, diff =  0.000000
   64  0.420103  0.420103, diff =  0.000000

Might want to take a look, and if there is no way to make them match, we should probably disable the GET_ROWS test for IQ3_XXS.

It's strange, because I'm pretty sure that earlier, when I wrote to enable the tests, they passed on my Mac; but now they are failing even though there haven't been any changes since then.

@eramax commented Jan 31, 2024

I cannot find any of the new SOTA 2- or 3-bit quants (e.g. IQ3_XXS or Q3_K_XS) available yet when I checked:

➜  llama.cpp git:(master) quantize
usage: quantize [--help] [--allow-requantize] [--leave-output-tensor] [--pure] model-f32.gguf [model-quant.gguf] type [nthreads]

  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit
  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing
  --pure: Disable k-quant mixtures and quantize all tensors to the same type

Allowed quantization types:
   2  or  Q4_0   :  3.56G, +0.2166 ppl @ LLaMA-v1-7B
   3  or  Q4_1   :  3.90G, +0.1585 ppl @ LLaMA-v1-7B
   8  or  Q5_0   :  4.33G, +0.0683 ppl @ LLaMA-v1-7B
   9  or  Q5_1   :  4.70G, +0.0349 ppl @ LLaMA-v1-7B
  10  or  Q2_K   :  2.63G, +0.6717 ppl @ LLaMA-v1-7B
  12  or  Q3_K   : alias for Q3_K_M
  11  or  Q3_K_S :  2.75G, +0.5551 ppl @ LLaMA-v1-7B
  12  or  Q3_K_M :  3.07G, +0.2496 ppl @ LLaMA-v1-7B
  13  or  Q3_K_L :  3.35G, +0.1764 ppl @ LLaMA-v1-7B
  15  or  Q4_K   : alias for Q4_K_M
  14  or  Q4_K_S :  3.59G, +0.0992 ppl @ LLaMA-v1-7B
  15  or  Q4_K_M :  3.80G, +0.0532 ppl @ LLaMA-v1-7B
  17  or  Q5_K   : alias for Q5_K_M
  16  or  Q5_K_S :  4.33G, +0.0400 ppl @ LLaMA-v1-7B
  17  or  Q5_K_M :  4.45G, +0.0122 ppl @ LLaMA-v1-7B
  18  or  Q6_K   :  5.15G, -0.0008 ppl @ LLaMA-v1-7B
   7  or  Q8_0   :  6.70G, +0.0004 ppl @ LLaMA-v1-7B
   1  or  F16    : 13.00G              @ 7B
   0  or  F32    : 26.00G              @ 7B
          COPY   : only copy tensors, no quantizing

Any idea how to quantize using IQ3_XXS?

@JiHa-Kim commented Feb 3, 2024

> Any idea how to quantize using IQ3_XXS?

ikawrakow listed the commands in the previous pull request for the SOTA 2-bit quants:

./imatrix -m some_model -f wiki.train.raw -o imatrix_some_model.dat --chunks 100
./quantize --imatrix imatrix_some_model.dat some_model quantized_model.gguf iq2_xxs

I am guessing you replace iq2_xxs with iq3_xxs (see the example below). I would appreciate a tutorial since, as a total beginner, I am also trying to use this, but it is not working, and I am not sure how to do it.
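That is, presumably something like this (model paths are placeholders, as in the commands above):

```sh
./imatrix -m some_model -f wiki.train.raw -o imatrix_some_model.dat --chunks 100
./quantize --imatrix imatrix_some_model.dat some_model quantized_model.gguf iq3_xxs
```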

Edit: I tried to use the method and it showed a perplexity of about 3200 for the model deepseek-ai/deepseek-coder-7b-base-v1.5
The code can be found here:
https://colab.research.google.com/drive/1qGmyHl2MXVScovQuu4vh2nheRXawZprj?usp=sharing
If someone could help, it would be appreciated. Thanks.

Edit #2: It turns out that you need to use a .imatrix extension rather than .dat now.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* iq3_xxs: quantize/dequantize

RMSE seems a bit high-ish at about half-way between q2_K and
q3_K, so need to check more.

* iq3_xxs: CUDA dequantize works

* iq2_xxs: tuning quantization

* iq3_xxs: starting to look better

PPL on wiki.test.raw
LLaMA-v1-7B: 6.4218
LLaMA-v2-7B: 6.3560
Mistral-7B : 6.0717

This is better than Q3_K_XS, with a 5% reduction in quantized model
size.

* iq3_xxs: CUDA dot product

We have
PP-512: 5891 t/s
TG-128: 143.9 t/s

* iq3_xxs: scalar and AVX2 dot products

* iq3_xxs: ARM_NEON and Metal

Metal performance is decent, ARM_NEON is pathetic

* iq3_xxs: slightly better grid points

* Faster iq3_xxs and iq2_xs dot products on CUDA

* iq3_xxs: add some quant mix

* iq3_xxs: fix failing quantization test

Dot product still fails. Is this real?

* iq3_xxs: hopefully fix ROCm

* iq3_xxs: failing tests

This time the dot product accuracy did find an actual bug
in the AVX2 implementation.

* Add IQ3_XXS to test-backend-ops

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024