# ResNet-BK vs Mamba: Long-Context Stability (Colab)

Two-mode run:
- Step 1 (baseline): USE_THEORY=False → check whether ResNet-BK is at least as stable as Mamba.
- Step 2 (theory ON): USE_THEORY=True → check whether the stability gap widens once the scattering router and Birman-Schwinger features are enabled.

Prerequisite: a GPU runtime. On Colab, select Runtime → Change runtime type → GPU.

In [ ]:
# Stable install: Torch 2.3.1 cu121 + deps pinned for transformers/sklearn/scipy
# If you just upgraded pip/setuptools, restart runtime once before running below.
!pip install --upgrade --no-cache-dir pip setuptools wheel ninja packaging cmake jedi
# Pin numpy/scipy/sklearn to avoid ABI issues on Colab (Py3.12)
# NOTE(review): this force-installs NumPy 2.x while torch is pinned to 2.3.1 below;
# confirm that torch build interoperates with NumPy 2 (e.g. tensor.numpy()).
!pip install --force-reinstall --no-cache-dir numpy==2.1.4 scipy==1.13.1 scikit-learn==1.5.2 transformers==4.43.4 datasets==2.20.0 matplotlib==3.8.4
# torch/torchvision/torchaudio are taken from the CUDA 12.1 wheel index so the three match.
!pip install --force-reinstall --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
# Build-time settings in case mamba-ssm has to be compiled from source
# (presumably read by the torch extension build / mamba-ssm setup — verify).
%env FORCE_CUDA=1
%env MAX_JOBS=4
# NOTE(review): %env may keep the double quotes as part of the value; check with
# `!echo $TORCH_CUDA_ARCH_LIST` if a source build rejects the arch list.
%env TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6"
# --no-build-isolation: a source build uses the torch installed above instead of
# fetching a separate torch into an isolated build environment.
!pip install --no-cache-dir --no-build-isolation mamba-ssm==2.2.2 --extra-index-url https://download.pytorch.org/whl/cu121


In [None]:
import platform, torch
# Report interpreter/library versions and whether a CUDA GPU is visible, so a
# missing GPU runtime is noticed before any training is attempted.
print("Python:", platform.python_version())
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # Name of CUDA device 0.
    print("GPU:", torch.cuda.get_device_name(0))

Install remaining dependencies (overlaps with the pinned install cell above; keep versions consistent with it)

In [None]:
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q datasets transformers matplotlib seaborn tqdm
!pip install -q mamba-ssm

Clone repo and add to path

In [None]:
# Repo setup (clone if needed, add to sys.path)
import os, sys, subprocess, pathlib
REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'
# A directory counts as the repo root only if it contains the module this
# notebook imports later. A bare `src/` check could match an unrelated project
# (for example one in the parent directory).
_REPO_MARKER = pathlib.Path('src') / 'models' / 'resnet_bk.py'
cwd = pathlib.Path.cwd()
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]
root = next((p for p in candidates if (p / _REPO_MARKER).exists()), None)
if root is None:
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)
    # Fail here with a clear message instead of with an ImportError several cells later.
    if not (root / _REPO_MARKER).exists():
        raise FileNotFoundError(
            f"{root} exists but does not contain {_REPO_MARKER}; "
            "remove or repair that directory and re-run this cell."
        )
if root != pathlib.Path.cwd():
    os.chdir(root)
root_str = str(pathlib.Path.cwd())
# Make `src.*` importable from the repo root.
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


Imports and shared helpers

In [None]:
import itertools, json, random, time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
from types import SimpleNamespace

from src.models.resnet_bk import LanguageModel as ResNetBK
from src.models.mamba_baseline import MambaLM, create_mamba_from_resnetbk_config

# Use CUDA when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Let cuDNN benchmark candidate kernels and keep the fastest for each input shape.
# NOTE(review): benchmark mode may pick different algorithms from run to run, so
# results are not guaranteed bit-reproducible even with set_seed; set this to
# False if exact repeatability across seeds matters.
torch.backends.cudnn.benchmark = True

