In [2]:
import torch
import torch.nn as nn

import functools
from functools import partial

from collections import defaultdict

from tqdm import tqdm

import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from utils import build_model_and_tokenizer, opt_eval
from quant import quant, dequant
from auto_gptq import AutoGPTQForCausalLM

from transformers.pytorch_utils import Conv1D
from copy import deepcopy
import gc

from err_gen import error_injection

model_name = 'facebook/opt-125m'
device = "cuda"
orig_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testenc = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def cal_qsnr(orig_tensor, target_tensor):
    mse = torch.mean((orig_tensor - target_tensor) ** 2)
    
    # Calculate the mean squared value of the original signal
    signal_power = torch.mean(orig_tensor ** 2)
    
    # Calculate the QSNR in dB
    qsnr = -10 * np.log(mse / signal_power)
    
    return qsnr

In [32]:
p = 0.003
p*(1-p)**7 * 8

0.02350051338791765

In [90]:
def cal_syn_qsnr(orig_tensor, target_tensor, ber, bits, scale):
    avg_bit_error =  ((2**bits - 1) / bits) * torch.mean(scale)
    num_error_elem = orig_tensor.numel() * bits * ber
    mse = torch.mean((orig_tensor - target_tensor) ** 2)
    sum = torch.sum((orig_tensor - target_tensor) ** 2) + (avg_bit_error ** 2) * num_error_elem
    mse = sum / orig_tensor.numel()
    signal_power = torch.mean(orig_tensor ** 2)
    qsnr = -10 * np.log(mse / signal_power)
    return qsnr

In [91]:
gs = 16
bits = 8
model = deepcopy(orig_model)
scale, zero, qs = quant(bits, gs, model)
q_x = dequant(scale, zero, qs, gs, bits)
qsnr = 0
num_key = 0
for key in q_x.keys():
    if key.split('.')[-1] != 'lm_head':
        num_key += 1
        weight = key+'.weight'
        orig_tensor = orig_model.state_dict()[weight]
        qsnr += cal_syn_qsnr(orig_tensor, q_x[key], 1e-3, bits, scale[key])

print(f'average qsnr = {qsnr / num_key}')

average qsnr = 64.99832916259766


In [6]:
# original perplexity of quantized model
precision = [8]
group_size = [128]
ber = np.linspace(1, 10, 10) * 1e-3
loop_cnt = 10

for bits in precision:
    for gs in group_size:
        for i, rate in enumerate(ber):
            for j in range(loop_cnt):
                print(f'bit error rate = {rate}')
                print(f'q_bits: {bits} group_size: {gs}')
                model = deepcopy(orig_model)
                scale, zero, qs = quant(bits, gs, model)
                qs_err = deepcopy(qs)
                for key in qs:
                    qs_err[key] = error_injection(qs_err[key], rate, (i+1)*(42+j), "cpu")
                q_x = dequant(scale, zero, qs_err, gs, bits)

                qsnr = 0
                num_key = 0
                for key in q_x.keys():
                    if key.split('.')[-1] != 'lm_head':
                        num_key += 1
                        weight = key+'.weight'
                        name = key.split('.')[-1]
                        layer = key.split('.')[-2] if (name == 'fc1') or (name == 'fc2') else key.split('.')[-3]
                        orig_tensor = orig_model.state_dict()[weight]
                        qsnr += cal_qsnr(orig_tensor, q_x[key])

                        model.state_dict()[weight][:] = q_x[key]
                
                print(f'average qsnr = {qsnr / num_key}')
                
                print(f'ppl {opt_eval(model, testenc, device)}')
                del model, scale, zero, qs, q_x, qs_err
                gc.collect()


bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 44.965667724609375
ppl 31.544193267822266
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.32555389404297
ppl 33.78642654418945
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.10350036621094
ppl 33.88242721557617
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.36417007446289
ppl 342.2835388183594
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.085811614990234
ppl 30.85785675048828
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.442657470703125
ppl 44.03659439086914
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.00423049926758
ppl 33.49163818359375
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 45.13892364501953
ppl 36.879364013671875
bit error rate = 0.001
q_bits: 8 group_size: 128
average qsnr = 44.96913528442383
ppl 30.077341079711914
bit error rate = 0.001
q_bits: 8 group_size: 128
average q