In [1]:
# Kernel-wide environment variables, set ahead of the torch / transformers
# imports that follow in this cell.
# Restrict CUDA to the physical GPU with index 6 (it shows up as the only
# visible CUDA device inside this kernel).
%env CUDA_VISIBLE_DEVICES=6
# Directory used as the Hugging Face transformers download cache.
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
# OpenMP thread count; the same value (16) is passed to torch.set_num_threads
# further down in this cell.
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedLinear


# ---------------------------------------------------------------------------
# Numerical backend
# ---------------------------------------------------------------------------
# TF32 is switched off for both CUDA matmuls and cuDNN so that float32 math is
# not silently run at reduced precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_num_threads(16)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---------------------------------------------------------------------------
# Calibration data
# ---------------------------------------------------------------------------
# NOTE: despite the name, this is the path of a single .pt file, not a directory.
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'
batch_size = 16384  # number of activation rows processed at once when accumulating XTX

# ---------------------------------------------------------------------------
# Quantizer configuration (forwarded to QuantizedLinear / logged to wandb)
# ---------------------------------------------------------------------------
num_codebooks = 2
nbits_per_codebook = 8
in_group_size = 8
out_group_size = 1
rrr_rank = 32
scale_nbits = 0    # 0 means no scales, 16 means no compression;
codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
init_max_iter = 100  # passed as max_iter to the QuantizedLinear constructor

# ---------------------------------------------------------------------------
# Training-loop schedule
# ---------------------------------------------------------------------------
beam_size = 1
beam_search_epochs = 100  # codes are re-searched every this many optimizer steps
sparsity_regularizer = 0  # non-zero additionally enables the sparsity logging branch
print_frequency = 10      # print the loss every this many optimizer steps

env: CUDA_VISIBLE_DEVICES=6
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16




In [2]:
import wandb

# Resolve the notebook name and working directory once; both are needed for
# the wandb notebook path and for the run itself.
notebook_name = ipynbname.name()
working_dir = os.getcwd()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(working_dir, notebook_name + ".ipynb")

# Hyperparameters and run metadata tracked alongside the run.
run_config = dict(
    num_codebooks=num_codebooks,
    out_group_size=out_group_size,
    in_group_size=in_group_size,
    group_size=out_group_size * in_group_size,
    batch_size=batch_size,
    beam_size=beam_size,
    nbits_per_codebook=nbits_per_codebook,
    codebook_values_nbits=codebook_values_nbits,
    scale_nbits=scale_nbits,
    beam_search_epochs=beam_search_epochs,
    sparsity_regularizer=sparsity_regularizer,
    rrr_rank=rrr_rank,
    init_max_iter=init_max_iter,
)

# Start a new wandb run to track this notebook; the run name encodes the main
# quantizer settings so runs can be told apart in the project view.
run = wandb.init(
    dir=working_dir,
    project="AddQuantization-debug",
    entity="rock-and-roll",
    save_code=True,
    name=f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}",
    settings=wandb.Settings(code_dir="."),
    config=run_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
# The full checkpoint is loaded only to read a single weight matrix out of it.
# `model` is kept alive afterwards; `del model` here if nothing later needs it.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Recorded input activations, collapsed to 2-D: every leading dimension is
# flattened into rows, the last dimension is kept as the feature axis.
# torch.load unpickles the file, so it must only be pointed at trusted data.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)

# Weight to be approximated, as float32 on the compute device.
# NOTE(review): the layer index 10 / q_proj here presumably has to match the
# 'layer10.self_attn.q_proj' in input_loading_dir - confirm when changing either.
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()

# XTX = X^T X / len(X), accumulated batch by batch in float64 so that X never
# has to be resident on the device in full; cast to float32 at the end.
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    # Move the batch to the same device as XTX (was a hard-coded .cuda(), which
    # breaks the CPU fallback selected by `device`).
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch
XTX = XTX.float()
del X

Downloading shards:   0%|          | 0/15 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [4]:
# Seed the global torch RNG so that the stochastic parts of this cell are
# repeatable across kernel restarts.
seed = 0
torch.manual_seed(seed)

# Build the quantized approximation of `reference_weight`.
quantized_weight = QuantizedLinear(
    XTX=XTX, reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, rrr_rank=rrr_rank,
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=init_max_iter,   # faster init, not tested
)

