# Usage

A quick guide to using the `weight_formats` library.

**Imports**

By convention, each submodule is imported under a single-letter alias, e.g. `import weight_formats.quantisation as Q`.

In [None]:
%env WANDB_SILENT=true
%env TOKENIZERS_PARALLELISM=true
%load_ext autoreload
%autoreload 2

from pathlib import Path
import torch

import weight_formats.analysis as A
import weight_formats.experiments as E
import weight_formats.fit as F
import weight_formats.model_quantisation as M
import weight_formats.quantisation as Q
import weight_formats.sensitivity as S

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env: WANDB_SILENT=true
env: TOKENIZERS_PARALLELISM=true


**Load a model and dataset (with calculation of reference output)**

In [None]:
model = E.RequantisableModel.load("meta-llama/Llama-3.2-1B", DEVICE, torch.bfloat16)
params = {k: v.detach() for k, v in model.model.named_parameters() if v.ndim == 2}

# For a quick test - 1 batch of shape (16, 256)
data = E.token_prediction.Dataset.load(model.model, sequence_length=256, batch_size=16, kl_topk=128, sequence_limit=16)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


**Quantise a single parameter, and evaluate**

In [3]:
model.reset()

p = model.model.state_dict()['model.layers.15.self_attn.v_proj.weight']
fmt = Q.CompressedLUTFormat.train_grid(p, p.std().item() / 4)
print(fmt.count_bits_tensor(p) / p.nelement(), "bpp")
print("R =", Q.qrmse_norm(fmt, p).cpu())
# state_dict() tensors share storage with the model, so this overwrites the weight in place
p[...] = fmt.quantise(p)

display(data.evaluate(model.model))

3.997058629989624 bpp
R = tensor(0.0724)


{'cross_entropy': tensor([2.5975, 2.6224, 3.1826, 2.2801, 3.3747, 2.7373, 3.1326, 2.9880, 3.0249,
         2.5361, 3.1598, 3.0744, 2.7587, 2.9885, 3.0168, 2.8383],
        device='cuda:0'),
 'kl_div': tensor([0.0010, 0.0010, 0.0009, 0.0008, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010,
         0.0009, 0.0009, 0.0010, 0.0009, 0.0010, 0.0009, 0.0010],
        device='cuda:0')}
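
The two printed numbers can be sanity-checked without the library. The sketch below is not the `weight_formats` implementation. It assumes that `train_grid` builds a uniform grid with the given spacing, that the "compressed" bit count approaches the empirical entropy of the grid indices, and that `R` is the RMS quantisation error divided by the RMS of the tensor. The helper name `grid_quantise_stats` is hypothetical, and synthetic Gaussian samples stand in for the weight tensor.

```python
import math
import random
from collections import Counter


def grid_quantise_stats(values, step):
    """Round values to a uniform grid; return (entropy in bits/element, relative RMS error)."""
    indices = [round(v / step) for v in values]
    counts = Counter(indices)
    n = len(values)
    # Empirical entropy of the grid indices: the rate an ideal entropy coder approaches
    entropy_bits = -sum(c / n * math.log2(c / n) for c in counts.values())
    # RMS quantisation error, normalised by the RMS of the original values (assumed meaning of R)
    sq_err = sum((i * step - v) ** 2 for i, v in zip(indices, values))
    sq_norm = sum(v * v for v in values)
    return entropy_bits, math.sqrt(sq_err / sq_norm)


rng = random.Random(0)
values = [rng.gauss(0.0, 1.0) for _ in range(100_000)]  # stand-in for a weight tensor
std = math.sqrt(sum(v * v for v in values) / len(values))  # the mean is ~0 by construction
bits, r = grid_quantise_stats(values, std / 4)
print(f"{bits:.2f} bits/element, R = {r:.3f}")
```

For Gaussian data with grid spacing σ/4, the high-resolution approximation predicts an entropy of about ½·log2(2πe) + 2 ≈ 4.05 bits and a relative error of 1/(4·√12) ≈ 0.072, which is in line with the values printed above for the real weight.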

**Quantise all parameters using `F.Scaled.fit`**

This takes a few seconds, as it runs k-means (the `"lloyd_max"` fit) on each weight tensor.

In [None]:
log = M.quantise_2d_fixed(model.model, F.Scaled(4, "lloyd_max", Q.BFLOAT16, (1, 64), "absmax", compressor=None, args=dict(init="kmeans++")))
print(log["bits_per_param"], "bits/param")
display(log["params"]["model.layers.0.mlp.gate_proj.weight"])
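
To make the configuration above more concrete, here is a dependency-free sketch of block-scaled quantisation with a Lloyd-Max codebook. It is an illustration under stated assumptions, not the library's code: the arguments of `F.Scaled` are read as 4-bit elements, a `bfloat16` scale per `(1, 64)` block, and absmax scaling. The helpers `absmax_normalise`, `lloyd_max` and `nearest_level` are hypothetical names, and the initialisation uses quantiles rather than k-means++.

```python
import bisect
import math
import random


def absmax_normalise(row, block_size):
    """Split a row into blocks and scale each block so its largest magnitude is 1."""
    blocks = [row[i:i + block_size] for i in range(0, len(row), block_size)]
    scales = [max(abs(v) for v in block) or 1.0 for block in blocks]
    return [[v / s for v in block] for block, s in zip(blocks, scales)], scales


def lloyd_max(values, n_levels, n_iter=20):
    """Fit a sorted 1-D codebook by alternating assignment and centroid updates."""
    ordered = sorted(values)
    # Assumption: quantile initialisation (the notebook passes init="kmeans++" instead)
    levels = [ordered[int((i + 0.5) * len(ordered) / n_levels)] for i in range(n_levels)]
    for _ in range(n_iter):
        # Decision boundaries sit midway between adjacent levels
        bounds = [(a + b) / 2 for a, b in zip(levels, levels[1:])]
        sums, counts = [0.0] * n_levels, [0] * n_levels
        for v in values:
            k = bisect.bisect_left(bounds, v)
            sums[k] += v
            counts[k] += 1
        # Move each level to the mean of its cell; an empty cell keeps its old level
        levels = [s / c if c else old for s, c, old in zip(sums, counts, levels)]
    return levels


def nearest_level(levels, v):
    """Return the codebook level closest to v (levels must be sorted)."""
    k = bisect.bisect_left(levels, v)
    candidates = levels[max(k - 1, 0):k + 1]
    return min(candidates, key=lambda level: abs(level - v))


rng = random.Random(0)
row = [rng.gauss(0.0, 0.02) for _ in range(4096)]  # synthetic stand-in for one weight row
blocks, scales = absmax_normalise(row, block_size=64)
levels = lloyd_max([v for block in blocks for v in block], n_levels=16)

# Dequantise: nearest level times the block scale (rounding the scale to bfloat16 is ignored)
restored = [nearest_level(levels, v) * s for block, s in zip(blocks, scales) for v in block]
rel_rmse = math.sqrt(sum((a - b) ** 2 for a, b in zip(restored, row)) / sum(b * b for b in row))
bits_per_param = 4 + 16 / 64  # 4-bit index plus one 16-bit scale per 64 elements
print(f"{bits_per_param} bits/param, relative RMSE = {rel_rmse:.3f}")
```

Under these assumptions the storage cost is 4.25 bits per parameter: the 16-bit scale is amortised over 64 elements, so smaller blocks give finer scaling at a higher overhead.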

**Run a tiny sweep**

In [None]:
tests = [
    E.token_prediction.Baseline(),
    E.token_prediction.QuantiseFixed(F.Scaled(4, "fp", Q.BFLOAT16, (1, 32), "rms", compressor=None)),
]
E.token_prediction.run_sweep([E.token_prediction.Run("dev", test, E.core.MODELS[0]) for test in tests])

In [None]:
runs = E.runs("dev")
print(len(runs), "runs")

print("\nLast run:")
log = runs[-1]
print(log["summary"]["bits_per_param"], "bits/param")
print(torch.tensor(log["summary"]["kl_div"]).mean().item(), "KL")
print(torch.tensor(log["summary"]["cross_entropy"]).mean().item(), "X-Ent")

7 runs

Last run:
4.5006289 bits/param
0.3004560172557831 KL
2.523895025253296 X-Ent
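
The summary metrics are stored per sequence, which is why the cell above takes `.mean()`. With only 16 sequences in a quick test, it may help to report a standard error next to the mean. The sketch below uses only the standard library. The helper `mean_and_stderr` is a hypothetical name, not part of `weight_formats`, and the sample values are the per-sequence cross-entropies printed in the single-parameter example earlier.

```python
import math
import statistics


def mean_and_stderr(values):
    """Return the sample mean and its standard error (sample std / sqrt(n))."""
    mean = statistics.fmean(values)
    stderr = statistics.stdev(values) / math.sqrt(len(values))
    return mean, stderr


# Per-sequence cross-entropy values copied from the earlier evaluation output
cross_entropy = [
    2.5975, 2.6224, 3.1826, 2.2801, 3.3747, 2.7373, 3.1326, 2.9880,
    3.0249, 2.5361, 3.1598, 3.0744, 2.7587, 2.9885, 3.0168, 2.8383,
]
mean, stderr = mean_and_stderr(cross_entropy)
print(f"X-Ent = {mean:.3f} +/- {stderr:.3f}")
```

Because every test is evaluated on the same sequences, a paired comparison (the difference per sequence between two runs) would usually give a tighter estimate than comparing two independent means.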
