# SmoothQuant on OPT-13B

### Guangxuan Xiao\*, Ji Lin\*, Mickael Seznec, Julien Demouth, Song Han

In this notebook, we use OPT-13B model to demonstrate SmoothQuant can use 8-bit for both weights and activations to achieve the same accuracy as FP16 models. Unlike previous method [[Dettmers *et al.*, 2022]](https://arxiv.org/abs/2208.07339), SmoothQuant enables fully INT8 GEMMs for linear layers and does not require high precision numbers to represent outliers. 

This notebook demonstrates SmoothQuant on OPT-13B in consideration of the user's resouce constraints. We have tested SmoothQuant on up to 176 billion parameter models (OPT-175B, BLOOM-176B, GLM-130B). You can also adjust the model name to validate SmoothQuant on other models. `../act_scales/` provides the activation channel scales for OPT and BLOOM models.

In order to run this notebook, you need to install the following packages:

- smoothquant
- PyTorch
- Transformers
- Accelerate

In [80]:
import torch
from transformers.models.opt.modeling_opt import (
    OPTAttention, OPTDecoderLayer, OPTForCausalLM)

from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import (
    LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm) #, LlamaForCausalLM)

from transformers import GPT2Tokenizer, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import W8A8Linear

In this notebook, we simulate the 8-bit dynamic per-tensor weight and activation quantization with FP16, i.e., fake quantization. We have implemented the real 8-bit quantization with INT8 CUTLASS GEMM kernels for both PyTorch and FasterTransformer. Please stay tuned for the release.

In [97]:
def quantize_model(model, weight_quant='per_tensor', act_quant='per_tensor', quantize_bmm_input=True):
    for name, m in model.model.named_modules():
        if isinstance(m, LlamaDecoderLayer):  # OPTDecoderLayer
            # m.fc1 = W8A8Linear.from_float(m.fc1, weight_quant=weight_quant, act_quant=act_quant)
            # m.fc2 = W8A8Linear.from_float(m.fc2, weight_quant=weight_quant, act_quant=act_quant)
            
            m.self_attn.q_proj = W8A8Linear.from_float(
                m.self_attn.q_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.self_attn.k_proj = W8A8Linear.from_float(
                m.self_attn.k_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.self_attn.v_proj = W8A8Linear.from_float(
                m.self_attn.v_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.self_attn.o_proj = W8A8Linear.from_float(
                m.self_attn.o_proj, weight_quant=weight_quant, act_quant=act_quant)
            
            m.mlp.gate_proj = W8A8Linear.from_float(
                m.mlp.gate_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.mlp.up_proj = W8A8Linear.from_float(
                m.mlp.up_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.mlp.down_proj = W8A8Linear.from_float(
                m.mlp.down_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
        elif isinstance(m, LlamaAttention):  # OPTAttention
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            # print(f'!!! {m}')
            continue
            
            print(m.q_proj)
            m.q_proj = W8A8Linear.from_float(
                m.q_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.k_proj = W8A8Linear.from_float(
                m.k_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.v_proj = W8A8Linear.from_float(
                m.v_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.out_proj = W8A8Linear.from_float(m.out_proj, weight_quant=weight_quant, act_quant=act_quant)
    return model


The following is an evaluator to see the performance of the model. We use a toy dataset (the first 1000 examples in the validation set of the Lambada dataset) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same.

**In this demo, we have simplified the evaluation by using the first 1,000 samples from the LAMBADA dataset's validation set. We employ the "Last Token Prediction Accuracy" as our evaluation metric. This approximate evaluation is intended for demonstration purposes, providing simple but meaningful comparisons of relative performance between methods. For a more strict assessment, we recommend using the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) to obtain the "Last Word Prediction Accuracy" for the LAMBADA dataset, which is the reported metric in our paper.**

In [135]:
class Evaluator:
    def __init__(self, dataset, tokenizer, device):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        # tokenize the dataset
        def tokenize_function(examples):
            example = self.tokenizer(examples['text'], )
            return example

        self.dataset = self.dataset.map(tokenize_function, batched=True)
        self.dataset.set_format(type='torch', columns=['input_ids'])
        
#         for batch in self.dataset:
#             print(batch['input_ids'].shape)

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        # The task is to predict the last word of the input.
        total, hit = 0, 0
        for i, batch in enumerate(self.dataset):
            input_ids = batch['input_ids'].to(self.device).unsqueeze(0)
            
            if input_ids.shape[1] < 5:
                # print(f'!!! Skip input of shape {input_ids.shape} !!!')
                continue
            
            label = input_ids[:, -1]
            outputs = model(input_ids)

            
            # print(input_ids.shape)
            # print(outputs)
            # print(outputs.logits.shape)
            # assert False
            
            last_token_logits = outputs.logits[:, -1, :]  # TODO: ???
            # -1 or -2 ???
            
            pred = last_token_logits.argmax(dim=-1)
            
            #if i < 10:
            #    print('From evaluate:')
            #    print(input_ids)
            #    print(pred)
            
            total += label.size(0)
            hit += (pred == label).sum().item()
        acc = hit / total
        return acc


In [6]:
! ls ~/models

llama-7b


In [7]:
! pwd

/home/alekseev_v/projects_node4/smoothquant/examples


In [8]:
MODEL_PATH = '/home/alekseev_v/models/llama-7b'

In [12]:
from datasets import load_dataset

#tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b')
#dataset = load_dataset('lambada', split='validation[:1000]')

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation[:1000]')

# evaluator = Evaluator(dataset, tokenizer, 'cuda')


Map: 100%|█████████████████████████████| 1000/1000 [00:00<00:00, 2343.36 examples/s]


## FP16 Model Accuracy

Let's first check the performance of the original FP16 model.

In [62]:
# OPTForCausalLM

model_fp16 = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map='auto')

Loading checkpoint shards: 100%|████████████████████| 33/33 [00:30<00:00,  1.07it/s]


In [136]:
evaluator = Evaluator(dataset, tokenizer, 'cuda')

Map: 100%|█████████████████████████████| 1000/1000 [00:00<00:00, 2973.88 examples/s]


In [44]:
acc_fp16 = evaluator.evaluate(model_fp16)
print(f'Original model (fp16) accuracy: {acc_fp16}')


!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip inpu

!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip inpu

!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
Original model (fp16) accuracy: 0.968944099378882


We then quantize the model to W8A8 and check the performance.

## Naive W8A8 Quantized Model Accuracy

In [64]:
model_w8a8 = quantize_model(model_fp16)
print(model_w8a8)

!!! LlamaAttention(
  (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
  (rotary_emb): LlamaRotaryEmbedding()
)
!!! LlamaAttention(
  (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor

!!! LlamaAttention(
  (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
  (rotary_emb): LlamaRotaryEmbedding()
)
!!! LlamaAttention(
  (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
  (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor

In [65]:
acc_w8a8 = evaluator.evaluate(model_w8a8)
print(f'Naive W8A8 quantized model accuracy: {acc_w8a8}')

!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip inpu

!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip inpu

!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
!!! Skip input of shape torch.Size([1, 1]) !!!
Naive W8A8 quantized model accuracy: 0.10093167701863354


We can see there is a significant accuracy drop. This is consistent with LLM.int8()'s finding: when the model size increases larger than 6.7B, systematic outliers will emerge in activations, which makes fully INT8 quantization impossible.

In [137]:
import torch
import torch.nn as nn

from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.bloom.modeling_bloom import BloomBlock


@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    if not isinstance(fcs, list):
        fcs = [fcs]
        
    assert isinstance(ln, LlamaRMSNorm) #nn.LayerNorm), type(ln)
    
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    weight_scales = torch.cat([fc.weight.abs().max(
        dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

    scales = (act_scales.pow(alpha) / weight_scales.pow(1-alpha)
              ).clamp(min=1e-5).to(device).to(dtype)

    ln.weight.div_(scales)
    # ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))
    
    return scales


@torch.no_grad()
def smooth_lm(model, scales, alpha=0.5):
    for name, module in model.named_modules():
        if isinstance(module, LlamaDecoderLayer):  # OPTDecoderLayer
            attn_ln = module.input_layernorm  # self_attn_layer_norm
            qkv = [module.self_attn.q_proj,
                   module.self_attn.k_proj,
                   module.self_attn.v_proj,]
                   # module.self_attn.o_proj]  # ???
            qkv_input_scales = scales[name + '.self_attn.q_proj']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            
            ffn_ln = module.post_attention_layernorm  # final_layer_norm
            #fc1 = module.mlp.gate_proj  # ??? dense_h_to_4h
            fc1 = [
                module.mlp.gate_proj, module.mlp.up_proj
            ]
            fc1_input_scales = scales[name + '.mlp.gate_proj']
            
            _scales = smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
            
            # print(_scales)
            
        """
        elif isinstance(module, BloomBlock):
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + '.self_attention.query_key_value']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.gate_proj  # ??? dense_h_to_4h
            fc1_input_scales = scales[name + '.mlp.gate_proj']
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        """

## SmoothQuant W8A8 Quantized Model Accuracy

Let's smooth the model, quantize it, and check the performance! In `../act_scales`, we provide the activation scales for OPT and BLOOM models. You can also use this notebook to test quantizing those models.

In [73]:
! ls ../act_scales/llama-7b.pt

../act_scales/llama-7b.pt


In [148]:
# model = OPTForCausalLM.from_pretrained(
#     'facebook/opt-13b', torch_dtype=torch.float16, device_map='auto')

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map='auto')

act_scales = torch.load('../act_scales/llama-7b.pt')

Loading checkpoint shards: 100%|████████████████████| 33/33 [00:42<00:00,  1.27s/it]


In [149]:
act_scales

{'model.layers.0.self_attn.q_proj': tensor([0.1954, 0.0732, 0.0819,  ..., 0.1273, 0.0908, 0.1018]),
 'model.layers.0.self_attn.k_proj': tensor([0.1954, 0.0732, 0.0819,  ..., 0.1273, 0.0908, 0.1018]),
 'model.layers.0.self_attn.v_proj': tensor([0.1954, 0.0732, 0.0819,  ..., 0.1273, 0.0908, 0.1018]),
 'model.layers.0.self_attn.o_proj': tensor([0.0148, 0.0150, 0.0174,  ..., 0.1532, 0.0299, 0.0484]),
 'model.layers.0.mlp.gate_proj': tensor([0.1642, 0.2338, 0.2385,  ..., 0.1931, 0.2471, 0.2076]),
 'model.layers.0.mlp.up_proj': tensor([0.1642, 0.2338, 0.2385,  ..., 0.1931, 0.2471, 0.2076]),
 'model.layers.0.mlp.down_proj': tensor([0.1388, 0.5039, 0.4951,  ..., 0.1766, 0.3137, 0.2393]),
 'model.layers.1.self_attn.q_proj': tensor([0.4578, 0.2279, 0.2401,  ..., 0.1924, 0.2510, 0.2461]),
 'model.layers.1.self_attn.k_proj': tensor([0.4578, 0.2279, 0.2401,  ..., 0.1924, 0.2510, 0.2461]),
 'model.layers.1.self_attn.v_proj': tensor([0.4578, 0.2279, 0.2401,  ..., 0.1924, 0.2510, 0.2461]),
 'model.lay

In [150]:
evaluator.evaluate(model)

0.968944099378882

In [151]:
smooth_lm(model, act_scales, 0.5)

In [152]:
evaluator.evaluate(model)

0.968944099378882

In [153]:
model.save_pretrained('llama-7b-smooth-0.5')  
tokenizer.save_pretrained('llama-7b-smooth-0.5') 

('llama-7b-smooth-0.5/tokenizer_config.json',
 'llama-7b-smooth-0.5/special_tokens_map.json',
 'llama-7b-smooth-0.5/tokenizer.model',
 'llama-7b-smooth-0.5/added_tokens.json')

In [154]:
! ls llama-7b-smooth-0.5

config.json			  pytorch_model.bin.index.json
generation_config.json		  special_tokens_map.json
pytorch_model-00001-of-00002.bin  tokenizer_config.json
pytorch_model-00002-of-00002.bin  tokenizer.model


In [5]:
! ls /home/alekseev_v/models/smooth-0.5/llama-7b/

config.json			  pytorch_model.bin.index.json
generation_config.json		  special_tokens_map.json
pytorch_model-00001-of-00002.bin  tokenizer_config.json
pytorch_model-00002-of-00002.bin  tokenizer.model


In [7]:
! cat /home/alekseev_v/models/llama-7b/config.json

{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000}

In [6]:
! cat /home/alekseev_v/models/smooth-0.5/llama-7b/config.json

{
  "_name_or_path": "/home/alekseev_v/models/llama-7b",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": -1,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.33.1",
  "use_cache": true,
  "vocab_size": 32000
}


In [115]:
1

1

In [143]:
model_smoothquant_w8a8 = quantize_model(model)
print(model_smoothquant_w8a8)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_tensor, act_qua

We can see the smoothed model has the same accuracy as the FP16 model. This is because SmoothQuant smooths the outliers in activations and moves the quantization difficulty from activations to weights.

In [144]:
1

1

In [145]:
acc_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)

In [146]:
print(f'SmoothQuant W8A8 quantized model accuracy: {acc_smoothquant_w8a8}')

SmoothQuant W8A8 quantized model accuracy: 0.14596273291925466
