In [None]:
# ==============================================================================
# ERROR HANDLING SETUP — Full Tracebacks and Formatting
# ==============================================================================
import sys, traceback
# sys.tracebacklimit is a *cap* on printed traceback levels (the documented
# default is 1000), so 50 shortens very deep stacks (e.g. runaway recursion)
# rather than lengthening them. The traceback module's formatters also consult it.
# NOTE(review): IPython's own traceback renderer may ignore this setting — confirm.
sys.tracebacklimit = 50  # show up to 50 frames

def format_exception(e: Exception, context_lines: int = 5) -> str:
    """Render an exception and its traceback as one printable string.

    Args:
        e: The exception to render; frames come from ``e.__traceback__``.
        context_lines: Reserved for future source-context output; ignored today.
    Returns:
        The text ``traceback.format_exception`` produces for ``e``, including
        any chained ``__cause__`` / ``__context__`` exceptions.
    """
    pieces = traceback.format_exception(e.__class__, e, e.__traceback__)
    return "".join(pieces)

# Install IPython custom exception handler to avoid truncation.
# Best-effort: if IPython is unavailable or there is no active shell
# (get_ipython() returns None), nothing is installed and errors are ignored.
try:
    from IPython import get_ipython
    ip = get_ipython()
    if ip is not None:
        def _custom_exc(shell, etype, evalue, tb, tb_offset=None):
            """Print the complete traceback of an Exception raised in a cell.

            Registered through ``set_custom_exc``; IPython passes the shell
            followed by the exception type, value and traceback object.
            """
            print('‚ùå Exception occurred')
            print('=' * 60)
            print(''.join(traceback.format_exception(etype, evalue, tb)))
            print('=' * 60)
            # NOTE(review): IPython calls this handler *instead of* rendering its own
            # traceback, so the printout above is the only traceback the user sees.
            # Returning None is accepted by IPython; it does not make IPython show
            # its standard traceback afterwards.
            return None
        ip.set_custom_exc((Exception,), _custom_exc)
except Exception:
    pass

# ==============================================================================
# NETWORK RETRY MONKEY-PATCH — urllib.urlopen with retries (GitHub/HF)
# ==============================================================================
# Replaces urllib.request.urlopen with a wrapper that retries transient failures
# (HTTP 403/429/5xx, connection errors, timeouts) using exponential backoff plus
# jitter, honoring a numeric Retry-After header. Safe to re-run: the genuine
# urlopen is remembered on the wrapper, so retries are never nested.
try:
    import http.client as _hc
    import random as _r
    import time as _t
    import urllib.error as _ue
    import urllib.request as _ur

    # On a re-run _ur.urlopen is already our wrapper; unwrap it first.
    _orig_urlopen = getattr(_ur.urlopen, '_tb_original', _ur.urlopen)

    _RETRYABLE_HTTP_CODES = (429, 500, 502, 503, 504, 403)
    _MAX_RETRY_SLEEP_S = 30.0

    def _retry_delay(attempt, backoff, retry_after=None):
        """Seconds to sleep before retry number ``attempt`` (1-based).

        A numeric Retry-After value takes precedence; otherwise (missing or an
        HTTP-date) the delay is ``backoff * 2**(attempt - 1)``. Up to 25% random
        jitter is added, and the result is capped at ``_MAX_RETRY_SLEEP_S``.
        """
        try:
            delay = float(retry_after) if retry_after is not None else None
        except (TypeError, ValueError):
            delay = None
        if delay is None:
            delay = backoff * (2 ** (attempt - 1))
        delay += _r.random() * 0.25 * delay
        return min(delay, _MAX_RETRY_SLEEP_S)

    def _retrying_urlopen(req, timeout=20, max_retries=5, backoff=1.0, data=None, **kwargs):
        """Call the original urlopen, retrying transient failures.

        Args:
            req: URL string or ``urllib.request.Request``.
            timeout: Per-attempt timeout in seconds.
            max_retries: Retries allowed after the first attempt.
            backoff: Base delay in seconds; doubles on every retry.
            data: Optional request body, forwarded unchanged.
            **kwargs: Further urlopen keywords (e.g. ``context``), forwarded.
        Returns:
            The response object from the original urlopen.
        Raises:
            HTTPError immediately for 404 and other non-retryable statuses; the
            last error once retries are exhausted; non-network errors (e.g.
            ValueError for a malformed URL) immediately.
        """
        attempt = 0
        while True:
            try:
                return _orig_urlopen(req, data, timeout, **kwargs)
            except _ue.HTTPError as e:
                code = getattr(e, 'code', None)
                if code not in _RETRYABLE_HTTP_CODES:
                    raise  # 404 and other permanent errors are not retried
                attempt += 1
                if attempt > max_retries:
                    raise
                headers = getattr(e, 'headers', None) or {}
                sleep_for = _retry_delay(attempt, backoff, headers.get('Retry-After'))
                print(f"‚è≥ Network retry {attempt}/{max_retries} in {sleep_for:.1f}s (HTTP {code})")
                _t.sleep(sleep_for)
            except (OSError, _hc.HTTPException):
                # Connection-level failures: URLError, timeouts, resets, bad status
                # lines. Programming errors are not caught here and surface at once.
                attempt += 1
                if attempt > max_retries:
                    raise
                sleep_for = _retry_delay(attempt, backoff)
                print(f"‚è≥ Network retry {attempt}/{max_retries} in {sleep_for:.1f}s")
                _t.sleep(sleep_for)

    def urlopen_with_retry(req, data=None, timeout=20, **kwargs):
        """Drop-in replacement for ``urllib.request.urlopen`` that retries.

        Mirrors urlopen's positional order (url, data, timeout) and forwards any
        keyword arguments such as ``context`` to the real urlopen.
        """
        return _retrying_urlopen(req, timeout=timeout, data=data, **kwargs)

    urlopen_with_retry._tb_original = _orig_urlopen
    _ur.urlopen = urlopen_with_retry
