In [1]:
# Check the differences between per-channel, per-tensor, and per-token quantization

from smoothquant.fake_quant import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Evaluator:
    """Last-token prediction accuracy evaluator.

    The dataset is tokenized once at construction time. ``evaluate`` then feeds
    each example to a model one at a time and checks whether the greedy
    prediction taken at the second-to-last position equals the final token.
    """

    def __init__(self, dataset, tokenizer, device):
        """Tokenize ``dataset`` and prepare it for evaluation.

        Args:
            dataset: Object exposing ``map`` and ``set_format``; its examples
                must carry a ``'text'`` field.
            tokenizer: Callable applied to the batched ``'text'`` values; its
                result must provide an ``'input_ids'`` column.
            device: Device each example's ``input_ids`` is moved to before the
                forward pass.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        # Tokenize the raw text column in batches.
        def tokenize_function(examples):
            return self.tokenizer(examples['text'])

        self.dataset = self.dataset.map(tokenize_function, batched=True)
        # Only the token ids are needed downstream; expose them as torch tensors.
        self.dataset.set_format(type='torch', columns=['input_ids'])

    @torch.no_grad()
    def evaluate(self, model):
        """Return the fraction of examples whose final token is predicted correctly.

        Args:
            model: Object with an ``eval()`` method that, when called as
                ``model(input_ids)``, returns something with a ``logits``
                attribute indexed as ``[batch, position, vocab]``.

        Returns:
            ``hit / total`` as a float, or ``0.0`` when no example could be
            scored (empty dataset, or every example shorter than two tokens).
        """
        model.eval()
        # The task is to predict the last token of each input sequence.
        total, hit = 0, 0
        for batch in self.dataset:
            # Examples are evaluated one at a time, so add a batch dimension.
            input_ids = batch['input_ids'].to(self.device).unsqueeze(0)
            # The prediction is read at position -2, so at least two tokens
            # are required; shorter examples cannot be scored and are skipped.
            if input_ids.size(1) < 2:
                continue
            label = input_ids[:, -1]
            outputs = model(input_ids)
            # Logits at the second-to-last position predict the last token.
            last_token_logits = outputs.logits[:, -2, :]
            pred = last_token_logits.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label).sum().item()
        # Avoid ZeroDivisionError when nothing was scored.
        if total == 0:
            return 0.0
        return hit / total


In [3]:
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM

# Single source of truth for the checkpoint name, so the model and the
# tokenizer loaded from `model_id` cannot drift apart.
model_id = 'facebook/opt-13b'
# Load the weights as float16; device placement is delegated to device_map='auto'.
model_fp16 = OPTForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')


Loading checkpoint shards: 100%|██████████| 3/3 [00:16<00:00,  5.49s/it]


In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer
# Tokenizer for the checkpoint named by `model_id`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# First 1000 examples of the 'lambada' validation split.
dataset = load_dataset(path='lambada', split='validation[:1000]')
# The evaluator tokenizes the dataset up front and moves inputs to 'cuda'.
evaluator = Evaluator(dataset=dataset, tokenizer=tokenizer, device='cuda')


tokenizer_config.json: 100%|██████████| 721/721 [00:00<00:00, 3.34MB/s]
vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 3.41MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 822kB/s]
special_tokens_map.json: 100%|██████████| 441/441 [00:00<00:00, 2.62MB/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 14613.43 examples/s]


In [5]:
# Baseline: accuracy of model_fp16 before any weights are quantized.
acc_before_q = evaluator.evaluate(model=model_fp16)
print(acc_before_q)

KeyboardInterrupt: 

In [None]:
# Sub-layers whose weights are passed through quantize_weight_per_channel_absmax,
# grouped by the module type that owns them. Entries are checked in order and
# the first matching type wins.
quant_targets = (
    (OPTDecoderLayer, ('fc1', 'fc2')),
    (OPTAttention, ('q_proj', 'k_proj', 'v_proj', 'out_proj')),
)

with torch.no_grad():
    for module_name, module in model_fp16.model.named_modules():
        for module_type, layer_names in quant_targets:
            if not isinstance(module, module_type):
                continue
            for layer_name in layer_names:
                layer = getattr(module, layer_name)
                # Replace the weight with its quantized counterpart.
                layer.weight = quantize_weight_per_channel_absmax(layer.weight)
            # A module is handled by at most one entry of the table.
            break

In [None]:
# Same evaluation as the baseline, run on model_fp16 after the weight-quantization cell.
acc = evaluator.evaluate(model=model_fp16)
print(acc)

0.797
