In [1]:
# MODEL = "meta-llama/Llama-2-7b-hf"
MODEL = "meta-llama/Meta-Llama-3.1-8B"
BASE_PPL = 5.606692790985107

In [2]:
from transformers import AutoModelForCausalLM

model_pt = AutoModelForCausalLM.from_pretrained(
    '/mnt/LLM/hub/models--meta-llama--Meta-Llama-3.1-8B/snapshots/13f04ed6f85ef2aa2fd11b960a275c3e31a8069e/',
    trust_remote_code=True, torch_dtype="auto", device_map='meta',
)

def get_module_by_path(model, path):
    if path == '':
        return model
    splitted = path.split('.', 1)
    if len(splitted) == 1:
        splitted.append('')
    next_name, suffix = splitted

    try:
        next_module = model[int(next_name)]
    except:
        next_module = getattr(model, next_name)

    return get_module_by_path(next_module, suffix)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
import functools


@functools.cache
def get_numel(path):
    return get_module_by_path(model_pt, path).weight.numel()

total_params = sum(p.numel() for p in model_pt.model.layers.parameters())

In [4]:
import tqdm
import pandas as pd 
import wandb
import functools


@functools.cache
def get_df_from_wandb(path):
    api = wandb.Api()

    # Project is specified by <entity/project-name>
    runs = api.runs(path)
    
    data_df_lines = []
    for run in tqdm.tqdm(runs): 
        data_df_lines.append({
            'Name': run.name,
            'Commit': run.commit,
            **run.summary._json_dict,
            **{k: v for k,v in run.config.items() if not k.startswith('_')},
        })
    data_df = pd.DataFrame(data_df_lines)
    return data_df

In [5]:
def get_scale_by_layer(exp_path):
    data_df = get_df_from_wandb(exp_path)
    
    data_df = data_df.rename(columns={
        'wikitext2_PPL': 'wikitext2',
    })
    
    data_df = data_df[data_df['model'] == MODEL]
    data_df = data_df[['layer_idx', 'edenn_d', 'edenn_n', 'wikitext2']]
    
    data_df = data_df.dropna().copy()
    
    layer_names = []
    
    for layer_idx in range(32):
        layer_names.append(f'model.layers.{layer_idx}.self_attn.q_proj')
        layer_names.append(f'model.layers.{layer_idx}.self_attn.k_proj')
        layer_names.append(f'model.layers.{layer_idx}.self_attn.v_proj')
        layer_names.append(f'model.layers.{layer_idx}.self_attn.o_proj')
        layer_names.append(f'model.layers.{layer_idx}.mlp.gate_proj')
        layer_names.append(f'model.layers.{layer_idx}.mlp.up_proj')
        layer_names.append(f'model.layers.{layer_idx}.mlp.down_proj')

    import requests
    from ast import literal_eval
    import pandas as pd
    
    grids = literal_eval(requests.get(
        'https://gist.githubusercontent.com/BlackSamorez/c74f24a648eb8bbfbbbf83f3145ba3c7/raw/ddc3280a4861938e2e2034c29d6802817e26e799/gistfile1.txt'
    ).text)
    
    grids.append({
        'edenn_d': -1,
        'edenn_n': -1,
        'bits': 16,
        'mse': 0.0,
    })
    
    grids = pd.DataFrame(grids)
    grids['name'] = grids.apply(
        lambda row: 'edenn_d=' + str(row['edenn_d']) + ';edenn_n=' + str(row['edenn_n']),
        axis=1,
    )
    grids = grids[['bits', 'mse', 'name', 'edenn_d', 'edenn_n']]
    def get_mse(grid_tuple):
        edenn_d, edenn_n = grid_tuple
        name = f'edenn_d={edenn_d}.0;edenn_n={edenn_n}.0'
        output = grids[grids['name'] == name]['mse'].values[0]
        return output

    data_df['mse'] = data_df[['edenn_d', 'edenn_n']].apply(
        lambda row: get_mse(tuple(row.values)),
        axis=1,
    )
    
    data_df['layer'] = data_df['layer_idx'].apply(lambda idx: layer_names[idx])
    
    layers = sorted(set(data_df['layer']))
    
    from sklearn.linear_model import LinearRegression
    
    scale_by_layer = {}
    intercept_by_layer = {}
    
    for layer_idx, layer in enumerate(layers):
        to_fit = data_df[data_df['layer'] == layer]
        to_fit = to_fit[to_fit['mse'] < 0.04]
    
        slope = LinearRegression(fit_intercept=False).fit(to_fit['mse'].values.reshape(-1, 1), to_fit['wikitext2'] - BASE_PPL).coef_
        
        scale_by_layer[layer] = slope.item()
        intercept_by_layer[layer] = BASE_PPL
    
    return scale_by_layer

