Copyright (c) 2025 Graphcore Ltd. All rights reserved.

# Usage

A quick guide to using the `weight_formats` library.

#### Imports

In [1]:
%env TOKENIZERS_PARALLELISM=true
%load_ext autoreload
%autoreload 2

import torch

# By convention, we use `import weight_formats.quantisation as Q`, etc.
import weight_formats.analysis as A
import weight_formats.experiments as E
import weight_formats.experiments.token_prediction as ET
import weight_formats.fit as F
import weight_formats.model_quantisation as M
import weight_formats.quantisation as Q
import weight_formats.sensitivity as S

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env: TOKENIZERS_PARALLELISM=true


#### Load a model and dataset (and compute the reference outputs)

In [2]:
model = E.RequantisableModel.load("meta-llama/Llama-3.2-1B", DEVICE, torch.bfloat16)
# Detached handles to every 2D (matrix) parameter of the model, keyed by parameter name.
params = {k: v.detach() for k, v in model.model.named_parameters() if v.ndim == 2}

# For a quick test - 1 batch of shape (16, 256): sequence_limit equals batch_size, so one batch.
# NOTE(review): kl_topk=128 presumably limits the reference distribution kept for KL to the
# top-128 tokens — confirm in ET.Dataset.load.
data = ET.Dataset.load(model.model, batch_size=16, sequence_limit=16, sequence_length=256, kl_topk=128)

#### Quantise a single parameter, and evaluate

In [16]:
# Restore the model's weights before this experiment (presumably undoing any earlier
# quantisation — confirm against RequantisableModel.reset).
model.reset()
# Tensors returned by state_dict() share storage with the live parameters, so the
# in-place write into `p` below modifies the model itself.
p = model.model.state_dict()['model.layers.15.self_attn.v_proj.weight']

# Train a compressed lookup-table format on this tensor. The second argument is p.std() / 4;
# NOTE(review): its meaning (e.g. grid resolution) is not shown here — confirm in train_grid.
fmt = Q.CompressedLUTFormat.train_grid(p, p.std().item() / 4)
# b: average number of bits per element of p under `fmt`.
print(f"b = {fmt.count_bits_tensor(p) / p.nelement():.2f}")
# R: quantisation error of `fmt` on p as computed by Q.qrmse_norm (presumably a normalised RMSE).
print(f"R = {Q.qrmse_norm(fmt, p).cpu():.3f}")
# Overwrite the weight with its quantised values (in place; see the state_dict note above).
p[...] = fmt.quantise(p)

# Evaluate the modified model on `data`; metrics (one value per sequence, judging by the
# output) are moved to CPU for display.
display({k: v.cpu() for k, v in data.evaluate(model.model).items()})

b = 4.00
R = 0.072


{'cross_entropy': tensor([2.6887, 2.6609, 2.9141, 3.2589, 2.1054, 2.9393, 1.9774, 3.2417, 2.9185,
         2.8662, 3.1528, 3.0990, 2.6329, 3.1736, 2.8883, 3.0719]),
 'kl_div': tensor([0.0011, 0.0011, 0.0011, 0.0012, 0.0010, 0.0010, 0.0012, 0.0013, 0.0012,
         0.0011, 0.0012, 0.0013, 0.0013, 0.0011, 0.0011, 0.0012])}

#### Quantise all 2D parameters with `M.quantise_2d_fixed`, fitting an `F.Scaled` format (`F.Scaled.fit`)

In [None]:
# This takes a few seconds, as it runs k-means per parameter tensor.
# Start from the original weights, then quantise the model's 2D parameters with one fixed
# F.Scaled spec (4 bits, "lloyd_max" element format, BFLOAT16 scales over (1, 64) blocks,
# "absmax" scaling, no compressor, no sparse outliers).
# NOTE(review): positional-argument meanings inferred from names — confirm against F.Scaled.
model.reset()
log = M.quantise_2d_fixed(model.model, F.Scaled(4, "lloyd_max", Q.BFLOAT16, (1, 64), "absmax", compressor=None, sparse_ratio=0))
# Overall average bits per parameter reported by the quantisation log.
print(f"b = {log['bits_per_param']:.2f}")
# Per-tensor entry for one example weight (contains 'bits' and 'rmse', per the output).
k = "model.layers.0.mlp.gate_proj.weight"
print(f"{k}: {log['params'][k]}")

b = 4.25
model.layers.0.mlp.gate_proj.weight: {'bits': 71303168, 'rmse': 0.0017555853119120002}


#### Run a mini-sweep

In [33]:
# Two 4-bit "int" formats with BFLOAT16 scales over (1, 64) blocks, differing only in the
# final argument ("absmax" vs "signmax").
tests = [
    ET.QuantiseFixed(F.Scaled(4, "int", Q.BFLOAT16, (1, 64), "absmax")),
    ET.QuantiseFixed(F.Scaled(4, "int", Q.BFLOAT16, (1, 64), "signmax")),
]
# Run each test on Llama-3.2-1B under the "dev" sweep name; results are read back
# afterwards with E.runs("dev").
ET.run_sweep([ET.Run("dev", test, "meta-llama/Llama-3.2-1B") for test in tests])

100%|██████████| 1/1 [00:12<00:00, 12.02s/it]
100%|██████████| 1/1 [00:12<00:00, 12.16s/it]


In [34]:
# Load every recorded run of the "dev" sweep and report how many there are.
runs = E.runs("dev")
print(len(runs), "runs")
print()

# Show only the runs produced by the sweep above (one per entry in `tests`), rather than a
# hard-coded count that drifts when `tests` changes. max(..., 0) avoids `runs[-0:]`, which
# would select *all* runs if `tests` were empty.
# NOTE(review): assumes E.runs returns runs oldest-first — confirm.
for run in runs[max(len(runs) - len(tests), 0):]:
    print(f"### {run.config.test.fmt_str} ###")
    # Average bits per parameter for this format.
    print(f"     b = {run.summary.bits_per_param:.2f}")
    # Mean KL divergence across the recorded per-sequence values.
    print(f"  D_KL = {torch.tensor(run.summary.kl_div).mean().item():.3f}")
    print()

14 runs

### 4b-int{1,64:BFLOAT16:absmax:search} ###
     b = 4.25
  D_KL = 0.196

### 4b-int{1,64:BFLOAT16:signmax:search} ###
     b = 4.25
  D_KL = 0.163

