In [None]:
import torch
# Pre-collected per-layer statistics. Later cells index this as
# H_dict[layer_index][sublayer_name] and unpack each entry into
# (col_perm, dead, H_inv_diag).
# map_location='cpu' makes the load independent of whichever device the file
# was saved from; the quantization cell indexes CPU weights with these tensors,
# and the evaluation cell moves col_perm to the GPU explicitly.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files from
# a trusted source.
H_dict = torch.load('collected_H/H_opt-6.7b_pajama_seed0.pt', map_location='cpu')

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from spqr.evalutils import evaluate_perplexity
import torch
from copy import deepcopy

# Run configuration: target accelerator and local checkpoint directory.
device = 'cuda:3'
model_path = '/raid/LLM/opt-6.7b/'

# Tokenizer and half-precision weights both come from the same checkpoint.
# The model is not moved to `device` here; the evaluation cell moves a deep
# copy of it instead.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

In [None]:
from spqr.quant_groups import Quantizer, quantize, dequantize, quantize_dequantize
from spqr.spqr_engine import calculate_bit_error_injection_mask_quantized
import time

# Names (relative to a decoder layer) of the linear sublayers that get
# quantized and error-injected. Order matches the original flat list.
_attention_projections = [
    'self_attn.k_proj',
    'self_attn.q_proj',
    'self_attn.v_proj',
    'self_attn.out_proj',
]
_feed_forward_projections = ['fc1', 'fc2']

linear_list = [*_attention_projections, *_feed_forward_projections]


In [None]:



import os

# Quantize every targeted linear sublayer of the base model and persist the
# integer codes plus the metadata (row permutation, scale, zero point) needed
# to dequantize them later.

# Bind the decoder stack from the pristine `model` explicitly. Previously this
# cell relied on a `layers` global that is only created by the evaluation cell
# further down, so it failed on a fresh kernel and, when run out of order,
# quantized the error-injected copy instead of the original weights.
layers = model.model.decoder.layers

# Create the output directory up front so the saves after the long loop
# cannot fail on a missing folder.
os.makedirs('ordered_quant_models', exist_ok=True)

# All four dicts are indexed as [layer_index][sublayer_name].
qweight_dict = {}
row_perm_dict = {}
scale_dict = {}
zero_dict = {}
for i, layer in enumerate(layers):
    qweight_dict[i] = {}
    row_perm_dict[i] = {}
    scale_dict[i] = {}
    zero_dict[i] = {}
    sublayers = {name: sublayer for name, sublayer in layer.named_modules() if name in linear_list}
    for name, sublayer in sublayers.items():
        # Work on a detached copy so the model's own parameters stay untouched.
        weight = sublayer.weight.detach().clone()
        col_perm, dead, _ = H_dict[i][name]
        # Reorder input columns, then zero the columns selected by `dead`
        # (indexed in the permuted order).
        weight = weight[:, col_perm]
        weight[:, dead] = 0
        weight = weight.to(device)

        quantizer = Quantizer(weight.shape)
        # NOTE(review): first argument is presumably the bit width (4); the
        # meaning of the two boolean flags is defined by spqr -- confirm.
        quantizer.configure(4, True, False)
        quantizer.find_params(weight, weight=True)
        # Order rows by descending quantization scale.
        row_perm = torch.argsort(quantizer.scale.T.squeeze(), descending=True).to(device)
        qweight = quantize(weight, quantizer.scale, quantizer.zero, quantizer.maxq)
        qweight = qweight[row_perm, :]
        qweight_dict[i][name] = qweight.to(torch.int8)
        row_perm_dict[i][name] = row_perm
        scale_dict[i][name] = quantizer.scale
        zero_dict[i][name] = quantizer.zero

torch.save(qweight_dict, 'ordered_quant_models/opt-6.7b-qweight.pt')
torch.save(row_perm_dict, 'ordered_quant_models/opt-6.7b-row-perm.pt')
torch.save(scale_dict, 'ordered_quant_models/opt-6.7b-scale.pt')
torch.save(zero_dict, 'ordered_quant_models/opt-6.7b-zero.pt')

In [None]:
from spqr.errorutils import error_injection

# Reload the quantization artifacts written by the previous cell.
qweight_dict = torch.load(f'ordered_quant_models/opt-6.7b-qweight.pt')
row_perm_dict = torch.load(f'ordered_quant_models/opt-6.7b-row-perm.pt')
scale_dict = torch.load(f'ordered_quant_models/opt-6.7b-scale.pt')
zero_dict = torch.load(f'ordered_quant_models/opt-6.7b-zero.pt')

# Move every per-sublayer tensor onto the evaluation device. The key set of
# qweight_dict drives the traversal; the other three dicts are indexed with
# the same (layer, sublayer) keys.
_artifact_stores = (qweight_dict, row_perm_dict, scale_dict, zero_dict)
for layer_idx, per_layer in qweight_dict.items():
    for sub_name in list(per_layer):
        for store in _artifact_stores:
            store[layer_idx][sub_name] = store[layer_idx][sub_name].to(device)


In [None]:

# Sweep the protected fraction from 10% down to 2%. For each setting, inject
# errors into the quantized codes of a fresh model copy, leave the leading
# `percentile`% of rows and columns (in permuted order) error-free, and report
# perplexity.
for percentile in range(10, 1, -1):
    print(f'error masking percentile: {percentile}%')
    # The seed restarts for each percentile and advances by 10 per sublayer,
    # so every percentile setting uses the same per-sublayer seeds.
    seed = 0
    # Fresh copy per setting so injected errors never accumulate in `model`.
    cp_model = deepcopy(model).to(device)
    layers = cp_model.model.decoder.layers
    for i, layer in enumerate(layers):
        sublayers = {name: sublayer for name, sublayer in layer.named_modules() if name in linear_list}
        for name, sublayer in sublayers.items():
            qweight = qweight_dict[i][name].clone()
            row_perm = row_perm_dict[i][name]
            col_perm, _, _ = H_dict[i][name]
            col_perm = col_perm.to(device)

            # NOTE(review): 1e-3 is presumably the error rate and 4 the bit
            # width; both are defined by spqr.errorutils -- confirm.
            err_matrix = error_injection(
                qweight, 1e-3, seed, 4, device
            ).reshape_as(qweight)
            # Clear the error pattern on the top-left block of the permuted
            # matrix so those entries stay uncorrupted.
            err_mask_row = round(qweight.shape[0] * (percentile/100))
            err_mask_col = round(qweight.shape[1] * (percentile/100))
            err_matrix[:err_mask_row, :err_mask_col] = 0
            # XOR applies the remaining error pattern to the integer codes.
            qweight = qweight.to(torch.int32) ^ err_matrix.to(device)

            # Undo the row ordering, dequantize, then undo the column ordering
            # so the weight matches the layout the layer expects.
            row_invperm = torch.argsort(row_perm)
            qweight = qweight[row_invperm, :]
            scale = scale_dict[i][name].to(device)
            zero = zero_dict[i][name].to(device)
            dqweight = dequantize(qweight, scale, zero)

            col_invperm = torch.argsort(col_perm)
            dqweight = dqweight[:, col_invperm]
            sublayer.weight.data = dqweight.to(device)
            seed = seed + 10

    # cp_model was placed on `device` when it was created. The old branch
    # tested the base `model`'s device against a plain string, and both of
    # its arms evaluated this same model on this same device.
    print(evaluate_perplexity(cp_model, tokenizer))