In [7]:
scale_by_layer_gptq = get_scale_by_layer('rock-and-roll/GALQIWI_EDENN_GPTQ')

layers = sorted(scale_by_layer_gptq.keys())

In [211]:
from ortools.linear_solver import pywraplp
import numpy as np

def find_grids_with_budget(
    slopes,    # INVALID linear coefficients for [layerwise mse -> metric]
    numels,   # INVALID linear coefficients for [layer bitwidth -> total bitwidth] (1 / num_blocks for blockwise)
    budget,    # INVALID upper bound on total bitwidth
    num_codebooks, # INVALID available grid bitwidths
    num_bits_per_codebook, # INVALID available grid bitwidths
    grid_mses  # INVALID available grid mses
) -> tuple[float, list]:
    num_layers = len(slopes)
    num_grids = len(num_codebooks)
    assert len(num_codebooks) == len(num_bits_per_codebook)
    assert len(grid_mses) == num_grids
    
    solver = pywraplp.Solver.CreateSolver("CP-SAT")

    x = {(j, i) : solver.BoolVar("name") for i in range(num_grids) for j in range(num_layers)}
    
    for j in range(num_layers) :
        solver.Add(sum(x[(j, i)] for i in range(num_grids)) == 1)
    
    solver.Add(sum(x[(j, i)] * (
        numels[j] * num_codebooks[i] * num_bits_per_codebook[i] / 8 + (2 ** num_bits_per_codebook[i]) * 8 * 16 * num_codebooks[i]
    ) for j in range(num_layers) for i in range(num_grids)) <= budget)
    
    solver.Minimize(sum(x[(j, i)] * slopes[j] * grid_mses[i] for j in range(num_layers) for i in range(num_grids)))

    status = solver.Solve()
    if status == pywraplp.Solver.OPTIMAL:
        # avg_bits = sum(x[(j, i)].solution_value() * numels[j] * grid_bits[i] for j in range(num_layers) for i in range(num_grids))
        solution = np.asarray([[x[(j, i)].solution_value() for i in range(num_grids)] for j in range(num_layers)])
        indices = np.argwhere(solution == 1.0)
        assert len(indices) == num_layers
        return 0.0, indices[:,1]
    else:
        raise Exception("Didn't solve")

In [212]:
from io import StringIO
import requests

In [213]:
aqlm_config = pd.read_csv(StringIO(requests.get('https://gist.githubusercontent.com/galqiwi/bd1ac3d724aa9fc0a058c2d0ee94d541/raw/000ec2847320d5126c856dd3ebf220833e77ceb2/aqlm_configs.csv').text))
aqlm_config.sample(3)

Unnamed: 0.1,Unnamed: 0,n_codebooks,n_bits_per_codebook,wbits,mse
26,26,2,11,3.265625,0.032575
37,37,3,6,2.289062,0.085089
50,50,4,3,1.519531,0.219392


In [214]:
aqlm_config['wbits'] = aqlm_config.apply(
    lambda row: row['n_codebooks'] * row['n_bits_per_codebook'] / 8 + (
        2 ** row['n_bits_per_codebook'] * 16 / 4096 / 4096 * 8 * row['n_codebooks']
    ),
    axis=1,
)

In [249]:
mse_by_aqlm_config = dict(zip([tuple(k) for k in aqlm_config[['n_codebooks', 'n_bits_per_codebook']].values], aqlm_config['mse']))
mse_by_aqlm_config

