In [2]:
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from smoothquant.calibration import get_act_scales, get_static_decoder_layer_scales
from smoothquant.opt import Int8OPTForCausalLM
from smoothquant.smooth import smooth_lm
from smoothquant.utils import Perplexity

In [3]:
# Hugging Face Hub checkpoint evaluated in every section of this notebook.
model_name = "facebook/opt-6.7b"

## FP16 Baseline Perplexity

In [6]:
# Baseline: the unquantized checkpoint in half precision, with layer placement
# across the available devices delegated to device_map='auto'.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model_fp16 = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
# FP16 baseline perplexity, used as the reference for the W8A8 model below.
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt),
# so no baseline number has been captured yet -- re-run to completion.
#
# Cell-specific names (instead of the generic `ppl` / `out`) keep this result
# from being overwritten when the W8A8 evaluation cell runs.
fp16_perplexity = Perplexity(model_fp16, tokenizer)

# Evaluation needs no gradients; no_grad avoids building autograd graphs over
# the 6.7B-parameter forward passes (harmless if the helper already disables them).
with torch.no_grad():
    fp16_ppl_results = fp16_perplexity.calculate_perplexity()

# The helper returns a sequence whose last entry is reported as the perplexity
# (presumably the final running value -- TODO confirm against smoothquant.utils).
print(f'FP16 perplexity: {fp16_ppl_results[-1]}')

Perplexity: - :   0%|          | 0/559 [00:00<?, ?it/s]

KeyboardInterrupt: 

## SmoothQuant W8A8 Quantized Model Perplexity

In [4]:
# load model
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    torch_dtype=torch.float16, 
    device_map='auto',
    low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# smooth layers
act_scales = get_act_scales(model, tokenizer, 'mit-han-lab/pile-val-backup', 512, 512)
smooth_lm(model, act_scales, 0.5)

# get model scales
decoder_layer_scales, raw_scales = get_static_decoder_layer_scales(
    model,
    tokenizer,
    'mit-han-lab/pile-val-backup',
    num_samples=512,
    seq_len=512
)

with torch.device("cuda"):
    model_smoothquant_w8a8 = Int8OPTForCausalLM.from_float(model, decoder_layer_scales)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Repo card metadata block was not found. Setting CardData to empty.
100%|██████████| 512/512 [00:47<00:00, 10.79it/s]


Collecting activation scales...


  0%|          | 0/512 [00:00<?, ?it/s]Repo card metadata block was not found. Setting CardData to empty.
Mean input scale: 5.61: 100%|██████████| 512/512 [00:48<00:00, 10.51it/s]


RuntimeError: result type Float can't be cast to the desired output type Char

In [None]:
# Perplexity of the SmoothQuant W8A8 model, using the same helper as the FP16 baseline.
#
# The conversion cell above raised a RuntimeError in the recorded run. On a fresh
# kernel, fail fast with an actionable message instead of a bare NameError.
if 'model_smoothquant_w8a8' not in globals():
    raise RuntimeError(
        'model_smoothquant_w8a8 is undefined: run the SmoothQuant W8A8 cell above '
        'and make sure Int8OPTForCausalLM.from_float succeeds first.'
    )

# Cell-specific names avoid clobbering the FP16 baseline's results.
w8a8_perplexity = Perplexity(model_smoothquant_w8a8, tokenizer)

# Evaluation needs no gradients (harmless if the helper already disables them).
with torch.no_grad():
    w8a8_ppl_results = w8a8_perplexity.calculate_perplexity()

# NOTE(review): the saved output below (~82,763) is implausibly high for a language
# model, and its progress bar shows 655 batches vs 559 for the FP16 run. It likely
# comes from a different run/configuration -- re-run and compare before trusting it.
print(f'SmoothQuant W8A8 perplexity: {w8a8_ppl_results[-1]}')

Perplexity: - :   0%|          | 0/655 [00:00<?, ?it/s]

SmoothQuant W8A8 perplexity: 82762.64201220016
