In [1]:
from model import GPT
import numpy as np
import torch
from tqdm import tqdm
from copy import deepcopy
import tiktoken
from datasets import load_from_disk

In [2]:
model = GPT.from_pretrained('/home/mfx/huggingface/gpt2-xl')
device = "cuda:4"
model = model.to(device)
enc = tiktoken.get_encoding("gpt2")
test = load_from_disk("/home/mfx/CLOVer/output/wikitext-2-raw-v1")['test']
all_input_ids = torch.LongTensor(enc.encode("\n\n".join(test["text"]))).unsqueeze(0)
seq_len = all_input_ids.size(1)-1

loading weights from pretrained gpt: gpt2-xl
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 1555.97M


In [3]:
def calculating_perplexity(model, datasest):
    max_length = model.config.block_size
    stride = 512
    nll_sum = 0.0
    n_tokens = 0
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = datasest[:, begin_loc:end_loc].to(device)
        target_ids = datasest[:, begin_loc+1:end_loc+1].to(device)
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            _,loss,_,_ = model(input_ids, targets=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = loss

        # Accumulate the total negative log-likelihood and the total number of tokens
        num_valid_tokens = (target_ids != -100).sum().item()  # number of valid tokens in target_ids
        batch_size = target_ids.size(0)
        num_loss_tokens = num_valid_tokens - batch_size  # subtract batch_size due to internal label shift
        nll_sum += neg_log_likelihood * num_loss_tokens

        n_tokens += num_loss_tokens

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    avg_nll = nll_sum / n_tokens  # average negative log-likelihood per token
    ppl = torch.exp(avg_nll)
    return avg_nll, ppl
    

In [4]:
calculating_perplexity(model, all_input_ids)

100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(2.6933, device='cuda:4'), tensor(14.7809, device='cuda:4'))

In [5]:
qk_norm_list = []
vo_norm_list = []
for name, module in model.named_modules():
    if name.endswith("attn"):
        q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
        q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
        k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
        k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
        qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
        qk_norm_list.append(qk_norm)
        
        v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
        v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
        o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
        vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
        vo_norm_list.append(vo_norm)
qk_norm_list = torch.stack(qk_norm_list)
vo_norm_list = torch.stack(vo_norm_list)

In [5]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in model.named_modules():
        if name.endswith("attn"):
            q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
            q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
            k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
            k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
            qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
            module.q_proj.weight.data[qk_norm<qk_quantile]=0
            module.q_proj.bias.data[qk_norm<qk_quantile]=0
            module.k_proj.weight.data[qk_norm<qk_quantile]=0
            module.k_proj.bias.data[qk_norm<qk_quantile]=0
    print(calculating_perplexity(model, all_input_ids))
    

0.0 tensor(0.6197, device='cuda:4') tensor(0.1824, device='cuda:4')


  0%|          | 0/562 [00:00<?, ?it/s]

100%|█████████▉| 560/562 [01:03<00:00,  8.76it/s]


