In [1]:
from model import GPT
import numpy as np
import torch
from tqdm import tqdm
from copy import deepcopy
import tiktoken
from datasets import load_from_disk

In [2]:
orthogonal_model = GPT.from_pretrained('/home/mfx/huggingface/gpt2-xl')
device = "cuda:3"
state_dict = torch.load("/home/mfx/CLOVer/output/orthogonal/gpt2-xl.pt")
orthogonal_model.load_state_dict(state_dict)
orthogonal_model = orthogonal_model.to(device)
enc = tiktoken.get_encoding("gpt2")
test = load_from_disk("/home/mfx/CLOVer/output/wikitext-2-raw-v1")['test']
all_input_ids = torch.LongTensor(enc.encode("\n\n".join(test["text"]))).unsqueeze(0)
seq_len = all_input_ids.size(1)-1

loading weights from pretrained gpt: gpt2-xl
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 1555.97M


  state_dict = torch.load("/home/mfx/CLOVer/output/orthogonal/gpt2-xl.pt")


In [3]:
def calculating_perplexity(model, datasest):
    max_length = model.config.block_size
    stride = 512
    nll_sum = 0.0
    n_tokens = 0
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = datasest[:, begin_loc:end_loc].to(device)
        target_ids = datasest[:, begin_loc+1:end_loc+1].to(device)
        #target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            _,loss,_,_ = model(input_ids, targets=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = loss

        # Accumulate the total negative log-likelihood and the total number of tokens
        num_valid_tokens = (target_ids != -100).sum().item()  # number of valid tokens in target_ids
        batch_size = target_ids.size(0)
        num_loss_tokens = num_valid_tokens - batch_size  # subtract batch_size due to internal label shift
        nll_sum += neg_log_likelihood * num_loss_tokens

        n_tokens += num_loss_tokens
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    avg_nll = nll_sum / n_tokens  # average negative log-likelihood per token
    ppl = torch.exp(avg_nll)
    return avg_nll, ppl
    

In [4]:
calculating_perplexity(orthogonal_model, all_input_ids)

100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(2.6933, device='cuda:3'), tensor(14.7811, device='cuda:3'))

In [4]:
qk_norm_list = []
vo_norm_list = []
for name, module in orthogonal_model.named_modules():
    if name.endswith("attn"):
        q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
        q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
        k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
        k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
        qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
        qk_norm_list.append(qk_norm)
        
        v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
        v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
        v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
        o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
        vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
        vo_norm_list.append(vo_norm)
qk_norm_list = torch.stack(qk_norm_list)
vo_norm_list = torch.stack(vo_norm_list)

In [5]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in orthogonal_model.named_modules():
        if name.endswith("attn"):
            q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
            q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
            k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
            k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
            qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
            module.q_proj.weight.data[qk_norm<qk_quantile]=0
            module.q_proj.bias.data[qk_norm<qk_quantile]=0
            module.k_proj.weight.data[qk_norm<qk_quantile]=0
            module.k_proj.bias.data[qk_norm<qk_quantile]=0
    print(calculating_perplexity(orthogonal_model, all_input_ids))
    

0.0 tensor(0.2857, device='cuda:3') tensor(0.1108, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(2.6933, device='cuda:3'), tensor(14.7811, device='cuda:3'))
0.015625 tensor(0.6729, device='cuda:3') tensor(0.4198, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(2.6937, device='cuda:3'), tensor(14.7858, device='cuda:3'))
0.03125 tensor(1.0134, device='cuda:3') tensor(0.5602, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.6943, device='cuda:3'), tensor(14.7958, device='cuda:3'))
0.046875 tensor(1.3605, device='cuda:3') tensor(0.6584, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.6986, device='cuda:3'), tensor(14.8595, device='cuda:3'))
0.0625 tensor(1.6007, device='cuda:3') tensor(0.7402, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.7219, device='cuda:3'), tensor(15.2098, device='cuda:3'))
0.078125 tensor(1.7432, device='cuda:3') tensor(0.8104, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.7562, device='cuda:3'), tensor(15.7396, device='cuda:3'))
0.09375 tensor(1.8402, device='cuda:3') tensor(0.8713, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.7642, device='cuda:3'), tensor(15.8658, device='cuda:3'))
0.109375 tensor(1.9137, device='cuda:3') tensor(0.9268, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.54it/s]


