In [None]:
import torch
from smoothquant.smooth import smooth_lm
from smoothquant.utils import Perplexity
from smoothquant.llama import Int8LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.calibration import get_act_scales, get_static_llama_decoder_layer_scales

In [None]:
# Checkpoint id handed to every from_pretrained(...) call in this notebook.
# Alternative checkpoint id: 'PY007/TinyLlama-1.1B-Chat-v0.2'
model_name: str = 'TheBloke/Llama-2-7b-chat-fp16'

## FP16 Model Accuracy

In [None]:
# Tokenizer and half-precision baseline weights, both resolved from `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map='auto' leaves weight placement to the loader.
model_fp16 = AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', torch_dtype=torch.float16
)

In [None]:
# Evaluate the FP16 baseline: build the Perplexity helper from the model and
# tokenizer, run it, and print the last element of what it returns.
# NOTE(review): out[-1] is presumably the final perplexity value; confirm
# against smoothquant.utils.Perplexity.
ppl = Perplexity(model_fp16, tokenizer)
out = ppl.calculate_perplexity()
print(f'FP16 perplexity: {out[-1]}')

## SmoothQuant W8A8 Quantized Model Accuracy

In [None]:
# Calibration settings used by both calibration passes in this cell. Naming
# them once keeps the smoothing pass and the static-scale pass from drifting
# apart when one of the values is edited.
CALIB_DATASET = 'mit-han-lab/pile-val-backup'
CALIB_NUM_SAMPLES = 512
CALIB_SEQ_LEN = 512
# Third argument of smooth_lm; presumably the SmoothQuant migration strength
# (alpha). TODO confirm against smoothquant.smooth.smooth_lm.
SMOOTH_ALPHA = 0.5

# Load a fresh FP16 copy of the checkpoint to be smoothed and quantized.
# NOTE(review): `model_fp16` (and the `ppl` helper built from it) from the FP16
# section is still referenced at this point, so two FP16 copies coexist while
# this cell runs. Confirm there is enough device memory for both.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Smooth layers. The trailing positional arguments are kept positional as in
# the original call; their order is assumed to be (num_samples, seq_len), and
# both are 512, so the call is unchanged either way.
act_scales = get_act_scales(
    model, tokenizer, CALIB_DATASET, CALIB_NUM_SAMPLES, CALIB_SEQ_LEN
)
# The return value of smooth_lm is not used, so it is expected to modify
# `model` in place. TODO confirm.
smooth_lm(model, act_scales, SMOOTH_ALPHA)

# Collect the per-decoder-layer scales from the smoothed model. `raw_scales`
# is not used later in this cell; it is kept so the name stays available to
# other cells.
decoder_layer_scales, raw_scales = get_static_llama_decoder_layer_scales(
    model,
    tokenizer,
    CALIB_DATASET,
    num_samples=CALIB_NUM_SAMPLES,
    seq_len=CALIB_SEQ_LEN
)

# Build the INT8 model from the smoothed FP16 model and the collected scales.
model_smoothquant_w8a8 = Int8LlamaForCausalLM.from_float(model, decoder_layer_scales)

In [None]:
# Evaluate the quantized model the same way as the FP16 baseline: build the
# Perplexity helper, run it, and print the last element of what it returns.
# This rebinds `ppl` and `out`, replacing the objects from the FP16 cell.
# NOTE(review): out[-1] is presumably the final perplexity value; confirm
# against smoothquant.utils.Perplexity.
ppl = Perplexity(model_smoothquant_w8a8, tokenizer)
out = ppl.calculate_perplexity()
print(f'SmoothQuant W8A8 perplexity: {out[-1]}')