Copyright (c) 2025 Graphcore Ltd. All rights reserved.

# Follow-up

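Additional analyses addressing comments on the work: downstream-task accuracy with and without quantisation-aware training (QAT), Huffman-coding overhead, vector quantisation (VQ), and a Student-t example.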

In [63]:
%load_ext autoreload
%autoreload 2

import collections
import logging
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import tensor, Tensor
from typing import Any

import weight_formats.quantisation as Q
import weight_formats.experiments as E
import plot_utils

def to_markdown(d: pd.DataFrame) -> str:
    """Render a DataFrame as a Markdown table with centred columns (the index is not shown).

    Column labels and cell values are converted with ``str``. Literal ``|`` characters are
    escaped as ``\\|`` so they cannot split a cell. Rows are read with ``itertuples`` rather
    than ``iterrows`` because ``iterrows`` upcasts a row of mixed int/float columns to float,
    which would render an integer ``1`` as ``1.0``.
    """
    def cell(value: object) -> str:
        return str(value).replace("|", r"\|")

    def line(values) -> str:
        return "| " + " | ".join(map(cell, values)) + " |\n"

    header = line(d.columns)
    align = "| " + " | ".join(":-:" for _ in d.columns) + " |\n"
    body = "".join(line(values) for values in d.itertuples(index=False, name=None))
    return header + align + body

def flatten_columns(d: pd.DataFrame) -> pd.DataFrame:
    """Collapse (MultiIndex) column labels into single strings joined with "_".

    For example ("kl_div", "8B") becomes "kl_div_8B". Non-string levels (e.g. integer pivot
    keys) are converted with ``str``. A relabelled frame is returned; the input frame's columns
    are left untouched instead of being overwritten in place.
    """
    return d.set_axis(["_".join(map(str, label)) for label in d.columns], axis=1)

# Log at WARNING level and above, apply the shared plotting configuration, and make newly
# created tensors default to the GPU when one is available (CPU otherwise).
logging.basicConfig(force=True, level=logging.WARNING)
plot_utils.configure()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(DEVICE)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Recommend (Ubuntu):
  sudo apt-get install cm-super dvipng fonts-cmu texlive-latex-extra


## Downstream tasks and QAT

In [None]:
def fmt_name(fmt: E.AttrDict) -> str:
    """Human-readable label for a quantisation format, e.g. "Tensor RMS + Compression".

    Granularity comes from ``block_shape``: (None, None) -> "Tensor", (1, None) -> "Channel",
    anything else -> "Block". It is followed by the capitalised ``scaling`` name. "Compression"
    and "Sparse" are appended as extra " + " terms when ``compressor`` / ``sparse_ratio`` are
    truthy. Any "Rms" in the result is upper-cased to "RMS".
    """
    granularity = {(None, None): "Tensor", (1, None): "Channel"}
    label = granularity.get(tuple(fmt.block_shape), "Block") + " " + fmt.scaling.capitalize()
    extras = [
        term
        for enabled, term in ((fmt.compressor, "Compression"), (fmt.sparse_ratio, "Sparse"))
        if enabled
    ]
    return " + ".join([label, *extras]).replace("Rms", "RMS")

def load_run(run: E.AttrDict) -> dict[str, Any]:
    """Flatten one experiment run into a record (one DataFrame row).

    Always present: ``id``, ``tag``, ``model`` (the last "-"-separated part of the model
    name), ``steps``/``fmt``/``element_bits``, ``bits_per_param``, ``valid_kl`` and
    ``downstream`` (task -> ``primary_score``, in sorted task order). The format fields are
    ``None`` unless the test type is "qat". When the summary has a "train" entry, the logged
    step numbers, losses and validation KLs are added as ``step``/``step_loss``/``step_valid_kl``.
    Key order is kept stable because it becomes the DataFrame column order.
    """
    config, summary = run.config, run.summary
    record: dict[str, Any] = {
        "id": run.id,
        "tag": config.tag,
        "model": config.model.split("-")[-1],
    }

    is_qat = config.test.type == "qat"
    record["steps"] = config.train.steps if is_qat else None
    record["fmt"] = fmt_name(config.test.fmt) if is_qat else None
    record["element_bits"] = config.test.fmt.element_bits if is_qat else None

    if "train" in summary:
        train_log = summary.train
        # Entry i of the training log was recorded at step i * log_interval.
        record["step"] = torch.arange(len(train_log.loss)).mul(config.train.log_interval).tolist()
        record["step_loss"] = train_log.loss
        record["step_valid_kl"] = train_log.valid_kl_div

    record["bits_per_param"] = summary.get("bits_per_param")
    record["valid_kl"] = summary.get("valid_kl_div")
    scores = summary.get("downstream", {})
    record["downstream"] = {task: scores[task].primary_score for task in sorted(scores)}
    return record


# One record per run in the "20250708-qat-main" sweep; see load_run for the fields.
df = pd.DataFrame.from_records(
    list(map(load_run, E.runs("20250708-qat-main", progress=True)))
)
df.head()

query: 0it [00:00, ?it/s]

query: 204it [00:00, 226.01it/s]


Unnamed: 0,id,tag,model,steps,fmt,element_bits,step,step_loss,step_valid_kl,bits_per_param,valid_kl,downstream
0,20250708-qat-main/tScYXZ23ky,baseline,1B,,,,[0.0],[None],[0.00095602684],16.0,0.001064,"{'arc_challenge:mc': 0.31772575, 'arc_easy:mc'..."
1,20250708-qat-main/KS6wJynNyU,baseline,3B,,,,[0.0],[None],[0.00080347393],16.0,0.000912,"{'arc_challenge:mc': 0.71906355, 'arc_easy:mc'..."
2,20250708-qat-main/zev9ClTfSs,baseline,8B,,,,[0.0],[None],[0.00091089529],16.0,0.001009,"{'arc_challenge:mc': 0.79598662, 'arc_easy:mc'..."
3,20250708-qat-main/cPgba6wDKM,direct-cast,1B,0.0,Tensor RMS + Compression,3.0,[0.0],[None],[0.51605082],3.007565,0.515736,"{'arc_challenge:mc': 0.22742475, 'arc_easy:mc'..."
4,20250708-qat-main/G5xBOc0sd1,direct-cast,1B,0.0,Tensor RMS + Sparse,3.0,[0.0],[None],[1.1893706],3.047597,1.185388,"{'arc_challenge:mc': 0.32107023, 'arc_easy:mc'..."


In [None]:
def format_columns(s: pd.Series):
    """Format one table column for display, choosing the rule from the column name.

    - "Format" (any case): missing entries become "Baseline".
    - "b": two decimals.
    - names starting with "kl" or "downstream" (any case): three decimals.
    - known downstream-task columns: score x 100 with one decimal.
    - anything else: ``str``.
    """
    name = s.name
    lowered = name.lower()
    if lowered == "format":
        return s.fillna("Baseline")
    if name == "b":
        return s.apply("{:.2f}".format)
    if lowered.startswith(("kl", "downstream")):
        return s.apply("{:.3f}".format)
    task_columns = {
        "arc_challenge:mc", "arc_easy:mc", "boolq", "csqa:mc", "hellaswag",
        "openbookqa:mc", "piqa", "socialiqa:mc", "winogrande",
    }
    if name in task_columns:
        return s.apply(lambda score: f"{100*score:.1f}")
    return s.apply(str)

# 3-bit formats on the 8B model, plus the 16-bit baseline: direct-cast KL next to the mean
# downstream accuracy relative to the baseline, for direct cast ("dc") and after QAT ("qat").
d = df[
    df.tag.isin(["baseline", "direct-cast", "qat-v2"])
    & (df.model == "8B")
    & (df.element_bits.isna() | (df.element_bits == 3))
]
downstream_baselines = d[d.fmt.isna()].downstream.iloc[0]


def _mean_clipped_ratio(row: pd.Series) -> float:
    # Per-task accuracy / baseline accuracy, each clipped to [0, 1] (so tasks where a format
    # beats the baseline count as 1), then averaged over tasks.
    return tensor([
        torch.tensor(row.downstream[task] / baseline_accuracy).clip(0, 1)
        for task, baseline_accuracy in downstream_baselines.items()
    ]).mean().item()


(
    d.assign(downstream_mean_ratio=d.apply(_mean_clipped_ratio, axis=1))
    .loc[lambda t: ~t.fmt.isna()]
    .assign(kind=lambda t: t.steps.apply(lambda steps: "qat" if steps else "dc"))
    .pivot(index=["fmt"], columns="kind", values=["valid_kl", "downstream_mean_ratio"])
    .pipe(flatten_columns)
    .loc[:, ["valid_kl_dc", "downstream_mean_ratio_dc", "downstream_mean_ratio_qat"]]
    .reset_index()
    .rename(columns={
        "fmt": "Format",
        "valid_kl_dc": "KL",
        "downstream_mean_ratio_dc": "Downstream Mean Ratio",
        "downstream_mean_ratio_qat": "Downstream Mean Ratio (QAT)",
    })
    .sort_values("KL")
    .apply(format_columns)
    .pipe(lambda t: display(t.style.hide()) or print(to_markdown(t)))
)

Format,KL,Downstream Mean Ratio,Downstream Mean Ratio (QAT)
Tensor RMS + Compression,0.205,0.929,0.972
Tensor RMS + Sparse,0.503,0.799,0.951
Channel Absmax,1.075,0.608,0.892
Block Absmax,1.264,0.545,0.925
Tensor Absmax,4.577,0.466,0.754
Tensor RMS,9.115,0.429,0.471


| Format | KL | Downstream Mean Ratio | Downstream Mean Ratio (QAT) |
| :-: | :-: | :-: | :-: |
| Tensor RMS + Compression | 0.205 | 0.929 | 0.972 |
| Tensor RMS + Sparse | 0.503 | 0.799 | 0.951 |
| Channel Absmax | 1.075 | 0.608 | 0.892 |
| Block Absmax | 1.264 | 0.545 | 0.925 |
| Tensor Absmax | 4.577 | 0.466 | 0.754 |
| Tensor RMS | 9.115 | 0.429 | 0.471 |



In [157]:
# Per-task downstream accuracy (%) of each 3-bit format on the 8B model, once for the
# direct-cast runs and once for the QAT runs; the baseline row appears in both tables.
for tag in ["direct-cast", "qat-v2"]:
    print(f"### {tag}")
    selected = df[
        df.tag.isin(["baseline", tag])
        & (df.model == "8B")
        & (df.element_bits.isna() | (df.element_bits == 3))
    ][["fmt", "bits_per_param", "valid_kl", "downstream"]]
    per_task = selected.downstream.apply(pd.Series)
    table = (
        pd.concat([selected, per_task], axis=1)
        .drop(columns="downstream")
        .sort_values("valid_kl")
        .rename(columns={"fmt": "Format", "bits_per_param": "b", "valid_kl": "KL"})
        .apply(format_columns)
    )
    display(table.style.hide())
    print(to_markdown(table))

### direct-cast


Format,b,KL,arc_challenge:mc,arc_easy:mc,boolq,csqa:mc,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
Baseline,16.0,0.001,79.6,90.0,82.2,70.2,80.8,76.0,81.8,64.8,73.7
Tensor RMS + Compression,3.0,0.205,71.6,84.4,73.8,65.4,79.2,63.8,80.7,58.6,72.2
Tensor RMS + Sparse,3.05,0.503,47.2,71.6,63.7,53.9,72.2,50.2,77.1,53.6,68.8
Channel Absmax,3.0,1.075,27.1,37.5,69.2,24.8,65.3,31.8,73.9,35.3,62.5
Block Absmax,3.25,1.264,24.4,33.0,62.0,21.9,46.1,27.0,70.4,37.1,59.2
Tensor Absmax,3.0,4.577,21.1,24.6,51.0,19.9,37.4,27.0,60.5,32.2,51.7
Tensor RMS,3.0,9.115,26.8,27.2,41.7,19.1,26.4,24.4,50.5,32.9,49.4


| Format | b | KL | arc_challenge:mc | arc_easy:mc | boolq | csqa:mc | hellaswag | openbookqa:mc | piqa | socialiqa:mc | winogrande |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Baseline | 16.00 | 0.001 | 79.6 | 90.0 | 82.2 | 70.2 | 80.8 | 76.0 | 81.8 | 64.8 | 73.7 |
| Tensor RMS + Compression | 3.00 | 0.205 | 71.6 | 84.4 | 73.8 | 65.4 | 79.2 | 63.8 | 80.7 | 58.6 | 72.2 |
| Tensor RMS + Sparse | 3.05 | 0.503 | 47.2 | 71.6 | 63.7 | 53.9 | 72.2 | 50.2 | 77.1 | 53.6 | 68.8 |
| Channel Absmax | 3.00 | 1.075 | 27.1 | 37.5 | 69.2 | 24.8 | 65.3 | 31.8 | 73.9 | 35.3 | 62.5 |
| Block Absmax | 3.25 | 1.264 | 24.4 | 33.0 | 62.0 | 21.9 | 46.1 | 27.0 | 70.4 | 37.1 | 59.2 |
| Tensor Absmax | 3.00 | 4.577 | 21.1 | 24.6 | 51.0 | 19.9 | 37.4 | 27.0 | 60.5 | 32.2 | 51.7 |
| Tensor RMS | 3.00 | 9.115 | 26.8 | 27.2 | 41.7 | 19.1 | 26.4 | 24.4 | 50.5 | 32.9 | 49.4 |

### qat-v2


Format,b,KL,arc_challenge:mc,arc_easy:mc,boolq,csqa:mc,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
Baseline,16.0,0.001,79.6,90.0,82.2,70.2,80.8,76.0,81.8,64.8,73.7
Tensor RMS + Compression,3.0,0.09,75.9,90.2,79.4,67.7,79.8,72.8,80.6,61.5,72.8
Tensor RMS + Sparse,3.05,0.129,73.9,86.5,76.5,66.3,78.6,70.6,79.8,61.4,71.6
Block Absmax,3.25,0.132,64.9,84.2,79.4,63.3,78.2,65.4,80.0,59.5,72.7
Channel Absmax,3.0,0.152,64.5,80.0,79.7,58.8,75.5,62.2,78.2,56.5,69.3
Tensor RMS,3.0,0.169,20.1,27.2,61.1,19.1,35.0,26.6,56.2,32.8,51.4
Tensor Absmax,3.0,0.286,46.8,65.4,69.8,41.1,70.4,43.4,75.7,49.9,66.5


| Format | b | KL | arc_challenge:mc | arc_easy:mc | boolq | csqa:mc | hellaswag | openbookqa:mc | piqa | socialiqa:mc | winogrande |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Baseline | 16.00 | 0.001 | 79.6 | 90.0 | 82.2 | 70.2 | 80.8 | 76.0 | 81.8 | 64.8 | 73.7 |
| Tensor RMS + Compression | 3.00 | 0.090 | 75.9 | 90.2 | 79.4 | 67.7 | 79.8 | 72.8 | 80.6 | 61.5 | 72.8 |
| Tensor RMS + Sparse | 3.05 | 0.129 | 73.9 | 86.5 | 76.5 | 66.3 | 78.6 | 70.6 | 79.8 | 61.4 | 71.6 |
| Block Absmax | 3.25 | 0.132 | 64.9 | 84.2 | 79.4 | 63.3 | 78.2 | 65.4 | 80.0 | 59.5 | 72.7 |
| Channel Absmax | 3.00 | 0.152 | 64.5 | 80.0 | 79.7 | 58.8 | 75.5 | 62.2 | 78.2 | 56.5 | 69.3 |
| Tensor RMS | 3.00 | 0.169 | 20.1 | 27.2 | 61.1 | 19.1 | 35.0 | 26.6 | 56.2 | 32.8 | 51.4 |
| Tensor Absmax | 3.00 | 0.286 | 46.8 | 65.4 | 69.8 | 41.1 | 70.4 | 43.4 | 75.7 | 49.9 | 66.5 |



## Huffman overhead

In [None]:
# Bits/param and mean KL of the non-VQ formats on the 8B model, including the Huffman-coded
# variant ("Zhuffman") alongside "Zoptimal". The result is stored in its own name,
# `huffman_df`, so this cell does not overwrite the `df` used by the QAT cells above (which
# would break re-running them).
runs = E.runs("20250729-results-additional")
huffman_df = pd.DataFrame.from_records([
    dict(
        fmt=(
            run.config.test.fmt_str
            .replace(":BFLOAT16", "")
            .replace(":search", "")
            .replace("(mode=asymmetric)", "")
        ),
        bpp=run.summary.bits_per_param,
        kl_div=tensor(run.summary.kl_div).mean().item(),
    )
    for run in runs
    # presumably "duration" is only recorded for completed runs — TODO confirm
    if "duration" in run.meta
    and run.config.test.fmt.vq_length is None
    and run.config.model.split("-")[-1] == "8B"
])
huffman_df.sort_values(["kl_div"])

Unnamed: 0,fmt,bpp,kl_div
0,"3b-int+Zoptimal{*,*:rms}",3.000367,0.206721
4,"3b-int+Zhuffman{*,*:rms}",3.003872,0.223946
3,"3b-t{*,*:rms}+S[1e-03]",3.047305,0.503222
1,"3b-t{1,64:absmax}",3.250423,1.273593
2,"3b-lloyd_max{1,64:absmax}",3.250423,5.362602


## VQ

### Experiments

In [10]:
# Fetch the same sweep as the Huffman-overhead section so this section can be run on its own.
runs = E.runs("20250729-results-additional")

In [None]:
def vq_fmt_name(fmt: E.AttrDict) -> str:
    """Human-readable label for a format in the VQ comparison, e.g. "VQ[4] + Block Signmax".

    Terms are joined with " + ":
    - the element type: "VQ[n]" when ``vq_length`` is set, otherwise the element family's
      display name. Families without one fall back to their raw name instead of raising KeyError.
    - the scaling granularity ("Tensor" / "Channel" / "Block" from ``block_shape``) and the
      capitalised scaling method.
    - "Compression", "Sparse" and "Rotation" for whichever options are enabled.

    Named differently from ``fmt_name`` so that running this cell does not replace the labeller
    that ``load_run`` uses for the QAT tables.
    """
    family_names = dict(int="INT", lloyd_max="Lloyd-Max", t=r"$\sqrt[3]{p}$ t")
    s = []
    if fmt.vq_length:
        s.append(f"VQ[{fmt.vq_length:.0f}]")
    else:
        s.append(family_names.get(fmt.element_family, str(fmt.element_family)))
    s.append({(None, None): "Tensor", (1, None): "Channel"}.get(tuple(fmt.block_shape), "Block") + f" {fmt.scaling.capitalize().replace('Rms', 'RMS')}")
    if fmt.compressor:
        s.append("Compression")
    if fmt.sparse_ratio:
        s.append("Sparse")
    if fmt.rotation:
        s.append("Rotation")
    return " + ".join(s)

# One row per finished Llama run, excluding Huffman-compressed formats.
df = pd.DataFrame.from_records([dict(
    model_size=run.config.model.split("-")[-1],
    fmt=vq_fmt_name(fmt),
    rotation=bool(fmt.rotation),
    bits_per_param=run.summary.bits_per_param,
    kl_div=tensor(run.summary.kl_div).mean().item(),
) for run in runs
  if run.meta.get("status") == "finished"
  for fmt in [run.config.test.fmt]  # bind a short name for the format config
  if fmt.compressor != "huffman"
  if run.config.model.startswith("meta-llama")
])

Unnamed: 0,model_size,fmt,rotation,bits_per_param,kl_div
0,1B,VQ[2] + Tensor RMS,False,3.0009,7.26792
1,1B,VQ[2] + Block Absmax,False,3.250885,0.730076
2,1B,VQ[2] + Block Signmax,False,3.250885,0.582039
3,1B,VQ[2] + Tensor RMS + Sparse,False,3.047772,0.67185
4,1B,VQ[4] + Tensor RMS,False,3.024682,0.694887


In [45]:
# 8B model without rotation, lowest-KL formats first.
(
    df.loc[(df.model_size == "8B") & ~df.rotation]
    .drop(columns=["model_size", "rotation"])
    .sort_values("kl_div")
)

Unnamed: 0,fmt,bits_per_param,kl_div
84,INT + Tensor RMS + Compression,3.000367,0.206721
22,VQ[4] + Block Signmax,3.2578,0.260657
23,VQ[4] + Tensor RMS + Sparse,3.054683,0.263678
21,VQ[4] + Block Absmax,3.2578,0.289415
18,VQ[2] + Block Signmax,3.25048,0.299134
88,$\sqrt[3]{p}$ t + Block Signmax,3.250423,0.304966
19,VQ[2] + Tensor RMS + Sparse,3.047363,0.346648
17,VQ[2] + Block Absmax,3.25048,0.3943
20,VQ[4] + Tensor RMS,3.007809,0.431194
89,$\sqrt[3]{p}$ t + Tensor RMS + Sparse,3.047305,0.503222


In [59]:
# KL of a few selected 3-bit formats across Llama model sizes (rotation disabled).
selected_formats = [
    "INT + Tensor RMS + Compression",
    "VQ[4] + Block Signmax",
    "VQ[2] + Block Signmax",
    r"$\sqrt[3]{p}$ t + Block Signmax",
]


def format_vq_columns(s: pd.Series) -> pd.Series:
    """Format one column of the VQ comparison table for display.

    "b": two decimals; names starting with "kl" (any case): three decimals; otherwise ``str``.
    Named separately from ``format_columns`` so that running this cell does not replace the
    formatter the downstream tables rely on (Baseline fill, task percentages).
    """
    if s.name == "b":
        return s.apply("{:.2f}".format)
    if s.name.lower().startswith("kl"):
        return s.apply("{:.3f}".format)
    return s.apply(str)


(df.pipe(lambda d: d[~d.rotation].drop(columns="rotation"))
 .pipe(lambda d: d[d.fmt.isin(selected_formats)])
 .pivot(index=["fmt"], columns="model_size", values=["kl_div"])
 .pipe(flatten_columns)
 .reset_index()
 .rename(columns=dict(fmt="Format", kl_div_8B="KL (Llama 8B)", kl_div_3B="KL (3B)", kl_div_1B="KL (1B)"))
 [["Format", "KL (Llama 8B)", "KL (3B)", "KL (1B)"]]
 # Sort on the numeric KL before formatting: the formatted strings sort lexicographically
 # (e.g. "10.000" < "9.000").
 .sort_values("KL (Llama 8B)")
 .apply(format_vq_columns)
 .pipe(lambda d: display(d.style.hide()) or print(to_markdown(d)))
)

Format,KL (Llama 8B),KL (3B),KL (1B)
INT + Tensor RMS + Compression,0.207,0.247,0.515
VQ[4] + Block Signmax,0.261,0.279,0.472
VQ[2] + Block Signmax,0.299,0.423,0.582
$\sqrt[3]{p}$ t + Block Signmax,0.305,0.345,0.692


| Format | KL (Llama 8B) | KL (3B) | KL (1B) |
| :-: | :-: | :-: | :-: |
| INT + Tensor RMS + Compression | 0.207 | 0.247 | 0.515 |
| VQ[4] + Block Signmax | 0.261 | 0.279 | 0.472 |
| VQ[2] + Block Signmax | 0.299 | 0.423 | 0.582 |
| $\sqrt[3]{p}$ t + Block Signmax | 0.305 | 0.345 | 0.692 |



### Student-t example

In [73]:
# Student-t samples (t_df degrees of freedom) quantised at ~3 bits/param with a scalar
# Lloyd-Max LUT, vector Lloyd-Max LUTs over groups of 2 and 4 values, and a compressed grid
# LUT. Prints the measured bits/param (b) and Q.qrmse_norm (R) for each format.
torch.manual_seed(100)
t_df = 10
bits_per_param = 3

x = torch.distributions.StudentT(t_df).sample((2**22,))
sfmt = Q.lut_lloyd_max(x, bits_per_param, 10**-4)
# Trained in order: groups of 2, then groups of 4 (order matters for the seeded RNG stream).
vfmt2, vfmt4 = (
    Q.vlut_lloyd_max(x.view(-1, group), bits_per_param, 10**-3, Q.BFLOAT16)
    for group in (2, 4)
)
cfmt = Q.CompressedLUTFormat.train_grid(x, 0.58)

print("fmt".ljust(25), "b".ljust(3), "R", sep="  ")
for fmt in (sfmt, vfmt2, vfmt4, cfmt):
    bits = fmt.count_bits_tensor(x) / x.nelement()
    rmse = Q.qrmse_norm(fmt, x)
    print(f"{str(fmt):<25}  {bits:.1f}  {rmse:.3f}")

fmt                        b    R
LUT3[LM]                   3.0  0.212
VLUT3x2[LM]                3.0  0.186
VLUT3x4[LM]                3.1  0.161
LUT6[GRID{0.58}]+Zoptimal  3.0  0.150