(tensor(2.7739, device='cuda:3'), tensor(16.0217, device='cuda:3'))
0.125 tensor(1.9766, device='cuda:3') tensor(0.9791, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.53it/s]


(tensor(2.8678, device='cuda:3'), tensor(17.5985, device='cuda:3'))
0.140625 tensor(2.0301, device='cuda:3') tensor(1.0261, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.54it/s]


(tensor(2.8886, device='cuda:3'), tensor(17.9677, device='cuda:3'))
0.15625 tensor(2.0798, device='cuda:3') tensor(1.0710, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.54it/s]


(tensor(2.9098, device='cuda:3'), tensor(18.3523, device='cuda:3'))
0.171875 tensor(2.1251, device='cuda:3') tensor(1.1131, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.9296, device='cuda:3'), tensor(18.7192, device='cuda:3'))
0.1875 tensor(2.1676, device='cuda:3') tensor(1.1523, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.54it/s]


(tensor(2.9367, device='cuda:3'), tensor(18.8544, device='cuda:3'))
0.203125 tensor(2.2084, device='cuda:3') tensor(1.1899, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.9551, device='cuda:3'), tensor(19.2039, device='cuda:3'))
0.21875 tensor(2.2471, device='cuda:3') tensor(1.2257, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.9675, device='cuda:3'), tensor(19.4431, device='cuda:3'))
0.234375 tensor(2.2849, device='cuda:3') tensor(1.2603, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.9775, device='cuda:3'), tensor(19.6390, device='cuda:3'))
0.25 tensor(2.3207, device='cuda:3') tensor(1.2941, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.9907, device='cuda:3'), tensor(19.9006, device='cuda:3'))
0.265625 tensor(2.3567, device='cuda:3') tensor(1.3267, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(3.0057, device='cuda:3'), tensor(20.1995, device='cuda:3'))
0.28125 tensor(2.3911, device='cuda:3') tensor(1.3595, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(3.0086, device='cuda:3'), tensor(20.2591, device='cuda:3'))
0.296875 tensor(2.4252, device='cuda:3') tensor(1.3923, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(3.0270, device='cuda:3'), tensor(20.6362, device='cuda:3'))
0.3125 tensor(2.4592, device='cuda:3') tensor(1.4244, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(3.0418, device='cuda:3'), tensor(20.9439, device='cuda:3'))
0.328125 tensor(2.4923, device='cuda:3') tensor(1.4568, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(3.0643, device='cuda:3'), tensor(21.4192, device='cuda:3'))
0.34375 tensor(2.5272, device='cuda:3') tensor(1.4893, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(3.0821, device='cuda:3'), tensor(21.8041, device='cuda:3'))
0.359375 tensor(2.5594, device='cuda:3') tensor(1.5212, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.0855, device='cuda:3'), tensor(21.8779, device='cuda:3'))
0.375 tensor(2.5931, device='cuda:3') tensor(1.5534, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.0870, device='cuda:3'), tensor(21.9111, device='cuda:3'))
0.390625 tensor(2.6255, device='cuda:3') tensor(1.5864, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.2503, device='cuda:3'), tensor(25.7968, device='cuda:3'))
0.40625 tensor(2.6607, device='cuda:3') tensor(1.6183, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(3.2802, device='cuda:3'), tensor(26.5798, device='cuda:3'))
0.421875 tensor(2.6936, device='cuda:3') tensor(1.6500, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.3246, device='cuda:3'), tensor(27.7887, device='cuda:3'))
0.4375 tensor(2.7272, device='cuda:3') tensor(1.6826, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.3765, device='cuda:3'), tensor(29.2693, device='cuda:3'))
0.453125 tensor(2.7603, device='cuda:3') tensor(1.7163, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.4007, device='cuda:3'), tensor(29.9849, device='cuda:3'))
0.46875 tensor(2.7960, device='cuda:3') tensor(1.7497, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.4763, device='cuda:3'), tensor(32.3412, device='cuda:3'))
0.484375 tensor(2.8294, device='cuda:3') tensor(1.7831, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(3.4850, device='cuda:3'), tensor(32.6232, device='cuda:3'))
0.5 tensor(2.8649, device='cuda:3') tensor(1.8166, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.4977, device='cuda:3'), tensor(33.0399, device='cuda:3'))
0.515625 tensor(2.9024, device='cuda:3') tensor(1.8516, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.5090, device='cuda:3'), tensor(33.4153, device='cuda:3'))
0.53125 tensor(2.9387, device='cuda:3') tensor(1.8869, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.5560, device='cuda:3'), tensor(35.0234, device='cuda:3'))
0.546875 tensor(2.9766, device='cuda:3') tensor(1.9229, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.5798, device='cuda:3'), tensor(35.8673, device='cuda:3'))
0.5625 tensor(3.0143, device='cuda:3') tensor(1.9593, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.6226, device='cuda:3'), tensor(37.4340, device='cuda:3'))
0.578125 tensor(3.0534, device='cuda:3') tensor(1.9951, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.5930, device='cuda:3'), tensor(36.3420, device='cuda:3'))
0.59375 tensor(3.0928, device='cuda:3') tensor(2.0328, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(3.6275, device='cuda:3'), tensor(37.6182, device='cuda:3'))
0.609375 tensor(3.1354, device='cuda:3') tensor(2.0714, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(3.6563, device='cuda:3'), tensor(38.7159, device='cuda:3'))
0.625 tensor(3.1779, device='cuda:3') tensor(2.1100, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(3.6877, device='cuda:3'), tensor(39.9525, device='cuda:3'))
0.640625 tensor(3.2234, device='cuda:3') tensor(2.1504, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(3.7443, device='cuda:3'), tensor(42.2796, device='cuda:3'))
0.65625 tensor(3.2723, device='cuda:3') tensor(2.1914, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(3.8942, device='cuda:3'), tensor(49.1181, device='cuda:3'))
0.671875 tensor(3.3203, device='cuda:3') tensor(2.2336, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(3.9478, device='cuda:3'), tensor(51.8226, device='cuda:3'))
0.6875 tensor(3.3726, device='cuda:3') tensor(2.2771, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(4.0304, device='cuda:3'), tensor(56.2832, device='cuda:3'))
0.703125 tensor(3.4262, device='cuda:3') tensor(2.3214, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.1120, device='cuda:3'), tensor(61.0662, device='cuda:3'))
0.71875 tensor(3.4847, device='cuda:3') tensor(2.3669, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.62it/s]


