In [1]:
%env CUDA_VISIBLE_DEVICES=0
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
import time
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight, _reconstruct_weight  # see adjacent file (aq.py)
from src.utils import  calc_avg_bits, get_mean_nbits_by_codebook  # see adjacent file (utils.py)

# --- Runtime setup ---------------------------------------------------------
torch.set_num_threads(16)
# Disable TF32 for CUDA matmuls and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Input data ------------------------------------------------------------
# Saved input activations; the file name refers to layer 10, self_attn.q_proj.
input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'
batch_size = 16384  # rows of activations processed per chunk

# --- Quantization layout ---------------------------------------------------
num_codebooks = 7
nbits_per_codebook = 12
out_group_size, in_group_size = 1, 32
kmeans_init = True
alpha = 1.0

# --- Beam-search schedule --------------------------------------------------
beam_size, big_beam_size = 1, 8
beam_search_epochs, big_beam_search_epochs = 100, 1000
sparsity_regularizer = 0  # a non-zero value adds the "Sparse" tag to the run name

# --- Logging ---------------------------------------------------------------
print_frequency = 10  # print the loss every this many epochs
naming = ""  # prefix of the wandb run name

# For each group we store num_codebooks * nbits_per_codebook bits.
# For the whole W matrix we store out_features * (in_features // group_size) * num_codebooks * nbits_per_codebook bits.
# One codebook stores codebook_size * group_size * 16 bits.
# All codebooks together store num_codebooks * codebook_size * group_size * 16 bits.
# Dimensions of the weight matrix used for the bits-per-parameter estimate.
in_features = out_features = 8192

# Average storage cost per weight for the configured layout, via the project helper.
estimated_bits_per_param = calc_avg_bits(
    num_codebooks,
    out_group_size,
    in_group_size,
    nbits_per_codebook,
    in_features,
    out_features,
)
print("Estimated bits / param", estimated_bits_per_param)

env: CUDA_VISIBLE_DEVICES=0
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16
Estimated bits / param 2.140623092651367


In [2]:
import wandb

# Point wandb at this notebook file (current working directory + notebook name).
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), f"{ipynbname.name()}.ipynb")

# Tag the run name with the enabled options.
# NOTE: `naming` is extended in place, so re-running this cell appends the tags again.
if kmeans_init:
    naming += "KMeans"
if sparsity_regularizer != 0:
    naming += "Sparse"

run_name = f"{naming}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}_{big_beam_search_epochs=}"

# Hyperparameters and run metadata tracked by wandb.
run_config = {
    "num_codebooks": num_codebooks,
    "out_group_size": out_group_size,
    "in_group_size": in_group_size,
    "group_size": out_group_size * in_group_size,
    "batch_size": batch_size,
    "beam_size": beam_size,
    "big_beam_size": big_beam_size,
    "nbits_per_codebook": nbits_per_codebook,
    "Avg_bits": estimated_bits_per_param,
    "beam_search_epochs": beam_search_epochs,
    "big_beam_search_epochs": big_beam_search_epochs,
    "sparsity_regularizer": sparsity_regularizer,
    "alpha": alpha,
    "kmeans_init": kmeans_init,
}

# Start a new wandb run to track this notebook.
run = wandb.init(
    dir=os.getcwd(),
    project="AddQuantization-debug",  # wandb project where this run is logged
    entity="rock-and-roll",
    save_code=True,
    name=run_name,
    settings=wandb.Settings(code_dir="."),
    config=run_config,
)
run.log({"Avg_bits": estimated_bits_per_param})

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Load the pretrained model; in the visible code it is only used to read one
# projection matrix (layer 10, self_attn.q_proj), which matches the layer named
# in `input_loading_dir`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Saved input activations, kept on CPU as float32 and flattened to 2D
# (all leading dimensions merged, last dimension preserved).
# NOTE(review): torch.load may unpickle arbitrary objects; only point
# `input_loading_dir` at trusted files.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)

# Use `device` instead of a hard-coded .cuda() so this tensor always lives on
# the same device as XTX and the cell also works with the CPU fallback.
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()

# Accumulate XTX = X^T X / len(X) chunk by chunk in float64, so the full
# activation matrix never has to be moved to `device` at once.
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1/len(X))
    del x_batch  # release the chunk before loading the next one
XTX = XTX.float()  # downcast once accumulation is finished
del X

In [None]:
# Build the quantized representation of `reference_weight`.
# `alpha` comes from the config cell (the same value that is logged to wandb)
# instead of a separately hard-coded literal.
quantized_layer = QuantizedWeight(
    weight_shape=reference_weight.shape, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook,
    out_group_size=out_group_size, in_group_size=in_group_size,
    device=device,
    init_kmeans=kmeans_init, reference_weight=reference_weight, alpha=alpha, verbose=True
)

opt = torch.optim.Adam(quantized_layer.parameters(), lr=1e-5, betas=(0.9, 0.95))

# Alternate between Adam steps on the quantizer's parameters and periodic
# re-assignment of the codes via requantize_.
for epoch in range(40_000):
    start = time.perf_counter()

    # Loss: sum of (delta @ XTX) * delta over all entries, divided by the number
    # of rows of delta, evaluated in float64.
    reconstructed_weight = _reconstruct_weight(quantized_layer.codes, quantized_layer.codebooks)
    delta_weight = (reconstructed_weight - reference_weight).double()
    loss = (delta_weight @ XTX.double()).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    run.log({'loss':loss.item()}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss.item():.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    if (epoch + 1) % beam_search_epochs == 0:
        # A single flag drives both the message and the beam size, so the two
        # cannot disagree when big_beam_search_epochs is changed.
        is_big_beam_epoch = (epoch + 1) % big_beam_search_epochs == 0
        if is_big_beam_epoch:
            print("BIG beam search")
        quantized_layer.requantize_(
            XTX, reference_weight,
            beam_size=big_beam_size if is_big_beam_epoch else beam_size,
            sparsity_regularizer=sparsity_regularizer,  # tip: use const_hparam * quantized_layer.codes.numel()
            verbose=True)
        if sparsity_regularizer != 0:
            # Fraction of codes equal to 0.
            sparsity_rate = ((quantized_layer.codes == 0).sum() / quantized_layer.codes.numel()).item()
            print(f"Sparsity rate {sparsity_rate:.5f}")
            run.log({'sparsity rate': sparsity_rate}, step=epoch)
            if is_big_beam_epoch:
                mean_code_nbits = sum(get_mean_nbits_by_codebook(quantized_layer.codes)) / num_codebooks
                print(f"mean_code_nbits {mean_code_nbits:.5f}")
                # NOTE(review): "ldngth" looks like a typo for "length"; the key is kept
                # unchanged so the metric name stays the same across runs.
                run.log({'Mean codebook ldngth nbits': mean_code_nbits}, step=epoch)
                # NOTE(review): this branch never runs with out_group_size = 1 as set in
                # the config cell. mean_code_nbits is passed in the position that the
                # config-cell call uses for in_group_size -- confirm the intended
                # argument order against calc_avg_bits.
                if in_group_size>1 and out_group_size>1:
                    curr_avg_bits  = calc_avg_bits(num_codebooks, 1, mean_code_nbits,
                                         nbits_per_codebook, in_features,out_features)
                    run.log({"Avg_bits": curr_avg_bits}, step=epoch)