In [64]:
import os
import torch
from tqdm import tqdm
import bitsandbytes as bnb
from copy import deepcopy

In [2]:
ckpt = torch.load("../checkpoints/llama2/Llama-2-7b/consolidated.00.pth",map_location='cpu')

In [133]:
wo_list = []
for key,val in ckpt.items():
    if key.endswith("wk.weight"):
        wo_list += [val.to("cuda", torch.float32)]

In [176]:
rank = 256
range_anchor = 2
# def low_rank_equivalent(base, target):
#     delta = target.clone().detach() - base.clone().detach()
#     u,s,v = torch.svd(delta.to(torch.float32))
#     k = rank
#     u_topk, s_topk, v_topk = u[:, :k], s[:k], v[:, :k]

#     lora_b = torch.mm(u_topk, torch.diag(s_topk.sqrt())).to(target.dtype)
#     lora_a = torch.mm(torch.diag(s_topk.sqrt()), v_topk.t()).to(target.dtype)
#     return lora_a, lora_b
# w_base = torch.nn.Parameter(wo_list[0].data.clone()).cuda()
# w_lora_list = []
# for ii in tqdm([kk for kk in range(1,32)]):
#     lora_a, lora_b = low_rank_equivalent(w_base.data, wo_list[ii])
#     lora1 = torch.nn.Parameter(lora_b.data.clone(),requires_grad=True)
#     lora2 = torch.nn.Parameter(lora_a.data.clone(),requires_grad=True)
#     w_lora_list +=[(lora1.cuda(), lora2.cuda())]
# lora1 = torch.nn.Parameter(torch.empty_like(wo_list[0][:,:rank]),requires_grad=True)
# lora2 = torch.nn.Parameter(torch.empty_like(wo_list[0][:rank,:]),requires_grad=True)
# torch.nn.init.xavier_normal_(lora1)
# torch.nn.init.xavier_normal_(lora2)
# w_lora_list = [(lora1.cuda(), lora2.cuda())] + w_lora_list


w_base = torch.nn.Parameter(torch.empty_like(wo_list[0])).cuda()
torch.nn.init.xavier_normal_(w_base)
w_lora_list = []
for ii in range(range_anchor):
    lora1 = torch.nn.Parameter(torch.empty_like(wo_list[0][:,:rank]),requires_grad=True)
    lora2 = torch.nn.Parameter(torch.empty_like(wo_list[0][:rank,:]),requires_grad=True)
    torch.nn.init.xavier_normal_(lora1)
    torch.nn.init.xavier_normal_(lora2)
    w_lora_list +=[(lora1.cuda(), lora2.cuda())]


weight = bnb.nn.Params4bit(
                wo_list[0].data.clone().cpu(), 
                requires_grad=False,
                quant_type='nf4',
)

In [179]:
weight.to("cuda")
weight_nf4 = bnb.functional.dequantize_4bit(weight, weight.quant_state)
((wo_list[0].data.clone()-weight_nf4)**2).mean()

tensor(3.2611e-06, device='cuda:0')

In [171]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD([w_lora_list[ii][0] for ii in range(range_anchor)]+[w_lora_list[ii][1] for ii in range(range_anchor)]+[w_base], lr=100000)

In [175]:
pbar = tqdm([ll for ll in range(500)], desc='Training', leave=True)
for epoch in pbar:
    optimizer.zero_grad()           # Zero the gradients
    loss = 0
    for idx, ww in enumerate(wo_list[:range_anchor]):
        w_approx = w_base + w_lora_list[idx][0] @ w_lora_list[idx][1]
        loss += criterion(w_approx, ww)  # Calculate loss
    # print(loss.item())
    loss.backward()
    pbar.set_description(f'Training - Epoch {epoch+1}, Loss: {loss.item():.2e}, w:{w_base[0,0].item():.2e} grad:{w_base.grad[0,0].item():.2e}')
    # print(w_base.grad[0,0])
    # print(w_base[0,0])
    optimizer.step()