(tensor(4.1859, device='cuda:3'), tensor(65.7522, device='cuda:3'))
0.734375 tensor(3.5446, device='cuda:3') tensor(2.4165, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.2272, device='cuda:3'), tensor(68.5273, device='cuda:3'))
0.75 tensor(3.6122, device='cuda:3') tensor(2.4650, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.6142, device='cuda:3'), tensor(100.9049, device='cuda:3'))
0.765625 tensor(3.6815, device='cuda:3') tensor(2.5200, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(4.6613, device='cuda:3'), tensor(105.7733, device='cuda:3'))
0.78125 tensor(3.7603, device='cuda:3') tensor(2.5764, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(4.7247, device='cuda:3'), tensor(112.6915, device='cuda:3'))
0.796875 tensor(3.8435, device='cuda:3') tensor(2.6354, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(4.8658, device='cuda:3'), tensor(129.7701, device='cuda:3'))
0.8125 tensor(3.9359, device='cuda:3') tensor(2.6960, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(4.9340, device='cuda:3'), tensor(138.9337, device='cuda:3'))
0.828125 tensor(4.0404, device='cuda:3') tensor(2.7632, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(5.0071, device='cuda:3'), tensor(149.4746, device='cuda:3'))
0.84375 tensor(4.1553, device='cuda:3') tensor(2.8338, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(5.0587, device='cuda:3'), tensor(157.3782, device='cuda:3'))
0.859375 tensor(4.2874, device='cuda:3') tensor(2.9127, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(5.3220, device='cuda:3'), tensor(204.7841, device='cuda:3'))
0.875 tensor(4.4391, device='cuda:3') tensor(3.0010, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(5.4313, device='cuda:3'), tensor(228.4410, device='cuda:3'))
0.890625 tensor(4.6106, device='cuda:3') tensor(3.0970, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(5.5442, device='cuda:3'), tensor(255.7487, device='cuda:3'))
0.90625 tensor(4.8117, device='cuda:3') tensor(3.2072, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(5.6529, device='cuda:3'), tensor(285.1132, device='cuda:3'))
0.921875 tensor(5.0441, device='cuda:3') tensor(3.3362, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(5.8285, device='cuda:3'), tensor(339.8391, device='cuda:3'))
0.9375 tensor(5.3527, device='cuda:3') tensor(3.4874, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(6.0155, device='cuda:3'), tensor(409.7274, device='cuda:3'))
0.953125 tensor(5.7999, device='cuda:3') tensor(3.6818, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(6.2179, device='cuda:3'), tensor(501.6692, device='cuda:3'))
0.96875 tensor(6.5356, device='cuda:3') tensor(3.9532, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]

