# ResNet-BK Efficiency: GradientCache (ACT proxy) vs Mamba

Toy-scale FLOPs estimate for ResNet-BK with and without GradientCache (used as a proxy for ACT), compared against a Mamba baseline. Results are saved as JSON, optionally bundled into a ZIP.

In [None]:
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q datasets transformers tqdm

In [None]:
import os, sys, subprocess
REPO_URL = "https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git"
REPO_DIR = "Project-ResNet-BK-An-O-N-Language-Model-Architecture"
# Clone (once) and enter the repository checkout.
# Guard on the current directory name: after the first run the cwd already IS
# the checkout, and re-running this cell would otherwise look for REPO_DIR
# inside it and clone a nested copy.
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL], check=True)
    os.chdir(REPO_DIR)
# Make the checkout root importable so `src.*` modules resolve.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

In [None]:
import itertools, json, random
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from types import SimpleNamespace
from pathlib import Path

from src.models.resnet_bk import LanguageModel as ResNetBK
from src.models.mamba_baseline import MambaLM, create_mamba_from_resnetbk_config
from src.training.gradient_caching import GradientCache

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# Suffix used in every output file name produced by this notebook.
RUN_TAG = "efficiency"
# Hyper-parameters shared by all models in the comparison.
BASE_CONFIG = dict(
    tokenizer_name="gpt2",
    seq_length=2048,
    train_steps=40,
    batch_size=2,
    lr=3e-4,
)

In [None]:
import random

def set_seed(seed):
    """Seed Python's `random` module and torch (plus every CUDA device, if any)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def get_tokenizer():
    """Load the tokenizer named in BASE_CONFIG.

    If the tokenizer defines no pad token, its EOS token is reused for padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_CONFIG["tokenizer_name"])
    if tokenizer.pad_token is not None:
        return tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def make_loader(seq_length, batch_size, seed):
    """Build a shuffled next-token-prediction DataLoader over WikiText-2 (train split).

    The raw text is tokenized without special tokens, concatenated, and cut
    into contiguous windows of ``seq_length + 1`` tokens. Tokens that do not
    fill a whole window at the end of each ``map`` batch are dropped. Each
    window becomes an ``(inputs, targets)`` pair shifted by one token.

    Args:
        seq_length: number of tokens per model input.
        batch_size: windows per batch; the final incomplete batch is dropped.
        seed: seed for the generator that drives shuffling, so the batch
            order is reproducible for a given seed.

    Returns:
        DataLoader yielding ``(inputs, targets)`` tensors of shape
        ``(batch_size, seq_length)``.
    """
    tok = get_tokenizer()
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")

    def tok_fn(examples):
        return tok(examples["text"], add_special_tokens=False)

    # raw["train"].map(...) returns a single Dataset, not a DatasetDict:
    # indexing it with "train" would be a (non-existent) column lookup.
    tokenized = raw["train"].map(tok_fn, batched=True, remove_columns=["text"])
    seq_plus_one = seq_length + 1

    def group(examples):
        concat = list(itertools.chain.from_iterable(examples["input_ids"]))
        total = len(concat) // seq_plus_one * seq_plus_one
        return {
            "input_ids": [concat[i:i + seq_plus_one] for i in range(0, total, seq_plus_one)]
        }

    # Grouping changes the number of rows, so every tokenizer output column
    # must be removed; only the regrouped "input_ids" survive.
    grouped = tokenized.map(group, batched=True, remove_columns=tokenized.column_names)
    grouped.set_format(type="torch", columns=["input_ids"])
    g = torch.Generator().manual_seed(seed)

    def collate(batch):
        inputs = torch.stack([b["input_ids"][:-1] for b in batch])
        targets = torch.stack([b["input_ids"][1:] for b in batch])
        return inputs, targets

    return DataLoader(
        grouped,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=g,
        collate_fn=collate,
    )

In [None]:
def build_models(seq_length, vocab_size):
    """Build the three models compared in this notebook.

    Returns ``(bk_base, bk_act, mamba)``:
    - ``bk_base`` and ``bk_act`` are two ResNet-BK language models with the
      same configuration.
    - ``mamba`` is a MambaLM built from that configuration via
      ``create_mamba_from_resnetbk_config``.

    ``bk_act`` starts from ``bk_base``'s weights, so any difference between
    the baseline and GradientCache runs is not due to initialization.
    """
    d_model = 256
    n_layers = 6
    dropout = 0.1
    bk_kwargs = dict(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_seq=seq_length,
        num_experts=4,
        top_k=1,
        dropout_p=dropout,
        use_scattering_router=False,
        use_birman_schwinger=False,
    )
    bk_base = ResNetBK(**bk_kwargs)
    # Still construct a separate instance (keeps RNG consumption, and thus the
    # Mamba init below, unchanged), then copy the baseline's weights into it.
    bk_act = ResNetBK(**bk_kwargs)
    bk_act.load_state_dict(bk_base.state_dict())
    res_cfg = SimpleNamespace(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_seq=seq_length,
        dropout=dropout,
        tie_weights=True,
    )
    mamba = MambaLM(create_mamba_from_resnetbk_config(res_cfg))
    return bk_base, bk_act, mamba

