In [2]:
# Enable IPython's autoreload extension so edits to imported modules are picked up
# live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

import collections
import logging
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import tensor

import weight_formats.experiments as E

# Plot defaults for this notebook: small figures, no top/right spines, frameless legends.
matplotlib.rcParams.update(
    {
        "figure.figsize": (5, 3),
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
    }
)

# Set the root logger to WARNING. force=True makes this call take effect even if
# logging was already configured before this cell ran.
logging.basicConfig(level=logging.WARNING, force=True)

2025-07-10:09:37:50,448 INFO     [rouge_scorer.py:83] Using default tokenizer.
2025-07-10:09:37:51,149 INFO     [_client.py:1025] HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK"


In [3]:
# Identifier of the experiment whose runs are analysed in this notebook.
EXPERIMENT = "20250708-qat-main"

# progress=True shows the "query" progress bar while the runs are collected.
runs = E.runs(EXPERIMENT, progress=True)

query: 57it [00:00, 140.71it/s]


In [5]:
# Sanity check: tally the runs by (last "-"-separated token of the model name, run status).
status_counts = collections.Counter(
    (run.config.model.split("-")[-1], run.meta.status) for run in runs
)
display(status_counts)

Counter({('1B', 'finished'): 19,
         ('3B', 'finished'): 19,
         ('8B', 'finished'): 19})

In [6]:
def _run_record(run):
    """Flatten one run into a single table row.

    The row holds the model tag (last "-"-separated token of the model name), the
    format string (runs without a "fmt_str" entry are labelled "bfloat16"), the
    bits per parameter, the negated validation KL divergence (so larger is better),
    and one column per downstream task holding its primary score.
    """
    task_scores = {
        task: result.primary_score
        for task, result in sorted(run.summary.downstream.items())
    }
    return dict(
        model=run.config.model.split("-")[-1],
        fmt=run.config.test.get("fmt_str", "bfloat16"),
        bits_per_param=run.summary.bits_per_param,
        neg_valid_kl_div=-run.summary.valid_kl_div,
        **task_scores,
    )


# Only finished runs contribute a row.
df = pd.DataFrame.from_records(
    [_run_record(run) for run in runs if run.meta.status == "finished"]
)

# Alternative single-table view, kept for reference:
# df.set_index(["model", "fmt", "bits_per_param"]).sort_values("neg_valid_kl_div", ascending=False).style.background_gradient()

# One table per model, best (highest) negated KL divergence first.
for model, d in df.groupby("model"):
    print(f"# {model}")
    ranked = d.sort_values("neg_valid_kl_div", ascending=False)
    display(ranked.style.background_gradient())

# 1B


Unnamed: 0,model,fmt,bits_per_param,neg_valid_kl_div,arc_challenge:mc,arc_easy:mc,boolq,csqa:mc,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
0,1B,bfloat16,16.0,-0.001064,0.317726,0.550877,0.642202,0.418509,0.655,0.428,0.748096,0.467247,0.605367
15,1B,"5b-int+Zoptimal{*,*:BFLOAT16:rms:search}",5.000572,-0.02261,0.334448,0.522807,0.637615,0.380016,0.652,0.394,0.752992,0.440123,0.60221
17,1B,"5b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",5.250635,-0.02549,0.327759,0.526316,0.625076,0.432432,0.648,0.42,0.745919,0.462129,0.602999
16,1B,"5b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",5.047522,-0.033453,0.38796,0.517544,0.637615,0.381654,0.642,0.41,0.744287,0.454452,0.587214
18,1B,"5b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",5.007188,-0.04583,0.331104,0.517544,0.633945,0.385749,0.633,0.438,0.741023,0.440123,0.603788
9,1B,"4b-int+Zoptimal{*,*:BFLOAT16:rms:search}",4.003795,-0.094183,0.317726,0.492982,0.60367,0.386568,0.634,0.358,0.742111,0.408393,0.588003
11,1B,"4b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",4.250666,-0.1093,0.317726,0.457895,0.61896,0.377559,0.635,0.362,0.736126,0.41044,0.601421
10,1B,"4b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",4.047554,-0.154208,0.364548,0.475439,0.633639,0.344799,0.63,0.362,0.73395,0.429887,0.581689
12,1B,"4b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",4.007219,-0.185236,0.274247,0.412281,0.623853,0.324324,0.616,0.314,0.723069,0.390993,0.5809
19,1B,"5b-t(mode=asymmetric){*,*:BFLOAT16:absmax:search}",5.00065,-0.199299,0.331104,0.429825,0.599388,0.316953,0.614,0.32,0.725245,0.394575,0.58011


# 3B