(tensor(6.4243, device='cuda:3'), tensor(616.6791, device='cuda:3'))





In [5]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in orthogonal_model.named_modules():
        if name.endswith("attn"):
            v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
            v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
            o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
            vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
            module.v_proj.weight.data[vo_norm<vo_quantile]=0
            module.v_proj.bias.data[vo_norm<vo_quantile]=0
            module.v_proj.weight.data[:,vo_norm<vo_quantile]=0
    print(calculating_perplexity(orthogonal_model, all_input_ids))
    

0.0 tensor(0.2857, device='cuda:3') tensor(0.1108, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.70it/s]


(tensor(2.6933, device='cuda:3'), tensor(14.7811, device='cuda:3'))
0.015625 tensor(0.6729, device='cuda:3') tensor(0.4198, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(2.6925, device='cuda:3'), tensor(14.7690, device='cuda:3'))
0.03125 tensor(1.0134, device='cuda:3') tensor(0.5602, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(2.6979, device='cuda:3'), tensor(14.8483, device='cuda:3'))
0.046875 tensor(1.3605, device='cuda:3') tensor(0.6584, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.7020, device='cuda:3'), tensor(14.9099, device='cuda:3'))
0.0625 tensor(1.6007, device='cuda:3') tensor(0.7402, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.7266, device='cuda:3'), tensor(15.2814, device='cuda:3'))
0.078125 tensor(1.7432, device='cuda:3') tensor(0.8104, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(2.8762, device='cuda:3'), tensor(17.7467, device='cuda:3'))
0.09375 tensor(1.8402, device='cuda:3') tensor(0.8713, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(3.4741, device='cuda:3'), tensor(32.2676, device='cuda:3'))
0.109375 tensor(1.9137, device='cuda:3') tensor(0.9268, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.1759, device='cuda:3'), tensor(176.9607, device='cuda:3'))
0.125 tensor(1.9766, device='cuda:3') tensor(0.9791, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.8038, device='cuda:3'), tensor(331.5579, device='cuda:3'))
0.140625 tensor(2.0301, device='cuda:3') tensor(1.0261, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.0209, device='cuda:3'), tensor(411.9470, device='cuda:3'))
0.15625 tensor(2.0798, device='cuda:3') tensor(1.0710, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.1135, device='cuda:3'), tensor(451.9392, device='cuda:3'))
0.171875 tensor(2.1251, device='cuda:3') tensor(1.1131, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.2235, device='cuda:3'), tensor(504.4834, device='cuda:3'))
0.1875 tensor(2.1676, device='cuda:3') tensor(1.1523, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.2826, device='cuda:3'), tensor(535.1931, device='cuda:3'))
0.203125 tensor(2.2084, device='cuda:3') tensor(1.1899, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.3874, device='cuda:3'), tensor(594.2902, device='cuda:3'))
0.21875 tensor(2.2471, device='cuda:3') tensor(1.2257, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.3434, device='cuda:3'), tensor(568.7365, device='cuda:3'))
0.234375 tensor(2.2849, device='cuda:3') tensor(1.2603, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.4860, device='cuda:3'), tensor(655.8740, device='cuda:3'))
0.25 tensor(2.3207, device='cuda:3') tensor(1.2941, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.5678, device='cuda:3'), tensor(711.8134, device='cuda:3'))
0.265625 tensor(2.3567, device='cuda:3') tensor(1.3267, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.60it/s]


(tensor(6.4286, device='cuda:3'), tensor(619.3145, device='cuda:3'))
0.28125 tensor(2.3911, device='cuda:3') tensor(1.3595, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.62it/s]


