# Notebook: Scaling Law Analysis of Spectral vs. Vanilla Attention

**Objective:** To empirically validate the computational complexity and performance of Spectral Attention (`O(n log n)`) against standard Self-Attention (`O(n^2)`).

We will perform a micro-benchmark, running both models across a range of sequence lengths and measuring three key metrics:
1.  **Throughput** (tokens/second)
2.  **Latency** (ms/iteration)
3.  **Peak GPU Memory** (MB)

The expectation is to observe a significant performance advantage for Spectral Attention, especially at longer sequence lengths.

In [None]:
#
# Cell 2: Setup and Configuration
#
import os
import sys
import subprocess
import json
import importlib
import platform
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Locate the benchmark script ---
# Two places are probed, in order: relative to the current working directory,
# then one level up (the case where the notebook is launched from notebooks/).
# The first existing candidate wins and is stored as an absolute path.
_script_candidates = (
    "scripts/bench_spectral_attention.py",
    os.path.abspath(os.path.join(os.getcwd(), "..", "scripts", "bench_spectral_attention.py")),
)
for _candidate in _script_candidates:
    if os.path.exists(_candidate):
        BENCH_SCRIPT = os.path.abspath(_candidate)
        break
else:
    # Neither location exists: fail loudly instead of continuing without a script.
    raise FileNotFoundError("Could not locate scripts/bench_spectral_attention.py from current working directory")

# --- Experiment Configuration ---

# Results for this analysis get their own folder under 'experiments/runs/',
# the location the README names for run logs.
LOG_DIR = "experiments/runs/scaling_analysis"
os.makedirs(LOG_DIR, exist_ok=True)

# Metrics file inside that folder; any file left over from an earlier
# session is deleted before new runs start.
LOG_FILE = os.path.join(LOG_DIR, "metrics.jsonl")
if os.path.exists(LOG_FILE):
    os.remove(LOG_FILE)
    print(f"Removed old log file: {LOG_FILE}")

# --- Model & Benchmark Parameters ---

# Sequence lengths to sweep: powers of two from 512 up to 16384.
SEQ_LENGTHS = [512 << step for step in range(6)]

# Vanilla attention is benchmarked only up to this length (the sweep checks
# `seq <= VANILLA_MAX_SEQ`); beyond it, it tends to run out of memory on
# consumer GPUs.
VANILLA_MAX_SEQ = 4096

# Model shape shared by both attention kinds, taken from the README examples.
DMODEL, HEADS, DEPTH, BATCH = 512, 8, 6, 4

DEVICE = "gpu"  # Force GPU if available
USE_COMPILE = False  # torch.compile stays off to avoid CUDA/Triton conflicts

def get_compile_flag() -> list[str]:
    """Return the optional ``--compile`` argument for the benchmark command.

    Returns:
        ``["--compile"]`` when ``USE_COMPILE`` is true *and* CUDA is
        available, the ``triton`` package can be found, and the OS is not
        Windows. Otherwise ``[]``; if compilation was requested but the
        environment does not qualify, a warning is printed first.
    """
    if not USE_COMPILE:
        return []

    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    # Require CUDA + Triton + non-Windows (inductor+triton works best on Linux)
    try:
        import torch
        cuda_ok = torch.cuda.is_available()
    except Exception:
        # Deliberately broad best-effort probe: any failure to import or
        # query torch is treated as "no CUDA" rather than aborting setup.
        cuda_ok = False
    triton_ok = importlib.util.find_spec("triton") is not None
    on_windows = platform.system().lower().startswith("win")

    if cuda_ok and triton_ok and not on_windows:
        return ["--compile"]
    print("[warn] --compile disabled at runtime: CUDA/Triton not available or unsupported OS")
    return []