print(loss)

Training - Epoch 500, Loss: 3.86e-05, w:2.99e-03 grad:-6.86e-12: 100%|██████████| 500/500 [00:02<00:00, 216.46it/s]

tensor(3.8600e-05, device='cuda:0', grad_fn=<AddBackward0>)





In [163]:
w_base.std(), w_base.mean()

(tensor(0.0144, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(3.7319e-06, device='cuda:0', grad_fn=<MeanBackward0>))

In [164]:
idx=0
(w_lora_list[idx][0] @ w_lora_list[idx][1]).std(),(w_lora_list[idx][0] @ w_lora_list[idx][1]).mean()

(tensor(0.0174, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(3.2446e-06, device='cuda:0', grad_fn=<MeanBackward0>))

In [55]:
def low_rank_equivalent(base, target):
    delta = target.clone().detach() - base.clone().detach()
    u,s,v = torch.svd(delta.to(torch.float32))
    k = 1024
    u_topk, s_topk, v_topk = u[:, :k], s[:k], v[:, :k]

    lora_b = torch.mm(u_topk, torch.diag(s_topk.sqrt())).to(target.dtype)
    lora_a = torch.mm(torch.diag(s_topk.sqrt()), v_topk.t()).to(target.dtype)
    return base + lora_b@lora_a
delta = 0
for ww in tqdm(wo_list[1:]):
    new_ww = low_rank_equivalent(wo_list[0],ww)
    delta += ((ww-new_ww)**2).mean()
print(delta)

100%|██████████| 31/31 [00:59<00:00,  1.92s/it]

tensor(0.0039, device='cuda:0')





In [210]:
ckpt_new = deepcopy(ckpt)

In [211]:
for ending_name in ["wq.weight","wk.weight","wv.weight","wo.weight"]:
    wo_list = []
    w_name_list = []
    for key,val in ckpt.items():
        if key.endswith(ending_name):
            wo_list += [val.to("cuda", torch.float32)]
            w_name_list += [key]

    rank = 512
    range_anchor = 2
    print(f"(In a group) full params:{4096*4096*range_anchor}, retained params:{4096*rank*2*range_anchor+4096}, reduced:{(4096*rank*2*range_anchor+4096)/(4096*4096*range_anchor):.3f}")

    for group_idx in range(0,32,range_anchor):
        w_base = torch.nn.Parameter(torch.empty_like(wo_list[group_idx])).cuda()
        torch.nn.init.xavier_normal_(w_base)
        w_lora_list = []
        for ii in range(range_anchor):
            lora1 = torch.nn.Parameter(torch.empty_like(wo_list[0][:,:rank]),requires_grad=True)
            lora2 = torch.nn.Parameter(torch.empty_like(wo_list[0][:rank,:]),requires_grad=True)
            torch.nn.init.xavier_normal_(lora1)
            torch.nn.init.xavier_normal_(lora2)
            w_lora_list +=[(lora1.cuda(), lora2.cuda())]

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD([w_lora_list[ii][0] for ii in range(range_anchor)]+[w_lora_list[ii][1] for ii in range(range_anchor)]+[w_base], lr=500000)

        pbar = tqdm([ll for ll in range(1000)], desc=f'Training group_idx={group_idx}', leave=True)
        for epoch in pbar:
            optimizer.zero_grad()           # Zero the gradients
            loss = 0
            names = w_name_list[group_idx:group_idx+range_anchor]
            for idx, ww in enumerate(wo_list[group_idx:group_idx+range_anchor]):
                w_approx = w_base + w_lora_list[idx][0] @ w_lora_list[idx][1]
                loss += criterion(w_approx, ww)  # Calculate loss
            loss.backward()
            pbar.set_description(f'Training - Epoch {epoch+1}, Loss: {loss.item():.2e}, w:{w_base[0,0].item():.2e} grad:{w_base.grad[0,0].item():.2e}')
            optimizer.step()
        for layer_idx, name in enumerate(names):
            ckpt_new[name] = w_base.to(ckpt_new[name].dtype)
            ckpt_new[name.replace("weight","lora_a.weight")] = w_lora_list[layer_idx][1]
            ckpt_new[name.replace("weight","lora_b.weight")] = w_lora_list[layer_idx][0]
        print(group_idx,loss, names)
        print("-"*40)

(In a group) full params:33554432, retained params:8392704, reduced:0.250


Training - Epoch 1000, Loss: 1.04e-05, w:4.93e-03 grad:9.92e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.09it/s]


0 tensor(1.0443e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.0.attention.wq.weight', 'layers.1.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.16e-04, w:-8.26e-03 grad:-1.39e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.77it/s]


2 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.2.attention.wq.weight', 'layers.3.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.38e-04, w:-1.85e-02 grad:3.17e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.88it/s]