except Exception as _patch_err:
    # Best-effort: without the patch, urlopen simply behaves as stock urllib.
    print(f"Network retry patch not installed: {_patch_err}")

---

### 🛠️ Troubleshooting

- This notebook shows full Python tracebacks (up to 50 frames).
- When an error occurs, you'll see the complete stack to the root cause.
- If a model load fails, check ImportError messages and missing packages.

Tip: You can also call `print(format_exception(e))` inside your own try/except blocks to display a full traceback.

# 🧪 Transformer Builder - Advanced Testing Lab

**Welcome! This notebook tests your custom transformer architecture.**

---

## 🚀 **Quick Start (3 Steps)**

### **STEP 1:** Paste Your Gist ID
↓ Scroll down to Cell 3 (**📥 Gist ID Configuration**) and paste the Gist ID you received from Transformer Builder

### **STEP 2:** Run All Cells  
Click **Runtime → Run all** (or run cells one-by-one)

### **STEP 3:** Review Test Results
Your model will be validated through 3 testing tiers

---

## 📋 **What's Included:**

- ✅ **Tier 1:** Critical validation (shape, gradients, numerical stability)
- 🔬 **Tier 2:** Advanced analysis (attention patterns, robustness, profiling)
- 🚀 **Tier 3:** Training utilities (fine-tuning, hyperparameter sweeps, benchmarks), provided in the separate `training.ipynb` notebook

---

## ⚠️ **First Time Setup:**

If this is your first time OR you're continuing from a previous session:

1. **Runtime** → **Restart runtime** (takes 5 seconds)
2. **Edit** → **Clear all outputs** (optional, cleans up UI)
3. **Scroll down to Cell 3** → Paste your Gist ID
4. **Runtime** → **Run all**

This ensures a clean environment and prevents dependency conflicts.

---