avg_bits = quantized_weight.estimate_nbits_per_parameter()
# Log with an explicit step=0. A step-less run.log() commits step 0 and advances
# wandb's internal step to 1, after which the loss logged below with step=0 is
# rejected as non-monotonic and dropped.
run.log({"Avg_bits": avg_bits}, step=0)
print("AVG bits:", avg_bits)
opt = torch.optim.Adam(quantized_weight.parameters(), lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

num_epochs = 1000
# One seeded generator shared by all beam-search rounds (was a fresh, unseeded
# random.Random() per round, which made runs non-reproducible).
beam_search_rng = random.Random(seed)
# Loop-invariant: float64 copy of XTX used by the loss.
XTX_double = XTX.double()
# reference_weight is an nn.Linear weight, laid out as (out_features, in_features).
# These names are needed by the sparsity logging branch and were previously undefined.
out_features, in_features = reference_weight.shape

for epoch in range(num_epochs):
    start = time.perf_counter()
    delta_weight = (quantized_weight() - reference_weight).double()
    # loss = sum(delta_weight @ XTX * delta_weight) / number of rows of delta_weight
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_value = loss.item()
    run.log({'loss': loss_value}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss_value:.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    if (epoch + 1) % beam_search_epochs == 0:
        # Periodically re-search the discrete codes.
        quantized_weight.beam_search_update_codes_(
            XTX, reference_weight, beam_size=beam_size, sparsity_regularizer=sparsity_regularizer,
            dim_rng=beam_search_rng, verbose=True)

        if sparsity_regularizer != 0:
            # NOTE(review): get_mean_nbits_by_codebook and calc_avg_bits are not
            # defined or imported anywhere in this notebook; this branch raises
            # NameError until they are imported. It is skipped while
            # sparsity_regularizer == 0.
            sparsity_rate = ((quantized_weight.codes == 0).sum() / quantized_weight.codes.numel()).item()
            print(f"Sparsity rate {sparsity_rate:.5f}")
            run.log({'sparsity rate': sparsity_rate}, step=epoch)
            mean_code_nbits = sum(get_mean_nbits_by_codebook(quantized_weight.codes)) / num_codebooks
            print(f"mean_code_nbits {mean_code_nbits:.5f}")
            run.log({'Mean codebook length nbits': mean_code_nbits}, step=epoch)
            if in_group_size > 1 and out_group_size > 1:
                curr_avg_bits  = calc_avg_bits(num_codebooks, 1, mean_code_nbits,
                                     nbits_per_codebook, in_features, out_features, scale_nbits)
                run.log({"Avg_bits": curr_avg_bits}, step=epoch)

!!!! RRR RANK = 32


initializing with kmeans:   0%|          | 0/2 [00:00<?, ?it/s]

  codebook_i, _, _ = fit_kmeans(


AVG bits: 2.1279296875
loss=0.0068659294	 time_on_epoch 0 = 0.47034866688773036
loss=0.0038049441	 time_on_epoch 10 = 0.14076759619638324
loss=0.0032719078	 time_on_epoch 20 = 0.140756756067276




loss=0.0031104048	 time_on_epoch 30 = 0.14079759689047933
loss=0.0030090041	 time_on_epoch 40 = 0.14085792610421777
loss=0.0029696115	 time_on_epoch 50 = 0.14074284583330154
loss=0.0029664642	 time_on_epoch 60 = 0.14095029700547457
loss=0.0029449556	 time_on_epoch 70 = 0.1408701059408486
loss=0.0029216796	 time_on_epoch 80 = 0.14093184703961015
loss=0.0029125094	 time_on_epoch 90 = 0.14112638588994741


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0021841616	 time_on_epoch 100 = 0.14104199595749378
loss=0.0021734770	 time_on_epoch 110 = 0.1407064269296825
loss=0.0021719400	 time_on_epoch 120 = 0.14103615563362837
loss=0.0021710287	 time_on_epoch 130 = 0.14092997601255774
loss=0.0021703476	 time_on_epoch 140 = 0.140886546112597
loss=0.0021697919	 time_on_epoch 150 = 0.14096741611137986
loss=0.0021693199	 time_on_epoch 160 = 0.14089361624792218
loss=0.0021689113	 time_on_epoch 170 = 0.14139958564192057
loss=0.0021685567	 time_on_epoch 180 = 0.14104339620098472
loss=0.0021682554	 time_on_epoch 190 = 0.14083258667960763


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0020116357	 time_on_epoch 200 = 0.141018767375499
loss=0.0020030833	 time_on_epoch 210 = 0.14083915622904897
loss=0.0020018875	 time_on_epoch 220 = 0.14116966631263494
loss=0.0020012201	 time_on_epoch 230 = 0.14136980613693595
loss=0.0020007471	 time_on_epoch 240 = 0.14109506597742438
loss=0.0020003870	 time_on_epoch 250 = 0.14094640593975782
loss=0.0020001124	 time_on_epoch 260 = 0.14105019578710198
loss=0.0019998809	 time_on_epoch 270 = 0.1407855972647667
loss=0.0019997095	 time_on_epoch 280 = 0.14073178684338927
loss=0.0019995642	 time_on_epoch 290 = 0.14074815716594458


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0019441461	 time_on_epoch 300 = 0.14089556690305471
loss=0.0019385082	 time_on_epoch 310 = 0.14108717627823353
loss=0.0019377522	 time_on_epoch 320 = 0.1418600450269878
loss=0.0019373201	 time_on_epoch 330 = 0.14099987596273422
loss=0.0019370013	 time_on_epoch 340 = 0.14132737554609776
loss=0.0019367432	 time_on_epoch 350 = 0.14097792701795697
loss=0.0019365274	 time_on_epoch 360 = 0.1407167660072446
loss=0.0019363497	 time_on_epoch 370 = 0.1407507872208953
loss=0.0019362172	 time_on_epoch 380 = 0.1408348362892866
loss=0.0019361535	 time_on_epoch 390 = 0.14073202619329095


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0019086780	 time_on_epoch 400 = 0.14102852577343583
loss=0.0019044359	 time_on_epoch 410 = 0.14070075610652566
loss=0.0019038915	 time_on_epoch 420 = 0.14081126591190696
loss=0.0019035815	 time_on_epoch 430 = 0.14158996613696218
loss=0.0019033524	 time_on_epoch 440 = 0.1409751968458295
loss=0.0019031678	 time_on_epoch 450 = 0.14130343589931726
loss=0.0019030173	 time_on_epoch 460 = 0.14079406578093767
loss=0.0019029037	 time_on_epoch 470 = 0.14101905561983585
loss=0.0019028438	 time_on_epoch 480 = 0.14085278613492846
loss=0.0019028805	 time_on_epoch 490 = 0.14104067580774426


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0018868461	 time_on_epoch 500 = 0.14096689596772194
loss=0.0018834354	 time_on_epoch 510 = 0.14120643632486463
loss=0.0018830216	 time_on_epoch 520 = 0.14085106598213315
loss=0.0018827914	 time_on_epoch 530 = 0.1408526566810906
loss=0.0018826264	 time_on_epoch 540 = 0.14084624592214823
loss=0.0018825031	 time_on_epoch 550 = 0.14081582613289356
loss=0.0018824239	 time_on_epoch 560 = 0.1408047671429813
loss=0.0018824118	 time_on_epoch 570 = 0.14076677709817886
loss=0.0018825246	 time_on_epoch 580 = 0.14083088701590896
loss=0.0018828821	 time_on_epoch 590 = 0.14076133631169796


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0018722054	 time_on_epoch 600 = 0.1410132572054863
loss=0.0018693853	 time_on_epoch 610 = 0.14079836569726467
loss=0.0018690574	 time_on_epoch 620 = 0.1412493558600545
loss=0.0018688773	 time_on_epoch 630 = 0.1409091972745955
loss=0.0018687482	 time_on_epoch 640 = 0.14089863607659936
loss=0.0018686490	 time_on_epoch 650 = 0.14081498701125383
loss=0.0018685755	 time_on_epoch 660 = 0.14082390582188964
loss=0.0018685313	 time_on_epoch 670 = 0.14101921673864126
loss=0.0018685275	 time_on_epoch 680 = 0.14097003592178226
loss=0.0018685857	 time_on_epoch 690 = 0.14109452720731497


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0018612342	 time_on_epoch 700 = 0.14127511624246836
loss=0.0018588825	 time_on_epoch 710 = 0.1410245569422841
loss=0.0018586225	 time_on_epoch 720 = 0.1408826056867838
loss=0.0018584836	 time_on_epoch 730 = 0.14087909599766135
loss=0.0018583875	 time_on_epoch 740 = 0.1411249772645533
loss=0.0018583188	 time_on_epoch 750 = 0.14113233610987663
loss=0.0018582771	 time_on_epoch 760 = 0.14121503569185734
loss=0.0018582708	 time_on_epoch 770 = 0.14279864495620131
loss=0.0018583182	 time_on_epoch 780 = 0.14098927564918995
loss=0.0018584528	 time_on_epoch 790 = 0.14077151706442237


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0018531749	 time_on_epoch 800 = 0.14108340675011277
loss=0.0018511775	 time_on_epoch 810 = 0.14083313709124923
loss=0.0018509621	 time_on_epoch 820 = 0.14077292708680034
loss=0.0018508461	 time_on_epoch 830 = 0.14111936604604125
loss=0.0018507620	 time_on_epoch 840 = 0.1409753169864416
loss=0.0018506944	 time_on_epoch 850 = 0.14106252696365118
loss=0.0018506385	 time_on_epoch 860 = 0.14094429602846503
loss=0.0018505932	 time_on_epoch 870 = 0.14078499702736735
loss=0.0018505592	 time_on_epoch 880 = 0.14128152607008815
loss=0.0018505392	 time_on_epoch 890 = 0.1409889874048531


  0%|          | 0/2048 [00:00<?, ?it/s]

loss=0.0018467491	 time_on_epoch 900 = 0.14098746608942747
loss=0.0018450877	 time_on_epoch 910 = 0.140658896882087
loss=0.0018449137	 time_on_epoch 920 = 0.14111529709771276
loss=0.0018448209	 time_on_epoch 930 = 0.14112113695591688
loss=0.0018447540	 time_on_epoch 940 = 0.14079731702804565
loss=0.0018447008	 time_on_epoch 950 = 0.14076189696788788
loss=0.0018446578	 time_on_epoch 960 = 0.1407570568844676
loss=0.0018446243	 time_on_epoch 970 = 0.14077628683298826
loss=0.0018446019	 time_on_epoch 980 = 0.14097715634852648
loss=0.0018445937	 time_on_epoch 990 = 0.14090442704036832


  0%|          | 0/2048 [00:00<?, ?it/s]