4 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.4.attention.wq.weight', 'layers.5.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.15e-04, w:-1.38e-02 grad:-1.72e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.88it/s]


6 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.6.attention.wq.weight', 'layers.7.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.29e-04, w:1.74e-03 grad:-1.64e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.67it/s]


8 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.8.attention.wq.weight', 'layers.9.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.09e-04, w:4.75e-03 grad:1.91e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.72it/s]


10 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.10.attention.wq.weight', 'layers.11.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.21e-04, w:-7.91e-03 grad:1.58e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.58it/s]


12 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.12.attention.wq.weight', 'layers.13.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.15e-04, w:1.15e-02 grad:-2.86e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.99it/s]


14 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.14.attention.wq.weight', 'layers.15.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.24e-04, w:-3.82e-03 grad:-4.10e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.31it/s]


16 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.16.attention.wq.weight', 'layers.17.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.32e-04, w:-9.23e-03 grad:2.44e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.64it/s]


18 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.18.attention.wq.weight', 'layers.19.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.27e-04, w:-4.46e-03 grad:2.90e-13: 100%|██████████| 1000/1000 [00:07<00:00, 130.35it/s]


20 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.20.attention.wq.weight', 'layers.21.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.45e-04, w:-6.03e-03 grad:-3.72e-13: 100%|██████████| 1000/1000 [00:07<00:00, 130.05it/s]


22 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.22.attention.wq.weight', 'layers.23.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.32e-04, w:-4.72e-03 grad:3.03e-12: 100%|██████████| 1000/1000 [00:07<00:00, 129.94it/s]


24 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.24.attention.wq.weight', 'layers.25.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.41e-04, w:-1.61e-02 grad:2.50e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.37it/s]


26 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.26.attention.wq.weight', 'layers.27.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.29e-04, w:5.77e-03 grad:-2.91e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.37it/s]


28 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.28.attention.wq.weight', 'layers.29.attention.wq.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.22e-04, w:-2.95e-02 grad:1.58e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.14it/s]


30 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.30.attention.wq.weight', 'layers.31.attention.wq.weight']
----------------------------------------
(In a group) full params:33554432, retained params:8392704, reduced:0.250


Training - Epoch 1000, Loss: 9.09e-06, w:1.01e-02 grad:-1.23e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.38it/s]


0 tensor(9.0937e-06, device='cuda:0', grad_fn=<AddBackward0>) ['layers.0.attention.wk.weight', 'layers.1.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.26e-04, w:-8.82e-04 grad:-2.61e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.03it/s]


2 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.2.attention.wk.weight', 'layers.3.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.51e-04, w:-2.40e-02 grad:1.43e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.07it/s]


4 tensor(0.0002, device='cuda:0', grad_fn=<AddBackward0>) ['layers.4.attention.wk.weight', 'layers.5.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.24e-04, w:-1.77e-03 grad:-5.86e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.69it/s]


6 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.6.attention.wk.weight', 'layers.7.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.41e-04, w:-2.54e-02 grad:-3.01e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.89it/s]


8 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.8.attention.wk.weight', 'layers.9.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.20e-04, w:-1.81e-02 grad:2.55e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.30it/s]