(tensor(2.6933, device='cuda:4'), tensor(14.7809, device='cuda:4'))
0.015625 tensor(1.5088, device='cuda:4') tensor(0.5361, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(2.8149, device='cuda:4'), tensor(16.6913, device='cuda:4'))
0.03125 tensor(2.0249, device='cuda:4') tensor(0.7478, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(3.0299, device='cuda:4'), tensor(20.6960, device='cuda:4'))
0.046875 tensor(2.2731, device='cuda:4') tensor(0.8428, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(3.4355, device='cuda:4'), tensor(31.0459, device='cuda:4'))
0.0625 tensor(2.3851, device='cuda:4') tensor(0.9352, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(3.6798, device='cuda:4'), tensor(39.6370, device='cuda:4'))
0.078125 tensor(2.4543, device='cuda:4') tensor(1.0077, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.1096, device='cuda:4'), tensor(60.9204, device='cuda:4'))
0.09375 tensor(2.5030, device='cuda:4') tensor(1.0541, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.1835, device='cuda:4'), tensor(65.5932, device='cuda:4'))
0.109375 tensor(2.5409, device='cuda:4') tensor(1.0950, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(4.2020, device='cuda:4'), tensor(66.8213, device='cuda:4'))
0.125 tensor(2.5702, device='cuda:4') tensor(1.1331, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.2907, device='cuda:4'), tensor(73.0206, device='cuda:4'))
0.140625 tensor(2.5984, device='cuda:4') tensor(1.1650, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(4.3282, device='cuda:4'), tensor(75.8050, device='cuda:4'))
0.15625 tensor(2.6229, device='cuda:4') tensor(1.1957, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.3687, device='cuda:4'), tensor(78.9437, device='cuda:4'))
0.171875 tensor(2.6464, device='cuda:4') tensor(1.2269, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.4672, device='cuda:4'), tensor(87.1112, device='cuda:4'))
0.1875 tensor(2.6676, device='cuda:4') tensor(1.2580, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(4.4933, device='cuda:4'), tensor(89.4140, device='cuda:4'))
0.203125 tensor(2.6879, device='cuda:4') tensor(1.2923, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.5454, device='cuda:4'), tensor(94.1982, device='cuda:4'))
0.21875 tensor(2.7075, device='cuda:4') tensor(1.3303, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.5888, device='cuda:4'), tensor(98.3755, device='cuda:4'))
0.234375 tensor(2.7270, device='cuda:4') tensor(1.3692, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.6265, device='cuda:4'), tensor(102.1592, device='cuda:4'))
0.25 tensor(2.7462, device='cuda:4') tensor(1.4129, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.6338, device='cuda:4'), tensor(102.8994, device='cuda:4'))
0.265625 tensor(2.7644, device='cuda:4') tensor(1.4589, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.6661, device='cuda:4'), tensor(106.2831, device='cuda:4'))
0.28125 tensor(2.7822, device='cuda:4') tensor(1.4982, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(4.6694, device='cuda:4'), tensor(106.6294, device='cuda:4'))
0.296875 tensor(2.8009, device='cuda:4') tensor(1.5329, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.7128, device='cuda:4'), tensor(111.3615, device='cuda:4'))
0.3125 tensor(2.8188, device='cuda:4') tensor(1.5657, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(4.7500, device='cuda:4'), tensor(115.5855, device='cuda:4'))
0.328125 tensor(2.8360, device='cuda:4') tensor(1.6011, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(4.7216, device='cuda:4'), tensor(112.3464, device='cuda:4'))
0.34375 tensor(2.8525, device='cuda:4') tensor(1.6372, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(4.7723, device='cuda:4'), tensor(118.1865, device='cuda:4'))
0.359375 tensor(2.8688, device='cuda:4') tensor(1.6768, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(4.8154, device='cuda:4'), tensor(123.3919, device='cuda:4'))
0.375 tensor(2.8858, device='cuda:4') tensor(1.7110, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(4.8347, device='cuda:4'), tensor(125.7953, device='cuda:4'))
0.390625 tensor(2.9022, device='cuda:4') tensor(1.7425, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(4.8694, device='cuda:4'), tensor(130.2411, device='cuda:4'))
0.40625 tensor(2.9197, device='cuda:4') tensor(1.7743, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.8919, device='cuda:4'), tensor(133.2045, device='cuda:4'))
0.421875 tensor(2.9373, device='cuda:4') tensor(1.8058, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.9102, device='cuda:4'), tensor(135.6686, device='cuda:4'))
0.4375 tensor(2.9547, device='cuda:4') tensor(1.8354, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(4.9300, device='cuda:4'), tensor(138.3821, device='cuda:4'))
0.453125 tensor(2.9732, device='cuda:4') tensor(1.8639, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.9462, device='cuda:4'), tensor(140.6419, device='cuda:4'))
0.46875 tensor(2.9924, device='cuda:4') tensor(1.8934, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(5.0030, device='cuda:4'), tensor(148.8555, device='cuda:4'))
0.484375 tensor(3.0120, device='cuda:4') tensor(1.9245, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(5.0196, device='cuda:4'), tensor(151.3533, device='cuda:4'))
0.5 tensor(3.0321, device='cuda:4') tensor(1.9553, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(5.0809, device='cuda:4'), tensor(160.9242, device='cuda:4'))
0.515625 tensor(3.0538, device='cuda:4') tensor(1.9859, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(5.1279, device='cuda:4'), tensor(168.6678, device='cuda:4'))
0.53125 tensor(3.0763, device='cuda:4') tensor(2.0183, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(5.2202, device='cuda:4'), tensor(184.9759, device='cuda:4'))
0.546875 tensor(3.1022, device='cuda:4') tensor(2.0524, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.62it/s]


