
Add support for AWQ quantized models #781

Closed
0x1997 opened this issue Aug 7, 2023 · 37 comments · Fixed by #1019

Comments

@0x1997

0x1997 commented Aug 7, 2023

Compared to GPTQ, AWQ is more accurate and has much better inference performance.

Benchmark: https://github.com/lm-sys/FastChat/blob/main/docs/awq.md#benchmark

Note: Multi-Query Attention is not yet supported.

@abhinavkulkarni
Contributor

I have released a bunch of AWQ quantized models here: https://huggingface.co/abhinavkulkarni?sort_models=downloads#models

Instructions on how to run these with the HuggingFace API are in the model cards.

@Narsil
Collaborator

Narsil commented Aug 7, 2023

Can anyone run benchmarks against TGI + exllama kernels?

Those are supposed to provide a similar speedup.

We don't want to support every quantization scheme in TGI, just the best possible subset:

  • No quantization: best PPL
  • bitsandbytes: Low VRAM - no quantization step - works on every model
  • GPTQ: Low VRAM - fastest inference (should be ~2x, if I'm not mistaken) with exllama.

@abhinavkulkarni
Contributor

@Narsil: Any kernel optimizations done for GPTQ should translate to AWQ, since both are based on similar zero-point quantization schemes; they simply differ in how those exact zero-point weights are found, and admittedly AWQ is superior to GPTQ.

So, someone simply needs to write a "translator" from AWQ to GPTQ state dicts, and everything else should work as is.
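For illustration, here is a minimal sketch of the key-mapping half of such a translator, assuming AWQ checkpoints store per-layer qweight/qzeros/scales tensors and that a GPTQ-style QuantLinear additionally wants a g_idx tensor. The two formats pack the 4-bit values differently, so the actual unpack/repack step is deliberately omitted, and the shape assumption below is mine, not taken from either library:

```
import torch

def awq_to_gptq_state_dict(awq_sd, groupsize=128):
    # qweight / qzeros / scales / bias keep the same tensor names in both formats
    gptq_sd = dict(awq_sd)
    for name, qweight in awq_sd.items():
        if not name.endswith("qweight"):
            continue
        prefix = name[: -len("qweight")]
        # Assumption: dim 0 of the AWQ-packed qweight corresponds to in_features.
        in_features = qweight.shape[0]
        # GPTQ's QuantLinear expects a g_idx tensor mapping each input channel
        # to its quantization group; AWQ checkpoints don't carry one.
        gptq_sd[prefix + "g_idx"] = torch.arange(in_features, dtype=torch.int32) // groupsize
    return gptq_sd
```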

@edwardzjl
Contributor

@Narsil I agree that we should evaluate before adding another quantization support, but it's difficult to perform a fair comparison.

A fair comparison should be performed between TGI + gptq + exllama and TGI + awq, but not between TGI + gptq + exllama and {some_other_inference_framework} + awq

Besides, awq can be applied to LLMs other than Llama, where exllama cannot.

@Narsil
Collaborator

Narsil commented Aug 7, 2023

Besides, awq can be applied to LLMs other than Llama, where exllama cannot.

This statement is wrong: exllama is only called that because the kernels were created in https://github.com/turboderp/exllama; it has nothing to do with Llama. (In general, current quantization techniques basically just replace Linear with QuantLinear.)

@edwardzjl
Contributor

Besides, awq can be applied to LLMs other than Llama, where exllama cannot.

This statement is wrong: exllama is only called that because the kernels were created in https://github.com/turboderp/exllama; it has nothing to do with Llama. (In general, current quantization techniques basically just replace Linear with QuantLinear.)

My apologies, I made an assumption about exllama based solely on its name. I mistakenly thought it was specifically for Llama models. 😅

@Narsil
Collaborator

Narsil commented Aug 8, 2023

No worries.

@casper-hansen

My own benchmark of AWQ is 134 tokens/s (7.46 ms/token) on a 4090 + i9-13900K for MPT 7B models.

As Narsil mentions, quantization methods mostly replace Linear with QuantLinear layers. AWQ does this with its optimized GEMM kernel. Additionally, AWQ TinyChat runs further optimizations specifically for LLaMa models; with those, LLaMa models reach 100+ tokens/s.
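To make the Linear-to-QuantLinear point concrete, here is a generic sketch of the swap most of these schemes perform; make_quant_linear is a placeholder for whichever quantized layer constructor is actually used (WQLinear in AWQ, QuantLinear in GPTQ/TGI), not an API from either project:

```
import torch.nn as nn

def replace_linears(module: nn.Module, make_quant_linear):
    """Recursively swap every nn.Linear for a quantized drop-in replacement."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_quant_linear(child))
        else:
            replace_linears(child, make_quant_linear)
    return module
```

The speed differences discussed in this thread come from the kernels behind that drop-in, not from the swap itself.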

Why AWQ is faster than GPTQ

AWQ is faster than GPTQ. It is not faster than exllama because exllama runs a lot of kernel optimizations on top to make it faster. But the problem is that exllama is written explicitly to optimize LLaMa models, so the full performance boost will not be seen in other models.

From the AWQ paper:

Different weight channels have different importance; updating the salient channels to compensate for the non-salient ones will likely destroy the performance. Reordering prevents it by quantizing important channels first. However, it will lead to bad hardware efficiency due to irregular memory access (Figure 2), while our scaling method does not suffer from the issue.

[image: figure from the AWQ paper]

@abhinavkulkarni
Contributor

abhinavkulkarni commented Aug 11, 2023

Hi,

I have added rudimentary support for AWQ models at https://github.com/abhinavkulkarni/text-generation-inference/tree/abhinavkulkarni/add-awq-support

You can view the side-by-side changes here.

This requires installing the AWQ library and its CUDA kernels for 4-bit matrix multiplication:

git clone https://github.com/mit-han-lab/llm-awq \
&& cd llm-awq \
&& git checkout ce4a6bb1c238c014a06672cb74f6865573494d66 \
&& pip install -e . \
&& cd awq/kernels \
&& python setup.py install

After that

git clone https://github.com/abhinavkulkarni/text-generation-inference.git \
&& cd text-generation-inference \
&& git checkout abhinavkulkarni/add-awq-support \
&& make install

I did upgrade to the latest versions: pip install --upgrade transformers accelerate bitsandbytes

I was able to run TGI as follows:

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq \
--trust-remote-code \
--port 8080 \
--max-input-length 4000 --max-total-tokens 4096 \
--quantize awq

This change of course borrows from the AWQ library: for zero-point quantization, I use their WQLinear layer, which is very similar to the QuantLinear layer in TGI. For now, I have hardcoded bits to 4 and groupsize to 128, but it should be possible to read them from quantize_config.json. None of my models have a quantize_config.json yet, but I'll update the model repos with one.
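As a hedged sketch of that last point, reading the values from quantize_config.json (field names assumed to follow the GPTQ-style convention, since these repos don't ship the file yet) and falling back to the current hardcoded defaults could look like this:

```
import json
from pathlib import Path

def load_quant_config(model_path, default_bits=4, default_groupsize=128):
    cfg_file = Path(model_path) / "quantize_config.json"
    if not cfg_file.exists():
        # No config shipped with the model: keep the hardcoded defaults.
        return default_bits, default_groupsize
    cfg = json.loads(cfg_file.read_text())
    return cfg.get("bits", default_bits), cfg.get("group_size", default_groupsize)
```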

I don't think this change is comprehensive and I would welcome any pull requests.

The ideal scenario would be to subsume the logic of WQLinear from AWQ into QuantLinear of TGI, so that we can benefit from flash attention goodness.

Thanks!

CC: @casperbh96, @Narsil, @TheBloke

@abhinavkulkarni
Contributor

abhinavkulkarni commented Aug 12, 2023

I benchmarked Llama 2 7B AWQ vs GPTQ with FlashAttention v1 and vLLM on RTX 3060 (12GB of VRAM). Note, I do not have exllama installed. Following are the results:

AWQ model_id: abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq
GPTQ model_id: TheBloke/Llama-2-7b-Chat-GPTQ

Both models were run with --max-input-length 4000 --max-total-tokens 4096

GPTQ benchmarks:

| Parameter          | Value                         |
|--------------------|-------------------------------|
| Model              | TheBloke/Llama-2-7b-Chat-GPTQ |
| Sequence Length    | 10                            |
| Decode Length      | 8                             |
| N Runs             | 10                            |
| Warmups            | 1                             |
| Temperature        | None                          |
| Top K              | None                          |
| Top P              | None                          |
| Typical P          | None                          |
| Repetition Penalty | None                          |
| Watermark          | false                         |
| Do Sample          | false                         |


| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 41.23 ms  | 41.16 ms  | 41.39 ms  | 41.22 ms  | 41.39 ms  | 41.39 ms  |
|                | 2          | 47.80 ms  | 47.72 ms  | 47.88 ms  | 47.81 ms  | 47.88 ms  | 47.88 ms  |
|                | 4          | 57.94 ms  | 57.83 ms  | 58.02 ms  | 57.95 ms  | 58.02 ms  | 58.02 ms  |
|                | 8          | 108.53 ms | 108.39 ms | 108.81 ms | 108.56 ms | 108.81 ms | 108.81 ms |
|                | 16         | 153.65 ms | 153.22 ms | 156.46 ms | 153.35 ms | 156.46 ms | 156.46 ms |
|                | 32         | 251.93 ms | 251.04 ms | 252.23 ms | 252.05 ms | 252.23 ms | 252.23 ms |
| Decode (token) | 1          | 40.33 ms  | 40.27 ms  | 40.45 ms  | 40.32 ms  | 40.32 ms  | 40.32 ms  |
|                | 2          | 40.83 ms  | 40.80 ms  | 40.90 ms  | 40.84 ms  | 40.82 ms  | 40.82 ms  |
|                | 4          | 41.07 ms  | 40.81 ms  | 41.15 ms  | 41.10 ms  | 40.81 ms  | 40.81 ms  |
|                | 8          | 41.28 ms  | 41.25 ms  | 41.34 ms  | 41.28 ms  | 41.29 ms  | 41.29 ms  |
|                | 16         | 48.03 ms  | 47.92 ms  | 48.22 ms  | 48.04 ms  | 47.95 ms  | 47.95 ms  |
|                | 32         | 59.45 ms  | 59.35 ms  | 59.65 ms  | 59.42 ms  | 59.65 ms  | 59.65 ms  |
| Decode (total) | 1          | 282.34 ms | 281.92 ms | 283.14 ms | 282.27 ms | 282.25 ms | 282.25 ms |
|                | 2          | 285.83 ms | 285.61 ms | 286.33 ms | 285.86 ms | 285.76 ms | 285.76 ms |
|                | 4          | 287.48 ms | 285.70 ms | 288.08 ms | 287.68 ms | 285.70 ms | 285.70 ms |
|                | 8          | 288.99 ms | 288.73 ms | 289.37 ms | 288.97 ms | 289.00 ms | 289.00 ms |
|                | 16         | 336.21 ms | 335.45 ms | 337.57 ms | 336.28 ms | 335.63 ms | 335.63 ms |
|                | 32         | 416.15 ms | 415.43 ms | 417.57 ms | 415.96 ms | 417.57 ms | 417.57 ms |


| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 24.25 tokens/secs  | 24.16 tokens/secs  | 24.30 tokens/secs  |
|         | 2          | 41.84 tokens/secs  | 41.77 tokens/secs  | 41.92 tokens/secs  |
|         | 4          | 69.04 tokens/secs  | 68.94 tokens/secs  | 69.17 tokens/secs  |
|         | 8          | 73.71 tokens/secs  | 73.52 tokens/secs  | 73.81 tokens/secs  |
|         | 16         | 104.14 tokens/secs | 102.26 tokens/secs | 104.43 tokens/secs |
|         | 32         | 127.02 tokens/secs | 126.87 tokens/secs | 127.47 tokens/secs |
| Decode  | 1          | 24.79 tokens/secs  | 24.72 tokens/secs  | 24.83 tokens/secs  |
|         | 2          | 48.98 tokens/secs  | 48.89 tokens/secs  | 49.02 tokens/secs  |
|         | 4          | 97.40 tokens/secs  | 97.20 tokens/secs  | 98.00 tokens/secs  |
|         | 8          | 193.78 tokens/secs | 193.53 tokens/secs | 193.95 tokens/secs |
|         | 16         | 333.13 tokens/secs | 331.78 tokens/secs | 333.88 tokens/secs |
|         | 32         | 538.27 tokens/secs | 536.43 tokens/secs | 539.21 tokens/secs |

AWQ benchmarks:

| Parameter          | Value                                                     |
|--------------------|-----------------------------------------------------------|
| Model              | abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq |
| Sequence Length    | 10                                                        |
| Decode Length      | 8                                                         |
| N Runs             | 10                                                        |
| Warmups            | 1                                                         |
| Temperature        | None                                                      |
| Top K              | None                                                      |
| Top P              | None                                                      |
| Typical P          | None                                                      |
| Repetition Penalty | None                                                      |
| Watermark          | false                                                     |
| Do Sample          | false                                                     |

| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 18.84 ms  | 18.70 ms  | 19.25 ms  | 18.75 ms  | 19.25 ms  | 19.25 ms  |
|                | 2          | 31.18 ms  | 31.05 ms  | 31.37 ms  | 31.19 ms  | 31.37 ms  | 31.37 ms  |
|                | 4          | 46.88 ms  | 46.63 ms  | 47.25 ms  | 46.90 ms  | 47.25 ms  | 47.25 ms  |
|                | 8          | 78.74 ms  | 78.44 ms  | 79.09 ms  | 78.81 ms  | 79.09 ms  | 79.09 ms  |
|                | 16         | 154.59 ms | 154.09 ms | 154.96 ms | 154.75 ms | 154.96 ms | 154.96 ms |
|                | 32         | 308.17 ms | 307.61 ms | 308.79 ms | 308.21 ms | 308.79 ms | 308.79 ms |
| Decode (token) | 1          | 16.21 ms  | 16.11 ms  | 16.69 ms  | 16.14 ms  | 16.69 ms  | 16.69 ms  |
|                | 2          | 16.62 ms  | 16.54 ms  | 16.80 ms  | 16.63 ms  | 16.80 ms  | 16.80 ms  |
|                | 4          | 17.28 ms  | 17.18 ms  | 17.42 ms  | 17.31 ms  | 17.42 ms  | 17.42 ms  |
|                | 8          | 18.56 ms  | 18.52 ms  | 18.61 ms  | 18.56 ms  | 18.61 ms  | 18.61 ms  |
|                | 16         | 22.51 ms  | 21.77 ms  | 28.57 ms  | 21.86 ms  | 28.57 ms  | 28.57 ms  |
|                | 32         | 37.61 ms  | 37.58 ms  | 37.67 ms  | 37.61 ms  | 37.67 ms  | 37.67 ms  |
| Decode (total) | 1          | 113.47 ms | 112.78 ms | 116.80 ms | 113.01 ms | 116.80 ms | 116.80 ms |
|                | 2          | 116.37 ms | 115.81 ms | 117.60 ms | 116.43 ms | 117.60 ms | 117.60 ms |
|                | 4          | 120.99 ms | 120.27 ms | 121.94 ms | 121.15 ms | 121.94 ms | 121.94 ms |
|                | 8          | 129.91 ms | 129.65 ms | 130.25 ms | 129.91 ms | 130.25 ms | 130.25 ms |
|                | 16         | 157.60 ms | 152.36 ms | 199.98 ms | 153.04 ms | 199.98 ms | 199.98 ms |
|                | 32         | 263.28 ms | 263.03 ms | 263.70 ms | 263.27 ms | 263.70 ms | 263.70 ms |

| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 53.09 tokens/secs  | 51.94 tokens/secs  | 53.49 tokens/secs  |
|         | 2          | 64.14 tokens/secs  | 63.75 tokens/secs  | 64.41 tokens/secs  |
|         | 4          | 85.32 tokens/secs  | 84.66 tokens/secs  | 85.78 tokens/secs  |
|         | 8          | 101.60 tokens/secs | 101.16 tokens/secs | 101.99 tokens/secs |
|         | 16         | 103.50 tokens/secs | 103.25 tokens/secs | 103.83 tokens/secs |
|         | 32         | 103.84 tokens/secs | 103.63 tokens/secs | 104.03 tokens/secs |
| Decode  | 1          | 61.70 tokens/secs  | 59.93 tokens/secs  | 62.07 tokens/secs  |
|         | 2          | 120.31 tokens/secs | 119.05 tokens/secs | 120.89 tokens/secs |
|         | 4          | 231.43 tokens/secs | 229.62 tokens/secs | 232.81 tokens/secs |
|         | 8          | 431.08 tokens/secs | 429.94 tokens/secs | 431.94 tokens/secs |
|         | 16         | 715.32 tokens/secs | 560.06 tokens/secs | 735.11 tokens/secs |
|         | 32         | 850.79 tokens/secs | 849.43 tokens/secs | 851.62 tokens/secs |

Thanks!

@Narsil
Collaborator

Narsil commented Aug 12, 2023

@abhinavkulkarni Can you try with exllama please?

It looks very promising!

@abhinavkulkarni
Contributor

abhinavkulkarni commented Aug 12, 2023

Hey @Narsil,

I am unable to install Exllama GPTQ kernels even when I run BUILD_EXTENSIONS=True make install.

Do I need to install them separately?

Edit:

I installed the exllama kernels by running `cd server/exllama_kernels && python setup.py install`.

I do see log lines while loading the server:

2023-08-12T18:03:38.108460Z INFO text_generation_launcher: Using exllama kernels

I get worse results than before for GPTQ:

| Parameter          | Value                         |
|--------------------|-------------------------------|
| Model              | TheBloke/Llama-2-7b-Chat-GPTQ |
| Sequence Length    | 10                            |
| Decode Length      | 8                             |
| N Runs             | 10                            |
| Warmups            | 1                             |
| Temperature        | None                          |
| Top K              | None                          |
| Top P              | None                          |
| Typical P          | None                          |
| Repetition Penalty | None                          |
| Watermark          | false                         |
| Do Sample          | false                         |


| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 54.29 ms  | 54.10 ms  | 54.52 ms  | 54.24 ms  | 54.52 ms  | 54.52 ms  |
|                | 2          | 58.94 ms  | 58.84 ms  | 59.08 ms  | 58.92 ms  | 59.08 ms  | 59.08 ms  |
|                | 4          | 68.53 ms  | 68.19 ms  | 68.76 ms  | 68.61 ms  | 68.76 ms  | 68.76 ms  |
|                | 8          | 102.44 ms | 102.32 ms | 102.63 ms | 102.43 ms | 102.63 ms | 102.63 ms |
|                | 16         | 143.92 ms | 143.65 ms | 144.09 ms | 143.99 ms | 144.09 ms | 144.09 ms |
|                | 32         | 227.84 ms | 227.70 ms | 228.07 ms | 227.82 ms | 228.07 ms | 228.07 ms |
| Decode (token) | 1          | 31.17 ms  | 30.53 ms  | 36.57 ms  | 30.58 ms  | 30.56 ms  | 30.56 ms  |
|                | 2          | 33.92 ms  | 33.88 ms  | 33.97 ms  | 33.93 ms  | 33.91 ms  | 33.91 ms  |
|                | 4          | 41.06 ms  | 40.83 ms  | 41.31 ms  | 41.23 ms  | 41.07 ms  | 41.07 ms  |
|                | 8          | 54.19 ms  | 54.14 ms  | 54.25 ms  | 54.20 ms  | 54.18 ms  | 54.18 ms  |
|                | 16         | 59.27 ms  | 59.18 ms  | 59.45 ms  | 59.25 ms  | 59.41 ms  | 59.41 ms  |
|                | 32         | 70.56 ms  | 70.50 ms  | 70.62 ms  | 70.56 ms  | 70.62 ms  | 70.62 ms  |
| Decode (total) | 1          | 218.16 ms | 213.71 ms | 256.01 ms | 214.03 ms | 213.94 ms | 213.94 ms |
|                | 2          | 237.45 ms | 237.13 ms | 237.81 ms | 237.50 ms | 237.37 ms | 237.37 ms |
|                | 4          | 287.43 ms | 285.84 ms | 289.14 ms | 288.59 ms | 287.47 ms | 287.47 ms |
|                | 8          | 379.34 ms | 379.00 ms | 379.73 ms | 379.44 ms | 379.29 ms | 379.29 ms |
|                | 16         | 414.88 ms | 414.25 ms | 416.15 ms | 414.76 ms | 415.86 ms | 415.86 ms |
|                | 32         | 493.91 ms | 493.49 ms | 494.36 ms | 493.90 ms | 494.36 ms | 494.36 ms |


| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 18.42 tokens/secs  | 18.34 tokens/secs  | 18.48 tokens/secs  |
|         | 2          | 33.93 tokens/secs  | 33.85 tokens/secs  | 33.99 tokens/secs  |
|         | 4          | 58.37 tokens/secs  | 58.18 tokens/secs  | 58.65 tokens/secs  |
|         | 8          | 78.10 tokens/secs  | 77.95 tokens/secs  | 78.19 tokens/secs  |
|         | 16         | 111.18 tokens/secs | 111.04 tokens/secs | 111.38 tokens/secs |
|         | 32         | 140.45 tokens/secs | 140.31 tokens/secs | 140.54 tokens/secs |
| Decode  | 1          | 32.18 tokens/secs  | 27.34 tokens/secs  | 32.75 tokens/secs  |
|         | 2          | 58.96 tokens/secs  | 58.87 tokens/secs  | 59.04 tokens/secs  |
|         | 4          | 97.42 tokens/secs  | 96.84 tokens/secs  | 97.96 tokens/secs  |
|         | 8          | 147.62 tokens/secs | 147.47 tokens/secs | 147.76 tokens/secs |
|         | 16         | 269.96 tokens/secs | 269.13 tokens/secs | 270.37 tokens/secs |
|         | 32         | 453.53 tokens/secs | 453.11 tokens/secs | 453.91 tokens/secs |

@0x1997
Author

0x1997 commented Aug 21, 2023

Thanks to @abhinavkulkarni's code, I did some simple evaluation. The output quality of the AWQ model is a little worse than GPTQ's, but the 60% speedup at inference is quite nice.

My branch is at https://github.com/0x1997/text-generation-inference/tree/awq.

Currently, multi-GPU support is broken; the model generates garbled outputs like this. Do you have any idea how to fix this? @abhinavkulkarni

re,:,\\\\.,,,,,,,,,,,,,,,,,,,,,,, a rad, k,,,,,,,,,,,,,,,,,,,e\\\\ the the they, I have,,,,,,,,,, the\\\\ and, I\\\\,,,,,,,,,,,.,\\\\anded the the\\\\\\\\\\\\,,,. the thesers,\\\\\\\\ and, ap\\\\\\\\\\\\\\\\ the, a\\\\ and.\\\\\\\\\\\\\\\\.,,,,,,,,,,,, a I,,,:,\\\\, avision, aon,,,,,,,,,,,, in a a the the ,,,,,,,,,,,, a, a a the the ,,,,,,- the the the the\\\\ made\\,, the k,,,,,,.,,,,,a a,,,,,,, the, a a,her it,,,,,,, a a,,,,,, the it,,, theo and., you., the,

@sjzhou4

sjzhou4 commented Aug 31, 2023

@abhinavkulkarni: Issue #948 came up when trying your AWQ method. Can you help solve this problem? Thank you.

@MichaelHauser0971

Thanks to @abhinavkulkarni's code, I did some simple evaluation. The output quality of the AWQ model is a little worse than GPTQ's, but the 60% speedup at inference is quite nice.

My branch is at https://github.com/0x1997/text-generation-inference/tree/awq.

Currently, multi-GPU support is broken; the model generates garbled outputs like this. Do you have any idea how to fix this? @abhinavkulkarni

re,:,\\\\.,,,,,,,,,,,,,,,,,,,,,,, a rad, k,,,,,,,,,,,,,,,,,,,e\\\\ the the they, I have,,,,,,,,,, the\\\\ and, I\\\\,,,,,,,,,,,.,\\\\anded the the\\\\\\\\\\\\,,,. the thesers,\\\\\\\\ and, ap\\\\\\\\\\\\\\\\ the, a\\\\ and.\\\\\\\\\\\\\\\\.,,,,,,,,,,,, a I,,,:,\\\\, avision, aon,,,,,,,,,,,, in a a the the ,,,,,,,,,,,, a, a a the the ,,,,,,- the the the the\\\\ made\\,, the k,,,,,,.,,,,,a a,,,,,,, the, a a,her it,,,,,,, a a,,,,,, the it,,, theo and., you., the,

Have you solved this problem? I ran into the same issue.

@dingjingzhen

dingjingzhen commented Sep 13, 2023

Can anyone run benchmarks against TGI + exllama kernels?

Those are supposed to provide a similar speedup.

We don't want to support every quantization scheme in TGI, just the best possible subset:

  • No quantization: best PPL
  • bitsandbytes: Low VRAM - no quantization step - works on every model
  • GPTQ: Low VRAM - fastest inference (should be ~2x, if I'm not mistaken) with exllama.

This works very well: I measured almost no PPL loss, and the performance is faster than GPTQ.
#1018

@dingjingzhen

Try this solution of ours: the best PPL with faster performance than GPTQ.
#1018
Below is our test on a 3090. Environment: torch 2.0.1, CUDA 11.8, NVIDIA driver 525.78.01; prompt=1024, max_new_tokens=50.
[image: benchmark results on the 3090]

@abhinavkulkarni
Contributor

@Narsil: I have opened PR #1019 adding AWQ support for FlashLlama models. Please take a look, and refer to my earlier replies for benchmark results against GPTQ.

@MichaelHauser0971, @sjzhou4, @0x1997, @casper-hansen: I have not yet tested a multi-GPU setup; let's first try to get the single-GPU PR approved.

@ryanshrott

@abhinavkulkarni I'm trying to catch up on this thread. How can I run Llama 2 AWQ or GPTQ with vLLM? Is it possible yet?

@casper-hansen

@ryanshrott vLLM support for AWQ is close to being merged, check their branch out:
https://github.com/vllm-project/vllm/tree/add_awq_quant_support

@abhinavkulkarni
Contributor

@ryanshrott: Please check the PR I have raised. It runs a Llama 2 model with FlashAttention v2 and vLLM.

@ryanshrott

What's the timeline on merging to main branch?

@Narsil
Collaborator

Narsil commented Sep 22, 2023

I ran some tests and the PR is very close to ready. If the OP doesn't want to make the changes, I'll do them in a few days.
Our responsiveness is a bit lower at the moment; we have some nice things cooking, please bear with us.

@TheBloke

TheBloke commented Sep 22, 2023

Great to hear! I've uploaded plenty of models to be used with it :)
[image: uploaded AWQ model repos]

My READMEs linked to this PR and mentioned support was coming 'soon'. Once this is merged I can update them all to include TGI details.

@ryanshrott

ryanshrott commented Sep 22, 2023

Will this PR have comparable speeds to regular non-quantized models?

I currently find AWQ quantization with vLLM to run very slowly.

@casper-hansen

casper-hansen commented Sep 22, 2023

INT4 throughput will not be higher than FP16 at very high data parallelism. For that, you must use INT8 or FP16. High batch sizes mean that you are compute-bound, and INT4 is not made for this scenario.
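A rough back-of-envelope illustration of that point (all numbers below are assumptions for a 7B model on a single consumer GPU and ignore activations, the KV cache, and kernel efficiency, so treat it as a toy model, not a measurement):

```
# Toy decode-step roofline: the weights are re-read every step (memory cost),
# and every sequence in the batch multiplies against them (compute cost).
PARAMS = 7e9
MEM_BW = 1.0e12     # ~1 TB/s memory bandwidth (assumption)
COMPUTE = 80e12     # ~80 TFLOP/s usable FP16 compute (assumption)

for batch in (1, 8, 64, 256):
    t_compute = 2 * PARAMS * batch / COMPUTE
    for name, bytes_per_weight in (("fp16", 2.0), ("int4", 0.5)):
        t_memory = PARAMS * bytes_per_weight / MEM_BW
        step_time = max(t_compute, t_memory)  # whichever bound dominates
        print(f"batch={batch:>3} {name}: ~{batch / step_time:,.0f} tokens/s")

# Small batches are memory-bound, so 4-bit weights help a lot; large batches
# are compute-bound and the INT4 advantage mostly disappears.
```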

@ryanshrott

@casper-hansen I'm not quite following all your technical notes. Are you saying that 4-bit awq will be fast or not?

@casper-hansen

Yes, it can be much faster. But like I just explained, there are cases where it will not be faster. It depends on your use-case.

@RonanKMcGovern

Will this approach default to using GEMM? Or is there a parameter where one can configure GEMV or FP16?

I may be missing something in the code updates; I just didn't find any reference to GEMM.

Also, great work on this.

@casper-hansen

GEMV is only faster at batch size 1 with a small context (20% faster). For deployment purposes with many concurrent requests, GEMM will overall be much faster as it scales better. @RonanKMcGovern

Narsil pushed a commit that referenced this issue Sep 25, 2023
# Add AWQ quantization inference support

Fixes
#781

This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).

A quick way to test this PR would be to bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```

Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested. 
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

---------

Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
Narsil closed this as completed in c5de7cd Sep 25, 2023
@RonanKMcGovern

RonanKMcGovern commented Sep 25, 2023

I pulled this docker image and it's recognising awq.

Version 1.0.3 from the README won't work for AWQ though. Might be worth putting a note in the README if it's not ready for a release?

Also, after trying the following flags on the latest image:

--model-id TheBloke/Llama-2-70B-chat-AWQ --trust-remote-code --port 8080 --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 --quantize awq

I'm hitting:

2023-09-25T07:54:05.251889450-07:00     return callback(**use_params)  # type: ignore
2023-09-25T07:54:05.251891440-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 82, in serve
2023-09-25T07:54:05.251893660-07:00     server.serve(
2023-09-25T07:54:05.251895730-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 195, in serve
2023-09-25T07:54:05.251898130-07:00     asyncio.run(
2023-09-25T07:54:05.251900180-07:00   File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
2023-09-25T07:54:05.251902330-07:00     return loop.run_until_complete(main)
2023-09-25T07:54:05.251905440-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
2023-09-25T07:54:05.251907560-07:00     self.run_forever()
2023-09-25T07:54:05.251909610-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
2023-09-25T07:54:05.251911670-07:00     self._run_once()
2023-09-25T07:54:05.251913740-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
2023-09-25T07:54:05.251915790-07:00     handle._run()
2023-09-25T07:54:05.251917850-07:00   File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
2023-09-25T07:54:05.251920010-07:00     self._context.run(self._callback, *self._args)
2023-09-25T07:54:05.251924120-07:00 > File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 147, in serve_inner
2023-09-25T07:54:05.251926260-07:00     model = get_model(
2023-09-25T07:54:05.251928480-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 187, in get_model
2023-09-25T07:54:05.251930600-07:00     return FlashLlama(
2023-09-25T07:54:05.251932640-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_llama.py", line 68, in __init__
2023-09-25T07:54:05.251934730-07:00     model = FlashLlamaForCausalLM(config, weights)
2023-09-25T07:54:05.251936780-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 474, in __init__
2023-09-25T07:54:05.251939230-07:00     self.model = FlashLlamaModel(config, weights)
2023-09-25T07:54:05.251941310-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 412, in __init__
2023-09-25T07:54:05.251943390-07:00     [
2023-09-25T07:54:05.251945510-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 413, in <listcomp>
2023-09-25T07:54:05.251947630-07:00     FlashLlamaLayer(
2023-09-25T07:54:05.251949740-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 349, in __init__
2023-09-25T07:54:05.251951740-07:00     self.self_attn = FlashLlamaAttention(
2023-09-25T07:54:05.251953700-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 229, in __init__
2023-09-25T07:54:05.251955840-07:00     self.query_key_value = load_attention(config, prefix, weights)
2023-09-25T07:54:05.251957850-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 154, in load_attention
2023-09-25T07:54:05.251959930-07:00     return _load_gqa(config, prefix, weights)
2023-09-25T07:54:05.251962000-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 183, in _load_gqa
2023-09-25T07:54:05.251966830-07:00     weight = weight.to(dtype=weights.dtype).to(device=weights.device)
2023-09-25T07:54:05.251968970-07:00 AttributeError: 'tuple' object has no attribute 'to'

@abhinavkulkarni
Contributor

@RonanKMcGovern: I built the latest commit on main and was able to run the command you posted, except with the model TheBloke/Llama-2-7B-chat-AWQ instead of TheBloke/Llama-2-70B-chat-AWQ.

text-generation-launcher --model-id TheBloke/Llama-2-7B-chat-AWQ --trust-remote-code --port 8080 --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 --quantize awq

I am able to send inputs to the server using cURL and obtain a legible response.
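For anyone following along, the Python equivalent of that cURL check against the launcher command above (prompt and parameters here are just examples):

```
import requests

resp = requests.post(
    "http://localhost:8080/generate",  # same port as in the launcher command
    json={
        "inputs": "What is AWQ quantization?",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```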

@RonanKMcGovern

Thanks @abhinavkulkarni!

  • Yes, I have no issues with the 7B model on an A6000.
  • I have now tried the 70B model on an A6000 (error as per above) and also on an A100. The A100 with 80 GB gives the same error.

@RonanKMcGovern

Confirming 70B is running now as expected. Thanks all.

@jjmlovesgit

I am also running TheBloke/Llama-2-7B-chat-AWQ on a 3090, with success on LangChain POCs, using the commands from above:

Demo mccorji@llama:~/tgi/dev/text-generation-inference$ ./start_7b_awq_simple.sh
Starting the Docker container with local files (No Internet): Llama-2-7b-Chat-AWQ and volume: /home/mccorji/tgi/dev/text-generation-inference/data ...
cb808d64f57643e7a38e51d04e6ec48d6f9d8f27f8c814039944127fdb8fef20
Container started successfully!
Running the text-generation-launcher command from /data directory inside the container...Local Files only will be used
2023-09-27T04:20:17.283981Z INFO text_generation_launcher: Args { model_id: "TheBloke/Llama-2-7B-chat-AWQ", revision: None, validation_workers: 2, sharded: None, num_shard: None, quantize: Some(Awq), dtype: None, trust_remote_code: true, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 2048, max_total_tokens: 4096, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: "cb808d64f576", port: 8080, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }
2023-09-27T04:20:17.284020Z WARN text_generation_launcher: trust_remote_code is set. Trusting that model TheBloke/Llama-2-7B-chat-AWQ do not contain malicious code.
2023-09-27T04:20:17.284088Z INFO download: text_generation_launcher: Starting download process.
2023-09-27T04:20:19.235606Z INFO text_generation_launcher: Files are already present on the host. Skipping download.

2023-09-27T04:20:19.487792Z INFO download: text_generation_launcher: Successfully downloaded weights.
2023-09-27T04:20:19.488056Z INFO shard-manager: text_generation_launcher: Starting shard rank=0
2023-09-27T04:20:28.267572Z INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0

2023-09-27T04:20:28.299314Z INFO shard-manager: text_generation_launcher: Shard ready in 8.810838819s rank=0
2023-09-27T04:20:28.399514Z INFO text_generation_launcher: Starting Webserver
2023-09-27T04:20:28.719311Z WARN text_generation_router: router/src/main.rs:349: --revision is not set
2023-09-27T04:20:28.719339Z WARN text_generation_router: router/src/main.rs:350: We strongly advise to set it to a known supported commit.
2023-09-27T04:20:28.993330Z INFO text_generation_router: router/src/main.rs:371: Serving revision 47c8d2736daf1e3b57d9689129c3ddfc596299e1 of model TheBloke/Llama-2-7b-Chat-AWQ
2023-09-27T04:20:28.998950Z INFO text_generation_router: router/src/main.rs:213: Warming up model
2023-09-27T04:20:31.358658Z INFO text_generation_router: router/src/main.rs:246: Setting max batch total tokens to 31984
2023-09-27T04:20:31.358683Z INFO text_generation_router: router/src/main.rs:247: Connected
2023-09-27T04:20:31.358687Z WARN text_generation_router: router/src/main.rs:252: Invalid hostname, defaulting to 0.0.0.0
2023-09-27T04:20:58.063416Z INFO HTTP request{otel.name=POST / http.client_ip= http.flavor=1.1 http.host=localhost:8080 http.method=POST http.route=/ http.scheme=HTTP http.target=/ http.user_agent=python-requests/2.29.0 otel.kind=server trace_id=97f2b23363738b637d455608234b8cc9 http.status_code=200 otel.status_code="OK"}:compat_generate{default_return_full_text=true}:generate_stream{parameters=GenerateParameters { best_of: None, temperature: Some(0.01), repetition_penalty: Some(1.03), top_k: Some(10), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 512, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None } total_time="1.92377586s" validation_time="1.458384ms" queue_time="75.402µs" inference_time="1.922242217s" time_per_token="14.030965ms" seed="Some(3854729438860470207)"}: text_generation_router::server: router/src/server.rs:457: Success
2023-09-27T04:21:24.539278Z INFO HTTP request{otel.name=POST / http.client_ip= http.flavor=1.1 http.host=localhost:8080 http.method=POST http.route=/ http.scheme=HTTP http.target=/ http.user_agent=python-requests/2.29.0 otel.kind=server trace_id=f9923de7bbb6fe7f662d34a2d3566b00 http.status_code=200 otel.status_code="OK"}:compat_generate{default_return_full_text=true}:generate_stream{parameters=GenerateParameters { best_of: None, temperature: Some(0.01), repetition_penalty: Some(1.03), top_k: Some(10), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 512, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None } total_time="361.679818ms" validation_time="1.190807ms" queue_time="58.746µs" inference_time="360.430359ms" time_per_token="20.023908ms" seed="Some(6292320286675141931)"}: text_generation_router::server: router/src/server.rs:457: Success

@naticio

naticio commented Oct 14, 2023

So... can we run any AWQ model using TGI, or just some of them (as of now)?

Trying to launch text-generation-launcher --model-id TheBloke/Wizard-Vicuna-30B-Uncensored-AWQ

but it doesn't work:

RuntimeError: weight model.layers.0.self_attn.q_proj.weight does not exist
rank=0
2023-10-14T00:15:44.698002Z ERROR text_generation_launcher: Shard 0 failed to start
2023-10-14T00:15:44.698035Z INFO text_generation_launcher: Shutting down shards

@abhinavkulkarni
Contributor

abhinavkulkarni commented Oct 14, 2023

@naticio: Currently only FlashLlama models are supported for AWQ quantization. So, the underlying model has to be a Llama 1 or 2 architecture.

However, it should be easy to add support for other types of AWQ quantized models such as MPT, Falcon, etc.
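As a small aid for anyone wiring up another architecture, the snippet below lists which layers of a checkpoint are AWQ-packed, i.e. store qweight/qzeros/scales instead of a plain weight tensor, which is essentially what the new loading path has to account for. The file name is only an example:

```
from safetensors import safe_open

def list_awq_layers(safetensors_file):
    # AWQ-packed layers ship qweight/qzeros/scales instead of a plain .weight
    awq_layers = set()
    with safe_open(safetensors_file, framework="pt") as f:
        for key in f.keys():
            if key.endswith(".qweight"):
                awq_layers.add(key[: -len(".qweight")])
    return sorted(awq_layers)

# e.g. list_awq_layers("model.safetensors") on one of the AWQ repos above
```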

verdant621 added a commit to verdant621/text-generation-inference that referenced this issue Oct 19, 2023
cr313 added a commit to cr313/text-generation-inference-load-test that referenced this issue Apr 19, 2024
alfredgui2 pushed a commit to mlsys-io/kv.run that referenced this issue Jul 6, 2024