In [1]:
%env CUDA_VISIBLE_DEVICES=2
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight


torch.set_num_threads(16)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'
num_codebooks = 1
nbits_per_codebook = 14
out_group_size = 1
in_group_size = 8
batch_size = 16384
beam_size = 1
beam_search_epochs = 100
sparsity_regularizer = 0
print_frequency = 10
scale_nbits = 2    # 0 means no scales, 16 means no compression;
codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
init_max_iter = 100

env: CUDA_VISIBLE_DEVICES=2
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16


In [2]:
import wandb

# Tell wandb which notebook file to snapshot; the name is looked up once and reused below.
notebook_name = ipynbname.name()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), notebook_name + ".ipynb")

# Hyperparameters and run metadata recorded with the run.
run_config = dict(
    num_codebooks=num_codebooks,
    out_group_size=out_group_size,
    in_group_size=in_group_size,
    group_size=out_group_size * in_group_size,
    batch_size=batch_size,
    beam_size=beam_size,
    nbits_per_codebook=nbits_per_codebook,
    codebook_values_nbits=codebook_values_nbits,
    scale_nbits=scale_nbits,
    beam_search_epochs=beam_search_epochs,
    sparsity_regularizer=sparsity_regularizer,
    init_max_iter=init_max_iter,
)
# The `{name=}` specifiers embed both the variable name and its value in the run name.
run_name = (
    f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}"
    f"_{nbits_per_codebook=}_{beam_search_epochs=}"
)