(tensor(5.3016, device='cuda:4'), tensor(200.6488, device='cuda:4'))
0.5625 tensor(3.1284, device='cuda:4') tensor(2.0866, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(5.3593, device='cuda:4'), tensor(212.5769, device='cuda:4'))
0.578125 tensor(3.1567, device='cuda:4') tensor(2.1217, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(5.5142, device='cuda:4'), tensor(248.1798, device='cuda:4'))
0.59375 tensor(3.1870, device='cuda:4') tensor(2.1593, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(5.6232, device='cuda:4'), tensor(276.7810, device='cuda:4'))
0.609375 tensor(3.2202, device='cuda:4') tensor(2.1931, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(5.6710, device='cuda:4'), tensor(290.3139, device='cuda:4'))
0.625 tensor(3.2578, device='cuda:4') tensor(2.2306, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(5.7783, device='cuda:4'), tensor(323.2143, device='cuda:4'))
0.640625 tensor(3.2984, device='cuda:4') tensor(2.2688, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(5.8475, device='cuda:4'), tensor(346.3797, device='cuda:4'))
0.65625 tensor(3.3436, device='cuda:4') tensor(2.3062, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(5.9778, device='cuda:4'), tensor(394.5854, device='cuda:4'))
0.671875 tensor(3.3884, device='cuda:4') tensor(2.3416, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.0014, device='cuda:4'), tensor(404.0019, device='cuda:4'))
0.6875 tensor(3.4354, device='cuda:4') tensor(2.3770, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.0652, device='cuda:4'), tensor(430.5999, device='cuda:4'))
0.703125 tensor(3.4833, device='cuda:4') tensor(2.4094, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.1153, device='cuda:4'), tensor(452.7159, device='cuda:4'))
0.71875 tensor(3.5368, device='cuda:4') tensor(2.4448, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.1890, device='cuda:4'), tensor(487.3760, device='cuda:4'))
0.734375 tensor(3.5889, device='cuda:4') tensor(2.4799, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.2178, device='cuda:4'), tensor(501.5765, device='cuda:4'))
0.75 tensor(3.6474, device='cuda:4') tensor(2.5164, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.2413, device='cuda:4'), tensor(513.5456, device='cuda:4'))
0.765625 tensor(3.7078, device='cuda:4') tensor(2.5513, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.2433, device='cuda:4'), tensor(514.5684, device='cuda:4'))
0.78125 tensor(3.7696, device='cuda:4') tensor(2.5879, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.2956, device='cuda:4'), tensor(542.1810, device='cuda:4'))
0.796875 tensor(3.8340, device='cuda:4') tensor(2.6245, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.3899, device='cuda:4'), tensor(595.8213, device='cuda:4'))
0.8125 tensor(3.9036, device='cuda:4') tensor(2.6634, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.4300, device='cuda:4'), tensor(620.1620, device='cuda:4'))
0.828125 tensor(3.9703, device='cuda:4') tensor(2.7048, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(6.4833, device='cuda:4'), tensor(654.1536, device='cuda:4'))
0.84375 tensor(4.0384, device='cuda:4') tensor(2.7497, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.5413, device='cuda:4'), tensor(693.1728, device='cuda:4'))
0.859375 tensor(4.1129, device='cuda:4') tensor(2.7970, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.6293, device='cuda:4'), tensor(756.9771, device='cuda:4'))
0.875 tensor(4.1874, device='cuda:4') tensor(2.8467, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.7225, device='cuda:4'), tensor(830.9122, device='cuda:4'))
0.890625 tensor(4.2723, device='cuda:4') tensor(2.9027, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.8368, device='cuda:4'), tensor(931.5445, device='cuda:4'))
0.90625 tensor(4.3659, device='cuda:4') tensor(2.9642, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(6.9417, device='cuda:4'), tensor(1034.4968, device='cuda:4'))
0.921875 tensor(4.4702, device='cuda:4') tensor(3.0370, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.0468, device='cuda:4'), tensor(1149.2113, device='cuda:4'))
0.9375 tensor(4.5972, device='cuda:4') tensor(3.1197, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.70it/s]


(tensor(7.1653, device='cuda:4'), tensor(1293.8013, device='cuda:4'))
0.953125 tensor(4.7720, device='cuda:4') tensor(3.2120, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.2911, device='cuda:4'), tensor(1467.1998, device='cuda:4'))
0.96875 tensor(5.0900, device='cuda:4') tensor(3.3482, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]

(tensor(7.4612, device='cuda:4'), tensor(1739.2566, device='cuda:4'))