def set_seed(seed):
    """Seed every RNG the experiments use so a run can be repeated.

    Seeds Python's ``random`` module, NumPy's global RNG and torch, and
    additionally every CUDA device when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

Configuration (toggle theory ON/OFF)

In [None]:
# Theory features (scattering router + Birman-Schwinger) are switched by one flag.
# Step1: keep False. Step2: set True and re-run the notebook.
USE_THEORY = False
# Tag appended to every log line and output file name for this run.
RUN_TAG = "theory_on" if USE_THEORY else "vanilla"

BASE_CONFIG = dict(
    dataset=dict(name="wikitext", config="wikitext-2-raw-v1"),
    tokenizer_name="gpt2",
    training=dict(
        learning_rate=3e-4,
        min_lr=1e-5,          # floor of the cosine LR schedule
        weight_decay=0.01,
        max_steps=200,        # fallback when a length is missing from PER_SEQ_STEPS
        log_every=20,
        grad_clip=1.0,
        batch_size=2,
        seed=42,
        use_amp=True,
    ),
    model=dict(
        d_model=256,
        n_layers=6,
        num_experts=4,
        top_k=1,
        dropout=0.1,
    ),
)

# Sequence-length sweep: 8k / 32k / 128k (131072) tokens.
SEQ_LENGTHS = [8 * 1024, 32 * 1024, 128 * 1024]
SEEDS = [42, 43, 44]
# Optimizer steps per sequence length (longer contexts get fewer steps).
PER_SEQ_STEPS = {8 * 1024: 200, 32 * 1024: 120, 128 * 1024: 60}
SAVE_PLOTS = True

Dataset and dataloader

In [None]:
from functools import lru_cache
_tokenizer = None

def get_tokenizer():
    """Return the shared tokenizer, loading it on the first call.

    Loads ``BASE_CONFIG["tokenizer_name"]`` via AutoTokenizer. If the tokenizer
    has no pad token, its EOS token is reused for padding. The instance is kept
    in the module-level ``_tokenizer`` and returned on later calls.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer
    loaded = AutoTokenizer.from_pretrained(BASE_CONFIG["tokenizer_name"])
    if loaded.pad_token is None:
        loaded.pad_token = loaded.eos_token
    _tokenizer = loaded
    return _tokenizer

@lru_cache()
def load_lm_dataset(seq_length):
    """Tokenize the configured train split and pack it into fixed-length chunks.

    Every line of the train split is tokenized without special tokens. The
    token ids are concatenated into one stream and cut into non-overlapping
    chunks of ``seq_length + 1`` ids. The extra id makes room for the one-token
    input/target shift done in make_dataloader's collate function. Trailing
    ids that do not fill a whole chunk are dropped. Results are memoized per
    ``seq_length``.

    Returns:
        A ``datasets.Dataset`` with a single ``input_ids`` column, formatted
        as torch tensors.

    Raises:
        ValueError: if the corpus is too short to form even one chunk.
    """
    tokenizer = get_tokenizer()
    raw = load_dataset(BASE_CONFIG["dataset"]["name"], BASE_CONFIG["dataset"]["config"])

    def tok_fn(examples):
        return tokenizer(examples["text"], add_special_tokens=False)

    tokenized = raw["train"].map(tok_fn, batched=True, remove_columns=["text"])
    seq_plus_one = seq_length + 1

    def group_texts(examples):
        concatenated = list(itertools.chain.from_iterable(examples["input_ids"]))
        total_length = len(concatenated) // seq_plus_one * seq_plus_one
        chunks = [concatenated[i:i + seq_plus_one] for i in range(0, total_length, seq_plus_one)]
        return {"input_ids": chunks}

    # `tokenized` is already the train split (a Dataset, not a DatasetDict), so
    # its columns are read directly. A Dataset indexed with "train" would be a
    # column lookup and fail.
    # batch_size=None passes the whole split to group_texts in one call. With
    # the default 1000-row batches, each batch drops its own remainder, and a
    # batch holding fewer than seq_length + 1 tokens contributes nothing. At
    # 32k/131k context that discards most or all of the corpus.
    grouped = tokenized.map(
        group_texts,
        batched=True,
        batch_size=None,
        remove_columns=tokenized.column_names,
    )
    if len(grouped) == 0:
        raise ValueError(
            f"seq_length={seq_length} exceeds the tokenized training corpus; "
            f"no {seq_plus_one}-token chunks could be formed."
        )
    grouped.set_format(type="torch", columns=["input_ids"])
    return grouped

