# Compression
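We compress models trained on modular addition by shrinking their hidden (bottleneck) dimension. The goal is to see how the bottleneck affects their interaction matrices and the computation they implement.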

## Compact experiment: bottleneck sweep for modular addition (P=64)
This notebook runs a lean experiment that trains bilinear+projection models for `(a+b) mod 64` across varying hidden/bottleneck dimensions.
It records the validation accuracy reached at every hidden dimension, reports where perfect generalisation (100% accuracy on the held-out validation split) is lost or regained, and lets you inspect interaction matrices via a slider.

In [1]:
# Compact experiment: imports and model definitions
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import einsum
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import pickle
import datetime

# Prefer the GPU when one is present; the training helper moves the model and
# every batch onto this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

# Bilinear layer similar to the other notebook (left/right via chunk)
class Bilinear(nn.Linear):
    def __init__(self, d_in: int, d_out: int, bias=False) -> None:
        super().__init__(d_in, 2 * d_out, bias=bias)
    def forward(self, x):
        left, right = super().forward(x).chunk(2, dim=-1)
        return left * right
    @property
    def w_l(self):
        return self.weight.chunk(2, dim=0)[0]
    @property
    def w_r(self):
        return self.weight.chunk(2, dim=0)[1]

@dataclass
class ModelConfig:
    p: int = 64
    d_hidden: int | None = None
    bias: bool = False

class Model(nn.Module):
    """Bilinear hidden layer followed by a linear read-out to ``p`` logits.

    Data flow: input width ``2 * cfg.p`` -> ``Bilinear`` with ``cfg.d_hidden``
    outputs -> ``nn.Linear`` with ``cfg.p`` outputs. ``cfg.bias`` controls the
    bias terms of both layers.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        d_in, d_hid, d_out = 2 * cfg.p, cfg.d_hidden, cfg.p
        # Creating the bilinear layer first keeps the state_dict order and the
        # initialisation RNG draws as before.
        self.bi_linear = Bilinear(d_in=d_in, d_out=d_hid, bias=cfg.bias)
        self.projection = nn.Linear(d_hid, d_out, bias=cfg.bias)

    def forward(self, x):
        hidden = self.bi_linear(x)
        logits = self.projection(hidden)
        return logits

    @property
    def w_l(self):
        """Left factor of the bilinear layer, shape (d_hidden, 2*p)."""
        return self.bi_linear.w_l

    @property
    def w_r(self):
        """Right factor of the bilinear layer, shape (d_hidden, 2*p)."""
        return self.bi_linear.w_r

    @property
    def w_p(self):
        """Read-out (projection) weight, shape (p, d_hidden)."""
        return self.projection.weight

def init_model(p, d_hidden):
    """Build a fresh, bias-free ``Model`` for modulus ``p`` with ``d_hidden`` hidden units."""
    return Model(ModelConfig(p=p, d_hidden=d_hidden, bias=False))

def generate_modular_addition_data(P: int):
    """Enumerate every (a, b) pair for addition modulo ``P``.

    ``a`` is the outer index and ``b`` the inner one, so row ``a * P + b``
    holds the pair (a, b).

    Returns:
        x_1hot: float32 tensor of shape (P*P, 2*P). The first P columns
            one-hot encode ``a`` and the last P columns one-hot encode ``b``.
        targets: int64 tensor of shape (P*P,) holding ``(a + b) % P``.
    """
    r = torch.arange(P)
    a_grid, b_grid = torch.meshgrid(r, r, indexing='ij')
    operands = torch.stack((a_grid.reshape(-1), b_grid.reshape(-1)), dim=-1)  # (P*P, 2)
    # one_hot gives (P*P, 2, P). Flattening the last two axes puts a's
    # encoding before b's, the same layout as concatenating the two one-hots.
    x_1hot = F.one_hot(operands, num_classes=P).reshape(P * P, 2 * P).float()
    targets = operands.sum(dim=-1) % P
    return x_1hot, targets

# Create train/validation split (no separate test set)
def make_splits(x, y, train_frac=0.8, batch_size=256, seed=None):
    """Shuffle ``(x, y)`` and split them into train/validation loaders.

    There is no separate test set; the validation split is the held-out
    evaluation set.

    Args:
        x, y: tensors whose first dimension indexes samples (must match).
        train_frac: fraction of samples assigned to training. The count is
            ``int(train_frac * n)``, i.e. truncated.
        batch_size: batch size used by both loaders.
        seed: optional int. If given, the split permutation and the training
            loader's shuffling both draw from a dedicated ``torch.Generator``
            seeded with it. Runs are then reproducible without touching the
            global RNG. ``None`` uses the global RNG.

    Returns:
        ``(train_loader, val_loader, (x_val, y_val))``. The training loader
        shuffles and drops the last incomplete batch. If the training split is
        smaller than ``batch_size``, it therefore yields no batches. The
        validation loader keeps sample order and every sample.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    n = x.size(0)
    perm = torch.randperm(n, generator=generator)
    x = x[perm]
    y = y[perm]
    n_train = int(train_frac * n)
    x_train, y_train = x[:n_train], y[:n_train]
    x_val, y_val = x[n_train:], y[n_train:]

    train_ds = torch.utils.data.TensorDataset(x_train, y_train)
    val_ds = torch.utils.data.TensorDataset(x_val, y_val)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, drop_last=True, generator=generator
    )
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, (x_val, y_val)

# Marker that the definitions above were executed.
print('Ready')

Device: cuda
Ready


We train with cross-entropy loss. The intended setting is a bilinear layer inside a Transformer, where the final prediction step is a softmax.

In [38]:
# Training utilities (lean) - separate train and validation loaders
def _evaluate(model, loader):
    """Return ``(correct, total)`` argmax-prediction counts over ``loader`` in eval mode."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, tb in loader:
            xb = xb.to(device)
            tb = tb.to(device)
            preds = model(xb).argmax(dim=-1)
            correct += (preds == tb).sum().item()
            total += tb.numel()
    return correct, total


def train_model(model, train_loader, val_loader, epochs=500, lr=3e-3, weight_decay=1e-4):
    """Train ``model`` with AdamW and cross-entropy, stopping at perfect validation accuracy.

    The model and every batch are moved to the module-level ``device``. After
    each epoch, argmax accuracy is measured on ``val_loader``. Training stops
    early as soon as every validation sample is classified correctly.

    Returns:
        ``(model, history, epochs_run)``:
        - ``model`` is the trained model.
        - ``history['train_loss']`` holds the mean batch loss per epoch
          (0.0 if the train loader was empty).
        - ``history['val_acc']`` holds the validation accuracy per epoch
          (0.0 if the validation loader was empty).
        - ``epochs_run`` is the number of epochs actually executed, which is
          smaller than ``epochs`` when training stopped early.
    """
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    crit = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'val_acc': []}
    epochs_run = 0
    for ep in tqdm(range(epochs), desc='epochs'):
        model.train()
        losses = []
        for xb, tb in train_loader:
            xb = xb.to(device)
            tb = tb.to(device)
            loss = crit(model(xb), tb)

            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        history['train_loss'].append(sum(losses) / len(losses) if losses else 0.0)

        correct, total = _evaluate(model, val_loader)
        history['val_acc'].append(correct / total if total > 0 else 0.0)
        epochs_run = ep + 1
        # Compare the integer counts so the stop check never relies on float equality.
        if total > 0 and correct == total:
            print(f"Grokking at epoch {ep}.")
            break
    return model, history, epochs_run

# Marker that the training helper above has been defined.
print('Training helper ready')

Training helper ready


In [None]:
# Run the sweep: P=64, hidden dims from 64 down to 1, up to `max_epochs` epochs each.
# Every dim is trained; an individual run stops early once it reaches 100%
# validation accuracy.
P = 64
x, y = generate_modular_addition_data(P)
print('Dataset size:', x.size(0))

# We'll sweep dims from no-bottleneck (d=P) down to d=1 so the slider can go P->1
dims = list(range(P, 0, -1))
max_epochs = 600
wd = 3e-4
lr = 3e-3
models_state = {}
int_mats, val_acc = {}, {}
# Output classes (sums mod P) whose interaction matrices are stored for inspection.
remainders = [0, 5, 22, 42, 63]
# Create the split once so every dim is trained and evaluated on the same pairs.
train_loader, val_loader, test_pair = make_splits(x, y, train_frac=0.75, batch_size=64)
x_test, y_test = test_pair
# Whether the previous dim reached 100% validation accuracy. Only transitions
# (stops / resumes generalizing) get a dedicated message.
generalizes = True
for d in dims:
    print(f'\nTraining d_hidden={d}')
    m = init_model(P, d)
    trained, hist, used_epochs = train_model(m, train_loader, val_loader, epochs=max_epochs,
                                             lr=lr, weight_decay=wd)
    val_acc[d] = hist['val_acc'][-1] if len(hist['val_acc']) > 0 else 0.0
    if generalizes and val_acc[d] < 1.0:
        print(f"Model stops fully generalizing at bottleneck dim {d}.")
        print(f"Validation accuracy: {val_acc[d]:.3f}")
        generalizes = False
    elif not generalizes and val_acc[d] == 1.0:
        print(f"Model generalizes again at bottleneck dim {d}.")
        generalizes = True
    else:
        print(f"Validation accuracy: {val_acc[d]:.3f}")
    # save state dict on CPU to keep memory reasonable
    models_state[d] = {k: v.cpu().clone() for k, v in trained.state_dict().items()}
    # Interaction matrix for output r: Q_r[i, j] = sum_h w_p[r, h] * w_l[h, i] * w_r[h, j].
    # It is symmetrised because x^T Q x only depends on the symmetric part of Q.
    mats = []
    with torch.no_grad():
        # Move the weights to the CPU once per dim, not once per remainder.
        w_p_cpu = trained.w_p.to('cpu')
        w_l_cpu = trained.w_l.to('cpu')
        w_r_cpu = trained.w_r.to('cpu')
        for r in remainders:
            q = einsum(w_p_cpu[r], w_l_cpu, w_r_cpu, 'hid, hid in1, hid in2 -> in1 in2')
            mats.append(0.5 * (q + q.mT))
    int_mats[d] = torch.stack(mats, dim=0)

# Save to disk. The file name is tagged with the current day and month.
# strftime formats a datetime; strptime would try to parse a string and fail here.
date = datetime.datetime.now()
out_path = f'sweep_results_{date.strftime("%d%m")}.pkl'
with open(out_path, 'wb') as f:
    pickle.dump({'val_accs': val_acc, 'models': models_state, 'int_mats': int_mats, 'remainders': remainders}, f)
print(f'Saved {out_path}')

Dataset size: 4096

Training d_hidden=64


epochs:   0%|          | 0/600 [00:00<?, ?it/s]

epochs: 100%|██████████| 600/600 [01:13<00:00,  8.15it/s]


Model stops fully generalizing at bottleneck dim 64.
Validation accuracy: 0.996

Training d_hidden=63


epochs:  43%|████▎     | 258/600 [00:32<00:42,  7.96it/s]


Grokking at epoch 258.
Model generalizes again at bottleneck dim 63.

Training d_hidden=62


epochs:  89%|████████▊ | 532/600 [01:07<00:08,  7.84it/s]


Grokking at epoch 532.
Validation accuracy: 1.000

Training d_hidden=61


epochs:  85%|████████▌ | 510/600 [01:07<00:11,  7.60it/s]


Grokking at epoch 510.
Validation accuracy: 1.000

Training d_hidden=60


epochs: 100%|██████████| 600/600 [01:20<00:00,  7.41it/s]


Model stops fully generalizing at bottleneck dim 60.
Validation accuracy: 0.999

Training d_hidden=59


epochs:  49%|████▊     | 292/600 [00:40<00:42,  7.24it/s]


Grokking at epoch 292.
Model generalizes again at bottleneck dim 59.

Training d_hidden=58


epochs: 100%|██████████| 600/600 [01:14<00:00,  8.03it/s]


Model stops fully generalizing at bottleneck dim 58.
Validation accuracy: 0.996

Training d_hidden=57


epochs: 100%|██████████| 600/600 [01:21<00:00,  7.39it/s]


Validation accuracy: 0.997

Training d_hidden=56


epochs: 100%|██████████| 600/600 [01:17<00:00,  7.74it/s]


Validation accuracy: 0.998

Training d_hidden=55


epochs:  33%|███▎      | 199/600 [00:28<00:58,  6.91it/s]


Grokking at epoch 199.
Model generalizes again at bottleneck dim 55.

Training d_hidden=54


epochs: 100%|██████████| 600/600 [01:26<00:00,  6.94it/s]


Model stops fully generalizing at bottleneck dim 54.
Validation accuracy: 0.998

Training d_hidden=53


epochs: 100%|██████████| 600/600 [01:22<00:00,  7.26it/s]


Validation accuracy: 0.993

Training d_hidden=52


epochs: 100%|██████████| 600/600 [01:15<00:00,  7.94it/s]


Validation accuracy: 0.993

Training d_hidden=51


epochs:  74%|███████▎  | 441/600 [00:53<00:19,  8.19it/s]


Grokking at epoch 441.
Model generalizes again at bottleneck dim 51.

Training d_hidden=50


epochs: 100%|██████████| 600/600 [01:13<00:00,  8.20it/s]


Model stops fully generalizing at bottleneck dim 50.
Validation accuracy: 0.995

Training d_hidden=49


epochs:  69%|██████▉   | 413/600 [00:50<00:22,  8.19it/s]


Grokking at epoch 413.
Model generalizes again at bottleneck dim 49.

Training d_hidden=48


epochs: 100%|██████████| 600/600 [01:13<00:00,  8.20it/s]


Model stops fully generalizing at bottleneck dim 48.
Validation accuracy: 0.997

Training d_hidden=47


epochs: 100%|██████████| 600/600 [01:14<00:00,  8.06it/s]


Validation accuracy: 0.992

Training d_hidden=46


epochs: 100%|██████████| 600/600 [01:16<00:00,  7.83it/s]


Validation accuracy: 0.994

Training d_hidden=45


epochs:  69%|██████▉   | 413/600 [00:50<00:22,  8.21it/s]


Grokking at epoch 413.
Model generalizes again at bottleneck dim 45.

Training d_hidden=44


epochs:  81%|████████  | 484/600 [00:59<00:14,  8.20it/s]


Grokking at epoch 484.
Validation accuracy: 1.000

Training d_hidden=43


epochs: 100%|██████████| 600/600 [01:16<00:00,  7.81it/s]


Model stops fully generalizing at bottleneck dim 43.
Validation accuracy: 0.999

Training d_hidden=42


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.39it/s]


Validation accuracy: 0.991

Training d_hidden=41


epochs: 100%|██████████| 600/600 [01:13<00:00,  8.21it/s]


Validation accuracy: 0.997

Training d_hidden=40


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.31it/s]


Validation accuracy: 0.997

Training d_hidden=39


epochs: 100%|██████████| 600/600 [01:14<00:00,  8.10it/s]


Validation accuracy: 0.994

Training d_hidden=38


epochs:  78%|███████▊  | 468/600 [01:00<00:16,  7.79it/s]


Grokking at epoch 468.
Model generalizes again at bottleneck dim 38.

Training d_hidden=37


epochs: 100%|██████████| 600/600 [01:15<00:00,  7.94it/s]


Model stops fully generalizing at bottleneck dim 37.
Validation accuracy: 0.991

Training d_hidden=36


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.22it/s]


Validation accuracy: 0.994

Training d_hidden=35


epochs: 100%|██████████| 600/600 [01:16<00:00,  7.87it/s]


Validation accuracy: 0.997

Training d_hidden=34


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.22it/s]


Validation accuracy: 0.975

Training d_hidden=33


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.24it/s]


Validation accuracy: 0.997

Training d_hidden=32


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.25it/s]


Validation accuracy: 0.997

Training d_hidden=31


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.26it/s]


Validation accuracy: 0.998

Training d_hidden=30


epochs: 100%|██████████| 600/600 [01:17<00:00,  7.70it/s]


Validation accuracy: 0.991

Training d_hidden=29


epochs: 100%|██████████| 600/600 [01:10<00:00,  8.52it/s]


Validation accuracy: 0.993

Training d_hidden=28


epochs: 100%|██████████| 600/600 [01:03<00:00,  9.52it/s]


Validation accuracy: 0.993

Training d_hidden=27


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.23it/s]


Validation accuracy: 0.987

Training d_hidden=26


epochs: 100%|██████████| 600/600 [01:09<00:00,  8.69it/s]


Validation accuracy: 0.997

Training d_hidden=25


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.25it/s]


Validation accuracy: 0.988

Training d_hidden=24


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.43it/s]


Validation accuracy: 0.985

Training d_hidden=23


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.27it/s]


Validation accuracy: 0.997

Training d_hidden=22


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.27it/s]


Validation accuracy: 0.998

Training d_hidden=21


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.29it/s]


Validation accuracy: 0.996

Training d_hidden=20


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.28it/s]


Validation accuracy: 0.961

Training d_hidden=19


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.29it/s]


Validation accuracy: 0.976

Training d_hidden=18


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.25it/s]


Validation accuracy: 0.967

Training d_hidden=17


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.23it/s]


Validation accuracy: 0.987

Training d_hidden=16


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.28it/s]


Validation accuracy: 0.972

Training d_hidden=15


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.27it/s]


Validation accuracy: 0.963

Training d_hidden=14


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.29it/s]


Validation accuracy: 0.934

Training d_hidden=13


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.27it/s]


Validation accuracy: 0.959

Training d_hidden=12


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.31it/s]


Validation accuracy: 0.931

Training d_hidden=11


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.29it/s]


Validation accuracy: 0.967

Training d_hidden=10


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.27it/s]


Validation accuracy: 0.937

Training d_hidden=9


epochs: 100%|██████████| 600/600 [01:12<00:00,  8.23it/s]


Validation accuracy: 0.856

Training d_hidden=8


epochs: 100%|██████████| 600/600 [01:01<00:00,  9.80it/s]


Validation accuracy: 0.938

Training d_hidden=7


epochs: 100%|██████████| 600/600 [01:07<00:00,  8.85it/s]


Validation accuracy: 0.895

Training d_hidden=6


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.44it/s]


Validation accuracy: 0.827

Training d_hidden=5


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.43it/s]


Validation accuracy: 0.769

Training d_hidden=4


epochs: 100%|██████████| 600/600 [01:10<00:00,  8.48it/s]


Validation accuracy: 0.873

Training d_hidden=3


epochs: 100%|██████████| 600/600 [01:10<00:00,  8.47it/s]


Validation accuracy: 0.419

Training d_hidden=2


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.42it/s]


Validation accuracy: 0.159

Training d_hidden=1


epochs: 100%|██████████| 600/600 [01:11<00:00,  8.39it/s]


Validation accuracy: 0.026
Saved compression_sweep_results.pkl


Looking at the more compressed models (e.g. d=16), one sees that the main circuit is still visible in the weights, but similar circuits are also visible, albeit at a smaller scale.