(tensor(6.4395, device='cuda:3'), tensor(626.0801, device='cuda:3'))
0.296875 tensor(2.4252, device='cuda:3') tensor(1.3923, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.5194, device='cuda:3'), tensor(678.1672, device='cuda:3'))
0.3125 tensor(2.4592, device='cuda:3') tensor(1.4244, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4365, device='cuda:3'), tensor(624.1912, device='cuda:3'))
0.328125 tensor(2.4923, device='cuda:3') tensor(1.4568, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.4336, device='cuda:3'), tensor(622.4177, device='cuda:3'))
0.34375 tensor(2.5272, device='cuda:3') tensor(1.4893, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.3152, device='cuda:3'), tensor(552.8973, device='cuda:3'))
0.359375 tensor(2.5594, device='cuda:3') tensor(1.5212, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.2600, device='cuda:3'), tensor(523.1984, device='cuda:3'))
0.375 tensor(2.5931, device='cuda:3') tensor(1.5534, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.3234, device='cuda:3'), tensor(557.4923, device='cuda:3'))
0.390625 tensor(2.6255, device='cuda:3') tensor(1.5864, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.2933, device='cuda:3'), tensor(540.9361, device='cuda:3'))
0.40625 tensor(2.6607, device='cuda:3') tensor(1.6183, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.3495, device='cuda:3'), tensor(572.2211, device='cuda:3'))
0.421875 tensor(2.6936, device='cuda:3') tensor(1.6500, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.3716, device='cuda:3'), tensor(585.0118, device='cuda:3'))
0.4375 tensor(2.7272, device='cuda:3') tensor(1.6826, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.3889, device='cuda:3'), tensor(595.1824, device='cuda:3'))
0.453125 tensor(2.7603, device='cuda:3') tensor(1.7163, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.63it/s]


(tensor(6.4332, device='cuda:3'), tensor(622.1726, device='cuda:3'))
0.46875 tensor(2.7960, device='cuda:3') tensor(1.7497, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.64it/s]


(tensor(6.4325, device='cuda:3'), tensor(621.7180, device='cuda:3'))
0.484375 tensor(2.8294, device='cuda:3') tensor(1.7831, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.4601, device='cuda:3'), tensor(639.1425, device='cuda:3'))
0.5 tensor(2.8649, device='cuda:3') tensor(1.8166, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.5141, device='cuda:3'), tensor(674.5850, device='cuda:3'))
0.515625 tensor(2.9024, device='cuda:3') tensor(1.8516, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.5685, device='cuda:3'), tensor(712.3335, device='cuda:3'))
0.53125 tensor(2.9387, device='cuda:3') tensor(1.8869, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.6249, device='cuda:3'), tensor(753.6611, device='cuda:3'))
0.546875 tensor(2.9766, device='cuda:3') tensor(1.9229, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.65it/s]


(tensor(6.7053, device='cuda:3'), tensor(816.7587, device='cuda:3'))
0.5625 tensor(3.0143, device='cuda:3') tensor(1.9593, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.66it/s]


(tensor(6.7729, device='cuda:3'), tensor(873.8462, device='cuda:3'))
0.578125 tensor(3.0534, device='cuda:3') tensor(1.9951, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(6.8761, device='cuda:3'), tensor(968.8724, device='cuda:3'))
0.59375 tensor(3.0928, device='cuda:3') tensor(2.0328, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(6.9831, device='cuda:3'), tensor(1078.2074, device='cuda:3'))
0.609375 tensor(3.1354, device='cuda:3') tensor(2.0714, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.0939, device='cuda:3'), tensor(1204.6348, device='cuda:3'))
0.625 tensor(3.1779, device='cuda:3') tensor(2.1100, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.1897, device='cuda:3'), tensor(1325.6405, device='cuda:3'))
0.640625 tensor(3.2234, device='cuda:3') tensor(2.1504, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.3747, device='cuda:3'), tensor(1595.0669, device='cuda:3'))
0.65625 tensor(3.2723, device='cuda:3') tensor(2.1914, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.70it/s]


