In [1]:
import torch
from smoothquant.smooth import smooth_lm
from smoothquant.calibration import get_act_scales
from smoothquant.utils import quantize_model, Perplexity
from transformers import AutoModelForCausalLM, AutoTokenizer

## FP16 Model Accuracy

In [2]:
# Checkpoint evaluated throughout this notebook.
model_name = 'PY007/TinyLlama-1.1B-Chat-v0.2'

# Load the half-precision baseline; device placement is left to
# `device_map='auto'`.
fp16_load_options = dict(torch_dtype=torch.float16, device_map='auto')
model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, **fp16_load_options)

# Tokenizer for the same checkpoint; later cells reuse it for every evaluation.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
# Baseline measurement: perplexity of the FP16 model.
ppl = Perplexity(model_fp16, tokenizer)

# Tokenize the evaluator's text once, without truncation, and move the ids to
# the device the evaluator's model lives on. `tokens` is reused by the later
# evaluation cells so all models are scored on identical input.
encoded = tokenizer(ppl._text, truncation=False, return_tensors='pt')
tokens = encoded.input_ids.to(ppl._model.device)

# The last element of the returned sequence is the value reported.
out = ppl.calculate_perplexity(tokens=tokens)
print(f'FP16 perplexity: {out[-1]}')

Perplexity: - :   0%|          | 0/655 [00:00<?, ?it/s]

FP16 perplexity: 13.785424512084473


## Naive W8A8 Quantized Model Accuracy

In [4]:
# Quantize the baseline directly to W8A8, with no smoothing step.
# NOTE(review): the previously recorded output (about 40.49 when evaluating
# `model_fp16` after this call) suggests `quantize_model` modifies its argument
# in place, so `model_fp16` may no longer be an FP16 model after this cell --
# confirm before reusing it.
model_w8a8 = quantize_model(model_fp16)

# Evaluate the model returned by `quantize_model` rather than the `model_fp16`
# handle, so the reported number is for the quantized model regardless of
# whether quantization happens in place.
ppl = Perplexity(model_w8a8, tokenizer)

# `tokens` comes from the FP16 evaluation cell, so both runs score the same input.
out = ppl.calculate_perplexity(tokens=tokens)
print(f'Naive W8A8 perplexity: {out[-1]}')

Perplexity: - :   0%|          | 0/655 [00:00<?, ?it/s]

Naive W8A8 perplexity: 40.489475107108944


## SmoothQuant W8A8 Quantized Model Accuracy

In [5]:
# Reload the checkpoint in FP16 so this pipeline starts from a fresh model object.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map='auto'
)

# Collect activation scales from the calibration dataset.
# NOTE(review): the two trailing 512s are presumably the number of calibration
# samples and the sequence length -- confirm against `get_act_scales`.
calibration_dataset = 'mit-han-lab/pile-val-backup'
act_scales = get_act_scales(model, tokenizer, calibration_dataset, 512, 512)

# `smooth_lm` is called for its effect on `model` (its return value is unused).
# NOTE(review): 0.5 is presumably the SmoothQuant migration strength -- confirm.
smooth_lm(model, act_scales, 0.5)

# Quantize the smoothed model.
model_smoothquant_w8a8 = quantize_model(model)

Repo card metadata block was not found. Setting CardData to empty.
100%|██████████| 512/512 [00:31<00:00, 16.18it/s]


smooth llama decoder: model.layers.0
smooth llama decoder: model.layers.1
smooth llama decoder: model.layers.2
smooth llama decoder: model.layers.3
smooth llama decoder: model.layers.4
smooth llama decoder: model.layers.5
smooth llama decoder: model.layers.6
smooth llama decoder: model.layers.7
smooth llama decoder: model.layers.8
smooth llama decoder: model.layers.9
smooth llama decoder: model.layers.10
smooth llama decoder: model.layers.11
smooth llama decoder: model.layers.12
smooth llama decoder: model.layers.13
smooth llama decoder: model.layers.14
smooth llama decoder: model.layers.15
smooth llama decoder: model.layers.16
smooth llama decoder: model.layers.17
smooth llama decoder: model.layers.18
smooth llama decoder: model.layers.19
smooth llama decoder: model.layers.20
smooth llama decoder: model.layers.21


In [6]:
# Score the SmoothQuant W8A8 model on the same `tokens` used by the earlier
# evaluation cells.
ppl = Perplexity(
    model_smoothquant_w8a8,
    tokenizer,
)
out = ppl.calculate_perplexity(tokens=tokens)

# The last element of the returned sequence is the value reported.
print(f'SmoothQuant W8A8 perplexity: {out[-1]}')

Perplexity: - :   0%|          | 0/655 [00:00<?, ?it/s]

SmoothQuant W8A8 perplexity: 36.41558657884512