def effective_batch(seq: int, kind: str) -> int:
    """Pick the batch size for a single benchmark run.

    Every run uses ``BATCH`` except ``kind == "vanilla"`` at
    ``seq >= VANILLA_MAX_SEQ``, which gets a quarter of it (never below 1)
    to avoid OOM on consumer GPUs.
    """
    shrink = kind == "vanilla" and seq >= VANILLA_MAX_SEQ
    return max(1, BATCH // 4) if shrink else BATCH

print("Setup complete. Configuration is ready.")

Setup complete. Configuration is ready.


### Running the Benchmark

Now, we will execute the benchmark script (`scripts/bench_spectral_attention.py`) using the parameters defined above. We will loop through each sequence length and run the benchmark for both the `spectral` and `vanilla` models.

The output of each run will be appended to our log file (`experiments/runs/scaling_analysis/metrics.jsonl`). The `vanilla` model is benchmarked only up to its maximum defined sequence length (`VANILLA_MAX_SEQ`, inclusive); longer sequences are skipped for it to avoid out-of-memory errors.

In [26]:
#
# Cell 4: Execute Benchmark Script
#
import sys
print("Starting benchmark execution...")


def run_cmd(cmd: list[str]):
    """Run *cmd* as a subprocess and echo its captured output.

    Args:
        cmd: Argument vector; it is executed directly, without a shell.

    Raises:
        RuntimeError: If the process exits with a non-zero status. Both
            captured streams are printed before the exception is raised.
    """
    import shlex

    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        print("[stderr]\n" + (proc.stderr or ""))
        print("[stdout]\n" + (proc.stdout or ""))
        # shlex.join quotes arguments containing spaces or shell characters,
        # so the reported command is unambiguous and can be pasted into a shell.
        raise RuntimeError(f"Command failed ({proc.returncode}): {shlex.join(cmd)}")
    if proc.stdout:
        print(proc.stdout)
    if proc.stderr:
        # Successful runs can still write warnings to stderr; show them
        # instead of silently dropping the captured stream.
        print("[stderr]\n" + proc.stderr)

# Resolve the optional --compile flag once: the answer cannot change between
# runs, and re-evaluating it per command would repeat its warning each time.
_compile_flags = get_compile_flag()


def _bench_command(kind: str, seq: int, compile_flags: list[str]) -> list[str]:
    """Build the argv that runs the benchmark script for one (kind, seq) pair."""
    return [
        sys.executable, BENCH_SCRIPT,
        "--kind", kind,
        "--seq", str(seq),
        "--dmodel", str(DMODEL),
        "--heads", str(HEADS),
        "--depth", str(DEPTH),
        "--batch", str(effective_batch(seq, kind)),
        "--device", DEVICE,
        "--logdir", LOG_DIR,  # Directing output to our specific log directory
        *compile_flags,
    ]


for seq in SEQ_LENGTHS:
    # --- Run Spectral Attention Benchmark ---
    print(f"-> Running SPECTRAL with Sequence Length: {seq}")
    run_cmd(_bench_command("spectral", seq, _compile_flags))

    # --- Run Vanilla Attention Benchmark (with sequence length limit) ---
    if seq <= VANILLA_MAX_SEQ:
        print(f"-> Running VANILLA with Sequence Length: {seq}")
        run_cmd(_bench_command("vanilla", seq, _compile_flags))
    else:
        # Only lengths strictly above the limit reach this branch, so the
        # message reports ">" rather than ">=".
        print(f"-> Skipping VANILLA for Sequence Length {seq} (> {VANILLA_MAX_SEQ})")

print("\nBenchmark execution complete!")

Starting benchmark execution...
-> Running SPECTRAL with Sequence Length: 512
Device=gpu  (tensor on cuda)  B=4  T=512  d_model=512  heads=8  depth=6  kind=spectral
spectral_dct  tokens/s: 219,589   ms/iter: 9.33   peakMB: 102.5

-> Running VANILLA with Sequence Length: 512
Device=gpu  (tensor on cuda)  B=4  T=512  d_model=512  heads=8  depth=6  kind=spectral
spectral_dct  tokens/s: 219,589   ms/iter: 9.33   peakMB: 102.5

-> Running VANILLA with Sequence Length: 512
Device=gpu  (tensor on cuda)  B=4  T=512  d_model=512  heads=8  depth=6  kind=vanilla
     vanilla  tokens/s: 144,775   ms/iter: 14.15   peakMB: 130.3

-> Running SPECTRAL with Sequence Length: 1024
Device=gpu  (tensor on cuda)  B=4  T=512  d_model=512  heads=8  depth=6  kind=vanilla
     vanilla  tokens/s: 144,775   ms/iter: 14.15   peakMB: 130.3

-> Running SPECTRAL with Sequence Length: 1024
Device=gpu  (tensor on cuda)  B=4  T=1024  d_model=512  heads=8  depth=6  kind=spectral
spectral_dct  tokens/s: 225,204   ms/iter: