In [2]:
import torch
import torch.nn as nn

import functools
from functools import partial

from collections import defaultdict

from tqdm import tqdm

import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from utils import build_model_and_tokenizer
from quant import quant, dequant
from auto_gptq import AutoGPTQForCausalLM

from transformers.pytorch_utils import Conv1D

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_name = 'facebook/opt-125m'
device = "cuda"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testenc = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

In [4]:
# 약 5분 30초 가량 소모
'''
from utils import opt_eval
opt_eval(model, testenc, device)
'''

'\nfrom utils import opt_eval\nopt_eval(model, testenc, device)\n'

In [5]:
# 약 1분 16초 소모, base model (fp) perplexity
from utils import evaluate_opt, opt_eval #evaluate_opt: approx, opt_eval: more precise ppl
print(f'original ppl: {evaluate_opt(model, testenc).item()}')

# 약 28초 소모, cpu memory 약 67GB 까지 치솟음
from quant import quant, dequant
bits = 8
gs = 128
scale, zero, qs = quant(bits, gs, model)
q_x = dequant(scale, zero, qs, gs, bits)
for key in q_x.keys():
    if key.split('.')[-1] != 'lm_head':
        weight = key+'.weight'
        model.state_dict()[weight][:] = q_x[key]

# 1분 16초, quant_model perplexity
print(f'quantized model ppl: {evaluate_opt(model, testenc)}')

original ppl: 27.579069137573242
quantized model ppl: 27.56010627746582


In [6]:
def get_activation(model, dataset, tokenizer, num_samples=512, seq_len=512):
    model.eval()
    device = next(model.parameters()).device # next는 객체의 __next__ 호출, 다음 iter를 부름?
    act_scales_max = {}
    act_scales_min = {}

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).detach()
        comming_max = torch.max(tensor, dim=0)[0].float().cpu()
        comming_min = torch.min(tensor, dim=0)[0].float().cpu()
        if name in act_scales_max:
            act_scales_max[name] = torch.max(act_scales_max[name], comming_max)
            act_scales_min[name] = torch.min(act_scales_min[name], comming_min)
        else:
            act_scales_max[name] = comming_max
            act_scales_min[name] = comming_min
    
    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            y = y[0]
        stat_tensor(name, y)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) | isinstance(m, Conv1D):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name))
            )
    
    dataset = dataset.shuffle(seed=42)
    dataset_list = []
    for ii in range(len(dataset)):
        if dataset[ii]['text'] != '':
            dataset_list.append(dataset[ii])

    for i in tqdm(range(num_samples)):
        input_ids = tokenizer(dataset_list[i]["text"], return_tensors="pt",
                              max_length=seq_len, truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales_max, act_scales_min

In [7]:
act_scales_max, act_scales_min = get_activation(model, dataset, tokenizer)

100%|██████████| 512/512 [00:04<00:00, 106.68it/s]


In [8]:
act_scales_max

{'model.decoder.layers.0.self_attn.q_proj': tensor([3.5236, 4.7041, 5.4097, 4.0440, 4.2059, 4.5332, 4.9974, 4.7833, 4.9734,
         5.4693, 4.3798, 4.9689, 5.0644, 4.5536, 5.4619, 3.2922, 4.4473, 5.4894,
         4.8941, 5.9226, 5.0661, 4.8906, 4.5995, 4.7358, 5.1595, 5.0211, 6.4851,
         4.7572, 5.8412, 6.0131, 6.1208, 4.4302, 5.8073, 5.0905, 6.3212, 4.8501,
         5.1880, 5.5920, 4.6481, 5.3415, 4.5617, 4.7621, 5.0914, 4.0989, 5.2782,
         5.4063, 4.8896, 4.3351, 4.2364, 5.0839, 4.1968, 5.9454, 4.7280, 6.0213,
         4.1578, 3.3511, 4.6617, 3.8354, 5.0575, 4.3128, 6.2229, 4.7454, 5.2529,
         4.1334, 3.7444, 3.4968, 2.5612, 2.6447, 3.0013, 2.2865, 3.5780, 3.3171,
         2.3698, 2.2022, 2.7713, 2.8257, 2.0800, 3.5249, 1.5543, 2.2270, 2.8850,
         3.3389, 3.0965, 2.9451, 3.6881, 3.2100, 2.8744, 2.3677, 3.1087, 2.9242,
         2.9841, 2.4318, 4.3679, 2.6167, 1.7908, 3.3702, 2.8076, 3.1746, 2.6684,
         3.1553, 2.4586, 3.1246, 2.9282, 3.1437, 1.7464, 3.0627, 3

In [24]:
for key in act_scales_max:
    if key.split('.')[-1] == 'fc2':
        print(act_scales_max[key].mean().item())


-0.13267572224140167
-0.06528269499540329
-0.1813398152589798
-0.0426807701587677
-0.04478975757956505
-0.004377354402095079
-0.004760784562677145
-0.005610581021755934
-0.0004388898378238082
7.889144035289064e-05
-0.0028311836067587137
0.4849059283733368