(tensor(7.4658, device='cuda:3'), tensor(1747.3074, device='cuda:3'))
0.671875 tensor(3.3203, device='cuda:3') tensor(2.2336, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.5187, device='cuda:3'), tensor(1842.1316, device='cuda:3'))
0.6875 tensor(3.3726, device='cuda:3') tensor(2.2771, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.5688, device='cuda:3'), tensor(1936.8184, device='cuda:3'))
0.703125 tensor(3.4262, device='cuda:3') tensor(2.3214, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.67it/s]


(tensor(7.6085, device='cuda:3'), tensor(2015.1606, device='cuda:3'))
0.71875 tensor(3.4847, device='cuda:3') tensor(2.3669, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6134, device='cuda:3'), tensor(2025.0610, device='cuda:3'))
0.734375 tensor(3.5446, device='cuda:3') tensor(2.4165, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6134, device='cuda:3'), tensor(2025.1442, device='cuda:3'))
0.75 tensor(3.6122, device='cuda:3') tensor(2.4650, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6143, device='cuda:3'), tensor(2027.0107, device='cuda:3'))
0.765625 tensor(3.6815, device='cuda:3') tensor(2.5200, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6188, device='cuda:3'), tensor(2036.1702, device='cuda:3'))
0.78125 tensor(3.7603, device='cuda:3') tensor(2.5764, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6188, device='cuda:3'), tensor(2036.1702, device='cuda:3'))
0.796875 tensor(3.8435, device='cuda:3') tensor(2.6354, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.70it/s]


