In [22]:
import os
from copy import deepcopy

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from quant import quant, dequant
from utils import evaluate_opt, opt_eval

opt_model_id = 'facebook/opt-125m'
gpt_model_id = 'gpt2'
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
opt_model = AutoModelForCausalLM.from_pretrained(opt_model_id)
#gpt_model = AutoModelForCausalLM.from_pretrained(gpt_model_id)
opt_tokenizer = AutoTokenizer.from_pretrained(opt_model_id)
#gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_id)

# Evaluation corpus: the WikiText-2 raw test split, joined into one string and tokenized once.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
opt_testenc = opt_tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

# Quantization settings passed to the project-local quant/dequant helpers.
gs = 4       # group size
q_bits = 2   # bit width
opt_scale, opt_zero, opt_qweight = quant(q_bits, gs, opt_model)
dequant_weight = dequant(opt_scale, opt_zero, opt_qweight, gs, q_bits)

# Copy the model and overwrite its weights with the dequantized values.
dequant_opt = deepcopy(opt_model)
# state_dict() returns tensors that share storage with the parameters, so the
# in-place slice assignment below updates dequant_opt. Build the dict once
# instead of once per key.
dequant_state = dequant_opt.state_dict()
for key, new_weight in dequant_weight.items():
    if key.split('.')[-1] == 'lm_head':
        continue
    dequant_state[key + '.weight'][:] = new_weight

In [23]:
# Evaluate the full-precision and dequantized models on the same token stream.
# NOTE(review): opt_eval is project-local; presumably it returns perplexity (ppl).
original_ppl = opt_eval(opt_model, opt_testenc, device)
quant_ppl = opt_eval(dequant_opt, opt_testenc, device)

# Positive values mean the quantized model scores worse than the original.
ppl_increase = quant_ppl - original_ppl
print(ppl_increase)

14.701229095458984


In [3]:
import numpy as np

def calculate_qsnr(original_signal, quantized_signal):
    """
    Calculate the Quantization Signal-to-Noise Ratio (QSNR).

    QSNR = -10 * log10(MSE / P), where MSE is the mean squared difference
    between the two signals and P is the mean squared value of the original.

    Parameters:
    original_signal (array-like): The original signal vector.
    quantized_signal (array-like): The quantized signal vector; must be
        broadcastable against ``original_signal``.

    Returns:
    float: The QSNR value in decibels (dB). ``inf`` when the two signals are
    identical; ``-inf`` when the original signal is all zeros but the
    quantized signal is not.
    """
    # Promote to float64 first: subtracting/squaring integer arrays (e.g. int8
    # quantized values) would otherwise silently wrap around.
    original = np.asarray(original_signal, dtype=np.float64)
    quantized = np.asarray(quantized_signal, dtype=np.float64)

    # Mean squared error between the original and quantized signals
    mse = np.mean((quantized - original) ** 2)

    # Mean squared value (power) of the original signal
    signal_power = np.mean(original ** 2)

    # Handle degenerate ratios explicitly instead of emitting divide-by-zero warnings.
    if mse == 0:
        return float('inf')
    if signal_power == 0:
        return float('-inf')

    return float(-10 * np.log10(mse / signal_power))

In [4]:
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

import functools
from tqdm import tqdm

def get_activation(model, dataset, tokenizer, num_samples=512, seq_len=512):
    """
    Feed calibration text through ``model`` and record, for every ``nn.Linear``
    / ``Conv1D`` module whose name does not end in ``lm_head``, the tensor it
    returns.

    Parameters:
    model (nn.Module): The model to probe. It is switched to eval mode as a
        side effect.
    dataset: An object with ``shuffle(seed=...)`` and a ``"text"`` column
        (e.g. a Hugging Face ``datasets.Dataset``). Empty strings are skipped;
        the shuffle uses a fixed seed (42).
    tokenizer: Called as ``tokenizer(text, return_tensors="pt", max_length=...,
        truncation=True)`` and expected to expose ``.input_ids``.
    num_samples (int): Number of non-empty texts to feed, clamped to the number
        available.
    seq_len (int): Each text is truncated to at most this many tokens.

    Returns:
    dict[str, list[torch.Tensor]]: Maps module name to a list with one detached
    CPU tensor of shape ``(N, hidden_dim)`` per forward call of that module.
    All leading dimensions are flattened into N.
    """
    model.eval()
    # Inputs must live on the same device as the model's parameters.
    device = next(model.parameters()).device
    activations = {}

    def stat_tensor(name, tensor):
        # Flatten every leading dim so each row is one position's feature vector.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.reshape(-1, hidden_dim).detach().cpu()
        activations.setdefault(name, []).append(tensor)

    # NOTE(review): despite its name, this hook records the module OUTPUT (y),
    # not its input (x). Confirm which one the downstream analysis expects.
    def stat_input_hook(m, x, y, name):
        # Forward-hook inputs ``x`` are always a tuple, so the tuple check must
        # be on the output. Unwrap only when the module returns a tuple.
        if isinstance(y, tuple):
            y = y[0]
        stat_tensor(name, y)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, (nn.Linear, Conv1D)) and name.split('.')[-1] != 'lm_head':
            hooks.append(
                m.register_forward_hook(functools.partial(stat_input_hook, name=name)))

    try:
        dataset = dataset.shuffle(seed=42)
        # Column access is much faster than indexing the dataset row by row.
        texts = [text for text in dataset["text"] if text != '']
        num_samples = min(num_samples, len(texts))

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for i in tqdm(range(num_samples)):
                input_ids = tokenizer(texts[i], return_tensors="pt",
                                      max_length=seq_len, truncation=True).input_ids.to(device)
                model(input_ids)
    finally:
        # Always remove the hooks, even if a forward pass fails. Otherwise the
        # model keeps appending to a stale dict on every later call.
        for h in hooks:
            h.remove()

    return activations

In [7]:
# nn.Module.to() moves the parameters in place and returns the same module.
opt_model = opt_model.to(device)
# Per-layer outputs of every Linear/Conv1D (except lm_head) over the calibration texts.
activations = get_activation(opt_model, dataset, opt_tokenizer)

100%|██████████| 512/512 [00:05<00:00, 94.15it/s] 


In [8]:
# Save each layer's list of activation tensors as its own .pt file, named after the module.
acts_dir = os.path.join('loss_quantification', 'opt_125m_acts')
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(acts_dir, exist_ok=True)
for layer_name, layer_acts in activations.items():
    torch.save(layer_acts, os.path.join(acts_dir, f'{layer_name}.pt'))