**Source:** Generated from [Transformer Builder](https://transformer-builder.com)

# 🧪 Transformer Builder - Advanced Testing Lab

Welcome! This notebook provides comprehensive testing and training capabilities for your custom transformer architecture.

**What's included:**
- ✅ **Tier 1:** Critical validation (shape, gradients, numerical stability)
- 🔬 **Tier 2:** Advanced analysis (attention patterns, robustness, profiling)
- 🚀 **Tier 3:** Training utilities (fine-tuning, hyperparameter sweeps, benchmarks), provided in the separate `training.ipynb` notebook

**Quick Start:**
1. Click "Run all" (Runtime → Run all)
2. Review Tier 1 results (should complete in ~1 minute)
3. Explore Tier 2/3 sections as needed

**Source:** Generated from [Transformer Builder](https://transformer-builder.com)

---

---

## 📋 **STEP 1: Paste Your Gist ID**

When you exported from **Transformer Builder**, you received a **Gist ID**.

**Paste it in the cell below and run it.**

If you don't have a Gist ID yet, go back to Transformer Builder and click **"Export to Colab"**.

In [None]:
# ==============================================================================
# GIST ID CONFIGURATION - Auto-detect from URL or manual input
# ==============================================================================

#@title üì• **Gist ID Configuration** { display-mode: "form" }

import os
import re
import time

# ==============================================================================
# Step 1: Try to extract gist_id from URL
# ==============================================================================

# JavaScript that extracts gist_id and returns it synchronously
js_extraction_code = """
(function() {
    let gist_id = '';
    
    try {
        // Method 1: Check URL query parameters
        const url = new URL(window.location.href);
        gist_id = url.searchParams.get('gist_id') || '';
        
        // Method 2: Try parent window
        if (!gist_id) {
            try {
                const parentUrl = new URL(window.parent.location.href);
                gist_id = parentUrl.searchParams.get('gist_id') || '';
            } catch (e) {}
        }
        
        // Method 3: Check document.referrer
        if (!gist_id && document.referrer) {
            try {
                const refUrl = new URL(document.referrer);
                gist_id = refUrl.searchParams.get('gist_id') || '';
            } catch (e) {}
        }
        
        // Method 4: Check hash params
        if (!gist_id) {
            const hash = window.location.hash || '';
            if (hash.includes('gist_id')) {
                const hashParams = new URLSearchParams(hash.substring(1));
                gist_id = hashParams.get('gist_id') || '';
            }
        }
    } catch (e) {
        console.log('URL extraction error:', e.message);
    }
    
    return gist_id || '';
})();
"""

# Stays empty unless the Colab frontend reports a gist_id in its URL.
gist_id_from_url = ''

# Colab-only: output.eval_js runs the script in the browser and returns its
# result synchronously. Outside Colab the import fails and detection is skipped.
try:
    from google.colab import output
except ImportError:
    print("‚ÑπÔ∏è  Not running in Google Colab - URL auto-detection skipped")
except Exception as e:
    print(f"‚ÑπÔ∏è  URL auto-detection unavailable: {type(e).__name__}")
else:
    # Up to three tries; a short pause follows a failed eval_js call
    # (Colab can be slow to initialize).
    for attempt in range(3):
        try:
            result = output.eval_js(js_extraction_code)
            found = isinstance(result, str) and result.strip()
            if found:
                gist_id_from_url = found
                print(f"‚úÖ Auto-detected Gist ID from URL: {gist_id_from_url}")
                break
        except Exception:
            if attempt < 2:
                time.sleep(0.3)  # Brief delay between attempts

# ==============================================================================
# Step 2: Manual input (primary method if auto-detect fails)
# ==============================================================================

#@markdown ---
#@markdown **Enter Your Gist ID:**
GIST_ID_MANUAL = ""  #@param {type:"string"}

#@markdown ---
#@markdown **How to get your Gist ID:**
#@markdown 1. Go to Transformer Builder
#@markdown 2. Click "Export to Colab"
#@markdown 3. Copy the Gist ID shown in the modal
#@markdown 4. Paste it in the field above

# ==============================================================================
# Step 3: Environment variable fallback
# ==============================================================================

# Stripped so a trailing newline/space in the variable does not fail validation.
gist_id_env = os.getenv('GIST_ID', '').strip()

# ==============================================================================
# Step 4: Determine final value
# ==============================================================================

def _normalize_gist_id(value):
    """Return a bare Gist ID from an ID or a pasted gist.github.com URL.

    "https://gist.github.com/<user>/<id>" and "https://gist.github.com/<id>"
    (optionally with a trailing "/", "#..." or "?...") both yield "<id>".
    Any other value is returned stripped so the format check can report it.
    """
    value = (value or '').strip()
    match = re.search(r"gist\.github\.com/(?:[^/?#]+/)?([A-Za-z0-9]+)", value)
    return match.group(1) if match else value

gist_id_manual = _normalize_gist_id(GIST_ID_MANUAL)

# Priority: URL auto-detection, then the form field, then the environment.
GIST_ID = (
    _normalize_gist_id(gist_id_from_url)
    or gist_id_manual
    or _normalize_gist_id(gist_id_env)
)

# ==============================================================================
# Status display
# ==============================================================================

print()
print("=" * 70)

if not GIST_ID:
    print("‚è≥ GIST ID NEEDED")
    print("=" * 70)
    print()
    print("URL auto-detection did not find a Gist ID.")
    print()
    print("üìù TO CONTINUE:")
    print("   1. Enter your Gist ID in the 'GIST_ID_MANUAL' field above")
    print("   2. Re-run this cell (click the play button or Ctrl+Enter)")
    print()
    print("Don't have a Gist ID?")
    print("   ‚Üí Go to Transformer Builder ‚Üí Click 'Export to Colab'")
    print()
    # Don't raise error - let user fill in the field and re-run
    GIST_ID = None  # Explicitly set to None for downstream checks

else:
    # Validate format
    if not re.fullmatch(r"[A-Za-z0-9]+", GIST_ID):
        print("‚ö†Ô∏è  INVALID GIST ID FORMAT")
        print("=" * 70)
        print()
        print(f"The value entered: {GIST_ID!r}")
        print("Gist IDs should be alphanumeric (e.g., 'abc123def456')")
        raise ValueError("Invalid Gist ID format - please check and re-enter")

    # Show success
    if gist_id_from_url:
        source = "URL (auto-detected)"
    elif gist_id_manual:
        source = "Manual input"
    else:
        source = "Environment variable"

    print("‚úÖ GIST ID CONFIGURED")
    print("=" * 70)
    print()
    print(f"Gist ID: {GIST_ID}")
    print(f"Source:  {source}")
    print()
    print("Ready to load your model! Continue to the next cells.")

In [None]:
# ==============================================================================
# DEPENDENCY VERIFICATION - v3.4.0 (Zero Installation Strategy)
# ==============================================================================
print("=" * 70)
print("üì¶ DEPENDENCY VERIFICATION")
print("=" * 70)
print()
print("Strategy: Use Colab pre-installed packages (no pip install)")
print("This prevents NumPy corruption caused by package reinstallation.")
print()

# ==============================================================================
# VERIFY CORE DEPENDENCIES (All pre-installed in Google Colab 2025)
# ==============================================================================
required = {
    'torch': '2.6+',
    'numpy': '2.3+',
    'pandas': '1.5+',
    'matplotlib': '3.7+',
    'seaborn': '0.12+',
}

print("Checking pre-installed packages...")
print()
all_good = True
missing = []
for package, min_version in required.items():
    try:
        module = __import__(package)
    except ImportError:
        print(f"  ‚ùå {package:15s} NOT FOUND (should be pre-installed!)")
        all_good = False
        missing.append(package)
        continue
    # str(): __version__ is not guaranteed to be a str, and ':10s' requires one.
    version = str(getattr(module, '__version__', 'unknown'))
    print(f"  ‚úÖ {package:15s} {version:10s} (required: {min_version})")
print()

# ==============================================================================
# NUMPY INTEGRITY CHECK
# ==============================================================================
# Skipped when NumPy itself is missing: the missing-package error below is the
# accurate diagnosis there, not "corruption".
if 'numpy' not in missing:
    print("Checking NumPy integrity...")
    try:
        # NOTE(review): numpy._core is the NumPy 2.x module layout and _center is a
        # private ufunc; confirm this probe exists on the NumPy version Colab ships,
        # otherwise it reports corruption falsely.
        from numpy._core.umath import _center
        print("  ‚úÖ NumPy C extensions intact")
    except ImportError as e:
        print("  ‚ùå NumPy corrupted!")
        print()
        print("=" * 70)
        print("ERROR: NumPy corruption detected")
        print("=" * 70)
        print()
        print("This usually happens if you:")
        print("  1. Ran this notebook before without restarting runtime")
        print("  2. Manually installed packages that corrupted NumPy")
        print()
        print("FIX: Runtime ‚Üí Restart runtime, then run all cells again")
        print()
        raise ImportError("NumPy corrupted - please restart runtime") from e
    print()

if not all_good:
    print("=" * 70)
    print("ERROR: Missing required packages")
    print("=" * 70)
    print()
    print("This shouldn't happen in Google Colab.")
    print("Are you running this notebook in a different environment?")
    print()
    raise RuntimeError("Required packages not found")

print("=" * 70)
print("‚úÖ ALL DEPENDENCIES VERIFIED")
print("=" * 70)
print()
print("‚úÖ No installation needed - using Colab pre-installed packages")
print("‚úÖ NumPy corruption risk: ELIMINATED")
print()
print("Note: Tier 2 also uses pre-installed packages; Tier 3 training runs in the separate training.ipynb")
print()

In [None]:
# ==============================================================================
# DOWNLOAD UTILS PACKAGE
# ==============================================================================

print("üì¶ Downloading test utilities package...")

# Remove old utils directory if exists
!rm -rf utils/

# Download complete utils package from GitHub
!git clone --depth 1 --branch main https://github.com/matt-hans/transformer-builder-colab-templates.git temp_repo 2>/dev/null

# Copy utils directory
!cp -r temp_repo/utils ./

# Cleanup
!rm -rf temp_repo

# Verify package structure
import sys
import os

# Add current directory to Python path
if './' not in sys.path:
    sys.path.insert(0, './')

# Verify utils package is importable
try:
    import utils
    print(f"‚úÖ Utils package loaded (version {utils.__version__})")
    
    # Verify package structure
    utils_path = os.path.join(os.getcwd(), 'utils')
    subdirs = ['adapters', 'tokenization', 'training', 'ui']
    
    for subdir in subdirs:
        subdir_path = os.path.join(utils_path, subdir)
        if os.path.exists(subdir_path):
            print(f"‚úÖ {subdir}/ directory found")
        else:
            print(f"‚ö†Ô∏è  {subdir}/ directory missing")
    
    # Test importing test functions (backward compatibility)
    from utils import (
        test_shape_robustness,
        test_gradient_flow,
        test_output_stability,
        run_all_tier1_tests
    )
    print("‚úÖ Test functions importable")
    
    print("\n‚úÖ Utils package ready!")
    
except ImportError as e:
    print(f"‚ùå Failed to import utils package: {e}")
    print("Falling back to direct file download...")
    # Fallback: download test_functions.py directly
    !wget -q https://raw.githubusercontent.com/matt-hans/transformer-builder-colab-templates/main/utils/test_functions.py

In [None]:
# ==============================================================================
# LOAD CUSTOM MODEL FROM GIST
# ==============================================================================

import os, re, json, urllib.request, urllib.error

print("=" * 70)
print("MODEL LOADING")
print("=" * 70)
print()

# ==============================================================================
# VERIFY GIST ID WAS PROVIDED
# ==============================================================================

# GIST_ID is created by the Gist ID configuration cell, which sets it to None when
# no ID was entered. Either way, stop here with instructions.
if not globals().get('GIST_ID'):
    _divider = "=" * 70
    _instructions = (
        "‚ùå NO GIST ID CONFIGURED",
        "",
        _divider,
        "üîô ACTION REQUIRED",
        _divider,
        "",
        "Please go back and configure your Gist ID:",
        "",
        "  1. Scroll up to the 'üì• Gist ID Configuration' cell",
        "  2. Enter your Gist ID in the 'GIST_ID_MANUAL' field",
        "  3. Run that cell",
        "  4. Then run this cell again",
        "",
        "If you don't have a Gist ID:",
        "  ‚Üí Go to Transformer Builder ‚Üí Click 'Export to Colab'",
        "",
    )
    print("\n".join(_instructions))
    raise ValueError("Gist ID required - see instructions above")

# model_name starts as a default; config.json may override it below.
gist_id, model_name = GIST_ID, "Model"

print(f"üì• Loading model from GitHub Gist: {gist_id}")
print()

# ==============================================================================
# FETCH GIST AND LOAD MODEL FILES
# ==============================================================================

def _fetch_gist(gid: str) -> dict:
    """Fetch Gist data from GitHub API."""
    url = f"https://api.github.com/gists/{gid}"
    req = urllib.request.Request(url, headers={
        "Accept": "application/vnd.github+json",
        "User-Agent": "transformer-builder-colab"
    })
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        detail = f"HTTP {e.code}"
        try:
            body = e.read().decode("utf-8")
            if "rate limit" in body.lower():
                detail += " - GitHub API rate limit (try again in an hour)"
            elif e.code == 404:
                detail += " - Gist not found (check your Gist ID)"
        except:
            pass
        raise RuntimeError(f"GitHub API error: {detail}") from e
    except Exception as e:
        raise RuntimeError(f"Network error: {e}") from e

def _write(path: str, text: str):
    """Write text to file."""
    with open(path, "w") as f:
        f.write(text)

# Fetch Gist
try:
    gist_data = _fetch_gist(gist_id)
    files = gist_data.get("files") or {}

    # Check for required files
    if "model.py" not in files:
        raise RuntimeError("Gist is missing 'model.py' - please re-export from Transformer Builder")
    if "config.json" not in files:
        raise RuntimeError("Gist is missing 'config.json' - please re-export from Transformer Builder")

    # `or {}`: the API may list a file entry as null.
    model_code = (files["model.py"] or {}).get("content", "")
    config_json = (files["config.json"] or {}).get("content", "")

    if not model_code or not config_json:
        raise RuntimeError("Empty content in model.py or config.json")

    # Write to files
    _write("custom_transformer.py", model_code)
    _write("config.json", config_json)

    print("‚úÖ Model loaded successfully!")
    print(f"‚úÖ Gist URL: {gist_data.get('html_url', 'N/A')}")
    # Report encoded sizes: len() of a str counts characters, not bytes.
    print(f"‚úÖ Model code: {len(model_code.encode('utf-8')):,} bytes")
    print(f"‚úÖ Config: {len(config_json.encode('utf-8')):,} bytes")
    print()

    # Parse the model name from the config if available. Invalid JSON only warns
    # here; the model-instantiation cell loads config.json again and raises there.
    try:
        config_dict = json.loads(config_json)
    except json.JSONDecodeError as json_err:
        config_dict = None
        print(f"‚ö†Ô∏è  config.json is not valid JSON ({json_err}); keeping default model name")
        print()
    cfg_model_name = config_dict.get('model_name') if isinstance(config_dict, dict) else None
    if isinstance(cfg_model_name, str) and cfg_model_name:
        model_name = cfg_model_name
        print(f"‚úÖ Model name: {model_name}")
        print()

except Exception as e:
    print("‚ùå Failed to load model from Gist!")
    print()
    print(f"Error: {e}")
    print()
    print("=" * 70)
    print("TROUBLESHOOTING")
    print("=" * 70)
    print()
    print("Common issues:")
    print(f"  1. Verify Gist ID is correct: {gist_id}")
    print("  2. Ensure you exported from Transformer Builder successfully")
    print("  3. Check you're not hitting GitHub rate limit (60 requests/hour)")
    print("  4. Try re-exporting from Transformer Builder")
    print()
    print("Direct link to check:")
    print(f"  ‚Üí https://gist.github.com/{gist_id}")
    print()
    raise

print("=" * 70)
print("‚úÖ MODEL LOADING COMPLETE")
print("=" * 70)
print()
print("Next: Continue to model instantiation and testing below!")
print()

# Store model_name for next cell
params = {"name": model_name}

In [ ]:
# ==============================================================================
# DYNAMIC TRAINING LINK - Pass Gist ID to training.ipynb automatically
# ==============================================================================

import html
import json
from urllib.parse import quote

from IPython.display import HTML, Javascript, display

TRAINING_NOTEBOOK_URL = (
    "https://colab.research.google.com/github/matt-hans/"
    "transformer-builder-colab-templates/blob/main/training.ipynb"
)

# GIST_ID comes from the Gist ID configuration cell and model_name from the
# model-loading cell; fall back gracefully if either has not been run.
gist_id_for_js = globals().get('GIST_ID') or ''
model_name_for_js = globals().get('model_name') or 'Model'


def _js_string(value) -> str:
    """Encode ``value`` as a JavaScript string literal.

    JSON string syntax is valid JavaScript, so quotes and backslashes in the
    Gist-supplied model name cannot break out of (or inject into) the script.
    Escaping "</" keeps the literal from closing an enclosing <script> element.
    """
    return json.dumps(str(value)).replace("</", "<\\/")


js_code = f"""
(function() {{
    // Find all Colab badge links pointing to training.ipynb
    const links = document.querySelectorAll('a[href*="training.ipynb"]');
    const gistId = {_js_string(gist_id_for_js)};
    const modelName = {_js_string(model_name_for_js)};

    links.forEach(link => {{
        const baseUrl = link.href.split('#')[0];  // Remove existing hash if any
        if (gistId && gistId.trim()) {{
            // Append hash parameters for training.ipynb to read
            link.href = baseUrl + '#gist_id=' + encodeURIComponent(gistId) + '&name=' + encodeURIComponent(modelName);
            console.log('‚úÖ Updated training link:', link.href);
        }}
    }});
}})()
"""

# NOTE(review): Colab renders cell output in a sandboxed iframe, so this script may
# not reach the markdown badge in the notebook page. The direct link below works
# either way.
display(Javascript(js_code))

if gist_id_for_js:
    training_href = (
        f"{TRAINING_NOTEBOOK_URL}#gist_id={quote(str(gist_id_for_js), safe='')}"
        f"&name={quote(str(model_name_for_js), safe='')}"
    )
    display(HTML(
        f'<a href="{html.escape(training_href)}" target="_blank" rel="noopener">'
        'Open training notebook with this Gist ID</a>'
    ))

print("=" * 70)
print("‚úÖ TRAINING LINK UPDATED")
print("=" * 70)
print()
print(f"Gist ID: {gist_id_for_js}")
print(f"Model Name: {model_name_for_js}")
print()
print("The 'Open Training Notebook' button will now automatically pass")
print("your Gist ID to training.ipynb - no need to enter it again!")
print()
print("üí° Scroll down to the 'Tier 3' section and click the Colab badge")
print("If the badge does not carry your Gist ID, use the link shown above instead.")


## 📄 View Loaded Model Code

This cell displays the Python code that was loaded from your Transformer Builder export. You can review the architecture before running tests.

In [None]:
# Display the loaded model code for transparency.
# Imports are local so this cell also works after a kernel restart.
import json
from IPython.display import Code, display

print("=" * 80)
print("üìÑ LOADED MODEL CODE (custom_transformer.py)")
print("=" * 80)
print()

# Python source is UTF-8 by default (PEP 3120); read it that way regardless of locale.
with open('custom_transformer.py', 'r', encoding='utf-8') as f:
    model_code_display = f.read()

# Use syntax highlighting
display(Code(model_code_display, language='python'))

print()
print("=" * 80)
print("üìã MODEL CONFIGURATION (config.json)")
print("=" * 80)
print()

# JSON text is UTF-8 (RFC 8259).
with open('config.json', 'r', encoding='utf-8') as f:
    config_display = json.load(f)

# Pretty print JSON
print(json.dumps(config_display, indent=2))
print()
print("‚úÖ You can now proceed to run the model instantiation and tests below!")

## Import and Instantiate Model

Load your custom transformer and prepare for testing.

In [None]:
import json
import inspect

import torch
import torch.nn as nn

# Import the custom model.
# SECURITY: exec runs arbitrary Python from the Gist in this notebook's
# namespace; only load Gists you trust.
# Snapshot the namespace first so classes the Gist (re)defines can be told apart
# from stale classes left over from earlier runs or other cells.
_globals_before_exec = dict(globals())
with open('custom_transformer.py', encoding='utf-8') as _model_file:
    _model_source = _model_file.read()
# compile() with the real filename makes tracebacks name custom_transformer.py.
exec(compile(_model_source, 'custom_transformer.py', 'exec'), globals())

# Load config
with open('config.json', encoding='utf-8') as f:
    config_dict = json.load(f)


def _is_module_subclass(obj):
    """Return True for proper subclasses of torch.nn.Module (not nn.Module itself)."""
    return isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module


# nn.Module subclasses the exec above created or rebound, defined in this
# namespace. Library classes pulled in by name (e.g. `from torch.nn import
# Linear`) are excluded. The list is in namespace order.
_this_module = globals().get('__name__')
_gist_classes = [
    (cls_name, cls) for cls_name, cls in list(globals().items())
    if _is_module_subclass(cls)
    and cls.__module__ == _this_module
    and _globals_before_exec.get(cls_name) is not cls
]
del _globals_before_exec

# Find the model class: an exact match on the config's model_name wins.
model_class = None
if _is_module_subclass(globals().get(params['name'])):
    model_class = globals()[params['name']]

if model_class is None:
    # Fallback: the last class the Gist defined. Building blocks (attention, MLP,
    # ...) are typically defined before the full model. If the Gist defined none,
    # use the first nn.Module subclass in the namespace, as before.
    _fallback = _gist_classes[-1:] or [
        (cls_name, cls) for cls_name, cls in list(globals().items())
        if _is_module_subclass(cls)
    ][:1]
    if _fallback:
        name, model_class = _fallback[0]
        print(f"‚ö†Ô∏è Using {name} (expected {params['name']})")

if not model_class:
    raise RuntimeError(f"Could not find model class '{params['name']}' in generated code")

# Instantiate the model. Transformer Builder exports take no constructor
# arguments; other models receive the config.json keys as keyword arguments.
try:
    ctor_params = [
        p for p in inspect.signature(model_class.__init__).parameters.values()
        if p.name != 'self'
    ]

    if ctor_params:
        # Parameterized constructor (traditional models)
        print(f"‚ÑπÔ∏è  Model accepts {len(ctor_params)} parameter(s)")
        model = model_class(**config_dict)
    else:
        # Parameterless constructor (Transformer Builder models)
        print("‚ÑπÔ∏è  Model has parameterless constructor (Transformer Builder export)")
        model = model_class()

    model.eval()

    total_params = sum(t.numel() for t in model.parameters())
    trainable_params = sum(t.numel() for t in model.parameters() if t.requires_grad)

    print(f"‚úÖ Model instantiated: {model_class.__name__}")
    print(f"‚úÖ Total parameters: {total_params:,}")
    print(f"‚úÖ Trainable parameters: {trainable_params:,}")

    # Prefer the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(f"‚úÖ Device: {device}")

    # Summary via the module's own repr (no torchinfo dependency)
    rule = "=" * 70
    print()
    print(rule)
    print("MODEL SUMMARY")
    print(rule)
    print()
    print(model)
    print()
    print(rule)
    print(f"Total parameters:      {total_params:,}")
    print(f"Trainable parameters:  {trainable_params:,}")
    print(f"Non-trainable params:  {total_params - trainable_params:,}")

    # In-memory size of parameters plus buffers, in MiB
    size_bytes = sum(t.numel() * t.element_size() for t in [*model.parameters(), *model.buffers()])
    size_mb = size_bytes / 1024**2
    print(f"Model size:            {size_mb:.2f} MB")
    print(rule)
    print()

except Exception as e:
    print(f"‚ùå Failed to instantiate model: {e}")
    import traceback
    traceback.print_exc()
    raise

# Create config object for test functions (with proper vocab_size)
class ModelConfig:
    """Attribute bag describing the model for the test utilities.

    Construction order:
      1. Defaults: vocab_size=50257, max_seq_len=512, max_batch_size=8.
      2. For a nodes-based config, ``vocab_size`` and ``max_seq_len`` (or
         ``seq_length``) are read from each node's ``params``; later nodes win.
      3. Every other top-level key except ``nodes``, ``version`` and
         ``model_name`` is copied onto the instance, overriding the above.
    """

    def __init__(self, **kwargs):
        skipped_keys = ('nodes', 'version', 'model_name')

        # Set defaults
        self.vocab_size = 50257
        self.max_seq_len = 512
        self.max_batch_size = 8

        # If nodes-based config, extract common params. `or []` tolerates an explicit
        # null; nodes that are not dicts, or whose params are missing/null, are skipped.
        for node in kwargs.get('nodes') or []:
            node_params = node.get('params') if isinstance(node, dict) else None
            if not isinstance(node_params, dict):
                continue
            if 'vocab_size' in node_params:
                self.vocab_size = node_params['vocab_size']
            if 'max_seq_len' in node_params or 'seq_length' in node_params:
                self.max_seq_len = node_params.get('max_seq_len') or node_params.get('seq_length', 512)

        # Override with flat params if present
        for key, value in kwargs.items():
            if key not in skipped_keys:
                setattr(self, key, value)

config = ModelConfig(**config_dict)
print("‚úÖ Config prepared (vocab_size={}, max_seq_len={})".format(config.vocab_size, config.max_seq_len))
print("‚úÖ Ready for testing!")


---

# 🔍 Tier 1: Critical Validation

These tests verify your model is mathematically sound and ready for training.

**Estimated time:** ~1 minute

**What's tested:**
- ✅ Shape validation across edge cases
- ✅ Gradient flow (detect vanishing/exploding gradients)
- ✅ Numerical stability (NaN/Inf detection)
- ✅ Parameter initialization quality
- ✅ Memory footprint scaling
- ✅ Inference speed benchmarks

In [None]:
# Tier 1 test functions, provided by the utils package downloaded earlier (names sorted)
from utils.test_functions import (
    test_gradient_flow,
    test_inference_speed,
    test_memory_footprint,
    test_output_stability,
    test_parameter_initialization,
    test_shape_robustness,
)

print("‚úÖ Test functions loaded from utils package")

In [None]:
print("=" * 80)
print("TIER 1: CRITICAL VALIDATION")
print("=" * 80)
print()


def _run_tier1_test(index, title, test_fn, *args, show_result=True, **kwargs):
    """Print the numbered header for one Tier 1 test, run it and return its result.

    Args:
        index: 1-based test number shown as "Test i/6".
        title: Human-readable test name.
        test_fn: The test function; ``*args`` / ``**kwargs`` are forwarded to it.
        show_result: Render the return value with ``display``. The stability and
            speed tests pass False, as their results were never displayed.
    """
    print(f"Test {index}/6: {title}")
    print("-" * 80)
    result = test_fn(*args, **kwargs)
    if show_result:
        display(result)
    print()
    return result


shape_results = _run_tier1_test(1, "Shape Validation", test_shape_robustness, model, config)
grad_results = _run_tier1_test(2, "Gradient Flow Analysis", test_gradient_flow, model, config)
stability_stats = _run_tier1_test(
    3, "Numerical Stability", test_output_stability, model, config, n_samples=100, show_result=False
)
param_results = _run_tier1_test(4, "Parameter Initialization", test_parameter_initialization, model)
memory_results = _run_tier1_test(5, "Memory Footprint Analysis", test_memory_footprint, model, config)
speed_stats = _run_tier1_test(
    6, "Inference Speed Benchmark", test_inference_speed, model, config, n_trials=50, show_result=False
)

print("=" * 80)
print("‚úÖ TIER 1 VALIDATION COMPLETE")
print("=" * 80)
print()
# Reaching this point means every test ran; it does not by itself mean each check passed.
print("All Tier 1 tests ran to completion. Review the results above for any flagged issues.")
print()
print("üìù Next steps:")
print("   ‚Ä¢ Tier 2: Advanced analysis (attention patterns, robustness)")
print("     ‚Üí Runs in the cell below with pre-installed packages (no installation needed)")
print()
print("   ‚Ä¢ Tier 3: Training utilities (fine-tuning, hyperparameter search)")
print("     ‚Üí Open training.ipynb via the Colab badge in the Tier 3 section")

---

# üî¨ Tier 2: Advanced Analysis

Deep dive into model behavior with advanced diagnostic tools.

**Estimated time:** ~3-5 minutes

**What's tested:**
- 🎯 **Attention Patterns:** Visualize attention weights, detect collapsed attention, analyze head specialization
- 🔍 **Attribution Analysis:** Identify which input tokens contribute most to predictions (using Captum). Not part of the automated cell below, which runs only the attention and robustness tests.
- 🛡️ **Robustness Testing:** Measure stability under input perturbations and noise

**Note:** These tests are optional but highly recommended for understanding model behavior.

## Tier 2: Advanced Analysis

**Note:** Tier 2 tests use only Colab pre-installed packages (no installation required).

- Test 1: Attention Pattern Analysis (uses built-in PyTorch)
- Test 2: Robustness Testing (uses numpy/torch)

All tests run automatically in the cell below.

In [None]:
# Import Tier 2 test functions
from utils.test_functions import (
    test_attention_patterns,
    test_robustness
)

print("=" * 80)
print("TIER 2: ADVANCED ANALYSIS")
print("=" * 80)
print()


def _run_optional_test(index, title, label, test_fn, *args, **kwargs):
    """Run one optional Tier 2 test; a failure is reported instead of raised.

    Args:
        index: 1-based test number shown as "Test i/2".
        title: Header text for the test.
        label: Short name used in the completion / skipped messages.
        test_fn: The test function; ``*args`` / ``**kwargs`` are forwarded to it.
    Returns:
        The test's result (displayed when not None). Returns None if the test
        raised, so a failed re-run never leaves a stale result from an earlier run.
    """
    print(f"Test {index}/2: {title}")
    print("-" * 80)
    result = None
    try:
        result = test_fn(*args, **kwargs)
        if result is not None:
            display(result)
        print(f"‚úÖ {label} complete")
    except Exception as e:
        print(f"‚ö†Ô∏è {label} skipped: {e}")
    print()
    return result


# Test 1: Attention Patterns
attention_results = _run_optional_test(
    1, "Attention Pattern Analysis", "Attention analysis", test_attention_patterns, model, config
)

# Test 2: Robustness Testing
robustness_results = _run_optional_test(
    2, "Robustness Under Noise", "Robustness analysis", test_robustness, model, config, n_samples=20
)

print("=" * 80)
print("‚úÖ TIER 2 ANALYSIS COMPLETE")
print("=" * 80)
print()
print("Next: Scroll down for Tier 3 (Training & Fine-Tuning)")

---

# 🚀 Tier 3: Training & Production Utilities

**Training utilities have been moved to a separate notebook to prevent dependency conflicts.**

## 📓 Continue to Training Notebook

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/matt-hans/transformer-builder-colab-templates/blob/main/training.ipynb)

**Click the button above to open the training notebook in Colab.**

### What's included in training.ipynb:
- 🎓 **Fine-Tuning:** Training loop with loss tracking and gradient monitoring
- 🔧 **Hyperparameter Search:** Automated optimization using Optuna
- 📊 **Benchmark Comparison:** Compare against production baselines (distilgpt2, bert-base)

### Before running training.ipynb:
1. **Runtime → Restart runtime** (fresh environment required)
2. **Paste your same Gist ID** from Cell 3 above
3. **Run all cells** - dependencies will install automatically

**Estimated time:** 10-20 minutes (GPU recommended)

---

### Why separate notebooks?

Training utilities require `pytorch-lightning` and `optuna`, which have NumPy version requirements that conflict with the zero-installation strategy used in this testing notebook.

Running them in separate runtimes ensures:
- ✅ Testing notebook (this one) stays fast and dependency-free
- ✅ Training notebook has all the tools it needs without corruption
- ‚úÖ Clear separation between validation and training workflows

---

**Repository:** [transformer-builder-colab-templates](https://github.com/matt-hans/transformer-builder-colab-templates)

**Source:** Generated from [Transformer Builder](https://transformer-builder.com)

In [None]:
# Mode selection and config preview (v4.0.0)
from utils.ui.presets import build_configs_for_mode

# Available modes: FAST_DEV, STANDARD_EXPERIMENT, ABLATION_SWEEP
mode = 'FAST_DEV'
training_cfg, task_spec, eval_cfg = build_configs_for_mode(mode)

for _label, _value in (
    ('Mode:', mode),
    ('TrainingConfig:', training_cfg),
    ('TaskSpec:', task_spec),
    ('EvalConfig:', eval_cfg),
):
    print(_label, _value)


In [None]:
# Load model from GitHub Gist (with revision pinning)
from utils.adapters.gist_loader import load_gist_model
from utils.training.experiment_db import ExperimentDB
from pathlib import Path
import importlib.util, sys

# Reuse the Gist ID configured at the top of the notebook. The literal is only a
# placeholder for running this cell on its own; replace it with your gist id.
gist_id = globals().get('GIST_ID') or 'abcdef1234567890'
revision = None  # or a specific revision sha
md = load_gist_model(gist_id, revision)
print('Gist owner:', md.owner)
print('Files:', md.file_names)
print('SHA256:', md.sha256)

# Optional: dynamic import model.py if present
# NOTE(review): assumes load_gist_model stores files under
# ./external/gists/<gist_id>/<revision or 'latest'>/; confirm in gist_loader.
root = Path('./external/gists') / md.gist_id / (md.revision or 'latest')
model_path = root / 'model.py'
# `model` is only rebound when a model is actually built, so a missing or
# unsupported model.py leaves the model validated by the Tier 1/2 cells intact.
if model_path.exists():
    # SECURITY: exec_module runs arbitrary code from the Gist; use trusted Gists only.
    spec = importlib.util.spec_from_file_location('gist_model', str(model_path))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    # Expect either build_model() or Model class
    if hasattr(mod, 'build_model'):
        model = mod.build_model()
        print('Loaded model from gist')
    elif hasattr(mod, 'Model'):
        model = mod.Model()
        print('Loaded model from gist')
    else:
        print('model.py defines neither build_model() nor Model; define model manually')
else:
    print('model.py not found in gist; define model manually')

# Record the Gist provenance in ExperimentDB. Best-effort: any failure is
# reported and skipped.
try:
    db = ExperimentDB('experiments.db')
    run_id = db.log_run(**{
        'run_name': 'gist-validation',
        'config': {'source': 'gist'},
        'notes': 'Gist load test',
        'gist_id': md.gist_id,
        'gist_revision': md.revision,
        'gist_sha256': md.sha256,
    })
except Exception as e:
    print('DB logging skipped:', e)
else:
    print('Logged run_id:', run_id)