10 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.10.attention.wk.weight', 'layers.11.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.32e-04, w:1.43e-02 grad:-3.00e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.20it/s]


12 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.12.attention.wk.weight', 'layers.13.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.24e-04, w:-2.21e-03 grad:-5.59e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.31it/s]


14 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.14.attention.wk.weight', 'layers.15.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.31e-04, w:-1.16e-02 grad:-3.38e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.16it/s]


16 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.16.attention.wk.weight', 'layers.17.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.40e-04, w:6.06e-03 grad:3.50e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.97it/s]


18 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.18.attention.wk.weight', 'layers.19.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.32e-04, w:-1.02e-03 grad:-2.07e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.22it/s]


20 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.20.attention.wk.weight', 'layers.21.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.50e-04, w:-1.34e-02 grad:4.90e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.40it/s]


22 tensor(0.0002, device='cuda:0', grad_fn=<AddBackward0>) ['layers.22.attention.wk.weight', 'layers.23.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.37e-04, w:2.62e-03 grad:-3.48e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.17it/s]


24 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.24.attention.wk.weight', 'layers.25.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.48e-04, w:1.34e-04 grad:-1.95e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.88it/s]


26 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.26.attention.wk.weight', 'layers.27.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.37e-04, w:1.22e-02 grad:1.28e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.37it/s]


28 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.28.attention.wk.weight', 'layers.29.attention.wk.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.32e-04, w:-4.70e-03 grad:-4.29e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.80it/s]


30 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.30.attention.wk.weight', 'layers.31.attention.wk.weight']
----------------------------------------
(In a group) full params:33554432, retained params:8392704, reduced:0.250


Training - Epoch 1000, Loss: 2.27e-05, w:3.96e-03 grad:-2.75e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.01it/s]


0 tensor(2.2744e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.0.attention.wv.weight', 'layers.1.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.94e-05, w:3.50e-03 grad:-1.35e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.98it/s]


2 tensor(5.9404e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.2.attention.wv.weight', 'layers.3.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.69e-05, w:8.34e-04 grad:3.10e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.70it/s]


4 tensor(6.6884e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.4.attention.wv.weight', 'layers.5.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.64e-05, w:4.73e-03 grad:-1.34e-11: 100%|██████████| 1000/1000 [00:07<00:00, 131.65it/s]


6 tensor(5.6397e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.6.attention.wv.weight', 'layers.7.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.78e-05, w:-4.66e-03 grad:3.75e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.30it/s]


8 tensor(5.7849e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.8.attention.wv.weight', 'layers.9.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.67e-05, w:-1.94e-02 grad:4.15e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.04it/s]


10 tensor(5.6732e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.10.attention.wv.weight', 'layers.11.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.58e-05, w:1.22e-02 grad:2.06e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.25it/s]


12 tensor(6.5776e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.12.attention.wv.weight', 'layers.13.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.97e-05, w:1.10e-02 grad:-1.18e-11: 100%|██████████| 1000/1000 [00:07<00:00, 130.95it/s]


14 tensor(6.9748e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.14.attention.wv.weight', 'layers.15.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 8.82e-05, w:-1.18e-02 grad:8.42e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.14it/s]


16 tensor(8.8180e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.16.attention.wv.weight', 'layers.17.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 9.88e-05, w:-3.06e-04 grad:-4.10e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.37it/s]


18 tensor(9.8769e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.18.attention.wv.weight', 'layers.19.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.06e-04, w:3.66e-03 grad:4.65e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.32it/s]


20 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.20.attention.wv.weight', 'layers.21.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.19e-04, w:-9.93e-03 grad:7.47e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.09it/s]


22 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.22.attention.wv.weight', 'layers.23.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.25e-04, w:2.48e-02 grad:2.28e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.01it/s]


24 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.24.attention.wv.weight', 'layers.25.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.39e-04, w:4.02e-02 grad:3.99e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.24it/s]