# Start a new wandb run in the shared debug project, saving the code directory with it.
run = wandb.init(
    dir=os.getcwd(),
    project="AddQuantization-debug",
    entity="rock-and-roll",
    save_code=True,
    name=run_name,
    settings=wandb.Settings(code_dir="."),
    config=run_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
# Load the full model only to extract the reference q_proj weight of layer 10.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Collapse all leading dims so X has shape (num_samples, in_features).
# NOTE(review): torch.load unpickles the file — only point this at trusted data.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()

# Accumulate XTX = X^T X / len(X) in float64, one batch at a time, so only a single
# batch of activations lives on `device` at once (X itself stays on CPU).
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    # Use the configured device (not a hard-coded .cuda()) so this also runs on CPU.
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch
XTX = XTX.float()
del X

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [4]:
quantized_weight = QuantizedWeight(
    reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits,
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=init_max_iter,   # faster init, not tested
)
# Log at an explicit step=0: a step-less run.log advances wandb's internal step to 1,
# after which the step=epoch logs below (starting at 0) are rejected as non-monotonic
# and the first epoch's loss is silently dropped.
run.log({"Avg_bits": quantized_weight.estimate_nbits_per_parameter()}, step=0)
opt = torch.optim.Adam(quantized_weight.parameters(), lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

num_epochs = 1000
# Linear weight layout is (out_features, in_features); used by the avg-bits estimate below.
out_features, in_features = reference_weight.shape
# Loop-invariant: convert XTX to float64 once instead of on every epoch.
XTX_double = XTX.double()

for epoch in range(num_epochs):
    start = time.perf_counter()
    # Proxy loss: tr(dW @ XTX @ dW^T) / out_features, computed in float64.
    delta_weight = (quantized_weight() - reference_weight).double()
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    run.log({'loss': loss.item()}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss.item():.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    # Every `beam_search_epochs` steps, re-assign the discrete codes via beam search.
    if (epoch + 1) % beam_search_epochs == 0:
        quantized_weight.beam_search_update_codes_(
            XTX, reference_weight, beam_size=beam_size, sparsity_regularizer=sparsity_regularizer,
            dim_rng=random.Random(), verbose=True)

        if sparsity_regularizer != 0:
            sparsity_rate = ((quantized_weight.codes == 0).sum() / quantized_weight.codes.numel()).item()
            print(f"Sparsity rate {sparsity_rate:.5f}")
            run.log({'sparsity rate': sparsity_rate}, step=epoch)
            # TODO(review): get_mean_nbits_by_codebook and calc_avg_bits are not defined or
            # imported anywhere in this notebook; this branch raises NameError until they are.
            mean_code_nbits = sum(get_mean_nbits_by_codebook(quantized_weight.codes)) / num_codebooks
            print(f"mean_code_nbits {mean_code_nbits:.5f}")
            run.log({'Mean codebook length nbits': mean_code_nbits}, step=epoch)
            if in_group_size > 1 and out_group_size > 1:
                curr_avg_bits = calc_avg_bits(num_codebooks, 1, mean_code_nbits,
                                              nbits_per_codebook, in_features, out_features, scale_nbits)
                run.log({"Avg_bits": curr_avg_bits}, step=epoch)

  groupwise_cluster_indices = torch.searchsorted(border_indices[:, 1:], groupwise_ranks_1based, side='left')


initializing with kmeans:   0%|          | 0/1 [00:00<?, ?it/s]

  codebook_i, _, _ = fit_kmeans(


loss=0.0254404022	 time_on_epoch 0 = 1.802529029082507
loss=0.0210026318	 time_on_epoch 10 = 0.13315559295006096
loss=0.0175706957	 time_on_epoch 20 = 0.13393667386844754
loss=0.0149330052	 time_on_epoch 30 = 0.13344294298440218
loss=0.0129503049	 time_on_epoch 40 = 0.13311199308373034
loss=0.0114855729	 time_on_epoch 50 = 0.13305627298541367
loss=0.0104046396	 time_on_epoch 60 = 0.13350988295860589
loss=0.0096008013	 time_on_epoch 70 = 0.1330640739761293
loss=0.0089960344	 time_on_epoch 80 = 0.1331505631096661




loss=0.0085348955	 time_on_epoch 90 = 0.13324559293687344


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0035342623	 time_on_epoch 100 = 0.1342155139427632
loss=0.0035245917	 time_on_epoch 110 = 0.13356428290717304
loss=0.0035197623	 time_on_epoch 120 = 0.13342554285191
loss=0.0035165187	 time_on_epoch 130 = 0.13311086385510862
loss=0.0035141055	 time_on_epoch 140 = 0.13374247308820486
loss=0.0035121871	 time_on_epoch 150 = 0.13366003311239183
loss=0.0035105889	 time_on_epoch 160 = 0.13311564410105348
loss=0.0035092115	 time_on_epoch 170 = 0.13361279317177832
loss=0.0035079945	 time_on_epoch 180 = 0.1335298928897828
loss=0.0035068988	 time_on_epoch 190 = 0.13351217308081686


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0027583639	 time_on_epoch 200 = 0.13422761298716068
loss=0.0027521192	 time_on_epoch 210 = 0.1331038640346378
loss=0.0027489219	 time_on_epoch 220 = 0.1332344738766551
loss=0.0027466409	 time_on_epoch 230 = 0.13368744309991598
loss=0.0027448470	 time_on_epoch 240 = 0.13358768401667476
loss=0.0027433508	 time_on_epoch 250 = 0.13345329486764967
loss=0.0027420536	 time_on_epoch 260 = 0.13358681416139007
loss=0.0027408986	 time_on_epoch 270 = 0.133119055069983
loss=0.0027398508	 time_on_epoch 280 = 0.13355062482878566
loss=0.0027388872	 time_on_epoch 290 = 0.13312221597880125


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0025228964	 time_on_epoch 300 = 0.13364888890646398
loss=0.0025195496	 time_on_epoch 310 = 0.1336160791106522
loss=0.0025177463	 time_on_epoch 320 = 0.13310001883655787
loss=0.0025164050	 time_on_epoch 330 = 0.1330654399935156
loss=0.0025153146	 time_on_epoch 340 = 0.1332611700054258
loss=0.0025143814	 time_on_epoch 350 = 0.13354766997508705
loss=0.0025135557	 time_on_epoch 360 = 0.1331644500605762
loss=0.0025128087	 time_on_epoch 370 = 0.13312125112861395
loss=0.0025121221	 time_on_epoch 380 = 0.1331358510069549
loss=0.0025114841	 time_on_epoch 390 = 0.13362607080489397


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0024125674	 time_on_epoch 400 = 0.13408855418674648
loss=0.0024103773	 time_on_epoch 410 = 0.13317511486820877
loss=0.0024091397	 time_on_epoch 420 = 0.13309856480918825
loss=0.0024081878	 time_on_epoch 430 = 0.13308660499751568
loss=0.0024073935	 time_on_epoch 440 = 0.13323473604395986
loss=0.0024066997	 time_on_epoch 450 = 0.1335714349988848
loss=0.0024060761	 time_on_epoch 460 = 0.1331613960210234
loss=0.0024055047	 time_on_epoch 470 = 0.13312953617423773
loss=0.0024049743	 time_on_epoch 480 = 0.13308153697289526
loss=0.0024044772	 time_on_epoch 490 = 0.13308247597888112


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023482013	 time_on_epoch 500 = 0.13381659891456366
loss=0.0023466351	 time_on_epoch 510 = 0.13306462997570634
loss=0.0023457173	 time_on_epoch 520 = 0.13385432003997266
loss=0.0023449968	 time_on_epoch 530 = 0.13312284089624882
loss=0.0023443863	 time_on_epoch 540 = 0.13309816014952958
loss=0.0023438464	 time_on_epoch 550 = 0.1331249310169369
loss=0.0023433560	 time_on_epoch 560 = 0.13319928105920553
loss=0.0023429029	 time_on_epoch 570 = 0.13310051104053855
loss=0.0023424790	 time_on_epoch 580 = 0.13313604099676013
loss=0.0023420793	 time_on_epoch 590 = 0.13310815207660198


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023061153	 time_on_epoch 600 = 0.13407287397421896
loss=0.0023049505	 time_on_epoch 610 = 0.13359928503632545
loss=0.0023042411	 time_on_epoch 620 = 0.1331018649507314
loss=0.0023036717	 time_on_epoch 630 = 0.13314920500852168
loss=0.0023031810	 time_on_epoch 640 = 0.1330852450337261
loss=0.0023027412	 time_on_epoch 650 = 0.13313689501956105
loss=0.0023023376	 time_on_epoch 660 = 0.13315646490082145
loss=0.0023019614	 time_on_epoch 670 = 0.13312553614377975
loss=0.0023016072	 time_on_epoch 680 = 0.1330747460015118
loss=0.0023012713	 time_on_epoch 690 = 0.13316271593794227


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022764155	 time_on_epoch 700 = 0.13415284804068506
loss=0.0022754879	 time_on_epoch 710 = 0.13334769988432527
loss=0.0022749107	 time_on_epoch 720 = 0.13311076909303665
loss=0.0022744418	 time_on_epoch 730 = 0.13310146005824208
loss=0.0022740339	 time_on_epoch 740 = 0.13317885994911194
loss=0.0022736655	 time_on_epoch 750 = 0.1330872098915279
loss=0.0022733253	 time_on_epoch 760 = 0.1330617901403457
loss=0.0022730067	 time_on_epoch 770 = 0.13310838001780212
loss=0.0022727052	 time_on_epoch 780 = 0.13305363105610013
loss=0.0022724182	 time_on_epoch 790 = 0.13366883993148804


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022542918	 time_on_epoch 800 = 0.13398651196621358
loss=0.0022535495	 time_on_epoch 810 = 0.1330864131450653
loss=0.0022530730	 time_on_epoch 820 = 0.13310854299925268
loss=0.0022526800	 time_on_epoch 830 = 0.13312069419771433
loss=0.0022523345	 time_on_epoch 840 = 0.13311032392084599
loss=0.0022520199	 time_on_epoch 850 = 0.1331010339781642
loss=0.0022517274	 time_on_epoch 860 = 0.13319650408811867
loss=0.0022514521	 time_on_epoch 870 = 0.1330987939145416
loss=0.0022511904	 time_on_epoch 880 = 0.13309921487234533
loss=0.0022509403	 time_on_epoch 890 = 0.13314738497138023


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022368615	 time_on_epoch 900 = 0.13390998705290258
loss=0.0022362412	 time_on_epoch 910 = 0.1330994670279324
loss=0.0022358361	 time_on_epoch 920 = 0.13309631682932377
loss=0.0022354991	 time_on_epoch 930 = 0.1331188678741455
loss=0.0022352007	 time_on_epoch 940 = 0.1332482979632914
loss=0.0022349277	 time_on_epoch 950 = 0.1331654479727149
loss=0.0022346728	 time_on_epoch 960 = 0.13312231819145381
loss=0.0022344319	 time_on_epoch 970 = 0.13314944808371365
loss=0.0022342022	 time_on_epoch 980 = 0.13314449903555214
loss=0.0022339821	 time_on_epoch 990 = 0.13316197786480188


  0%|          | 0/1024 [00:00<?, ?it/s]