# ResNet-BK vs Mamba: Colab Head-to-Head

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture/blob/main/notebooks/colab_resnetbk_vs_mamba_headtohead.ipynb)

Runs a fair, reproducible head-to-head between the repo's ResNet-BK implementation and the bundled Mamba baseline. Both models share tokenizer, dataset (WikiText-2 raw), optimizer (AdamW), schedule (cosine), seeds, and logging.
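- Quick sanity: run only the 8k sequence length (`run_headtohead([8192])`) for up to 200 steps.
- Stress: extend `SEQ_LENGTHS` (8k and 32k by default) with 128k; a step budget for 128k is already defined in `PER_SEQ_STEPS`, and the batch size is halved automatically for sequences of 32k tokens and above.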
- Outputs: JSON loss traces + matplotlib curves per sequence length.

## Check runtime (GPU/versions)

In [None]:
# Report the interpreter version, the PyTorch build, and (when one is visible)
# the GPU this runtime exposes, so results can be tied to an environment.
import platform, torch, os

print(f"Python: {platform.python_version()}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried; any additional GPUs are not reported.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9, so the printed figure is in decimal GB.
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")

## Install dependencies (Colab)
- If you already have the repo checked out, skip the clone step.
- Everything installed here is CPU-friendly except the Mamba kernels, which require a GPU runtime.

In [None]:
# Core dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q datasets transformers matplotlib seaborn tqdm
# Optional: official mamba-ssm kernels (not required for the bundled baseline).
# They need a GPU runtime; comment out the next line when running on CPU.
!pip install -q mamba-ssm

## Clone repository & set paths

In [None]:
# Repo setup (clone if needed, add to sys.path)
# After this cell the working directory is the repository root and that root
# is importable, so `from src...` imports in later cells resolve.
import os, sys, subprocess, pathlib
REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'
cwd = pathlib.Path.cwd()
# Places where an existing checkout may live: the current directory, its
# parent, or a REPO_DIR folder under either of them.
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]
# The first candidate containing a `src` directory is treated as the repo root.
# NOTE(review): any directory that happens to contain `src` matches — confirm
# this is acceptable for the environments the notebook is run in.
root = next((p for p in candidates if (p / 'src').exists()), None)
if root is None:
    root = cwd / REPO_DIR
    # Clone only when the target folder is absent; an existing folder without
    # `src` is reused as-is.
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)
if root != pathlib.Path.cwd():
    os.chdir(root)
root_str = str(pathlib.Path.cwd())
# Put the repo root first on sys.path (once) so its packages win import lookup.
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


## Imports & helpers

In [None]:
import itertools, json, math, random, time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

from types import SimpleNamespace

from src.models.resnet_bk import LanguageModel as ResNetBK
from src.models.mamba_baseline import MambaLM, create_mamba_from_resnetbk_config

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
# Turn on cuDNN benchmark mode.
# NOTE(review): benchmark mode may select different kernels between runs —
# confirm this is compatible with the notebook's reproducibility goal.
torch.backends.cudnn.benchmark = True

def set_seed(seed: int):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator with ``seed``; when CUDA is available, all CUDA device
    generators are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

## Shared configuration

In [None]:
# Settings shared by both models so the comparison uses the same data,
# tokenizer, optimizer hyper-parameters and model dimensions.
BASE_CONFIG = {
    # Passed to `load_dataset(name, config)`.
    "dataset": {"name": "wikitext", "config": "wikitext-2-raw-v1"},
    # Passed to `AutoTokenizer.from_pretrained`.
    "tokenizer_name": "gpt2",
    "training": {
        "learning_rate": 3e-4,  # initial AdamW learning rate
        "min_lr": 1e-5,  # `eta_min` of the cosine schedule
        "weight_decay": 0.01,
        "max_steps": 200,  # fallback when a length is missing from PER_SEQ_STEPS
        "log_every": 20,  # print a progress line every N steps
        "grad_clip": 1.0,  # max gradient norm; None disables clipping
        "batch_size": 2,  # halved (min 1) for sequence lengths >= 32768
        "seed": 42,
        "use_amp": True,  # only takes effect when DEVICE is "cuda"
    },
    "model": {
        "d_model": 256,
        "n_layers": 6,
        "num_experts": 4,  # used by the ResNet-BK builder only
        "top_k": 1,  # used by the ResNet-BK builder only
        "dropout": 0.1,
    },
}

# Sequence lengths run by default when `run_headtohead()` gets no argument.
SEQ_LENGTHS = [8192, 32768]
# Per-length step budgets; 131072 is predefined for optional stress runs even
# though it is not part of the default SEQ_LENGTHS.
PER_SEQ_STEPS = {8192: 200, 32768: 120, 131072: 40}

## Dataset + dataloader (shared)

In [None]:
from functools import lru_cache

_tokenizer = None

def get_tokenizer():
    """Return the shared tokenizer, loading it on first use.

    The tokenizer named by ``BASE_CONFIG["tokenizer_name"]`` is loaded once
    and kept in the module-level ``_tokenizer``. If it has no pad token, the
    EOS token is used as the pad token.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer

    loaded = AutoTokenizer.from_pretrained(BASE_CONFIG["tokenizer_name"])
    if loaded.pad_token is None:
        loaded.pad_token = loaded.eos_token
    _tokenizer = loaded
    return _tokenizer


@lru_cache()
def load_lm_dataset(seq_length: int):
    """Build the tokenized training split, chunked for next-token prediction.

    The train split of the dataset named in ``BASE_CONFIG["dataset"]`` is
    tokenized with the shared tokenizer (no special tokens), all token ids are
    concatenated, and the stream is cut into rows of ``seq_length + 1`` ids.
    The extra id lets callers derive inputs and targets shifted by one
    position. A trailing remainder shorter than one row is dropped.

    Results are memoized per ``seq_length`` via ``lru_cache``.

    Args:
        seq_length: Number of input tokens per training sample.

    Returns:
        A ``datasets.Dataset`` with a single ``input_ids`` column formatted
        as torch tensors.
    """
    tokenizer = get_tokenizer()
    raw = load_dataset(BASE_CONFIG["dataset"]["name"], BASE_CONFIG["dataset"]["config"])

    def tok_fn(examples):
        return tokenizer(examples["text"], add_special_tokens=False)

    # `raw["train"]` is already a single split, so everything derived from it
    # is a Dataset (not a DatasetDict) and must not be indexed by split name.
    tokenized = raw["train"].map(tok_fn, batched=True, remove_columns=["text"])

    seq_plus_one = seq_length + 1

    def group_texts(examples):
        concatenated = list(itertools.chain.from_iterable(examples["input_ids"]))
        total_length = len(concatenated) // seq_plus_one * seq_plus_one
        return {
            "input_ids": [
                concatenated[i : i + seq_plus_one]
                for i in range(0, total_length, seq_plus_one)
            ]
        }

    # batch_size=None hands the whole split to `group_texts` in one call.
    # With the default 1000-row batches each batch's remainder is discarded,
    # which wastes most tokens at long sequence lengths and can produce an
    # empty dataset when a batch holds fewer than `seq_length + 1` tokens.
    grouped = tokenized.map(
        group_texts,
        batched=True,
        batch_size=None,
        remove_columns=tokenized.column_names,
    )
    grouped.set_format(type="torch", columns=["input_ids"])
    return grouped


def make_dataloader(seq_length: int, batch_size: int, seed: int):
    """Create a shuffled, seeded DataLoader over the chunked LM dataset.

    Each batch is an ``(inputs, targets)`` pair in which ``targets`` is the
    same token rows as ``inputs`` shifted left by one position. Shuffling is
    driven by a dedicated ``torch.Generator`` seeded with ``seed``, and
    incomplete final batches are dropped.
    """

    def split_inputs_targets(samples):
        # Every row holds seq_length + 1 ids: drop the last id for the
        # inputs and the first id for the targets.
        token_rows = [sample["input_ids"] for sample in samples]
        model_inputs = torch.stack([row[:-1] for row in token_rows])
        shifted_targets = torch.stack([row[1:] for row in token_rows])
        return model_inputs, shifted_targets

    lm_rows = load_lm_dataset(seq_length)

    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "drop_last": True,
        "generator": shuffle_rng,
        "collate_fn": split_inputs_targets,
    }
    return DataLoader(lm_rows, **loader_options)

## Model builders (identical dims)

In [None]:
def build_resnetbk(seq_length: int):
    """Instantiate the ResNet-BK language model for ``seq_length`` tokens.

    Dimensions come from ``BASE_CONFIG["model"]`` and the vocabulary size
    from the shared tokenizer. The scattering-router and Birman-Schwinger
    options are both switched off.
    """
    model_cfg = BASE_CONFIG["model"]
    vocab_size = get_tokenizer().vocab_size
    resnet_kwargs = {
        "vocab_size": vocab_size,
        "d_model": model_cfg["d_model"],
        "n_layers": model_cfg["n_layers"],
        "n_seq": seq_length,
        "num_experts": model_cfg["num_experts"],
        "top_k": model_cfg["top_k"],
        "dropout_p": model_cfg["dropout"],
        "use_scattering_router": False,
        "use_birman_schwinger": False,
    }
    return ResNetBK(**resnet_kwargs)


def build_mamba(seq_length: int):
    """Instantiate the Mamba baseline with the same dimensions as ResNet-BK.

    A ResNet-BK-style config namespace is assembled from
    ``BASE_CONFIG["model"]`` and the shared tokenizer, converted with
    ``create_mamba_from_resnetbk_config``, and used to build ``MambaLM``.
    """
    model_cfg = BASE_CONFIG["model"]
    vocab_size = get_tokenizer().vocab_size
    shared_fields = {
        "vocab_size": vocab_size,
        "d_model": model_cfg["d_model"],
        "n_layers": model_cfg["n_layers"],
        "n_seq": seq_length,
        "dropout": model_cfg["dropout"],
        "tie_weights": True,
    }
    resnet_like_cfg = SimpleNamespace(**shared_fields)
    return MambaLM(create_mamba_from_resnetbk_config(resnet_like_cfg))

## Training loop (shared optimizer/schedule/logging)

In [None]:
def train_one(model_name, builder, seq_length: int, max_steps: int):
    """Train one freshly built model and record its loss trace.

    Uses the shared settings in ``BASE_CONFIG["training"]``: AdamW, a cosine
    learning-rate schedule over ``max_steps``, optional gradient clipping and
    (on CUDA only) mixed precision with gradient scaling. The dataloader is
    re-iterated for as many epochs as needed to reach ``max_steps``, so short
    datasets no longer end the run early and truncate the schedule.

    Args:
        model_name: Label used in log lines and in the returned record.
        builder: Callable taking ``seq_length`` and returning the model.
        seq_length: Number of input tokens per sample.
        max_steps: Maximum number of optimizer steps to run.

    Returns:
        A dict with keys ``model``, ``seq_length``, ``losses`` (one float per
        completed step), ``steps``, ``batch_size`` and ``wall_clock_sec``.
        Training stops early if the loss becomes non-finite or if the
        dataloader yields no batches.
    """
    cfg = BASE_CONFIG["training"]
    batch_size = cfg["batch_size"]
    # keep long-context runs memory-friendly
    if seq_length >= 32768:
        batch_size = max(1, batch_size // 2)

    set_seed(cfg["seed"])
    dataloader = make_dataloader(seq_length, batch_size=batch_size, seed=cfg["seed"])
    model = builder(seq_length).to(DEVICE)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["learning_rate"],
        betas=(0.9, 0.999),
        weight_decay=cfg["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=max_steps, eta_min=cfg["min_lr"]
    )
    # Mixed precision is only meaningful on CUDA; elsewhere both the scaler
    # and autocast are created disabled and act as pass-throughs.
    amp_enabled = cfg["use_amp"] and DEVICE == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    losses = []
    wall_start = time.time()
    step = 0  # number of completed optimizer steps
    diverged = False

    while step < max_steps and not diverged:
        saw_batch = False
        # Each pass over the dataloader is one epoch; its seeded generator
        # yields a new (still deterministic) shuffle every time.
        for inputs, targets in dataloader:
            saw_batch = True

            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=amp_enabled):
                out = model(inputs)
                # Some models return (logits, extras); keep only the logits.
                logits = out[0] if isinstance(out, (tuple, list)) else out
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                )

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                print(f"[{model_name}] divergence detected at step {step+1} (loss={loss_value:.4f})")
                diverged = True
                break

            scaler.scale(loss).backward()
            if cfg["grad_clip"] is not None:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            scaler.step(opt)
            scaler.update()
            scheduler.step()

            losses.append(loss_value)
            step += 1

            if step % cfg["log_every"] == 0:
                print(
                    f"[{model_name}] step {step}/{max_steps} loss={loss_value:.4f} lr={scheduler.get_last_lr()[0]:.2e} bs={batch_size} seq={seq_length}"
                )

            if step >= max_steps:
                break

        if not saw_batch:
            # An empty dataloader would otherwise spin forever in the epoch loop.
            print(f"[{model_name}] no batches available (seq={seq_length}, bs={batch_size}); stopping")
            break

    return {
        "model": model_name,
        "seq_length": seq_length,
        "losses": losses,
        "steps": len(losses),
        "batch_size": batch_size,
        "wall_clock_sec": time.time() - wall_start,
    }

## Orchestrate the match

In [None]:
        def run_headtohead(seq_lengths=None):
            """Train ResNet-BK and Mamba back to back for each sequence length.

            For every length, the step budget is taken from ``PER_SEQ_STEPS``
            (falling back to ``BASE_CONFIG["training"]["max_steps"]``), both
            models are trained with ``train_one``, and the pair of results is
            written to ``colab_headtohead_<seq_len>.json`` in the working
            directory.

            Args:
                seq_lengths: Iterable of sequence lengths; defaults to
                    ``SEQ_LENGTHS`` when omitted or empty.

            Returns:
                Dict mapping each sequence length to
                ``{"resnet_bk": result, "mamba": result}``.
            """
            seq_lengths = seq_lengths or SEQ_LENGTHS
            all_results = {}
            for seq_len in seq_lengths:
                max_steps = PER_SEQ_STEPS.get(seq_len, BASE_CONFIG["training"]["max_steps"])
                # The leading "\n" separates the banner from earlier output.
                print(f"\n=== Sequence length {seq_len} | steps {max_steps} ===")

                # Release cached GPU memory before each run.
                torch.cuda.empty_cache()
                resnet_result = train_one("resnet_bk", build_resnetbk, seq_len, max_steps)
                torch.cuda.empty_cache()
                mamba_result = train_one("mamba", build_mamba, seq_len, max_steps)

                all_results[seq_len] = {"resnet_bk": resnet_result, "mamba": mamba_result}

                out_path = Path(f"colab_headtohead_{seq_len}.json")
                out_path.write_text(json.dumps(all_results[seq_len], indent=2))
                print(f"Saved {out_path}")

            return all_results


        # placeholder to hold results across cells
        results = {}

## Run (choose quick or full)

In [None]:
# Quick sanity (8k)
# results = run_headtohead([8192])

# Full sweep (defaults to SEQ_LENGTHS)
# results = run_headtohead()

## Plot losses

In [None]:
def plot_losses(results_dict, seq_length: int):
    """Plot both models' training-loss curves for one sequence length.

    ``results_dict`` is indexed by ``seq_length`` and must hold ``resnet_bk``
    and ``mamba`` entries, each with a ``losses`` list. Steps on the x-axis
    are numbered from 1.
    """
    per_model = results_dict[seq_length]
    palette = {"resnet_bk": "#1f77b4", "mamba": "#d62728"}

    plt.figure(figsize=(10, 5))
    for model_key, line_color in palette.items():
        trace = per_model[model_key]["losses"]
        step_axis = range(1, len(trace) + 1)
        plt.plot(step_axis, trace, label=model_key, color=line_color)

    plt.xlabel("Step")
    plt.ylabel("Cross-entropy loss")
    plt.title(f"Training loss vs steps (seq_len={seq_length})")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()


# Example after running:
# plot_losses(results, 8192)
# plot_losses(results, 32768)

## Inspect raw numbers

In [None]:
# After run_headtohead, run this to see summary
# import pprint
# pprint.pp(results)