26 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.26.attention.wv.weight', 'layers.27.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.49e-04, w:1.10e-02 grad:-3.39e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.75it/s]


28 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.28.attention.wv.weight', 'layers.29.attention.wv.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.45e-04, w:7.61e-03 grad:2.12e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.37it/s]


30 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.30.attention.wv.weight', 'layers.31.attention.wv.weight']
----------------------------------------
(In a group) full params:33554432, retained params:8392704, reduced:0.250


Training - Epoch 1000, Loss: 1.25e-05, w:1.03e-02 grad:-2.80e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.46it/s]


0 tensor(1.2460e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.0.attention.wo.weight', 'layers.1.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.92e-05, w:-1.59e-02 grad:-3.98e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.09it/s]


2 tensor(5.9250e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.2.attention.wo.weight', 'layers.3.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.30e-05, w:1.66e-02 grad:-2.58e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.10it/s]


4 tensor(6.3013e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.4.attention.wo.weight', 'layers.5.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.35e-05, w:1.22e-02 grad:-5.86e-12: 100%|██████████| 1000/1000 [00:07<00:00, 132.02it/s]


6 tensor(5.3504e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.6.attention.wo.weight', 'layers.7.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.78e-05, w:2.69e-02 grad:-3.33e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.21it/s]


8 tensor(5.7828e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.8.attention.wo.weight', 'layers.9.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 5.42e-05, w:-4.19e-03 grad:9.08e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.55it/s]


10 tensor(5.4223e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.10.attention.wo.weight', 'layers.11.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.15e-05, w:7.09e-04 grad:-6.45e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.82it/s]


12 tensor(6.1539e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.12.attention.wo.weight', 'layers.13.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 6.59e-05, w:-3.95e-03 grad:-5.63e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.47it/s]


14 tensor(6.5853e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.14.attention.wo.weight', 'layers.15.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 8.56e-05, w:9.19e-03 grad:-1.83e-12: 100%|██████████| 1000/1000 [00:07<00:00, 130.83it/s]


16 tensor(8.5564e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.16.attention.wo.weight', 'layers.17.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 9.94e-05, w:-5.30e-03 grad:7.77e-14: 100%|██████████| 1000/1000 [00:07<00:00, 131.26it/s]


18 tensor(9.9391e-05, device='cuda:0', grad_fn=<AddBackward0>) ['layers.18.attention.wo.weight', 'layers.19.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.13e-04, w:-2.27e-02 grad:7.68e-13: 100%|██████████| 1000/1000 [00:07<00:00, 131.13it/s]


20 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.20.attention.wo.weight', 'layers.21.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.21e-04, w:-1.73e-02 grad:-1.95e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.35it/s]


22 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.22.attention.wo.weight', 'layers.23.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.27e-04, w:-8.90e-03 grad:-8.30e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.32it/s]


24 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.24.attention.wo.weight', 'layers.25.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.38e-04, w:1.25e-02 grad:-6.57e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.08it/s]


26 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.26.attention.wo.weight', 'layers.27.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.47e-04, w:1.18e-02 grad:-5.97e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.68it/s]


28 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.28.attention.wo.weight', 'layers.29.attention.wo.weight']
----------------------------------------


Training - Epoch 1000, Loss: 1.43e-04, w:1.14e-02 grad:-2.65e-12: 100%|██████████| 1000/1000 [00:07<00:00, 131.23it/s]

30 tensor(0.0001, device='cuda:0', grad_fn=<AddBackward0>) ['layers.30.attention.wo.weight', 'layers.31.attention.wo.weight']
----------------------------------------





In [212]:
ckpt_new.keys()