def make_dataloader(seq_length, batch_size, seed):
    """Build a shuffled DataLoader yielding ``(inputs, targets)`` token batches.

    Each dataset row from load_lm_dataset holds ``seq_length + 1`` token ids.
    ``inputs`` drops each row's last id and ``targets`` drops its first, which
    gives next-token pairs. Shuffling uses a dedicated generator seeded with
    ``seed``, and incomplete final batches are dropped.
    """
    packed = load_lm_dataset(seq_length)

    def collate(batch):
        rows = [example["input_ids"] for example in batch]
        shifted_in = torch.stack([row[:-1] for row in rows])
        shifted_out = torch.stack([row[1:] for row in rows])
        return shifted_in, shifted_out

    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)
    return DataLoader(
        packed,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=shuffle_rng,
        collate_fn=collate,
    )

Model builders (respect USE_THEORY toggle)

In [None]:
from types import SimpleNamespace
from src.models.resnet_bk import LanguageModel as ResNetBK
from src.models.mamba_baseline import MambaLM, create_mamba_from_resnetbk_config

def build_resnetbk(seq_length):
    """Instantiate a ResNet-BK language model for ``seq_length``-token inputs.

    Sizes come from ``BASE_CONFIG["model"]`` and the vocabulary size from the
    shared tokenizer. Both theory features (scattering router and
    Birman-Schwinger) follow the ``USE_THEORY`` toggle.
    """
    model_cfg = BASE_CONFIG["model"]
    vocab = get_tokenizer().vocab_size
    theory_flags = dict(
        use_scattering_router=USE_THEORY,
        use_birman_schwinger=USE_THEORY,
    )
    return ResNetBK(
        vocab_size=vocab,
        d_model=model_cfg["d_model"],
        n_layers=model_cfg["n_layers"],
        n_seq=seq_length,
        num_experts=model_cfg["num_experts"],
        top_k=model_cfg["top_k"],
        dropout_p=model_cfg["dropout"],
        **theory_flags,
    )

def build_mamba(seq_length):
    """Instantiate the Mamba baseline for ``seq_length``-token inputs.

    Builds a ResNet-BK-style config namespace from ``BASE_CONFIG["model"]``
    (with tied embeddings) and converts it with
    ``create_mamba_from_resnetbk_config``. Presumably this sizes the baseline
    to match ResNet-BK; confirm in src.models.mamba_baseline.
    """
    model_cfg = BASE_CONFIG["model"]
    tokenizer = get_tokenizer()
    shared_cfg = SimpleNamespace(
        vocab_size=tokenizer.vocab_size,
        d_model=model_cfg["d_model"],
        n_layers=model_cfg["n_layers"],
        n_seq=seq_length,
        dropout=model_cfg["dropout"],
        tie_weights=True,
    )
    return MambaLM(create_mamba_from_resnetbk_config(shared_cfg))

Training loop

