In [1]:
# Process environment for this notebook, set before torch / transformers are imported below so the
# libraries see these values:
#   - CUDA_VISIBLE_DEVICES: expose only physical GPU 2 to this process;
#   - TRANSFORMERS_CACHE: directory used for Hugging Face model downloads;
#   - OMP_NUM_THREADS: OpenMP thread count (kept in sync with torch.set_num_threads(16) further down).
# NOTE: no inline comments on %env lines — everything after '=' becomes part of the value.
%env CUDA_VISIBLE_DEVICES=2
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight
from src.utils import calc_avg_bits, get_mean_nbits_by_codebook  # see adjacent file (aq.py)

torch.set_num_threads(16)
# Disable TF32 for cuDNN and CUDA matmuls so GPU math runs at full fp32 precision.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global python and torch RNGs so randomness drawn from them is repeatable across runs.
# NOTE: `random.Random()` instances created without an explicit seed are NOT affected by this.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

# Calibration activations recorded at the input of layer 10's self_attn.q_proj (a single .pt file).
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'

# --- Additive-quantization hyperparameters ---
num_codebooks = 1          # number of additive codebooks
nbits_per_codebook = 14    # bits per code index, i.e. codebook_size = 2 ** nbits_per_codebook
out_group_size = 1         # weights are quantized in groups of out_group_size x in_group_size
in_group_size = 8
batch_size = 16384         # rows of X processed per chunk when accumulating XTX
beam_size = 1              # beam width passed to QuantizedWeight.requantize_
beam_search_epochs = 100   # run beam-search requantization every this many optimizer steps
sparsity_regularizer = 0   # passed to requantize_; nonzero also enables sparsity logging
print_frequency = 10       # print the loss every this many steps
scale_nbits = 2            # 0 means no scales, 16 means no compression

# Storage accounting behind the bits/param estimate (group_size = out_group_size * in_group_size):
#   - each group stores num_codebooks * nbits_per_codebook bits of codes;
#   - the W matrix stores (out_features * in_features // group_size) * num_codebooks * nbits_per_codebook bits;
#   - one codebook stores codebook_size * group_size * 16 bits (16-bit entries);
#   - all codebooks store num_codebooks * codebook_size * group_size * 16 bits.
# Hard-coded shape of the quantized layer; must match reference_weight (out_features, in_features).
in_features, out_features = 8192, 8192
estimated_bits_per_param = calc_avg_bits(
    num_codebooks, out_group_size, in_group_size, nbits_per_codebook, in_features, out_features, scale_nbits)
print("Estimated bits / param", estimated_bits_per_param)

env: CUDA_VISIBLE_DEVICES=2
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16




Estimated bits / param 2.0390625


In [2]:
import wandb

# Name of the running notebook; used for wandb's code saving and for the run name.
notebook_name = ipynbname.name()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), notebook_name + ".ipynb")

# Start a new wandb run to track this experiment, recording all hyperparameters in its config.
run = wandb.init(
    dir=os.getcwd(),
    project="AddQuantization",
    entity="rock-and-roll",
    save_code=True,
    name=f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}",
    settings=wandb.Settings(code_dir="."),
    config={
        "num_codebooks": num_codebooks,
        "out_group_size": out_group_size,
        "in_group_size": in_group_size,
        "group_size": out_group_size * in_group_size,
        "batch_size": batch_size,
        "beam_size": beam_size,
        "nbits_per_codebook": nbits_per_codebook,
        "Avg_bits": estimated_bits_per_param,
        "beam_search_epochs": beam_search_epochs,
        "sparsity_regularizer": sparsity_regularizer,
        # Was previously logged under the misspelled key "scale_nvits" (search for it in older runs).
        "scale_nbits": scale_nbits,
    },
)
# Log at an explicit step 0. The training loop logs with step=epoch starting at 0, and wandb drops
# any log whose step is lower than the run's current step. An un-stepped run.log() here would
# advance the step to 1 and silently discard the epoch-0 metrics.
run.log({"Avg_bits": estimated_bits_per_param}, step=0)

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
# Load the full model (no device_map, so it stays on CPU) only to extract the layer being quantized.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Calibration inputs: all leading dims are flattened into rows, the last (feature) dim is kept.
# NOTE: torch.load unpickles the file — only point input_loading_dir at trusted data.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()
# The bits/param estimate in the config cell uses hard-coded in/out features; make sure they match.
assert reference_weight.shape == (out_features, in_features), (
    f"reference_weight shape {tuple(reference_weight.shape)} does not match "
    f"(out_features, in_features) = {(out_features, in_features)} used for the bits/param estimate")