Unnamed: 0,model,fmt,bits_per_param,neg_valid_kl_div,arc_challenge:mc,arc_easy:mc,boolq,csqa:mc,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
1,3B,bfloat16,16.0,-0.000912,0.719064,0.817544,0.736391,0.641278,0.749,0.626,0.775299,0.625384,0.690608
33,3B,"5b-int+Zoptimal{*,*:BFLOAT16:rms:search}",5.000959,-0.015582,0.695652,0.808772,0.739144,0.652744,0.744,0.634,0.773667,0.622313,0.691397
35,3B,"5b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",5.250617,-0.016523,0.705686,0.829825,0.730275,0.625717,0.749,0.65,0.769859,0.625896,0.692976
34,3B,"5b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",5.047504,-0.022491,0.702341,0.831579,0.720489,0.650287,0.754,0.622,0.771491,0.625896,0.696133
36,3B,"5b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",5.005125,-0.029194,0.682274,0.812281,0.719266,0.642097,0.749,0.622,0.776931,0.625384,0.704815
27,3B,"4b-int+Zoptimal{*,*:BFLOAT16:rms:search}",4.003353,-0.049629,0.688963,0.8,0.725382,0.597052,0.74,0.598,0.774755,0.599795,0.681137
29,3B,"4b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",4.250656,-0.061345,0.665552,0.812281,0.71896,0.619165,0.738,0.628,0.773123,0.620266,0.679558
28,3B,"4b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",4.047543,-0.078931,0.632107,0.780702,0.722324,0.561016,0.723,0.584,0.772035,0.579836,0.689029
30,3B,"4b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",4.005164,-0.114173,0.632107,0.764912,0.723242,0.561835,0.71,0.584,0.779108,0.593654,0.66614
21,3B,"3b-int+Zoptimal{*,*:BFLOAT16:rms:search}",2.998034,-0.246294,0.511706,0.707018,0.506116,0.502047,0.701,0.482,0.757345,0.527636,0.658248


# 8B


Unnamed: 0,model,fmt,bits_per_param,neg_valid_kl_div,arc_challenge:mc,arc_easy:mc,boolq,csqa:mc,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
2,8B,bfloat16,16.0,-0.001009,0.795987,0.9,0.822018,0.701884,0.808,0.76,0.818281,0.647902,0.737174
51,8B,"5b-int+Zoptimal{*,*:BFLOAT16:rms:search}",5.000855,-0.014903,0.80602,0.905263,0.818043,0.697789,0.806,0.762,0.813384,0.648925,0.740331
53,8B,"5b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",5.250371,-0.016905,0.795987,0.903509,0.824159,0.69697,0.801,0.754,0.815016,0.638178,0.736385
52,8B,"5b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",5.047253,-0.024815,0.772575,0.891228,0.812232,0.698608,0.802,0.744,0.813384,0.638178,0.738753
54,8B,"5b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",5.003632,-0.034894,0.795987,0.894737,0.818349,0.678133,0.807,0.742,0.809576,0.63306,0.74191
45,8B,"4b-int+Zoptimal{*,*:BFLOAT16:rms:search}",4.001627,-0.051359,0.779264,0.885965,0.778899,0.678952,0.801,0.744,0.813384,0.636643,0.729282
47,8B,"4b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}",4.250397,-0.066579,0.792642,0.885965,0.801223,0.678952,0.794,0.724,0.807943,0.625384,0.740331
46,8B,"4b-t(mode=asymmetric){*,*:BFLOAT16:rms:search}+S[1e-03:None]",4.047279,-0.095332,0.795987,0.882456,0.762997,0.660115,0.791,0.736,0.802503,0.618731,0.711918
48,8B,"4b-t(mode=asymmetric){1,*:BFLOAT16:absmax:search}",4.003658,-0.137074,0.802676,0.852632,0.763303,0.645373,0.792,0.69,0.800871,0.590583,0.721389
39,8B,"3b-int+Zoptimal{*,*:BFLOAT16:rms:search}",3.000824,-0.2052,0.715719,0.84386,0.73792,0.653563,0.792,0.638,0.807399,0.586489,0.722178


In [46]:
# Compare one quantised 1B run against the 1B bfloat16 baseline on the downstream tasks.

# Columns removed before comparing, so that only per-task scores remain.
# NOTE(review): "csqa:mc" is dropped as well; the reason is not recorded here.
non_score_columns = ["model", "fmt", "bits_per_param", "neg_valid_kl_div", "csqa:mc"]

is_1b = df.model == "1B"
# b: baseline scores, d: scores of the quantised format under inspection.
b = df[is_1b & (df.fmt == "bfloat16")].drop(columns=non_score_columns)
d = df[is_1b & (df.fmt == "3b-t(mode=asymmetric){1,64:BFLOAT16:absmax:search}")].drop(columns=non_score_columns)

# The delta below pairs rows by position after resetting the index, so it is only
# meaningful when each selection holds exactly one run. Fail loudly otherwise instead
# of silently displaying an empty or NaN-padded table.
assert len(b) == 1, f"expected exactly one 1B bfloat16 baseline run, found {len(b)}"
assert len(d) == 1, f"expected exactly one matching 1B quantised run, found {len(d)}"


def as_percent(x):
    """Render a fractional score as a percentage with one decimal place."""
    return f"{x*100:.1f}"


# Absolute scores of the quantised run, then its difference to the baseline, both in percent.
display(d.style.format(as_percent))
display((d.reset_index(drop=True) - b.reset_index(drop=True)).style.format(as_percent))

Unnamed: 0,arc_challenge:mc,arc_easy:mc,boolq,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
5,22.7,26.5,54.5,43.6,25.6,64.1,32.8,53.6


Unnamed: 0,arc_challenge:mc,arc_easy:mc,boolq,hellaswag,openbookqa:mc,piqa,socialiqa:mc,winogrande
0,-9.0,-28.6,-9.7,-21.9,-17.2,-10.7,-13.9,-6.9
