In [2]:
import functools
from collections import defaultdict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from auto_gptq import AutoGPTQForCausalLM
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.pytorch_utils import Conv1D

from quant import quant, dequant
from utils import build_model_and_tokenizer
from utils import evaluate_opt, opt_eval  # evaluate_opt: approximate ppl, opt_eval: more precise ppl

In [3]:
# Base model, tokenizer and the WikiText-2 test split used for perplexity evaluation.
model_name = 'facebook/opt-6.7b'
device = "cuda:1"  # NOTE(review): hard-coded GPU index — adjust for the machine running this
# Weights are loaded with from_pretrained's default dtype, then moved to `device`.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# The whole test split is concatenated into one string (entries joined with "\n\n")
# and tokenized once; the resulting tensors stay on the CPU.
testenc = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.24it/s]


In [5]:
# Precise (slower) perplexity via utils.opt_eval — takes about 5 min 30 s.
# Disabled by default; set RUN_PRECISE_EVAL = True to run it.
# NOTE: this evaluates whatever weights `model` holds at that moment, so running it
# after the in-place quantization cell measures the quantized model, not the base one.
RUN_PRECISE_EVAL = False
if RUN_PRECISE_EVAL:
    from utils import opt_eval
    print(f'precise ppl: {opt_eval(model, testenc, device)}')

10.860100746154785

In [4]:
# Perplexity of the base (full-precision) model, ~1 min 16 s.
# evaluate_opt is the fast, approximate metric; opt_eval is slower and more precise.
print(f'original ppl: {evaluate_opt(model, testenc).item()}')

# Weight quantization round-trip: quantize, dequantize, and write the dequantized
# weights back into the model in place. Takes ~28 s; CPU memory peaks around 67 GB.
# NOTE: this mutates `model`. Re-running the cell re-quantizes already-quantized
# weights, and the "original ppl" line above no longer reports the baseline.
bits = 4    # quantization bit width passed to quant/dequant
gs = 128    # group size passed to quant/dequant
scale, zero, qs = quant(bits, gs, model)
q_x = dequant(scale, zero, qs, gs, bits)

# Build the state dict once (not per key). Its tensors share storage with the
# model's parameters, so the slice assignment below updates the model directly.
state = model.state_dict()
for key, dequantized in q_x.items():
    if key.split('.')[-1] == 'lm_head':
        continue  # lm_head weights are never overwritten
    target = state[key + '.weight']
    if target.shape != dequantized.shape:
        raise ValueError(
            f'{key}: dequantized shape {tuple(dequantized.shape)} '
            f'does not match weight shape {tuple(target.shape)}'
        )
    target[:] = dequantized  # in-place copy; also handles CPU -> GPU transfer

# Perplexity of the quantized model, ~1 min 16 s.
print(f'quantized model ppl: {evaluate_opt(model, testenc).item()}')

evaluating...: 100%|██████████| 40/40 [01:16<00:00,  1.92s/it]


original ppl: 10.673616409301758


evaluating...: 100%|██████████| 40/40 [01:16<00:00,  1.91s/it]

quantized model ppl: 10.967483520507812





In [None]:
def get_activation(model, dataset, tokenizer, num_samples=512, seq_len=512):
    """Collect the per-channel max |activation| at the input of every linear layer.

    A forward hook is registered on each ``nn.Linear`` / ``Conv1D`` module. Up to
    ``num_samples`` non-empty texts from the shuffled ``dataset`` are run through
    ``model``. For every hooked module, the element-wise maximum of ``|input|`` is
    kept over all tokens seen (all dimensions except the last are reduced).

    Args:
        model: module whose forward accepts a tensor of ``input_ids``.
        dataset: object with a ``shuffle(seed=...)`` method and a ``"text"``
            column (e.g. a Hugging Face ``datasets.Dataset``).
        tokenizer: callable returning an object with an ``.input_ids`` tensor.
        num_samples: maximum number of non-empty texts to run. If fewer texts
            are available, all of them are used.
        seq_len: each text is truncated to at most this many tokens.

    Returns:
        dict mapping module name -> 1-D float32 CPU tensor whose length is the
        module input's last dimension, holding the running max of |input|.
        NOTE(review): only abs-max scales are collected; callers that expected
        separate FFN min/max statistics need a different collector.
    """
    model.eval()
    # Inputs must live on the same device as the model's weights; use the first parameter's.
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        # Flatten leading dims so each row is one token's feature vector, then keep
        # the per-channel max of |x|. reshape (unlike view) accepts non-contiguous input.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], coming_max)
        else:
            act_scales[name] = coming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple; take the first one.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = [
        m.register_forward_hook(functools.partial(stat_input_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, (nn.Linear, Conv1D))
    ]

    try:
        dataset = dataset.shuffle(seed=42)
        # Read the text column once instead of materialising each row; skip empty strings.
        texts = [text for text in dataset["text"] if text != '']
        with torch.no_grad():  # statistics only: no autograd graph needed
            for text in tqdm(texts[:num_samples]):
                input_ids = tokenizer(text, return_tensors="pt",
                                      max_length=seq_len, truncation=True).input_ids.to(device)
                model(input_ids)
    finally:
        # Always detach the hooks so a failed run does not leave the model instrumented.
        for h in hooks:
            h.remove()

    return act_scales

In [None]:
from err_gen import error_injection