(tensor(7.6188, device='cuda:3'), tensor(2036.1702, device='cuda:3'))
0.8125 tensor(3.9359, device='cuda:3') tensor(2.6960, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6188, device='cuda:3'), tensor(2036.1702, device='cuda:3'))
0.828125 tensor(4.0404, device='cuda:3') tensor(2.7632, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.6188, device='cuda:3'), tensor(2036.1702, device='cuda:3'))
0.84375 tensor(4.1553, device='cuda:3') tensor(2.8338, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.859375 tensor(4.2874, device='cuda:3') tensor(2.9127, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.68it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.875 tensor(4.4391, device='cuda:3') tensor(3.0010, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.890625 tensor(4.6106, device='cuda:3') tensor(3.0970, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.90625 tensor(4.8117, device='cuda:3') tensor(3.2072, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.921875 tensor(5.0441, device='cuda:3') tensor(3.3362, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.9375 tensor(5.3527, device='cuda:3') tensor(3.4874, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.953125 tensor(5.7999, device='cuda:3') tensor(3.6818, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))
0.96875 tensor(6.5356, device='cuda:3') tensor(3.9532, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]

(tensor(7.6158, device='cuda:3'), tensor(2029.9540, device='cuda:3'))





In [None]:
for i in range(63):
    ratio = i/64
    qk_quantile = torch.quantile(qk_norm_list, ratio)
    vo_quantile = torch.quantile(vo_norm_list, ratio)
    print(ratio, qk_quantile, vo_quantile)
    for name, module in orthogonal_model.named_modules():
        if name.endswith("attn"):
            q_weight = deepcopy(module.q_proj.weight.data) # (hidden_size, q_in_dim)
            q_bias = deepcopy(module.q_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            q_weight = torch.cat([q_weight, q_bias],dim=1)  # (hidden_size, q_in_dim+1)
            k_weight = deepcopy(module.k_proj.weight.data) # (hidden_size, k_in_dim)
            k_bias = deepcopy(module.k_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            k_weight = torch.cat([k_weight, k_bias],dim=1)  # (hidden_size, k_in_dim+1)
            qk_norm = q_weight.norm(p=2,dim=-1) * k_weight.norm(p=2,dim=-1)
            module.q_proj.weight.data[qk_norm<qk_quantile]=0
            module.q_proj.bias.data[qk_norm<qk_quantile]=0
            module.k_proj.weight.data[qk_norm<qk_quantile]=0
            module.k_proj.bias.data[qk_norm<qk_quantile]=0
            v_weight = deepcopy(module.v_proj.weight.data) # (hidden_size, v_in_dim)
            v_bias = deepcopy(module.v_proj.bias.data).unsqueeze(1) # (hidden_size, 1)
            v_weight = torch.cat([v_weight, v_bias],dim=1)  # (hidden_size, v_in_dim+1)
            o_weight = deepcopy(module.o_proj.weight.data) # (o_out_dim,hidden_size)
            vo_norm = v_weight.norm(p=2,dim=-1) * o_weight.norm(p=2,dim=0)
            module.v_proj.weight.data[vo_norm<vo_quantile]=0
            module.v_proj.bias.data[vo_norm<vo_quantile]=0
            module.v_proj.weight.data[:,vo_norm<vo_quantile]=0
    print(calculating_perplexity(orthogonal_model, all_input_ids))
    

0.0 tensor(0.2857, device='cuda:3') tensor(0.1108, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.69it/s]


(tensor(2.6933, device='cuda:3'), tensor(14.7811, device='cuda:3'))
0.015625 tensor(0.6729, device='cuda:3') tensor(0.4198, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(2.6929, device='cuda:3'), tensor(14.7742, device='cuda:3'))
0.03125 tensor(1.0134, device='cuda:3') tensor(0.5602, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.6986, device='cuda:3'), tensor(14.8589, device='cuda:3'))
0.046875 tensor(1.3605, device='cuda:3') tensor(0.6584, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(2.7082, device='cuda:3'), tensor(15.0018, device='cuda:3'))
0.0625 tensor(1.6007, device='cuda:3') tensor(0.7402, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(2.7758, device='cuda:3'), tensor(16.0514, device='cuda:3'))
0.078125 tensor(1.7432, device='cuda:3') tensor(0.8104, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.55it/s]


(tensor(3.0952, device='cuda:3'), tensor(22.0909, device='cuda:3'))
0.09375 tensor(1.8402, device='cuda:3') tensor(0.8713, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.56it/s]


(tensor(3.5216, device='cuda:3'), tensor(33.8369, device='cuda:3'))
0.109375 tensor(1.9137, device='cuda:3') tensor(0.9268, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.57it/s]


(tensor(5.0354, device='cuda:3'), tensor(153.7578, device='cuda:3'))
0.125 tensor(1.9766, device='cuda:3') tensor(0.9791, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.6663, device='cuda:3'), tensor(288.9579, device='cuda:3'))
0.140625 tensor(2.0301, device='cuda:3') tensor(1.0261, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(5.9191, device='cuda:3'), tensor(372.0832, device='cuda:3'))
0.15625 tensor(2.0798, device='cuda:3') tensor(1.0710, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.58it/s]


(tensor(6.0590, device='cuda:3'), tensor(427.9567, device='cuda:3'))
0.171875 tensor(2.1251, device='cuda:3') tensor(1.1131, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.1694, device='cuda:3'), tensor(477.8759, device='cuda:3'))
0.1875 tensor(2.1676, device='cuda:3') tensor(1.1523, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.59it/s]


(tensor(6.1910, device='cuda:3'), tensor(488.3209, device='cuda:3'))
0.203125 tensor(2.2084, device='cuda:3') tensor(1.1899, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.3229, device='cuda:3'), tensor(557.1805, device='cuda:3'))
0.21875 tensor(2.2471, device='cuda:3') tensor(1.2257, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.2870, device='cuda:3'), tensor(537.5338, device='cuda:3'))
0.234375 tensor(2.2849, device='cuda:3') tensor(1.2603, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.4389, device='cuda:3'), tensor(625.7310, device='cuda:3'))
0.25 tensor(2.3207, device='cuda:3') tensor(1.2941, device='cuda:3')


100%|█████████▉| 560/562 [01:05<00:00,  8.61it/s]


(tensor(6.5179, device='cuda:3'), tensor(677.1791, device='cuda:3'))
0.265625 tensor(2.3567, device='cuda:3') tensor(1.3267, device='cuda:3')


100%|█████████▉| 560/562 [01:04<00:00,  8.62it/s]


(tensor(6.4258, device='cuda:3'), tensor(617.5693, device='cuda:3'))
0.28125 tensor(2.3911, device='cuda:3') tensor(1.3595, device='cuda:3')


 40%|████      | 226/562 [00:26<00:38,  8.63it/s]

In [7]:
ids = torch.LongTensor(enc.encode("American identity. I am part of a resistance movement with my peers,")).unsqueeze(0).to("cuda:3")
output = orthogonal_model.generate(idx=ids, max_new_tokens=15)
enc.decode(output[0].tolist())

'American identity. I am part of a resistance movement with my peers,17FA) Vic To beaches\nspeading tostyle with the 21'