In [None]:
def train_and_profile(model, loader, use_act=False):
    """Train ``model`` briefly and estimate the FLOPs of each training step.

    Runs at most ``BASE_CONFIG["train_steps"]`` AdamW steps. Each step runs
    inside ``torch.autograd.profiler.profile(with_flops=True)``:
    - the forward pass (through ``GradientCache`` when ``use_act`` is set);
    - the cross-entropy loss;
    - the backward pass.
    The FLOP estimates of all recorded ops are summed per step.

    The real training computation is profiled, not a separate plain forward
    pass, so the GradientCache run can actually differ from the baseline in
    the reported FLOPs. The profiler estimates FLOPs only for operators it has
    formulas for (matrix multiplications and 2D convolutions per the PyTorch
    docs), so the totals are a relative proxy, not an exact count.

    Args:
        model: ``nn.Module``. Assumes ``model(x)`` returns logits whose last
            dimension is the vocabulary size (TODO confirm for each model).
        loader: iterable of ``(inputs, targets)`` token-id tensor pairs.
        use_act: if True, run the forward pass through ``GradientCache``.

    Returns:
        ``{"losses": [float, ...], "avg_flops": float}``. ``avg_flops`` is the
        mean per-step FLOP estimate, or 0.0 if no step ran.
    """
    steps = BASE_CONFIG["train_steps"]
    model = model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_CONFIG["lr"])
    losses = []
    flop_samples = []
    model.train()

    def forward_pass(x):
        return model(x)

    for step, (inp, tgt) in enumerate(loader):
        if step >= steps:
            break
        inp = inp.to(DEVICE)
        tgt = tgt.to(DEVICE)
        opt.zero_grad()
        # with_flops=True is required; without it FunctionEvent.flops is never
        # populated and every sum is 0. CUDA kernel timing is not needed,
        # because FLOP estimates come from the recorded operator input shapes.
        with torch.autograd.profiler.profile(with_flops=True) as prof:
            if use_act:
                cache = GradientCache(chunk_size=max(1, inp.shape[1] // 2))
                logits = cache(forward_pass, inp)
            else:
                logits = forward_pass(inp)
            # reshape (not view): the logits tensor is not guaranteed contiguous.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
            loss.backward()
        opt.step()
        loss_value = loss.item()
        losses.append(loss_value)
        flops = sum(e.flops or 0 for e in prof.function_events)
        flop_samples.append(flops)
        if (step + 1) % 10 == 0:
            print("step", step + 1, "loss", loss_value, "flops", flops)
    avg_flops = float(sum(flop_samples) / max(1, len(flop_samples)))
    return {"losses": losses, "avg_flops": avg_flops}

Run the efficiency experiment

In [None]:
# Seed shared by model initialization, training RNG and the loader shuffle.
SEED = 42


def _seeded_loader():
    """Return a new WikiText-2 loader whose shuffle order depends only on SEED."""
    return make_loader(BASE_CONFIG["seq_length"], BASE_CONFIG["batch_size"], seed=SEED)


tok = get_tokenizer()
# Seed before construction so the models' initial weights are reproducible.
set_seed(SEED)
bk_base, bk_act, mamba = build_models(BASE_CONFIG["seq_length"], tok.vocab_size)

# Give every run the same RNG state and a freshly seeded loader. A single
# shared loader consumes its shuffle generator on each pass, so each model
# would see a different batch order and the comparison would be confounded.
set_seed(SEED)
loader = _seeded_loader()
res_base = train_and_profile(bk_base, loader, use_act=False)
set_seed(SEED)
loader = _seeded_loader()
res_act = train_and_profile(bk_act, loader, use_act=True)
set_seed(SEED)
loader = _seeded_loader()
res_mamba = train_and_profile(mamba, loader, use_act=False)

results = {"resnet_bk": res_base, "resnet_bk_act": res_act, "mamba": res_mamba}
Path(f"colab_efficiency_results_{RUN_TAG}.json").write_text(
    json.dumps(results, indent=2), encoding="utf-8"
)
print(json.dumps(results, indent=2))

Zip the result artifacts (optional)

In [None]:
import shutil, glob

def zip_artifacts(prefix=None):
    """Bundle this run's result JSON into ``<prefix>.zip``.

    Only files matching ``colab_efficiency_results_<RUN_TAG>.json`` are
    archived, not the entire working directory. The cwd is the repository
    checkout, so zipping "." would also pull in the source tree, ``.git`` and
    any earlier archives.

    Args:
        prefix: archive path without the ``.zip`` suffix; defaults to
            ``artifacts_efficiency_<RUN_TAG>``.

    Returns:
        The archive path, or None when no result files exist yet.
    """
    import zipfile

    prefix = prefix or f"artifacts_efficiency_{RUN_TAG}"
    targets = glob.glob(f"colab_efficiency_results_{RUN_TAG}.json")
    if not targets:
        print("No artifacts found yet.")
        return None
    archive = f"{prefix}.zip"
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in targets:
            zf.write(path)
    print("Created", archive, "with", targets)
    return archive

# After running:
# zip_artifacts()