# ResNet-BK vs Mamba: Quantization (INT8 / Fake INT4)

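- Runs a short training run for each model, then evaluates perplexity (PPL) in FP32, INT8, and fake INT4.
- Fake INT4 rounds each Linear weight tensor onto a symmetric 4-bit grid (one scale per tensor) and stores the result back as float (reference only).
- Outputs a JSON results file, plus an optional ZIP of artifacts.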

Prerequisites: GPU runtime recommended for speed.

In [ ]:
# Stable install: Torch 2.3.1 cu121 + mamba-ssm build-friendly
# If you just upgraded pip/setuptools, restart runtime once before running below.
!pip install --upgrade --no-cache-dir pip setuptools wheel ninja packaging cmake jedi
!pip install --force-reinstall --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
# Prepare a minimal build path for mamba-ssm (no heredoc)
%env FORCE_CUDA=1
%env MAX_JOBS=4
%env TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6"
!pip install --no-cache-dir --no-build-isolation mamba-ssm==2.2.2 --extra-index-url https://download.pytorch.org/whl/cu121


In [None]:
# Report the interpreter version, the PyTorch build, and whether a GPU is visible.
import platform
import torch

cuda_ok = torch.cuda.is_available()
print("Python:", platform.python_version())
print("PyTorch:", torch.__version__)
print("CUDA available:", cuda_ok)
if cuda_ok:
    # Only query the device name when CUDA is actually usable.
    print("GPU:", torch.cuda.get_device_name(0))

Install deps

In [None]:
# Install the remaining Python packages used by the cells below.
# NOTE(review): torch/torchvision/torchaudio appear here without version pins; presumably pip keeps
# the pinned builds installed by the first cell because no --upgrade is passed — confirm.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q datasets transformers matplotlib tqdm

Clone repo

In [None]:
# Repo setup: find an existing checkout (a directory containing `src`), clone it
# when none is found, make it the working directory, and put it on sys.path so
# that `src.*` imports resolve.
import os, sys, subprocess, pathlib
REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'
cwd = pathlib.Path.cwd()
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]

# First candidate that already holds a `src` directory wins.
root = None
for candidate in candidates:
    if (candidate / 'src').exists():
        root = candidate
        break

if root is None:
    # Nothing found nearby: fall back to a checkout under the current directory.
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)

if root != pathlib.Path.cwd():
    os.chdir(root)

root_str = str(pathlib.Path.cwd())
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


Imports + config

In [None]:
import itertools, json, random
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from types import SimpleNamespace

from src.models.resnet_bk import LanguageModel as ResNetBK
from src.models.mamba_baseline import MambaLM, create_mamba_from_resnetbk_config

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Tag embedded in the names of the files this notebook writes.
RUN_TAG = "quant"

# Shared settings: tokenizer, dataset, model hyperparameters, and training schedule.
BASE_CONFIG = {
    "tokenizer_name": "gpt2",
    "dataset": {
        "name": "wikitext",
        "config": "wikitext-2-raw-v1",
    },
    "model": {
        "d_model": 256,
        "n_layers": 4,
        "num_experts": 2,
        "top_k": 1,
        "dropout": 0.1,
    },
    "training": {
        "steps": 80,
        "lr": 3e-4,
        "weight_decay": 0.01,
        "batch_size": 2,
        "seed": 42,
    },
}

Data helpers

In [None]:
def set_seed(seed):
    """Seed Python's `random` and PyTorch (CPU, plus all CUDA devices when CUDA is available)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_tokenizer():
    """Load the tokenizer named in BASE_CONFIG; reuse EOS as the pad token when none is defined."""
    name = BASE_CONFIG["tokenizer_name"]
    tokenizer = AutoTokenizer.from_pretrained(name)
    has_pad = tokenizer.pad_token is not None
    if not has_pad:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def load_data(seq_length):
    """Tokenize the configured corpus and pack the train split into fixed-length blocks.

    The train split is tokenized without special tokens. Within each mapped
    batch the token ids are concatenated and cut into blocks of
    ``seq_length + 1`` ids (the extra id lets a collate function derive inputs
    and next-token targets from one block); the incomplete tail of each batch
    is dropped.

    Args:
        seq_length: Number of input tokens per training example (must be >= 1).

    Returns:
        A tuple ``(train_blocks, raw_validation, tokenizer)`` where
        ``train_blocks`` is the packed train split formatted as torch tensors
        with a single ``input_ids`` column, ``raw_validation`` is the untouched
        validation split, and ``tokenizer`` is the tokenizer used.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")

    tok = get_tokenizer()
    raw = load_dataset(BASE_CONFIG["dataset"]["name"], BASE_CONFIG["dataset"]["config"])
    seq_plus_one = seq_length + 1

    def tok_fn(examples):
        return tok(examples["text"], add_special_tokens=False)

    def group(examples):
        # Flatten the batch into one id stream, then slice it into whole blocks.
        concat = list(itertools.chain.from_iterable(examples["input_ids"]))
        total = len(concat) // seq_plus_one * seq_plus_one
        return {"input_ids": [concat[i:i + seq_plus_one] for i in range(0, total, seq_plus_one)]}

    # `raw["train"]` is already a single split, so the mapped results are plain
    # datasets: use their own `column_names` and do not index by "train" again.
    tokenized = raw["train"].map(tok_fn, batched=True, remove_columns=["text"])
    grouped = tokenized.map(group, batched=True, remove_columns=tokenized.column_names)
    grouped.set_format(type="torch", columns=["input_ids"])
    return grouped, raw["validation"], tok

def make_loader(dataset, batch_size, seed):
    """Build a shuffled DataLoader that yields (inputs, targets) next-token pairs.

    Each item's ``input_ids`` is split into all-but-last ids (inputs) and
    all-but-first ids (targets). Shuffling is driven by a generator seeded with
    ``seed``; a trailing incomplete batch is dropped.
    """
    def collate(batch):
        sequences = [sample["input_ids"] for sample in batch]
        inputs = torch.stack([seq[:-1] for seq in sequences])
        targets = torch.stack([seq[1:] for seq in sequences])
        return inputs, targets

    rng = torch.Generator()
    rng.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=rng,
        collate_fn=collate,
    )

Model builders

In [None]:
def build_models(seq_length, vocab_size):
    """Create a ResNet-BK model and a Mamba baseline from the shared model settings.

    The Mamba config is derived from a ResNet-BK-style namespace through
    `create_mamba_from_resnetbk_config`. Returns ``(resnet_bk, mamba)``.
    """
    hp = BASE_CONFIG["model"]
    d_model = hp["d_model"]
    n_layers = hp["n_layers"]
    dropout = hp["dropout"]

    bk = ResNetBK(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_seq=seq_length,
        num_experts=hp["num_experts"],
        top_k=hp["top_k"],
        dropout_p=dropout,
        use_scattering_router=False,
        use_birman_schwinger=False,
    )

    res_cfg = SimpleNamespace(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_seq=seq_length,
        dropout=dropout,
        tie_weights=True,
    )
    mb = MambaLM(create_mamba_from_resnetbk_config(res_cfg))
    return bk, mb

Train + eval

In [None]:
def train_small(model, loader):
    """Train `model` with AdamW for at most BASE_CONFIG["training"]["steps"] batches.

    The model is moved to DEVICE and put in train mode. The loss is printed
    every 20 steps. Returns the list of per-step loss values.
    """
    hp = BASE_CONFIG["training"]
    model = model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=hp["lr"], weight_decay=hp["weight_decay"])
    model.train()
    history = []
    for step, (inp, tgt) in enumerate(loader, start=1):
        if step > hp["steps"]:
            break
        inp, tgt = inp.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        logits = model(inp)
        vocab = logits.size(-1)
        # Flatten (batch, position) so every token is one classification example.
        loss = F.cross_entropy(logits.view(-1, vocab), tgt.view(-1))
        loss.backward()
        optimizer.step()
        history.append(loss.item())
        if step % 20 == 0:
            print("step", step, "loss", loss.item())
    return history

@torch.no_grad()
def eval_ppl(model, loader):
    """Compute token-level perplexity of `model` over `loader`.

    Sums the cross-entropy over all target tokens and returns
    ``exp(total_loss / total_tokens)``. Evaluation stops once more than
    200000 tokens have been scored.

    Models containing dynamically-quantized Linear layers are evaluated on the
    CPU, because PyTorch's eager-mode quantized kernels are CPU backends; all
    other models run on DEVICE.

    Args:
        model: Module mapping an input id batch to logits whose last dimension
            is the vocabulary size.
        loader: Iterable of ``(inputs, targets)`` tensor pairs.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If `loader` yields no tokens (perplexity is undefined).
    """
    from torch.ao.nn.quantized.dynamic import Linear as DynamicQuantLinear

    is_dynamic_int8 = any(isinstance(m, DynamicQuantLinear) for m in model.modules())
    device = "cpu" if is_dynamic_int8 else DEVICE

    model = model.to(device)
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for inp, tgt in loader:
        inp = inp.to(device)
        tgt = tgt.to(device)
        logits = model(inp)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), reduction='sum')
        total_loss += loss.item()
        total_tokens += tgt.numel()
        if total_tokens > 200000:
            break
    if total_tokens == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("eval_ppl received no tokens: the loader yielded no batches")
    return float(torch.exp(torch.tensor(total_loss / total_tokens)))

Quantizers

In [None]:
def quantize_int8_linear(model):
    """Return a dynamically-quantized (INT8 weights) CPU copy of `model`.

    Only `torch.nn.Linear` modules are converted. The input model is left
    untouched: it is deep-copied and the copy is moved to the CPU before
    conversion, because the weight-packing step of dynamic quantization targets
    CPU backends and the model may still live on the GPU after training.

    Args:
        model: A floating-point `torch.nn.Module`.

    Returns:
        A new module on the CPU whose Linear layers are dynamically quantized.
    """
    import copy
    import torch.ao.quantization as tq

    cpu_copy = copy.deepcopy(model).cpu()
    # The copy is private to this function, so convert it in place rather than copying twice.
    qmodel = tq.quantize_dynamic(cpu_copy, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
    return qmodel

def quantize_int4_linear(model):
    """Return a CPU copy of `model` whose Linear weights are fake-quantized to 4 bits.

    For every `torch.nn.Linear`, one scale per weight tensor is derived from the
    largest absolute weight (mapped to level 7). Weights are rounded to integer
    levels, clamped to [-8, 7], and written back as floats in the original
    dtype. The input model is not modified.
    """
    import copy

    qmodel = copy.deepcopy(model)
    qmodel = qmodel.cpu()
    for layer in qmodel.modules():
        if not isinstance(layer, torch.nn.Linear):
            continue
        weights = layer.weight.data
        # Small epsilon keeps the scale non-zero for an all-zero weight tensor.
        step = weights.abs().max() / 7.0 + 1e-8
        levels = torch.round(weights / step).clamp(-8, 7)
        layer.weight.data = (levels * step).to(layer.weight.dtype)
    return qmodel

Run experiment

In [None]:
# End-to-end experiment: train both models briefly, then compare perplexity in
# FP32, dynamic INT8, and fake INT4 on the same validation blocks.
seq_length = 2048
train_cfg = BASE_CONFIG["training"]

# Seed before any random work (model init, dropout) so reruns are comparable.
set_seed(train_cfg["seed"])

train_data, val_raw, tok = load_data(seq_length)

# --- Validation data -------------------------------------------------------
# `val_raw` is already the validation split. Individual lines are far shorter
# than seq_length + 1 tokens, so pack the token stream into fixed-length blocks
# instead of keeping only lines that are long enough on their own.
val_block_len = seq_length + 1

def _pack_validation(examples):
    flat = list(itertools.chain.from_iterable(examples["input_ids"]))
    usable = len(flat) // val_block_len * val_block_len
    return {"input_ids": [flat[i:i + val_block_len] for i in range(0, usable, val_block_len)]}

val_tokenized = val_raw.map(lambda ex: tok(ex["text"], add_special_tokens=False), batched=True, remove_columns=["text"])
# Packing changes the row count, so every pre-existing column has to be dropped.
val_tokenized = val_tokenized.map(_pack_validation, batched=True, remove_columns=val_tokenized.column_names)
val_tokenized.set_format(type="torch", columns=["input_ids"])
if len(val_tokenized) == 0:
    raise RuntimeError("Validation split produced no blocks; lower seq_length.")

def _fresh_train_loader():
    # A new loader with the same seed replays the same batch order for each model.
    return make_loader(train_data, train_cfg["batch_size"], train_cfg["seed"])

def _fresh_val_loader():
    # make_loader shuffles and eval may stop early; re-creating the loader with a
    # fixed seed makes every precision score the same blocks.
    return make_loader(val_tokenized, batch_size=1, seed=0)

train_loader = _fresh_train_loader()
val_loader = _fresh_val_loader()

# --- Training --------------------------------------------------------------
bk, mb = build_models(seq_length, tok.vocab_size)
print("Training ResNet-BK")
train_small(bk, train_loader)
print("Training Mamba")
train_small(mb, _fresh_train_loader())

# --- Evaluation ------------------------------------------------------------
print("Eval FP32")
ppl_bk_fp32 = eval_ppl(bk, _fresh_val_loader())
ppl_mb_fp32 = eval_ppl(mb, _fresh_val_loader())

# NOTE(review): dynamic INT8 kernels are CPU backends in PyTorch eager mode;
# confirm the INT8 models are evaluated on the CPU.
print("Quantize INT8")
bk_int8 = quantize_int8_linear(bk)
mb_int8 = quantize_int8_linear(mb)
ppl_bk_int8 = eval_ppl(bk_int8, _fresh_val_loader())
ppl_mb_int8 = eval_ppl(mb_int8, _fresh_val_loader())

print("Quantize INT4 (fake-quant)")
bk_int4 = quantize_int4_linear(bk)
mb_int4 = quantize_int4_linear(mb)
ppl_bk_int4 = eval_ppl(bk_int4, _fresh_val_loader())
ppl_mb_int4 = eval_ppl(mb_int4, _fresh_val_loader())

# --- Report ----------------------------------------------------------------
results = {
    "seq_length": seq_length,
    "ppl": {
        "resnet_bk": {"fp32": ppl_bk_fp32, "int8": ppl_bk_int8, "int4_fake": ppl_bk_int4},
        "mamba": {"fp32": ppl_mb_fp32, "int8": ppl_mb_int8, "int4_fake": ppl_mb_int4},
    },
}
Path(f"colab_quant_results_{RUN_TAG}.json").write_text(json.dumps(results, indent=2))
print(json.dumps(results, indent=2))

Zip artifacts

In [None]:
import shutil, glob
def zip_artifacts(prefix=None):
    """Bundle this run's result files into ``<prefix>.zip``.

    Only the files matching ``colab_quant_results_<RUN_TAG>.json`` in the
    current directory are archived; nothing else from the working directory is
    included. When no such file exists, a message is printed and no archive is
    created.

    Args:
        prefix: Archive path without the ``.zip`` suffix. Defaults to
            ``artifacts_quant_<RUN_TAG>``.
    """
    import zipfile

    prefix = prefix or f"artifacts_quant_{RUN_TAG}"
    targets = sorted(glob.glob(f"colab_quant_results_{RUN_TAG}.json"))
    if not targets:
        print("No artifacts found yet.")
        return
    archive_path = f"{prefix}.zip"
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in targets:
            archive.write(path)
    print("Created", archive_path, "with", targets)

# After running:
# zip_artifacts()