GPTQ Quantization (3-bit and 4-bit) #9

Closed
MarkSchmidty opened this issue Mar 11, 2023 · 49 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@MarkSchmidty

MarkSchmidty commented Mar 11, 2023

4-bit quantization tends to come at the cost of some output quality. GPTQ is a state-of-the-art quantization method that outperforms prior 4-bit (and 3-bit/2-bit) quantization methods and shows negligible loss of output quality, even when compared with uncompressed fp16 inference.


It would be good to see benchmarks on the existing implementation. It's possible there is substantial quality loss from the 4-bit quantization. It's also possible that it isn't very substantial. We'd have to see benchmarks to know.

The related project GPTQ-for-LLaMA has some benchmarks available for their implementation.

References:
GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers
The case for 4-bit precision: k-bit Inference Scaling Laws

Related work:
https://github.com/qwopqwop200/GPTQ-for-LLaMA/

@MarkSchmidty
Author

The GitHub Issue for text-generation-webui's implementation of GPTQ-for-LLaMA may also be a helpful reference.

@dustydecapod

There’s no real benefit to using this method for 4-bit. Would be neat to see 3-bit or 2-bit attempts though.

@MarkSchmidty
Author

MarkSchmidty commented Mar 11, 2023

Well, there is likely a minor (currently unknown, pending benchmarks) benefit to GPTQ for 4-bit, yes?

Additionally, once 4-bit GPTQ is implemented, 3-bit and 2-bit are not much additional work and could bring much larger savings in VRAM/URAM consumption on top of the current 4-bit implementation, with potentially very little (and likely acceptable) output quality loss.

@ggerganov ggerganov added enhancement New feature or request help wanted Extra attention is needed labels Mar 12, 2023
@MarkSchmidty
Author

WebAssembly implementation is blocked pending 3-bit inference due to WASM's 4GB memory constraint.

@ggerganov
Owner

I had a quick glance at the GPTQ paper yesterday, but haven't dug into details yet.

Do you think it is possible to demonstrate a simple routine for performing quantization using this method?
For example, what is the most trivial way (not necessarily optimal) to implement a function like this:

// src - input 32-bit floats
// dst - output quantized data
// n - number of input floats
void quantize_gptq(float * src, void * dst, int n);

If I can get a prototype of this and it does not look too complex, I can try to plug it in ggml.
The main challenge will be to implement it efficiently with SIMD, but I need to see some initial implementation to work on.
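The core of GPTQ can be sketched in a few lines of NumPy as a starting point for such a prototype. This is a minimal, unoptimized sketch of the column-by-column update described in the GPTQ paper, assuming per-row asymmetric min/max grids and omitting grouping and the lazy batched updates of the reference code; the function names are hypothetical, and NumPy's `linalg` stands in for a LAPACK dependency. Note that GPTQ also needs calibration inputs `X`, so a signature like `quantize_gptq(src, dst, n)` would need an extra argument for them.

```python
import numpy as np


def minmax_grid(W, bits=4):
    """Per-row asymmetric grid (scale, zero point); assumes no grouping."""
    maxq = 2**bits - 1
    wmin = np.minimum(W.min(axis=1), 0.0)
    wmax = np.maximum(W.max(axis=1), 0.0)
    scale = (wmax - wmin) / maxq
    scale[scale == 0] = 1.0  # avoid dividing by zero for all-zero rows
    zero = np.round(-wmin / scale)
    return scale, zero, maxq


def quantize_column(w, scale, zero, maxq):
    """Round one column onto its rows' grids and return the dequantized values."""
    q = np.clip(np.round(w / scale) + zero, 0, maxq)
    return scale * (q - zero)


def gptq_quantize(W, X, bits=4, percdamp=0.01):
    """W: (rows, cols) layer weights; X: (cols, nsamples) calibration inputs."""
    W = np.array(W, dtype=np.float64)  # work on a copy
    cols = W.shape[1]
    H = 2.0 * X @ X.T  # (cols, cols) Hessian of the layer-wise squared error
    H += percdamp * np.mean(np.diag(H)) * np.eye(cols)  # dampening
    U = np.linalg.cholesky(np.linalg.inv(H)).T  # upper Cholesky factor of H^-1
    scale, zero, maxq = minmax_grid(W, bits)
    Q = np.zeros_like(W)
    for j in range(cols):
        Q[:, j] = quantize_column(W[:, j], scale, zero, maxq)
        err = (W[:, j] - Q[:, j]) / U[j, j]
        # Compensate: push this column's error onto the not-yet-quantized columns.
        W[:, j + 1:] -= np.outer(err, U[j, j + 1:])
    return Q, scale, zero
```

Comparing `np.linalg.norm(W @ X - Q @ X)` with the same quantity for plain round-to-nearest (drop the compensation line) is a quick way to see the effect on a given layer.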

@MarkSchmidty
Author

@zoidbb or @qwopqwop200 might have an answer for the question above.

@blackhole89
Collaborator

The actual quantization algorithm (spread across that file and another) seems to be a little hairy, and uses some nontrivial linear algebra (Cholesky decomposition) that we'd either have to reimplement or pull in another dependency for (LAPACK?).

However, if I read the CUDA kernels that they have for evaluation correctly, the format in which the quantized weights wind up is more or less equivalent to the Q4_1 mode (4-bit quantization with a blockwise f16 zero offset) that we already support, though there is currently no AVX2 implementation for that mode. Since someone has uploaded precomputed GPTQ weights to Huggingface, it might be worthwhile to start by implementing the right form of accelerated Q4_1 evaluation, plus some code to directly convert the Huggingface pickles into an appropriate format for us.
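The Q4_1 equivalence can be made concrete. Assuming GPTQ dequantizes a 4-bit code q as scale * (q - zero), this expands to scale * q + (-scale * zero), which has the Q4_1 shape d * q + m with d = scale and m = -scale * zero. A minimal sketch (function names hypothetical; the actual GPTQ-for-LLaMa packing and storage layout are not reproduced):

```python
import numpy as np


def gptq_dequant(q, scale, zero):
    """GPTQ-style asymmetric dequantization (assumed form): scale * (q - zero)."""
    return scale * (q - zero)


def gptq_to_q4_1(scale, zero):
    """Rewrite (scale, zero) as a Q4_1-style (d, m) pair so that d * q + m is equal."""
    return scale, -scale * zero


def q4_1_dequant(q, d, m):
    """Q4_1-style dequantization: d * q + m, with a float offset m."""
    return d * q + m
```

Because the scales and zeros in those checkpoints are per output row (the conversion log further down shows `scales`/`zeros` tensors of shape [4096, 1]), a converter would repeat the same (d, m) for every QK-sized block of a row.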

@Ayushk4

Ayushk4 commented Mar 14, 2023

https://twitter.com/NolanoOrg/status/1635409631530057728

In Q4_0 mode, GPTQ does worse on 7B than the current round-to-nearest (RTN) quantization approach.
However, in other cases it's better (only tested up to 13B models).

With Q4_1 on 13B, it can not only reduce RAM (by increasing the bin size QK from 32 to something larger, like 128) but also improve performance.

Also, 3-bit 13B GPTQ will perform better than 7B at FP16.

Disclaimer - these were observed on a small subset of WikiText and Penn TreeBank (following GPTQ).
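The RAM effect of a larger bin size follows from how the per-block metadata is amortized. A minimal sketch, assuming each block of QK weights stores two 16-bit values (a scale and an offset) next to its 4-bit codes; the exact ggml block layout may differ, so `meta_bits_per_block` is a parameter:

```python
def bits_per_weight(bits=4, qk=32, meta_bits_per_block=32):
    """Storage cost per weight: the code bits plus the block metadata spread over QK weights.

    meta_bits_per_block=32 assumes two 16-bit values (scale and offset) per block.
    """
    return bits + meta_bits_per_block / qk


for qk in (32, 64, 128):
    print(f"QK={qk}: {bits_per_weight(4, qk):.3f} bits/weight")
```

Under that assumption, going from QK = 32 to QK = 128 cuts the cost from 5.0 to 4.25 bits per weight, about 15% less memory for the quantized tensors.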

@Ayushk4

Ayushk4 commented Mar 14, 2023

Also, the above has results on both Q4_0 and Q4_1.
GPTQ can be implemented for Q4_0 by hardcoding zero-offset at https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/9bbfdeda2c80c20d8bc1babf4d4368df574afe30/quant.py#L6

@qwopqwop200

I have no understanding of GPTQ quantization algorithms.
However, for GPTQ quantization, refer to the following:
GPTQ
Quantizer
llama_sequential
For model inference, you can refer to:
vecquant4matmul_cuda and VecQuant4MatMulKernel
QuantLinear
For inference code, it has been implemented to replace Linear layer with QuantLinear.

@qwopqwop200

According to this paper, 3- or 2-bit quantization is not a very good idea.
Also, unlike GPTQ, RTN can perform poorly on LLMs.

@Ayushk4

Ayushk4 commented Mar 14, 2023

That does not consider grouping/binning with groups smaller than 64, or data-dependent quantization with weight reconstruction, which is already being used now (with QK=32, i.e. a bin of size 32).
https://arxiv.org/pdf/2206.09557.pdf, https://arxiv.org/pdf/2210.17323.pdf and GPTQ (Table 4, last row; Table 6 in GPTQ) all show otherwise: lower-than-4-bit quantization is promising when using grouping/binning.

@Ayushk4

Ayushk4 commented Mar 14, 2023

RTN is the one currently being used, and it is the go-to baseline (not the best, but decent for int4). But empirically, when the zero offset is fixed, GPTQ does worse than RTN on 7B LLaMA.

@MarkSchmidty
Author

MarkSchmidty commented Mar 14, 2023

I'm curious what your actual benchmark results were. A handful of use cases are blocked pending fitting 7B with inference into 4GB of RAM, including LLaMA in WebAssembly (which has 32-bit addressing with a 4GB max address space) and LLaMA on Raspberry Pi.

Aside from that, the benefits of GPTQ seem to go up with model size; like potentially enabling both 16GB and 32GB devices to move up one entire model size as well as increasingly better performance.

@Ayushk4

Ayushk4 commented Mar 14, 2023

Depends on what you mean by benchmark:

  1. If you mean perplexity/loss then you can find graph and link to google docs containing the table in the blog: https://nolanoorg.substack.com/p/int-4-llama-is-not-enough-int-3-and
  2. If you mean running time - then that is still pending with int-3 quant and quant 4 with 128 bin size.

@Ayushk4

Ayushk4 commented Mar 14, 2023

I should be out with at least one of the two (int-3 quant or quant-4 kernels and their running time) by tomorrow; I will share the code once I am done.

@MarkSchmidty
Author

Those 3bit graphs look better than I expected actually. This is quite promising. Thanks for your contributions @Ayushk4!

@Ronsor
Contributor

Ronsor commented Mar 14, 2023

> I'm curious what your actual benchmark results were. A handful of use cases are blocked pending fitting 7B with inference into 4GB of RAM, including LLaMA in WebAssembly (which has 32-bit addressing with a 4GB max address space) and LLaMA on Raspberry Pi.
>
> Aside from that, the benefits of GPTQ seem to go up with model size; like potentially enabling both 16GB and 32GB devices to move up one entire model size as well as increasingly better performance.

I'm curious if anyone has tried LLaMA on an 8GB RPi. If not, I might be the first.

@MarkSchmidty
Author

>> I'm curious what your actual benchmark results were. A handful of use cases are blocked pending fitting 7B with inference into 4GB of RAM, including LLaMA in WebAssembly (which has 32-bit addressing with a 4GB max address space) and LLaMA on Raspberry Pi.
>> Aside from that, the benefits of GPTQ seem to go up with model size; like potentially enabling both 16GB and 32GB devices to move up one entire model size as well as increasingly better performance.
>
> I'm curious if anyone has tried LLaMA on an 8GB RPi. If not, I might be the first.

Please post your results in the issue for that: #58
Lots of people will be interested to see how it goes and potentially help you work out how to get the best performance.

@qwopqwop200

> That does not consider grouping/binning, which is already being used now (with QK=32, i.e. a bin of size 32). https://arxiv.org/pdf/2206.09557.pdf, https://arxiv.org/pdf/2210.17323.pdf and GPTQ (Table 4, last row; Table 6 in GPTQ) all show otherwise: lower-than-4-bit quantization is promising when using grouping/binning.

It sure looks interesting.
The reason I didn't consider grouping is because I failed to implement grouping with CUDA. Clearly, implementing grouping seems to have practical advantages.

@blackhole89
Collaborator

blackhole89 commented Mar 14, 2023

I just committed in some code to get AVX2-accelerated Q4_1 inference into a new branch.

Since I've never written any sort of SIMD code before, this was a bit of a learning experience, and I can't guarantee its optimality. As it stands, I think it's somewhere around 50% slower than Q4_0 on the 7B model (I get 300~400 ms/token, whereas I had 200~300 ms/token on Q4_0).

However, it's still significantly faster than the unvectorised implementation that was there before, which was more in the region of a second or two per token, and in fact seems to make the difference between Q4_1 being academic and tolerable.

@ggerganov @Const-me Care to take a look?

@Const-me

@blackhole89 You can probably combine xsum/ysum into a single vector, like that:

__m256i xSumInt = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
__m256i ySumInt = _mm256_sad_epu8( by, _mm256_setzero_si256() );
__m256i sumInt = _mm256_or_si256( xSumInt, _mm256_slli_si256( ySumInt, 4 ) );
// Even lanes = xSumInt, odd lanes = ySumInt
__m256 sum = _mm256_cvtepi32_ps( sumInt );

Similarly, combine the multipliers:

const __m256 scale_01 = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );

These instructions are cheap, but they are saving _mm256_cvtepi32_ps and _mm256_fmadd_ps, which are slightly more expensive. Might become slightly faster overall.
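What the combined sequence computes can be checked without AVX2 hardware by emulating the intrinsics in NumPy. This is a scalar sketch that assumes the documented semantics (`_mm256_sad_epu8` against zero sums each group of 8 bytes into a 64-bit lane, `_mm256_slli_si256` shifts bytes within each 128-bit lane, and `_mm256_blend_ps` takes lane i from the second operand when bit i of the immediate is set); the helper names are hypothetical:

```python
import numpy as np


def sad_epu8_vs_zero(v):
    """Emulates _mm256_sad_epu8(v, 0): each 64-bit lane gets the sum of 8 bytes."""
    return v.reshape(4, 8).astype(np.uint64).sum(axis=1).astype("<u8")


def slli_si256(v64, nbytes):
    """Emulates _mm256_slli_si256: byte shift left inside each 128-bit lane."""
    b = v64.view(np.uint8).reshape(2, 16)
    out = np.zeros_like(b)
    out[:, nbytes:] = b[:, : 16 - nbytes]
    return out.reshape(-1).view("<u8")


def blend_ps(a, b, imm8):
    """Emulates _mm256_blend_ps: lane i comes from b if bit i of imm8 is set."""
    mask = np.array([(imm8 >> i) & 1 for i in range(8)], dtype=bool)
    return np.where(mask, b, a)


def combined_sums(bx, by):
    """Eight 32-bit lanes: even lanes hold the x sums, odd lanes the y sums."""
    x_sum = sad_epu8_vs_zero(bx)
    y_sum = sad_epu8_vs_zero(by)
    return (x_sum | slli_si256(y_sum, 4)).view("<u4")
```

Each 8-byte sum is at most 8 * 255 = 2040, so it fits in two bytes; the bytes that the 4-byte shift moves into the even 32-bit lanes are therefore zero, and the OR leaves the x sums intact. The 0b10101010 blend mask then lines scale_1 up with the odd (y) lanes.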

Another thing, you should definitely remove the * QK from that line: acc_offset += (*m0)*(*m1)*QK;
Instead, apply that * QK multiplier after the loop.
Scaling floats by powers of 2 is a lossless operation, anyway.
BTW, note there's std::fmaf in the standard library, in <cmath> header.

Also, you’re loading from these float pointers multiple times to get the same values.
I’m pretty sure recent versions of clang are going to optimize these away, but there are also older, less optimal compilers out there.
I would rather load only once from each pointer with _mm256_broadcast_ss, then use _mm256_cvtss_f32 to accumulate the scalars.
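The claim that hoisting `* QK` out of the loop is lossless can be checked directly: with QK a power of two, multiplying by it only changes the exponent, so (absent overflow or underflow) every partial sum in the loop is exactly QK times the corresponding unscaled partial sum. A small Python check, with Python's double-precision floats standing in for the C `float` code (the same argument applies to both):

```python
import random

QK = 32  # block size; a power of two


def acc_scaled_in_loop(ms0, ms1):
    """Original pattern: acc_offset += (*m0)*(*m1)*QK on every iteration."""
    acc = 0.0
    for m0, m1 in zip(ms0, ms1):
        acc += m0 * m1 * QK
    return acc


def acc_scaled_after_loop(ms0, ms1):
    """Suggested pattern: accumulate m0*m1, then apply * QK once after the loop."""
    acc = 0.0
    for m0, m1 in zip(ms0, ms1):
        acc += m0 * m1
    return acc * QK
```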

@antonkoenig

I did not expect the comments to contain generated code suggestions.

GPT-4 can fit about 9000 lines of code in a single input or output within the 32k-token context window version, with its code-optimized tokenizer. That is plenty to write fairly complex programs with multiple parts. To wit, llama.cpp is around 9k lines in its entirety, iirc.

Sounds well suited for such tasks. GPT-3 is a bit limited in that regard. Thank you for the info.

@dorianr666

dorianr666 commented Mar 20, 2023

Another example, showing how it suggests to implement it (let me know if these are wrong and I'll delete them! Just trying to be helpful and don't want to pollute the thread with useless stuff!):

#include "ggml.h"
#include "utils.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <regex>
#include <Eigen/Dense>

using namespace Eigen;

// ... (existing code, same as before)

// Cholesky decomposition
MatrixXd cholesky_decomposition(MatrixXd &matrix) {
    LLT<MatrixXd> llt(matrix);
    return llt.matrixL();
}

// Cholesky inverse
MatrixXd cholesky_inverse(MatrixXd &matrix) {
    MatrixXd L = cholesky_decomposition(matrix);
    MatrixXd Linv = L.inverse();
    return Linv.transpose() * Linv;
}

// ... (existing code, same as before)

bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
    // ... (existing code, same as before)

    if (quantize) {
        // ... (existing code, same as before)

        // Apply Cholesky decomposition and Cholesky inverse
        MatrixXd m_f32(ne[0], ne[1]);
        for (int i = 0; i < ne[0]; ++i) {
            for (int j = 0; j < ne[1]; ++j) {
                m_f32(i, j) = data_f32[i * ne[1] + j];
            }
        }

        MatrixXd m_chol_inv = cholesky_inverse(m_f32);
        for (int i = 0; i < ne[0]; ++i) {
            for (int j = 0; j < ne[1]; ++j) {
                data_f32[i * ne[1] + j] = m_chol_inv(i, j);
            }
        }

        // ... (existing quantization code, same as before)
    }

    // ... (existing code, same as before)
}

// ... (existing code, same as before)

I'm definitely not a math expert, but ne[0] and ne[1] have different values, and Cholesky decomposition requires a square matrix, no?
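On the square-matrix question: in the GPTQ paper the Cholesky decomposition is applied not to the weight matrix but to the layer Hessian H = 2 X X^T built from calibration inputs X, which is square (one row and column per input feature) whatever the shape of the weights. A minimal NumPy sketch of the shapes involved; the sizes and the 1% dampening factor are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
rows, cols, nsamples = 6, 4, 32            # illustrative sizes; W is not square
W = rng.standard_normal((rows, cols))      # layer weights: (rows, cols)
X = rng.standard_normal((cols, nsamples))  # calibration inputs: (cols, nsamples)

H = 2.0 * X @ X.T                               # Hessian: (cols, cols), square by construction
H += 0.01 * np.mean(np.diag(H)) * np.eye(cols)  # dampening keeps H safely positive definite
L = np.linalg.cholesky(H)                       # succeeds; np.linalg.cholesky(W) would raise
```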

@loretoparisi

loretoparisi commented Mar 20, 2023

Hello Georgi! Thank you for your excellent work. There are 4 models that already have GPTQ quantization: https://rentry.org/llama-tard-v2 https://huggingface.co/maderix/llama-65b-4bit/tree/main Maybe you could just convert them into your format?

Not sure, but after converting the HF 7B int4 GPTQ model to ggml bin format:

python convert-pth-to-ggml.py models/7Bb4 1
{'dim': 4096, 'multiple_of': 256, 'n_heads': 32, 'n_layers': 32, 'norm_eps': 1e-06, 'vocab_size': -1}
n_parts = 1

Processing part 0

Processing variable: model.decoder.embed_tokens.weight with shape: torch.Size([32000, 4096]) and type: torch.float16
Processing variable: model.decoder.layers.0.self_attn.q_proj.zeros with shape: torch.Size([4096, 1]) and type: torch.float32
Processing variable: model.decoder.layers.0.self_attn.q_proj.scales with shape: torch.Size([4096, 1]) and type: torch.float32
Processing variable: model.decoder.layers.0.self_attn.q_proj.bias with shape: torch.Size([4096]) and type: torch.float32
  Converting to float32
...
Processing variable: lm_head.weight with shape: torch.Size([32000, 4096]) and type: torch.float16
Done. Output file: models/7Bb4/ggml-model-f16.bin, (part 0)

I'm getting an unknown tensor error while loading

./main -m ./models/7Bb4/ggml-model-f16.bin -t 4 -n 512

main: seed = 1679334253
llama_model_load: loading model from './models/7Bb4/ggml-model-f16.bin' - please wait ...
llama_model_load: n_vocab = 32000
llama_model_load: n_ctx   = 512
llama_model_load: n_embd  = 4096
llama_model_load: n_mult  = 256
llama_model_load: n_head  = 32
llama_model_load: n_layer = 32
llama_model_load: n_rot   = 128
llama_model_load: f16     = 1
llama_model_load: n_ff    = 11008
llama_model_load: n_parts = 1
llama_model_load: ggml ctx size = 13365.09 MB
llama_model_load: memory_size =   512.00 MB, n_mem = 16384
llama_model_load: loading model part 1/1 from './models/7Bb4/ggml-model-f16.bin'
llama_model_load: llama_model_load: unknown tensor 'model.decoder.embed_tokens.weight' in model file
main: failed to load model from './models/7Bb4/ggml-model-f16.bin'

@MarkSchmidty
Author

> Hello Georgi! Thank you for your excellent work. There are 4 models that already have GPTQ quantization: https://rentry.org/llama-tard-v2 https://huggingface.co/maderix/llama-65b-4bit/tree/main Maybe you could just convert them into your format?
>
> Not sure, but after converting the HF 7B int4 GPTQ model to ggml bin format:

Unfortunately it is not that simple. GPTQ quantized weights are kind of compressed in a way.* The inference code needs to know how to "decompress" the GPTQ compression to run inference with them.

*It's technically not compression, but it is a good enough analogy for what is really going on.
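As a rough illustration of what the inference code has to undo: 4-bit checkpoints of this kind store integer codes packed several to a machine word, plus scales and zero points. The sketch below assumes eight 4-bit codes per 32-bit word, lowest nibble first, and dequantization as scale * (q - zero); this layout is an illustrative assumption, not the exact GPTQ-for-LLaMa format, and the function names are hypothetical:

```python
import numpy as np


def pack_4bit(codes):
    """Pack 4-bit codes (0..15, count divisible by 8) into uint32 words, low nibble first."""
    codes = np.asarray(codes, dtype=np.uint32).reshape(-1, 8)
    shifts = np.arange(8, dtype=np.uint32) * 4
    return np.bitwise_or.reduce(codes << shifts, axis=1)


def unpack_4bit(words):
    """Inverse of pack_4bit: eight 4-bit codes per uint32 word."""
    shifts = np.arange(8, dtype=np.uint32) * 4
    words = np.asarray(words, dtype=np.uint32)
    return ((words[:, None] >> shifts) & 0xF).reshape(-1)


def dequantize(words, scale, zero):
    """'Decompress' packed codes into floats usable by a matmul: scale * (q - zero)."""
    return scale * (unpack_4bit(words).astype(np.float32) - zero)
```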

@ggerganov
Owner

See the work in #301 and the TODO in #362.
GPTQ inference in ggml should now be possible thanks to this contribution.

@KerfuffleV2
Collaborator

If it's helpful, there are 3-bit GPTQ quantized llama models here: https://huggingface.co/decapoda-research/llama-smallint-pt/tree/main

It has 7B, 13B and 30B models. I think they accidentally made the 4-bit ones in that directory the same as the 3-bit ones. I have a llama 30B 4-bit model that's 18GB.

Would doing inference on 3-bit models be a difficult change? The 30B model would actually fit in a 16GB RAM machine (12.75GB), and if my math is correct, the 65B parameter model would fit in 32GB (27.62GB), which is pretty crazy.
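The arithmetic behind such estimates is parameter count times bits per weight, divided by 8. A minimal sketch; the parameter counts (as reported in the LLaMA paper) and the optional per-weight overhead for scales/zeros are assumptions, so the results will not exactly match the figures above:

```python
def model_size_gb(n_params, bits, overhead_bits_per_weight=0.0):
    """Approximate weight storage in GB (1e9 bytes); ignores activations and the KV cache."""
    return n_params * (bits + overhead_bits_per_weight) / 8 / 1e9


# Parameter counts as reported in the LLaMA paper (assumed here).
LLAMA_PARAMS = {"7B": 6.7e9, "13B": 13.0e9, "30B": 32.5e9, "65B": 65.2e9}

for name, n in LLAMA_PARAMS.items():
    print(f"{name}: 3-bit ~{model_size_gb(n, 3):.1f} GB, 4-bit ~{model_size_gb(n, 4):.1f} GB")
```

Before any scale/zero overhead, this works out to roughly 12.2 GB for "30B" and 24.45 GB for "65B" at 3 bits; passing, say, `overhead_bits_per_weight=0.25` models two 16-bit values per group of 128 weights.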

@ggerganov
Owner

Yeah, I haven't looked at how 3-bit works, but I think it will be a bit more difficult to achieve. Maybe at a later stage, when we have the 4-bit stuff properly integrated.

@MarkSchmidty
Author

> If it's helpful, there are 3-bit GPTQ quantized llama models here: huggingface.co/decapoda-research/llama-smallint-pt/tree/main

Just FYI, I would not trust these models. They were converted with the very first commit of GPTQ-for-LLaMa and will likely cause numerous problems, if they behave at all.

@redthing1

Well, let's keep an eye out for a better conversion of a GPTQ HF model; once we have one, is the procedure to convert it the same, using the updated script?

@bvanslyke

bvanslyke commented Mar 23, 2023

Existing converter script for gptq->ggml doesn't work for this model: https://huggingface.co/elinas/alpaca-30b-lora-int4/tree/main

> python3 convert-gptq-to-ggml.py models/alpaca-30b-lora-int4/alpaca-30b-4bit.pt models/alpaca-30b-lora-int4/tokenizer.model models/alpaca-30B-ggml/out.bin
Traceback (most recent call last):
  File "/Users/bvanslyke/Code/llama.cpp/convert-gptq-to-ggml.py", line 31, in <module>
    tokenizer = SentencePieceProcessor(fname_tokenizer)
  File "/Users/bvanslyke/Code/llama.cpp/venv/lib/python3.10/site-packages/sentencepiece/__init__.py", line 447, in Init
    self.Load(model_file=model_file, model_proto=model_proto)
  File "/Users/bvanslyke/Code/llama.cpp/venv/lib/python3.10/site-packages/sentencepiece/__init__.py", line 905, in Load
    return self.LoadFromFile(model_file)
  File "/Users/bvanslyke/Code/llama.cpp/venv/lib/python3.10/site-packages/sentencepiece/__init__.py", line 310, in LoadFromFile
    return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
RuntimeError: Internal: /Users/runner/work/sentencepiece/sentencepiece/src/sentencepiece_processor.cc(1102) [model_proto->ParseFromArray(serialized.data(), serialized.size())] 

(I know it's early days and this is just some random model from HF, but just posting in case it's helpful data or shows a new failure case 🤷)

@anzz1
Contributor

anzz1 commented Mar 23, 2023

The 4-bit GPTQ models seem to work fine in llama.cpp and anecdotally produce marginally better results; however, I haven't done any proper perplexity testing or such yet.

Also, I cannot run 65B properly because I run out of RAM.

If someone with a better PC wants to try the 4-bit 65B GPTQ model from #382 (comment), I would be interested in how that works out.

@loretoparisi

> The 4-bit GPTQ models seem to work fine in llama.cpp and anecdotally produce marginally better results; however, I haven't done any proper perplexity testing or such yet.
>
> Also, I cannot run 65B properly because I run out of RAM.
>
> If someone with a better PC wants to try the 4-bit 65B GPTQ model from #382 (comment), I would be interested in how that works out.

The GPTQ integration in llama.cpp is not ready yet; you cannot load these models directly.

@anzz1
Contributor

anzz1 commented Mar 24, 2023

> The GPTQ integration in llama.cpp is not ready yet; you cannot load these models directly.

See #9 (comment). Seems to work just fine.

@Orevantum

It seems a superior GPTQ approach already exists (from the authors of the original GPTQ paper). @ggerganov, @blackhole89 care to take a look?

abetlen pushed a commit to abetlen/llama.cpp that referenced this issue Apr 10, 2023
…_linux

Adds instructions and works on linux as well
@sw
Collaborator

sw commented May 12, 2023

From @Tom-Neverwinter in #1411:

https://github.com/0cc4m/GPTQ-for-LLaMa/blob/latestmerge/gptq/mpt.py
Last week in an issue, @ggerganov was looking for someone to write the C++ for GPTQ and 4-bit. This appears to be a working implementation.

@jjbenes

jjbenes commented May 24, 2023

@ggerganov Thanks for this code. I enjoy running it on my computer. I have an observation on compatibility with GPTQ, or any asymmetric uniform quantizers. Don't we want the number zero preserved for neural nets? I expect dequantize(quantize(0)) to return 0, but I don't believe q4_1 does.

The q4_1 quantization code and the q4_1 dequantization code don't seem to preserve zero. The floating-point bias (m in the code) could make a quantized 0 non-zero when dequantized. This bias may also flip the sign of a dequantized number, which changes the threshold for ReLU. See the second and third elements of the array x in the Python interpretation of the C code below. The second element is 0 but becomes about 0.0167 after dequantization (xhat[1]). So does the third element, -0.012. For -0.012 in x, the sign flips in xhat[2].

You must have a good reason to use a floating-point bias, and I may have misread the code. I'm also aware of the better perplexity due to q4_1 vs. q4_0 (symmetric).

I wonder what your thoughts are regarding q4_1 vs. other asymmetric quantizers, which evaluate Q(x) = clamp( round(x/s) + zero_point, Q_min, Q_max), where s is a floating-point scaling factor, and zero_point is an index offset that recovers zero exactly after quantization followed by dequantization.

import numpy as np

def quant_4_1(x):
  x_min, x_max = x.min(), x.max()
  assert x_min <=0 <= x_max
  d = (x_max - x_min)/15
  id = 1.0/d if d else 0.0
  x_int = ((x - x_min) * id + 0.5).astype(int) # should use round?
  x_int[x_int>15.0] = 15
  return x_int, d, x_min

def dequant_4_1(x_q, d, m):
  return x_q*d + m

x = np.array([-1.25, 0, -0.012, 3.5])
xq, d, x_min = quant_4_1(x)
xhat = dequant_4_1(xq, d, x_min)

print(f'x = {x}\nd = {d}\nm = {x_min}')
print(f'xq = {xq}\nxhat = {xhat}')

x = [-1.25   0.    -0.012  3.5  ]
d = 0.31666666666666665
m = -1.25
xq = [ 0  4  4 15]
xhat = [-1.25        0.01666667  0.01666667  3.5       ]

@ggerganov
Owner

@jjbenes

> Don't we want the number zero preserved for neural nets?

Not sure we want such a requirement.

The quantized values are used only during matrix multiplication. They are not directly used to compute the activations.

I could be wrong, but my intuition so far from the experiments is that quantization quality mainly depends on the amount of data that you effectively use. The specific approach for representing that data has second-order effects.
This does not mean that Q4_1 is optimal; in fact, it wastes a non-negligible amount of bits storing the 16-bit scaling factors. In reality, the information there is less than 16 bits.

In any case, it is very easy to add a zero-preserving quantization and see the effect on the perplexity.
We should do this at some point, and it's probably a nice exercise to play with quantizations.
It doesn't have to be efficiently implemented.
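A zero-preserving variant of the q4_1 example above, following the Q(x) = clamp(round(x/s) + zero_point, Q_min, Q_max) form with an integer zero point, might look like the sketch below. It is only meant for experiments like the one suggested here and is not an existing ggml quantization type; the function names are hypothetical:

```python
import numpy as np


def quant_zp(x, bits=4):
    """Asymmetric quantization with an integer zero point, so 0.0 maps exactly onto a code."""
    qmax = 2**bits - 1
    x_min = min(float(x.min()), 0.0)  # make sure the range contains 0
    x_max = max(float(x.max()), 0.0)
    s = (x_max - x_min) / qmax if x_max > x_min else 1.0
    zero_point = int(np.clip(np.round(-x_min / s), 0, qmax))
    q = np.clip(np.round(x / s) + zero_point, 0, qmax).astype(int)
    return q, s, zero_point


def dequant_zp(q, s, zero_point):
    return (q - zero_point) * s


x = np.array([-1.25, 0, -0.012, 3.5])
q, s, zp = quant_zp(x)
xhat = dequant_zp(q, s, zp)
```

On the same four values, this gives codes [0, 4, 4, 15], and both 0 and -0.012 come back as exactly 0.0; the trade-off is that the endpoints are no longer exact (-1.25 comes back as about -1.267 and 3.5 as about 3.483).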

@jjbenes

jjbenes commented May 27, 2023

>> Don't we want the number zero preserved for neural nets?
>
> Not sure we want such a requirement.
>
> The quantized values are used only during matrix multiplication. They are not directly used to compute the activations.

Yes, I understand that only the inputs to the matrix multiplication are converted from the quantized format to what the hardware supports (for example, fp16 or even fp32). In fact, I believe only the weights are converted. My question is less about optimality and more about compatibility with other asymmetric quantizers vs. q4_1. Perhaps the following pseudo code helps show what I'm thinking about.

w = dequantize(w_q4_1) # What happens if elements in w change signs, or zeros become non-zeros?
y = matmul(w, x) # Let me know if I understand the code correctly. Only weights (w) are quantized to q4_*, not activations (x and y).
act = relu(y) # Any sign flipping in w adds noise to the dot products. Unclear how often the output is suppressed to 0 or leaks through

The fp16 perplexity baselines are different between GPTQ and llama.cpp. Has someone already compared GPTQ and llama.cpp using the same baseline?