{(1, 1): 0.919545590877533,
 (1, 2): 0.8224712610244751,
 (1, 3): 0.7082441449165344,
 (1, 4): 0.6006791591644287,
 (1, 5): 0.5165958404541016,
 (1, 6): 0.4382802546024322,
 (1, 7): 0.3716720640659332,
 (1, 8): 0.3143085539340973,
 (1, 9): 0.2638692259788513,
 (1, 10): 0.2200138568878173,
 (1, 11): 0.1806469708681106,
 (1, 12): 0.1454421728849411,
 (1, 13): 0.1135108694434166,
 (1, 14): 0.0839976817369461,
 (1, 15): 0.0562547110021114,
 (1, 16): 0.028201874345541,
 (2, 1): 0.8393333554267883,
 (2, 2): 0.644393801689148,
 (2, 3): 0.4761597216129303,
 (2, 4): 0.3545707464218139,
 (2, 5): 0.2630593776702881,
 (2, 6): 0.1916685402393341,
 (2, 7): 0.1385392993688583,
 (2, 8): 0.0993240177631378,
 (2, 9): 0.0706783756613731,
 (2, 10): 0.0486727841198444,
 (2, 11): 0.0325747616589069,
 (2, 12): 0.0205543953925371,
 (2, 13): 0.0119948349893093,
 (2, 14): 0.0060074115172028,
 (2, 15): 0.0022380454465746,
 (2, 16): 0.0002627053763717,
 (3, 1): 0.7597394585609436,
 (3, 2): 0.4846304357051849,
 (3

In [215]:
aqlm_config.head(16)

Unnamed: 0.1,Unnamed: 0,n_codebooks,n_bits_per_codebook,wbits,mse
0,0,1,1,0.125015,0.919546
1,1,1,2,0.250031,0.822471
2,2,1,3,0.375061,0.708244
3,3,1,4,0.500122,0.600679
4,4,1,5,0.625244,0.516596
5,5,1,6,0.750488,0.43828
6,6,1,7,0.875977,0.371672
7,7,1,8,1.001953,0.314309
8,8,1,9,1.128906,0.263869
9,9,1,10,1.257812,0.220014


In [216]:
# aqlm_config = pd.concat([
#     aqlm_config,
#     pd.DataFrame([{
#         'n_codebooks': -1,
#         'n_bits_per_codebook': -1,
#         'wbits': 16,
#         'mse': 4 ** -16,
#     }])
# ])

In [290]:
aqlm_config

Unnamed: 0.1,Unnamed: 0,n_codebooks,n_bits_per_codebook,wbits,mse
0,0,1,1,0.125015,9.195456e-01
1,1,1,2,0.250031,8.224713e-01
2,2,1,3,0.375061,7.082441e-01
3,3,1,4,0.500122,6.006792e-01
4,4,1,5,0.625244,5.165958e-01
...,...,...,...,...,...
59,59,4,12,6.125000,4.258627e-04
60,60,4,13,6.750000,1.344127e-04
61,61,4,14,7.500000,3.031856e-05
62,62,4,15,8.500000,5.611464e-07


In [273]:
ok_aqlm_config = aqlm_config[((aqlm_config['n_codebooks'] >= 2) & (aqlm_config['n_bits_per_codebook'] >= 10)) | (aqlm_config['n_codebooks'] >= 3)]

In [274]:
layers = sorted(layers)

scales = [scale_by_layer_gptq[layer] for layer in layers]
numels = [get_numel(layer) for layer in layers]
# num_codebooks = aqlm

solution_size, solution_idxs = find_grids_with_budget(
    scales,
    numels,
    budget=sum(numels) * 3,
    num_codebooks=ok_aqlm_config['n_codebooks'].values,
    num_bits_per_codebook=ok_aqlm_config['n_bits_per_codebook'].values,
    grid_mses=ok_aqlm_config['mse'].values,
)

print(f'{solution_size / sum(numels)} bits')

0.0 bits


In [275]:
sum([scale * ok_aqlm_config['mse'].values[solution_idx] for scale, solution_idx in zip(scales, solution_idxs)])

0.5400175584965011

In [279]:
sum([scale * mse_by_aqlm_config[(2, 12)] for scale, solution_idx in zip(scales, solution_idxs)])

0.5994111631069067

In [287]:
opt_conf = dict(list([(layer, (ok_aqlm_config['n_codebooks'].values[idx], ok_aqlm_config['n_bits_per_codebook'].values[idx])) for layer, idx in zip(layers, solution_idxs)]))
opt_conf

{'model.layers.0.mlp.down_proj': (2, 14),
 'model.layers.0.mlp.gate_proj': (2, 10),
 'model.layers.0.mlp.up_proj': (2, 10),
 'model.layers.0.self_attn.k_proj': (2, 10),
 'model.layers.0.self_attn.o_proj': (2, 10),
 'model.layers.0.self_attn.q_proj': (3, 5),
 'model.layers.0.self_attn.v_proj': (2, 12),
 'model.layers.1.mlp.down_proj': (2, 13),
 'model.layers.1.mlp.gate_proj': (2, 10),
 'model.layers.1.mlp.up_proj': (2, 12),
 'model.layers.1.self_attn.k_proj': (2, 11),
 'model.layers.1.self_attn.o_proj': (2, 14),
 'model.layers.1.self_attn.q_proj': (3, 5),
 'model.layers.1.self_attn.v_proj': (3, 12),
 'model.layers.10.mlp.down_proj': (2, 11),
 'model.layers.10.mlp.gate_proj': (2, 10),
 'model.layers.10.mlp.up_proj': (2, 11),
 'model.layers.10.self_attn.k_proj': (2, 13),
 'model.layers.10.self_attn.o_proj': (2, 13),
 'model.layers.10.self_attn.q_proj': (2, 11),
 'model.layers.10.self_attn.v_proj': (3, 11),
 'model.layers.11.mlp.down_proj': (2, 10),
 'model.layers.11.mlp.gate_proj': (2, 10

In [288]:
n_bits = 0

for layer in layers:
    n_codebooks, n_bits_per_codebook = opt_conf[layer]
    # n_codebooks, n_bits_per_codebook = 2, 12
    layer_bits = get_numel(layer) * n_bits_per_codebook * n_codebooks / 8
    layer_bits += 2 ** n_bits_per_codebook * 16 * 8 * n_codebooks
    print(layer, layer_bits / get_numel(layer), n_codebooks, n_bits_per_codebook)
    n_bits += layer_bits

n_bits / total_params

model.layers.0.mlp.down_proj 3.5714285714285716 2 14
model.layers.0.mlp.gate_proj 2.5044642857142856 2 10
model.layers.0.mlp.up_proj 2.5044642857142856 2 10
model.layers.0.self_attn.k_proj 2.5625 2 10
model.layers.0.self_attn.o_proj 2.515625 2 10
model.layers.0.self_attn.q_proj 1.875732421875 3 5
model.layers.0.self_attn.v_proj 3.25 2 12
model.layers.1.mlp.down_proj 3.2857142857142856 2 13
model.layers.1.mlp.gate_proj 2.5044642857142856 2 10
model.layers.1.mlp.up_proj 3.017857142857143 2 12
model.layers.1.self_attn.k_proj 2.875 2 11
model.layers.1.self_attn.o_proj 3.75 2 14
model.layers.1.self_attn.q_proj 1.875732421875 3 5
model.layers.1.self_attn.v_proj 4.875 3 12
model.layers.10.mlp.down_proj 2.7589285714285716 2 11
model.layers.10.mlp.gate_proj 2.5044642857142856 2 10
model.layers.10.mlp.up_proj 2.7589285714285716 2 11
model.layers.10.self_attn.k_proj 3.75 2 13
model.layers.10.self_attn.o_proj 3.375 2 13
model.layers.10.self_attn.q_proj 2.78125 2 11
model.layers.10.self_attn.v_proj

2.9998532863849765

In [289]:
n_bits = 0

for block_idx in range(32):
    n_block_bits = 0
    n_block_numel = 0
    block_err = 0
    for layer in layers:
        if layer.startswith(f'model.layers.{block_idx}'):
            continue
        n_codebooks, n_bits_per_codebook = opt_conf[layer]

        # n_codebooks, n_bits_per_codebook = 2, 12
        
        layer_bits = get_numel(layer) * n_bits_per_codebook * n_codebooks / 8
        layer_bits += 2 ** n_bits_per_codebook * 16 * 8 * n_codebooks
        # print(layer, layer_bits / get_numel(layer), n_codebooks, n_bits_per_codebook)
        block_err += scale_by_layer_gptq[layer] * mse_by_aqlm_config[(n_codebooks, n_bits_per_codebook)]
        n_block_bits += layer_bits
        n_block_numel += get_numel(layer)
    print(n_block_bits / n_block_numel, block_err, sep='\t')

n_bits / total_params

3.0077173881436106	0.5240948105743083
3.0277041480654763	0.3499951586643389
2.979858845581502	0.3550153634281454
2.9803586600629974	0.4940040494336869
2.994342964872829	0.5239372041646434
3.0005076671060795	0.5218408296703543
3.0011667862127793	0.5226888026690527
3.0000811782723327	0.5247808469117028
3.004578696882754	0.5213248232135047
3.0045399251705955	0.522476137738773
3.0067886844758065	0.5218502323291769
3.009076215493176	0.5207743051290842
3.0067886844758065	0.5221113940491396
3.003842034351737	0.5223924931937499
3.001593275046526	0.5218110293307012
2.999383287453474	0.5235161635247775
2.9984915380738215	0.5240817005047229
2.998181364376551	0.5232810860634712
2.9958550616470223	0.5239559447847765
2.996940669587469	0.5237897902601301
2.99527348596464	0.5244701211560231
2.9986078532102978	0.5228569377939412
2.9969018978753104	0.5231169481267809
2.9975997886941688	0.5227707796497779
2.9975997886941688	0.5238492574573457
2.997948734103598	0.5237213070274888
2.999964863135856	0.52168

0.0

In [224]:
1

1

In [None]:
# Quantizing module mlp.down_proj of layer 0
# print(f"Quantizing module {sublayer_name} of layer {layer_index}")

In [41]:
sublayer_name = 'mlp.down_proj'
layer_index = 0

key = f'model.layers.{layer_index}.{sublayer_name}'
key

'model.layers.0.mlp.down_proj'

In [46]:
import requests
from ast import literal_eval

quant_aqlm_config = literal_eval(requests.get(
    'https://gist.githubusercontent.com/galqiwi/7c231a815da694fbcf374ebca14fb15f/raw/b894055c406fe9290e90d7dbc20ebd16335de5e9/optimal_3bit_aqlm'
).text)

quant_aqlm_config[key]

(3, 9)