# AWQ on Vicuna

In this notebook, we use the Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement real INT4 inference kernels for AWQ, wrapped as PyTorch modules so that existing models can use them easily. We also provide a simple example showing how to quantize a model with AWQ and how to save and load the quantized checkpoint.

To run this notebook, install the following packages:
- [AWQ](https://github.com/mit-han-lab/llm-awq)
- [PyTorch](https://pytorch.org/)
- [Accelerate](https://github.com/huggingface/accelerate)
- [Transformers](https://github.com/huggingface/transformers)

In [1]:
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tinychat.demo import gen_params, stream_output
from tinychat.stream_generators import StreamGenerator
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from tinychat.utils.prompt_templates import get_prompter
import os
# This demo only supports a single GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

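First, get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat), then run the following command to generate a quantized model checkpoint. The command reads AWQ results from `awq_cache/vicuna-7b-w4-g128.pt` (the `--load_awq` argument), so that file must already exist.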

```bash
mkdir quant_cache
python -m awq.entry --model_path [vicuna-7b_model_path] \
    --w_bit 4 --q_group_size 128 \
    --load_awq awq_cache/vicuna-7b-w4-g128.pt \
    --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt
```
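
The flags `--w_bit 4 --q_group_size 128` request 4-bit weights quantized in groups of 128. As a rough illustration of what group-wise quantization with a zero point means, here is a minimal pure-Python sketch. It is not the AWQ implementation: it omits the activation-aware scaling that AWQ searches for, and the function names (`quantize_group`, `dequantize_group`, `quantize_row`) are hypothetical helpers introduced only for this example.

```python
def quantize_group(weights, w_bit=4):
    """Asymmetric (zero-point) quantization of one group of floats.

    Illustrative helper, not part of the AWQ package.
    """
    q_max = 2 ** w_bit - 1                      # 15 for 4-bit
    w_min, w_max = min(weights), max(weights)
    scale = max(w_max - w_min, 1e-5) / q_max    # step size of the integer grid
    zero = max(0, min(q_max, round(-w_min / scale)))
    # Map each weight to an integer in [0, q_max]
    q = [max(0, min(q_max, round(w / scale) + zero)) for w in weights]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    """Recover approximate float weights from the integers."""
    return [(v - zero) * scale for v in q]


def quantize_row(row, w_bit=4, group_size=128):
    """Quantize one weight row; each group gets its own scale and zero point."""
    assert len(row) % group_size == 0, "row length must be a multiple of group_size"
    return [
        quantize_group(row[i:i + group_size], w_bit)
        for i in range(0, len(row), group_size)
    ]


if __name__ == "__main__":
    # Synthetic weights in [-1, 1], two groups of 128
    row = [((i * 37) % 101 - 50) / 50 for i in range(256)]
    groups = quantize_row(row)
    q, scale, zero = groups[0]
    restored = dequantize_group(q, scale, zero)
    max_err = max(abs(a - b) for a, b in zip(row[:128], restored))
    print(f"groups={len(groups)} scale={scale:.4f} zero={zero} max_err={max_err:.4f}")
```

Smaller groups track the local weight range more closely, at the cost of storing more scales and zero points.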

In [2]:
# model_path = "" # the path of vicuna-7b model
# load_quant_path = "quant_cache/vicuna-7b-w4-g128-awq.pt"
model_path = "/data/llm/checkpoints/vicuna-hf/vicuna-7b"
load_quant_path = "/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt"

We first load an empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint.

In [3]:
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                    torch_dtype=torch.float16)
q_config = {"zero_point": True, "q_group_size": 128}
real_quantize_model_weight(
    model, w_bit=4, q_config=q_config, init_only=True)

model = load_checkpoint_and_dispatch(
    model, load_quant_path,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"]
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

real weight quantization...(init only): 100%|███████████████████| 32/32 [00:11<00:00,  2.69it/s]
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
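
To get a feel for what `w_bit=4` with `q_group_size=128` and `zero_point=True` saves, the arithmetic below estimates the storage of one quantized linear layer against its FP16 size. This is a back-of-the-envelope sketch under a stated assumption: each group of 128 weights stores one 16-bit scale and one 4-bit zero point. The helper names are hypothetical, and the actual packed layout in WQLinear may add padding.

```python
def packed_bytes(in_features, out_features, w_bit=4, group_size=128, scale_bits=16):
    """Estimated bytes for a group-quantized linear layer.

    Assumption: one scale (scale_bits) and one zero point (w_bit) per group.
    """
    n_weights = in_features * out_features
    n_groups = (in_features // group_size) * out_features
    total_bits = n_weights * w_bit + n_groups * (scale_bits + w_bit)
    return total_bits // 8


def fp16_bytes(in_features, out_features):
    """Bytes for the same layer stored in FP16 (2 bytes per weight)."""
    return in_features * out_features * 2


if __name__ == "__main__":
    # A 4096 x 4096 projection, the shape of o_proj in a 7B LLaMA-style model
    q = packed_bytes(4096, 4096)
    f = fp16_bytes(4096, 4096)
    print(f"quantized: {q / 2**20:.2f} MiB, fp16: {f / 2**20:.2f} MiB, ratio: {f / q:.2f}x")
```

Under this assumption the effective cost is 4 + 20/128 bits per weight, so the reduction is a little under 4x rather than exactly 4x.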


In [4]:
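# Swap in the fused TinyChat modules (see the printed model below)
make_quant_attn(model, "cuda:0")  # attention with a fused qkv_proj
make_quant_norm(model)            # FTLlamaRMSNorm layers
make_fused_mlp(model)             # QuantLlamaMLP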



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): QuantLlamaAttention(
          (qkv_proj): WQLinear(in_features=4096, out_features=12288, bias=False, w_bit=4, group_size=128)
          (o_proj): WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)
          (rotary_emb): QuantLlamaRotaryEmbedding()
        )
        (mlp): QuantLlamaMLP(
          (down_proj): WQLinear(in_features=11008, out_features=4096, bias=False, w_bit=4, group_size=128)
        )
        (input_layernorm): FTLlamaRMSNorm()
        (post_attention_layernorm): FTLlamaRMSNorm()
      )
    )
    (norm): FTLlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

In [5]:
model_prompter = get_prompter("llama", model_path)
stream_generator = StreamGenerator
count = 0
while True:
    # Get input from the user; an empty line ends the chat
    input_prompt = input("USER: ")
    if input_prompt == "":
        print("EXIT...")
        break
    # Add the user message to the conversation prompt
    model_prompter.insert_prompt(input_prompt)
    output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device="cuda:0")
    # Stream the reply and keep the final text
    outputs = stream_output(output_stream)
    # Pass the reply back to the prompter for the next turn
    model_prompter.update_template(outputs)
    count += 1

USER:  Show me some attractions in Boston.


ASSISTANT: 1. Boston Public Library
2. Fenway Park
3. Harvard Square
4. Boston Common
5. Freedom Trail
6. Museum of Fine Arts
7. Isabella Stewart Gardner Museum
8. Paul Revere House
9. New England Aquarium
10. Museum of Science
Speed of Inference
--------------------------------------------------
Context Stage    : 7.18 ms/token
Generation Stage : 9.49 ms/token
Average Speed    : 8.53 ms/token


USER:  


EXIT...
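
The speed report above is given in milliseconds per token. A small conversion to tokens per second can make the figures easier to compare with other benchmarks. The helper below is a hypothetical convenience function, not part of TinyChat, and the inputs are simply the numbers printed in the run above.

```python
def tokens_per_second(ms_per_token):
    """Convert a latency in ms/token to a throughput in tokens/s."""
    if ms_per_token <= 0:
        raise ValueError("ms_per_token must be positive")
    return 1000.0 / ms_per_token


if __name__ == "__main__":
    # Latencies copied from the "Speed of Inference" report above
    report = {
        "Context Stage": 7.18,
        "Generation Stage": 9.49,
        "Average Speed": 8.53,
    }
    for stage, ms in report.items():
        print(f"{stage:<17}: {tokens_per_second(ms):.2f} tokens/s")
```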