# Accumulate XTX = X^T X / len(X) in float64, one chunk of batch_size rows at a time,
# on the configured device (not unconditionally CUDA, so the CPU fallback also works).
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1 / len(X))
    del x_batch
XTX = XTX.float()
del X

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
quantized_weight = QuantizedWeight(
    reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits,
    out_group_size=out_group_size, in_group_size=in_group_size,
    verbose=True, max_iter=100,   # faster init, not tested
)
opt = torch.optim.Adam(quantized_weight.parameters(), lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

# XTX is constant during training: convert it to float64 once instead of on every step.
XTX_double = XTX.double()

for epoch in range(10_000):
    start = time.perf_counter()
    # Loss = tr(dW @ XTX @ dW^T) / out_features with XTX = X^T X / len(X), i.e. the mean squared
    # error of the layer outputs on the calibration data, per token and per output channel.
    delta_weight = (quantized_weight() - reference_weight).double()
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    run.log({'loss': loss.item()}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss.item():.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    if (epoch + 1) % beam_search_epochs == 0:
        # Periodic beam-search requantization of the codes against the (float32) XTX.
        # NOTE: an unseeded random.Random() draws a fresh OS-entropy seed, so this step is not reproducible.
        quantized_weight.requantize_(
            XTX, reference_weight, beam_size=beam_size, sparsity_regularizer=sparsity_regularizer,
            dim_rng=random.Random(), verbose=True)

        if sparsity_regularizer != 0:
            # Fraction of all code entries equal to 0.
            sparsity_rate = ((quantized_weight.codes == 0).sum() / quantized_weight.codes.numel()).item()
            print(f"Sparsity rate {sparsity_rate:.5f}")
            run.log({'sparsity rate': sparsity_rate}, step=epoch)
            mean_code_nbits = sum(get_mean_nbits_by_codebook(quantized_weight.codes)) / num_codebooks
            print(f"mean_code_nbits {mean_code_nbits:.5f}")
            run.log({'Mean codebook length nbits': mean_code_nbits}, step=epoch)
            # NOTE(review): with out_group_size == 1 (current config) this branch never runs;
            # confirm whether `or` was intended here.
            if in_group_size > 1 and out_group_size > 1:
                # Re-estimate bits/param with the measured mean code length in place of nbits_per_codebook.
                # Arguments follow the same positional order as the initial estimate:
                # (num_codebooks, out_group_size, in_group_size, nbits, in_features, out_features, scale_nbits).
                curr_avg_bits = calc_avg_bits(num_codebooks, out_group_size, in_group_size, mean_code_nbits,
                                              in_features, out_features, scale_nbits)
                run.log({"Avg_bits": curr_avg_bits}, step=epoch)

  groupwise_cluster_indices = torch.searchsorted(border_indices[:, 1:], groupwise_ranks_1based, side='left')


initializing with kmeans:   0%|          | 0/1 [00:00<?, ?it/s]

  codebook_i, codes_i, reconstructed_weight_i = fit_kmeans(weight_residue, k=codebook_size, **kwargs)


loss=0.0252773022	 time_on_epoch 0 = 1.7946178630227223
loss=0.0208712124	 time_on_epoch 10 = 0.1330599730135873
loss=0.0174722935	 time_on_epoch 20 = 0.1330504339421168
loss=0.0148668374	 time_on_epoch 30 = 0.1330770329805091
loss=0.0129121828	 time_on_epoch 40 = 0.1331126029836014
loss=0.0114690073	 time_on_epoch 50 = 0.13304006296675652
loss=0.0104033902	 time_on_epoch 60 = 0.13304965302813798
loss=0.0096097461	 time_on_epoch 70 = 0.1331646329490468
loss=0.0090112719	 time_on_epoch 80 = 0.13303362298756838
loss=0.0085535870	 time_on_epoch 90 = 0.13318711298052222


  0%|          | 0/1024 [00:00<?, ?it/s]



loss=0.0034655112	 time_on_epoch 100 = 0.13364200305659324
loss=0.0034559176	 time_on_epoch 110 = 0.13328092207666487
loss=0.0034510969	 time_on_epoch 120 = 0.13307638303376734
loss=0.0034478450	 time_on_epoch 130 = 0.13346318306867033
loss=0.0034454202	 time_on_epoch 140 = 0.133056542952545
loss=0.0034434910	 time_on_epoch 150 = 0.13307083304971457
loss=0.0034418843	 time_on_epoch 160 = 0.13308089296333492
loss=0.0034405009	 time_on_epoch 170 = 0.13307898305356503
loss=0.0034392800	 time_on_epoch 180 = 0.13345021300483495
loss=0.0034381826	 time_on_epoch 190 = 0.1330749629996717


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0027243270	 time_on_epoch 200 = 0.1338208719389513
loss=0.0027181011	 time_on_epoch 210 = 0.13309926202055067
loss=0.0027149218	 time_on_epoch 220 = 0.1331594130024314
loss=0.0027126610	 time_on_epoch 230 = 0.13304181303828955
loss=0.0027108916	 time_on_epoch 240 = 0.13323908299207687
loss=0.0027094238	 time_on_epoch 250 = 0.13319915300235152
loss=0.0027081578	 time_on_epoch 260 = 0.1330402420135215
loss=0.0027070356	 time_on_epoch 270 = 0.13318077300209552
loss=0.0027060214	 time_on_epoch 280 = 0.13328849291428924
loss=0.0027050919	 time_on_epoch 290 = 0.1330608029384166


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0025033033	 time_on_epoch 300 = 0.13397369300946593
loss=0.0025002023	 time_on_epoch 310 = 0.13306451204698533
loss=0.0024984859	 time_on_epoch 320 = 0.1330777129624039
loss=0.0024971917	 time_on_epoch 330 = 0.1330963730579242
loss=0.0024961303	 time_on_epoch 340 = 0.13312530203256756
loss=0.0024952167	 time_on_epoch 350 = 0.13308885309379548
loss=0.0024944055	 time_on_epoch 360 = 0.1330688629532233
loss=0.0024936702	 time_on_epoch 370 = 0.13315177301410586
loss=0.0024929936	 time_on_epoch 380 = 0.13311038305982947
loss=0.0024923647	 time_on_epoch 390 = 0.1330994329182431


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023977929	 time_on_epoch 400 = 0.13366935297381133
loss=0.0023956214	 time_on_epoch 410 = 0.13307630200870335
loss=0.0023943873	 time_on_epoch 420 = 0.1334142730338499
loss=0.0023934440	 time_on_epoch 430 = 0.13307357300072908
loss=0.0023926622	 time_on_epoch 440 = 0.13311167294159532
loss=0.0023919833	 time_on_epoch 450 = 0.13307638303376734
loss=0.0023913759	 time_on_epoch 460 = 0.13314971199724823
loss=0.0023908215	 time_on_epoch 470 = 0.13303734199143946
loss=0.0023903082	 time_on_epoch 480 = 0.13307578291278332
loss=0.0023898282	 time_on_epoch 490 = 0.13305517204571515


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023359454	 time_on_epoch 500 = 0.13412763306405395
loss=0.0023344245	 time_on_epoch 510 = 0.13310393202118576
loss=0.0023335227	 time_on_epoch 520 = 0.13315586210228503
loss=0.0023328158	 time_on_epoch 530 = 0.13313941296655685
loss=0.0023322180	 time_on_epoch 540 = 0.13303359202109277
loss=0.0023316906	 time_on_epoch 550 = 0.13307232304941863