In [None]:
def train_one(model_name, builder, seq_length, max_steps, seed):
    """Train a freshly built model for ``max_steps`` steps and return its loss trace.

    The dataloader is re-iterated until ``max_steps`` optimizer steps have run.
    Each pass is reshuffled by the loader's seeded generator, so a packed corpus
    with few long chunks does not silently end training early. Training stops
    at the first non-finite loss.

    Args:
        model_name: Label used in log lines and stored in the result.
        builder: Callable ``builder(seq_length)`` returning a torch module whose
            output has the vocabulary on its last axis.
        seq_length: Tokens per training example.
        max_steps: Number of optimizer steps (also the cosine schedule length).
        seed: Seed passed to set_seed and to the dataloader's shuffle generator.

    Returns:
        dict with the per-step ``losses`` plus run metadata.
        ``diverged_at_step`` is the 1-based step that produced a non-finite
        loss, or ``None`` when training stayed finite.

    Raises:
        ValueError: if the dataloader has no full batch for this configuration.
    """
    cfg = BASE_CONFIG["training"]
    batch_size = cfg["batch_size"]
    if seq_length >= 32768:
        # Long contexts: halve the batch (never below 1) to limit memory use.
        batch_size = max(1, batch_size // 2)
    set_seed(seed)
    dataloader = make_dataloader(seq_length, batch_size=batch_size, seed=seed)
    if len(dataloader) == 0:
        # Guard before building the model; cycling an empty loader would never end.
        raise ValueError(
            f"No full batch of size {batch_size} for seq_length={seq_length}; "
            "the dataset is too small."
        )
    model = builder(seq_length).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg["learning_rate"], betas=(0.9, 0.999), weight_decay=cfg["weight_decay"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max_steps, eta_min=cfg["min_lr"])
    use_amp = cfg["use_amp"] and DEVICE == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def batches():
        # Re-iterate the finite loader; each new pass draws a fresh shuffle.
        while True:
            yield from dataloader

    losses = []
    diverged_at_step = None
    wall_start = time.perf_counter()
    for step, (inputs, targets) in zip(range(max_steps), batches()):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(inputs)
            # reshape (not view) also works if the model returns non-contiguous logits.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        if not torch.isfinite(loss):
            diverged_at_step = step + 1
            print(f"{model_name} divergence at step {step+1} loss={loss.item():.4f}")
            break
        scaler.scale(loss).backward()
        if cfg["grad_clip"]:
            # Unscale first so clipping applies to the true gradient norm.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
        scaler.step(opt)
        scaler.update()
        scheduler.step()
        losses.append(loss.item())
        if (step + 1) % cfg["log_every"] == 0:
            print(f"{model_name} step {step+1}/{max_steps} loss={loss.item():.4f} lr={scheduler.get_last_lr()[0]:.2e} bs={batch_size} seq={seq_length} seed={seed} tag={RUN_TAG}")
    if DEVICE == "cuda":
        # Include the last queued GPU work in the wall-clock measurement.
        torch.cuda.synchronize()
    return {
        "model": model_name,
        "seq_length": seq_length,
        "losses": losses,
        "steps": len(losses),
        "batch_size": batch_size,
        "seed": seed,
        "run_tag": RUN_TAG,
        "use_theory": USE_THEORY,
        "diverged_at_step": diverged_at_step,
        "wall_clock_sec": time.perf_counter() - wall_start,
    }

Run experiments across seeds and sequence lengths

In [None]:
        def run_headtohead(seq_lengths=None, seeds=None):
            """Train ResNet-BK and Mamba head-to-head for every (sequence length, seed) pair.

            Args:
                seq_lengths: Context lengths to sweep; falsy (None/empty) uses SEQ_LENGTHS.
                seeds: Seeds to run at each length; falsy (None/empty) uses SEEDS.

            Returns:
                dict mapping each sequence length to a list of
                ``{"seed", "resnet_bk", "mamba"}`` records holding train_one results.
                Each length's list is also written to
                ``colab_long_context_<len>_<RUN_TAG>_seeds.json`` as soon as that
                length finishes, so earlier lengths survive a later crash.
            """
            seq_lengths = seq_lengths or SEQ_LENGTHS
            seeds = seeds or SEEDS
            all_results = {}
            for seq_len in seq_lengths:
                max_steps = PER_SEQ_STEPS.get(seq_len, BASE_CONFIG["training"]["max_steps"])
                print(f"\n=== Sequence length {seq_len} | steps {max_steps} | seeds {seeds} | tag {RUN_TAG} ===")
                seed_results = []
                for seed in seeds:
                    # Release cached blocks from the previous model before building the next.
                    torch.cuda.empty_cache()
                    print(f"-- seed {seed} | ResNet-BK")
                    resnet_result = train_one("resnet_bk", build_resnetbk, seq_len, max_steps, seed)
                    torch.cuda.empty_cache()
                    print(f"-- seed {seed} | Mamba")
                    mamba_result = train_one("mamba", build_mamba, seq_len, max_steps, seed)
                    seed_results.append({"seed": seed, "resnet_bk": resnet_result, "mamba": mamba_result})
                all_results[seq_len] = seed_results
                out_path = Path(f"colab_long_context_{seq_len}_{RUN_TAG}_seeds.json")
                out_path.write_text(json.dumps(all_results[seq_len], indent=2))
                print("Saved", out_path)
            return all_results

        results = {}

Run (defaults include 8k/32k/131k and seeds 42,43,44)

In [None]:
# Step1 (baseline): USE_THEORY=False → run below
# results = run_headtohead()

# Step2 (theory ON): set USE_THEORY=True in the config cell above, re-run all cells, then run below
# results = run_headtohead()

Plot losses (per-seed overlays, saved with tag)

In [None]:
def plot_losses(results_dict, seq_length, save_fig=SAVE_PLOTS):
    """Overlay every seed's loss curve for both models at one sequence length.

    ResNet-BK curves are drawn blue (alpha 0.7), Mamba curves red (alpha 0.4).
    When ``save_fig`` is true the figure is also written to
    ``loss_<seq_length>_<RUN_TAG>.png`` at 200 dpi before being shown.

    Args:
        results_dict: Mapping from sequence length to the per-seed records
            produced by run_headtohead.
        seq_length: Which sequence length to plot.
        save_fig: Whether to save the PNG (default: SAVE_PLOTS at definition time).
    """
    records = results_dict[seq_length]
    curve_styles = (("resnet_bk", "blue", 0.7), ("mamba", "red", 0.4))
    plt.figure(figsize=(10, 5))
    for record in records:
        seed = record["seed"]
        for name, color, alpha in curve_styles:
            curve = record[name]["losses"]
            plt.plot(
                range(1, len(curve) + 1),
                curve,
                label=f"{name}-seed{seed}-{RUN_TAG}",
                color=color,
                alpha=alpha,
            )
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title(f"Seq {seq_length} ({RUN_TAG})")
    plt.legend()
    plt.grid(alpha=0.3)
    if save_fig:
        fname = f"loss_{seq_length}_{RUN_TAG}.png"
        plt.savefig(fname, dpi=200, bbox_inches="tight")
        print("Saved", fname)
    plt.show()

# After running:
# plot_losses(results, 8192)
# plot_losses(results, 32768)
# plot_losses(results, 131072)

Zip artifacts (JSON/PNG)

In [None]:
import shutil, glob
def zip_artifacts(prefix=None):
    """Pack this run's result JSONs and loss plots into ``<prefix>.zip``.

    Only files produced for the current RUN_TAG in the working directory are
    archived: ``colab_long_context_*_<RUN_TAG>_seeds.json`` and
    ``loss_*_<RUN_TAG>.png``.

    Args:
        prefix: Archive path without the ``.zip`` suffix; defaults to
            ``artifacts_long_context_<RUN_TAG>``.
    """
    import zipfile

    prefix = prefix or f"artifacts_long_context_{RUN_TAG}"
    targets = glob.glob(f"colab_long_context_*_{RUN_TAG}_seeds.json") + glob.glob(f"loss_*_{RUN_TAG}.png")
    if not targets:
        print("No artifacts found yet.")
        return
    # Write only the matched files. shutil.make_archive(prefix, "zip", ".")
    # would archive the entire working directory (the cloned repo, .git, ...).
    with zipfile.ZipFile(f"{prefix}.zip", "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in targets:
            archive.write(path)
    print("Created", f"{prefix}.zip", "with", targets)

# After plots:
# zip_artifacts()