In [5]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in model.named_modules():
        if name.endswith("attn"):
            v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
            v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
            o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
            vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
            module.v_proj.weight.data[vo_norm<vo_quantile]=0
            module.v_proj.bias.data[vo_norm<vo_quantile]=0
            module.v_proj.weight.data[:,vo_norm<vo_quantile]=0
    print(calculating_perplexity(model, all_input_ids))
    

0.0 tensor(0.6197, device='cuda:4') tensor(0.1824, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.75it/s]


(tensor(2.6933, device='cuda:4'), tensor(14.7809, device='cuda:4'))
0.015625 tensor(1.5088, device='cuda:4') tensor(0.5361, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(2.6959, device='cuda:4'), tensor(14.8191, device='cuda:4'))
0.03125 tensor(2.0249, device='cuda:4') tensor(0.7478, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(2.7000, device='cuda:4'), tensor(14.8797, device='cuda:4'))
0.046875 tensor(2.2731, device='cuda:4') tensor(0.8428, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.7118, device='cuda:4'), tensor(15.0569, device='cuda:4'))
0.0625 tensor(2.3851, device='cuda:4') tensor(0.9352, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.8808, device='cuda:4'), tensor(17.8292, device='cuda:4'))
0.078125 tensor(2.4543, device='cuda:4') tensor(1.0077, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.7263, device='cuda:4'), tensor(41.5258, device='cuda:4'))
0.09375 tensor(2.5030, device='cuda:4') tensor(1.0541, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.5208, device='cuda:4'), tensor(249.8459, device='cuda:4'))
0.109375 tensor(2.5409, device='cuda:4') tensor(1.0950, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.9289, device='cuda:4'), tensor(375.7513, device='cuda:4'))
0.125 tensor(2.5702, device='cuda:4') tensor(1.1331, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(6.2240, device='cuda:4'), tensor(504.7195, device='cuda:4'))
0.140625 tensor(2.5984, device='cuda:4') tensor(1.1650, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(6.5206, device='cuda:4'), tensor(679.0134, device='cuda:4'))
0.15625 tensor(2.6229, device='cuda:4') tensor(1.1957, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(6.4774, device='cuda:4'), tensor(650.2621, device='cuda:4'))
0.171875 tensor(2.6464, device='cuda:4') tensor(1.2269, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.5494, device='cuda:4'), tensor(698.8160, device='cuda:4'))
0.1875 tensor(2.6676, device='cuda:4') tensor(1.2580, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.3479, device='cuda:4'), tensor(571.3197, device='cuda:4'))
0.203125 tensor(2.6879, device='cuda:4') tensor(1.2923, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.3088, device='cuda:4'), tensor(549.3676, device='cuda:4'))
0.21875 tensor(2.7075, device='cuda:4') tensor(1.3303, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.2412, device='cuda:4'), tensor(513.4897, device='cuda:4'))
0.234375 tensor(2.7270, device='cuda:4') tensor(1.3692, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.2579, device='cuda:4'), tensor(522.1473, device='cuda:4'))
0.25 tensor(2.7462, device='cuda:4') tensor(1.4129, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.4282, device='cuda:4'), tensor(619.0568, device='cuda:4'))
0.265625 tensor(2.7644, device='cuda:4') tensor(1.4589, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.5205, device='cuda:4'), tensor(678.9036, device='cuda:4'))
0.28125 tensor(2.7822, device='cuda:4') tensor(1.4982, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.5225, device='cuda:4'), tensor(680.2676, device='cuda:4'))
0.296875 tensor(2.8009, device='cuda:4') tensor(1.5329, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.4413, device='cuda:4'), tensor(627.1923, device='cuda:4'))
0.3125 tensor(2.8188, device='cuda:4') tensor(1.5657, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.4452, device='cuda:4'), tensor(629.6584, device='cuda:4'))
0.328125 tensor(2.8360, device='cuda:4') tensor(1.6011, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.4318, device='cuda:4'), tensor(621.2968, device='cuda:4'))
0.34375 tensor(2.8525, device='cuda:4') tensor(1.6372, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.4519, device='cuda:4'), tensor(633.8748, device='cuda:4'))
0.359375 tensor(2.8688, device='cuda:4') tensor(1.6768, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.5256, device='cuda:4'), tensor(682.3607, device='cuda:4'))
0.375 tensor(2.8858, device='cuda:4') tensor(1.7110, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.2817, device='cuda:4'), tensor(534.7097, device='cuda:4'))
0.390625 tensor(2.9022, device='cuda:4') tensor(1.7425, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.1899, device='cuda:4'), tensor(487.7947, device='cuda:4'))
0.40625 tensor(2.9197, device='cuda:4') tensor(1.7743, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.2203, device='cuda:4'), tensor(502.8717, device='cuda:4'))
0.421875 tensor(2.9373, device='cuda:4') tensor(1.8058, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.2230, device='cuda:4'), tensor(504.2281, device='cuda:4'))
0.4375 tensor(2.9547, device='cuda:4') tensor(1.8354, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.2829, device='cuda:4'), tensor(535.3378, device='cuda:4'))
0.453125 tensor(2.9732, device='cuda:4') tensor(1.8639, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.3051, device='cuda:4'), tensor(547.3552, device='cuda:4'))
0.46875 tensor(2.9924, device='cuda:4') tensor(1.8934, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.3158, device='cuda:4'), tensor(553.2618, device='cuda:4'))
0.484375 tensor(3.0120, device='cuda:4') tensor(1.9245, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.3350, device='cuda:4'), tensor(563.9654, device='cuda:4'))
0.5 tensor(3.0321, device='cuda:4') tensor(1.9553, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.3760, device='cuda:4'), tensor(587.5867, device='cuda:4'))
0.515625 tensor(3.0538, device='cuda:4') tensor(1.9859, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.4026, device='cuda:4'), tensor(603.4080, device='cuda:4'))
0.53125 tensor(3.0763, device='cuda:4') tensor(2.0183, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4176, device='cuda:4'), tensor(612.5451, device='cuda:4'))
0.546875 tensor(3.1022, device='cuda:4') tensor(2.0524, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4291, device='cuda:4'), tensor(619.5972, device='cuda:4'))
0.5625 tensor(3.1284, device='cuda:4') tensor(2.0866, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4486, device='cuda:4'), tensor(631.8312, device='cuda:4'))
0.578125 tensor(3.1567, device='cuda:4') tensor(2.1217, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4985, device='cuda:4'), tensor(664.1646, device='cuda:4'))
0.59375 tensor(3.1870, device='cuda:4') tensor(2.1593, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.5430, device='cuda:4'), tensor(694.3326, device='cuda:4'))
0.609375 tensor(3.2202, device='cuda:4') tensor(2.1931, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.6874, device='cuda:4'), tensor(802.2726, device='cuda:4'))
0.625 tensor(3.2578, device='cuda:4') tensor(2.2306, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.8171, device='cuda:4'), tensor(913.3134, device='cuda:4'))
0.640625 tensor(3.2984, device='cuda:4') tensor(2.2688, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.8809, device='cuda:4'), tensor(973.4849, device='cuda:4'))
0.65625 tensor(3.3436, device='cuda:4') tensor(2.3062, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(7.0083, device='cuda:4'), tensor(1105.7375, device='cuda:4'))
0.671875 tensor(3.3884, device='cuda:4') tensor(2.3416, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.1066, device='cuda:4'), tensor(1219.9719, device='cuda:4'))
0.6875 tensor(3.4354, device='cuda:4') tensor(2.3770, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(7.2284, device='cuda:4'), tensor(1377.9847, device='cuda:4'))
0.703125 tensor(3.4833, device='cuda:4') tensor(2.4094, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.3174, device='cuda:4'), tensor(1506.2129, device='cuda:4'))
0.71875 tensor(3.5368, device='cuda:4') tensor(2.4448, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(7.4242, device='cuda:4'), tensor(1676.1086, device='cuda:4'))
0.734375 tensor(3.5889, device='cuda:4') tensor(2.4799, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.4903, device='cuda:4'), tensor(1790.5844, device='cuda:4'))
0.75 tensor(3.6474, device='cuda:4') tensor(2.5164, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.5706, device='cuda:4'), tensor(1940.3939, device='cuda:4'))
0.765625 tensor(3.7078, device='cuda:4') tensor(2.5513, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6279, device='cuda:4'), tensor(2054.8201, device='cuda:4'))
0.78125 tensor(3.7696, device='cuda:4') tensor(2.5879, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.796875 tensor(3.8340, device='cuda:4') tensor(2.6245, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.8125 tensor(3.9036, device='cuda:4') tensor(2.6634, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.828125 tensor(3.9703, device='cuda:4') tensor(2.7048, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.84375 tensor(4.0384, device='cuda:4') tensor(2.7497, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.859375 tensor(4.1129, device='cuda:4') tensor(2.7970, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.875 tensor(4.1874, device='cuda:4') tensor(2.8467, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.890625 tensor(4.2723, device='cuda:4') tensor(2.9027, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.90625 tensor(4.3659, device='cuda:4') tensor(2.9642, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.921875 tensor(4.4702, device='cuda:4') tensor(3.0370, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.9375 tensor(4.5972, device='cuda:4') tensor(3.1197, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.953125 tensor(4.7720, device='cuda:4') tensor(3.2120, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))
0.96875 tensor(5.0900, device='cuda:4') tensor(3.3482, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]

(tensor(7.6285, device='cuda:4'), tensor(2055.9226, device='cuda:4'))





In [None]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in model.named_modules():
        if name.endswith("attn"):
            q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
            q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
            k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
            k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
            qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
            module.q_proj.weight.data[qk_norm<qk_quantile]=0
            module.q_proj.bias.data[qk_norm<qk_quantile]=0
            module.k_proj.weight.data[qk_norm<qk_quantile]=0
            module.k_proj.bias.data[qk_norm<qk_quantile]=0
            v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
            v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
            o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
            vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
            module.v_proj.weight.data[vo_norm<vo_quantile]=0
            module.v_proj.bias.data[vo_norm<vo_quantile]=0
            module.v_proj.weight.data[:,vo_norm<vo_quantile]=0
    print(calculating_perplexity(model, all_input_ids))
    

0.0 tensor(0.6197, device='cuda:4') tensor(0.1824, device='cuda:4')


100%|█████████▉| 560/562 [01:03<00:00,  8.77it/s]


(tensor(2.6933, device='cuda:4'), tensor(14.7809, device='cuda:4'))
0.015625 tensor(1.5088, device='cuda:4') tensor(0.5361, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(2.8004, device='cuda:4'), tensor(16.4517, device='cuda:4'))
0.03125 tensor(2.0249, device='cuda:4') tensor(0.7478, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(2.8469, device='cuda:4'), tensor(17.2346, device='cuda:4'))
0.046875 tensor(2.2731, device='cuda:4') tensor(0.8428, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(3.0509, device='cuda:4'), tensor(21.1338, device='cuda:4'))
0.0625 tensor(2.3851, device='cuda:4') tensor(0.9352, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(3.4183, device='cuda:4'), tensor(30.5164, device='cuda:4'))
0.078125 tensor(2.4543, device='cuda:4') tensor(1.0077, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(4.6257, device='cuda:4'), tensor(102.0725, device='cuda:4'))
0.09375 tensor(2.5030, device='cuda:4') tensor(1.0541, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(5.8531, device='cuda:4'), tensor(348.3018, device='cuda:4'))
0.109375 tensor(2.5409, device='cuda:4') tensor(1.0950, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.1935, device='cuda:4'), tensor(489.5379, device='cuda:4'))
0.125 tensor(2.5702, device='cuda:4') tensor(1.1331, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.4652, device='cuda:4'), tensor(642.3787, device='cuda:4'))
0.140625 tensor(2.5984, device='cuda:4') tensor(1.1650, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.6721, device='cuda:4'), tensor(790.0670, device='cuda:4'))
0.15625 tensor(2.6229, device='cuda:4') tensor(1.1957, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.6794, device='cuda:4'), tensor(795.8639, device='cuda:4'))
0.171875 tensor(2.6464, device='cuda:4') tensor(1.2269, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.7696, device='cuda:4'), tensor(870.9604, device='cuda:4'))
0.1875 tensor(2.6676, device='cuda:4') tensor(1.2580, device='cuda:4')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.9586, device='cuda:4'), tensor(1052.1194, device='cuda:4'))
0.203125 tensor(2.6879, device='cuda:4') tensor(1.2923, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(7.0259, device='cuda:4'), tensor(1125.3682, device='cuda:4'))
0.21875 tensor(2.7075, device='cuda:4') tensor(1.3303, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.9522, device='cuda:4'), tensor(1045.4386, device='cuda:4'))
0.234375 tensor(2.7270, device='cuda:4') tensor(1.3692, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.8899, device='cuda:4'), tensor(982.2654, device='cuda:4'))
0.25 tensor(2.7462, device='cuda:4') tensor(1.4129, device='cuda:4')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.9386, device='cuda:4'), tensor(1031.3082, device='cuda:4'))
0.265625 tensor(2.7644, device='cuda:4') tensor(1.4589, device='cuda:4')


 67%|██████▋   | 377/562 [00:43<00:21,  8.64it/s]