In [1]:
# 3.1s
import functools
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.pytorch_utils import Conv1D

from quant import quant, dequant, quant_unpack, dequant_unpack
from utils import build_model_and_tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = 'facebook/opt-125m'
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# WikiText-2 (raw) test split; the whole split is joined into one string and
# tokenized in a single call.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testenc = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

In [3]:
# Quantization settings: bit width and group size (gs is later used as the
# number of consecutive input columns that share one scale).
bits, gs = 8, 16

# quant() is a project helper; presumably it returns per-layer scales, zero
# points and quantized weights keyed by layer name -- confirm against quant.py.
scale, zero, qs = quant(bits, gs, model)
# Reconstruct values from the quantized representation.
q_x = dequant(scale, zero, qs, gs, bits)

In [4]:
# Show the shape of every layer's scale tensor, one per line.
for layer_scale in scale.values():
    print(layer_scale.size())

torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([3072, 48])
torch.Size([768, 192])
torch.Size([768, 48])
torch.Size([768, 48])
torch.Size([768, 4

In [5]:
model.eval()
# next() pulls the first item out of the parameters() iterator; that
# parameter's device is taken as the device the model lives on.
device = next(model.parameters()).device
act_scales = {}

# Running statistics of the scaled activations, keyed by layer name and
# filled in by stat_tensor.
scaled_act_max, scaled_act_min = {}, {}
scaled_act_sum, scaled_act_absum = {}, {}
scaled_act_numel = {}

def stat_tensor(name, tensor):
    """Accumulate running statistics of a layer input scaled by its weight scales.

    The input is flattened to ``(tokens, hidden_dim)``.  For every row of
    ``scale[name]`` each scale value is repeated over ``gs`` consecutive
    input columns, the activations are multiplied by that expanded row, and
    the result is folded into the module-level dictionaries
    ``scaled_act_max``, ``scaled_act_min``, ``scaled_act_sum``,
    ``scaled_act_absum`` and ``scaled_act_numel`` under ``name``.

    Args:
        name: Layer name; must be a key of the module-level ``scale`` dict.
        tensor: Activation tensor whose last dimension is the hidden size.
            ``scale[name].shape[-1] * gs`` must equal that hidden size.

    Returns:
        None.  The statistics dictionaries are updated in place.
    """
    hidden_dim = tensor.shape[-1]
    # reshape (rather than view) also accepts non-contiguous activations.
    acts = tensor.reshape(-1, hidden_dim).detach()

    # Loop-invariant work is done once: a single device transfer and a single
    # expansion of every group scale over its gs input columns.
    row_scales = scale[name].to(acts.device).repeat_interleave(gs, dim=-1)

    for row_scale in row_scales:
        # (tokens, hidden_dim) * (hidden_dim,) broadcasts over the token axis.
        scaled_act = acts * row_scale
        row_max = scaled_act.max()
        row_min = scaled_act.min()
        row_sum = scaled_act.sum()
        row_absum = scaled_act.abs().sum()
        row_numel = scaled_act.numel()

        if name in scaled_act_max:
            # Element-wise max/min on 0-dim tensors keeps the comparison on
            # the tensor's device instead of branching in Python.
            scaled_act_max[name] = torch.maximum(scaled_act_max[name], row_max)
            scaled_act_min[name] = torch.minimum(scaled_act_min[name], row_min)
            scaled_act_sum[name] += row_sum
            scaled_act_absum[name] += row_absum
            scaled_act_numel[name] += row_numel
        else:
            scaled_act_max[name] = row_max
            scaled_act_min[name] = row_min
            scaled_act_sum[name] = row_sum
            scaled_act_absum[name] = row_absum
            scaled_act_numel[name] = row_numel


def stat_input_hook(m, x, y, name):
    """Forward hook that feeds a module's input to ``stat_tensor``.

    Args:
        m: The hooked module (unused).
        x: The module input; when it is a tuple only its first element is used.
        y: The module output (unused).
        name: Layer name under which the statistics are recorded.
    """
    layer_input = x[0] if isinstance(x, tuple) else x
    stat_tensor(name, layer_input)


In [6]:
# Attach the statistics hook to every Linear / Conv1D module except the
# output head, and keep the handles so the hooks can be removed afterwards.
hooks = [
    module.register_forward_hook(
        functools.partial(stat_input_hook, name=layer_name))
    for layer_name, module in model.named_modules()
    if isinstance(module, (nn.Linear, Conv1D))
    and layer_name.split('.')[-1] != 'lm_head'
]
        

In [7]:
# Shuffle with a fixed seed, then keep only the rows whose text is non-empty.
dataset = dataset.shuffle(seed=42)
dataset_list = [
    row
    for row in (dataset[idx] for idx in range(len(dataset)))
    if row['text'] != ''
]


In [8]:
num_samples = 4
seq_len = 64

# Only forward activations are needed, so run without autograd bookkeeping.
# The finally clause detaches the hooks even when a forward pass raises, so a
# failed run does not leave stale hooks on the model.
try:
    with torch.no_grad():
        for i in tqdm(range(num_samples)):
            input_ids = tokenizer(dataset_list[i]["text"], return_tensors="pt",
                                  max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)
finally:
    for h in hooks:
        h.remove()

100%|██████████| 4/4 [00:24<00:00,  6.11s/it]


In [9]:
# Per-layer summaries derived from the running totals:
#   avg    = sum / element count
#   absavg = sum of absolute values / element count
#   minmax = max - min
scaled_act_avg = {
    layer: scaled_act_sum[layer] / scaled_act_numel[layer]
    for layer in scaled_act_max
}
scaled_act_absavg = {
    layer: scaled_act_absum[layer] / scaled_act_numel[layer]
    for layer in scaled_act_max
}
scaled_act_minmax = {
    layer: scaled_act_max[layer] - scaled_act_min[layer]
    for layer in scaled_act_max
}

In [10]:
# Derive the output location from the quantization settings and the model
# name so that changing bits / gs / model_name cannot overwrite results from
# another configuration.  For bits=8, gs=16, 'facebook/opt-125m' this yields
# the same paths as before: scaled_act/8bit/gs16/scaled_act_<stat>_opt_125m.pt
out_dir = Path('scaled_act') / f'{bits}bit' / f'gs{gs}'
# torch.save does not create missing directories.
out_dir.mkdir(parents=True, exist_ok=True)
model_tag = model_name.split('/')[-1].replace('-', '_')

stats_to_save = {
    'max': scaled_act_max,
    'min': scaled_act_min,
    'minmax': scaled_act_minmax,
    'avg': scaled_act_avg,
    'absavg': scaled_act_absavg,
}
for stat_name, stat in stats_to_save.items():
    torch.save(stat, out_dir / f'scaled_act_{stat_name}_{model_tag}.pt')