dict_keys(['tok_embeddings.weight', 'norm.weight', 'output.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wv.weight', 'layers.0.attention.wo.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.attention_norm.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wv.weight', 'layers.1.attention.wo.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.attention_norm.weight', 'layers.1.ffn_norm.weight', 'layers.2.attention.wq.weight', 'layers.2.attention.wk.weight', 'layers.2.attention.wv.weight', 'layers.2.attention.wo.weight', 'layers.2.feed_forward.w1.weight', 'layers.2.feed_forward.w2.weight', 'layers.2.feed_forward.w3.weight', 'layers.2.attention_norm.weight', 'layers.2.ffn_norm.weight', 'layers.3.attention.wq.weight', 'layers.3.atten

In [213]:
torch.save(ckpt_new, "../checkpoints/effiLLaMA2/consolidated.00.pth")

In [214]:
print(ckpt_new["layers.30.attention.wq.weight"]==ckpt_new["layers.31.attention.wq.weight"])
print(ckpt_new["layers.10.attention.wk.weight"]==ckpt_new["layers.11.attention.wk.weight"])
print(ckpt_new["layers.2.attention.wv.weight"]==ckpt_new["layers.3.attention.wv.weight"])

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, Tr

In [205]:
ckpt_new["layers.31.attention.wq.lora_b.weight"].shape, ckpt_new["layers.31.attention.wq.lora_a.weight"].shape, 

(torch.Size([4096, 512]), torch.Size([512, 4096]))

In [209]:
ckpt_new["layers.3.attention.wk.weight"]==ckpt_new["layers.2.attention.wk.weight"]

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [None]:
model.llma.layers[31].attention.wq.weight
model.llma.layers[2].attention.wk.weight

In [None]:
for ending_name in ["feed_forward.w1"]:
    wo_list = []
    w_name_list = []
    for key,val in ckpt.items():
        if key.endswith(ending_name):
            wo_list += [val.to("cuda", torch.float32)]
            w_name_list += [key]

    rank = 512
    range_anchor = 2
    print(f"(In a group) full params:{4096*4096*range_anchor}, retained params:{4096*rank*2*range_anchor+4096}, reduced:{(4096*rank*2*range_anchor+4096)/(4096*4096*range_anchor):.3f}")

    for group_idx in range(0,32,range_anchor):
        w_base = torch.nn.Parameter(torch.empty_like(wo_list[group_idx])).cuda()
        torch.nn.init.xavier_normal_(w_base)
        w_lora_list = []
        for ii in range(range_anchor):
            lora1 = torch.nn.Parameter(torch.empty_like(wo_list[0][:,:rank]),requires_grad=True)
            lora2 = torch.nn.Parameter(torch.empty_like(wo_list[0][:rank,:]),requires_grad=True)
            torch.nn.init.xavier_normal_(lora1)
            torch.nn.init.xavier_normal_(lora2)
            w_lora_list +=[(lora1.cuda(), lora2.cuda())]

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD([w_lora_list[ii][0] for ii in range(range_anchor)]+[w_lora_list[ii][1] for ii in range(range_anchor)]+[w_base], lr=500000)

        pbar = tqdm([ll for ll in range(1000)], desc=f'Training group_idx={group_idx}', leave=True)
        for epoch in pbar:
            optimizer.zero_grad()           # Zero the gradients
            loss = 0
            names = w_name_list[group_idx:group_idx+range_anchor]
            for idx, ww in enumerate(wo_list[group_idx:group_idx+range_anchor]):
                w_approx = w_base + w_lora_list[idx][0] @ w_lora_list[idx][1]
                loss += criterion(w_approx, ww)  # Calculate loss
            loss.backward()
            pbar.set_description(f'Training - Epoch {epoch+1}, Loss: {loss.item():.2e}, w:{w_base[0,0].item():.2e} grad:{w_base.grad[0,0].item():.2e}')
            optimizer.step()
        for layer_idx, name in enumerate(names):
            ckpt_new[name] = w_base.to(ckpt_new[name].dtype)
            ckpt_new[name.replace("weight","lora_a.weight")] = w_lora_list[layer_idx][1]
            ckpt_new[name.replace("weight","lora_b.weight")] = w_lora_list[layer_idx][0]
        print(group_idx,loss, names)
        print("-"*40)