# 🧪 Transformer Builder - Advanced Testing Lab

**Welcome! This notebook tests your custom transformer architecture.**

---

## 🚀 **Quick Start (3 Steps)**

### **STEP 1:** Paste Your Gist ID
↓ Scroll down to Cell 3 and paste the Gist ID you received from Transformer Builder

### **STEP 2:** Run All Cells  
Click **Runtime → Run all** (or run cells one-by-one)

### **STEP 3:** Review Test Results
Your model will be validated through 3 testing tiers

---

## 📋 **What's Included:**

- ✅ **Tier 1:** Critical validation (shape, gradients, numerical stability)
- 🔬 **Tier 2:** Advanced analysis (attention patterns, robustness, profiling)
- 🚀 **Tier 3:** Training utilities (fine-tuning, hyperparameter sweeps, benchmarks)

---

## ⚠️ **First Time Setup:**

If this is your first time OR you're continuing from a previous session:

1. **Runtime** → **Restart runtime** (takes 5 seconds)
2. **Edit** → **Clear all outputs** (optional, cleans up UI)
3. **Scroll down to Cell 3** → Paste your Gist ID
4. **Runtime** → **Run all**

This ensures a clean environment and prevents dependency conflicts.

---

**Source:** Generated from [Transformer Builder](https://transformer-builder.com)

---

## 📋 **STEP 1: Paste Your Gist ID**

When you exported from **Transformer Builder**, you received a **Gist ID**.

**Paste it in the cell below and run it.**

If you don't have a Gist ID yet, go back to Transformer Builder and click **"Export to Colab"**.

In [None]:
# ==============================================================================
# GIST ID INPUT - Paste the ID from Transformer Builder
# ==============================================================================

#@title 📥 **Paste Your Gist ID Here**
GIST_ID = ""  #@param {type:"string"}

#@markdown ---
#@markdown **Where to find your Gist ID:**
#@markdown 1. Go to Transformer Builder
#@markdown 2. Click "Export to Colab"
#@markdown 3. Copy the Gist ID from the modal
#@markdown 4. Paste it above and run this cell

if not GIST_ID or not GIST_ID.strip():
    print("=" * 70)
    print("⚠️  NO GIST ID PROVIDED")
    print("=" * 70)
    print()
    print("Please paste your Gist ID in the field above and re-run this cell.")
    print()
    print("If you don't have a Gist ID:")
    print("  1. Go to Transformer Builder")
    print("  2. Click 'Export to Colab'")
    print("  3. Copy the Gist ID from the modal")
    print("  4. Come back here and paste it")
    print()
    raise ValueError("Gist ID is required to load your custom model")
else:
    # Validate format
    import re
    if not re.fullmatch(r"[A-Za-z0-9]+", GIST_ID.strip()):
        print("=" * 70)
        print("⚠️  INVALID GIST ID FORMAT")
        print("=" * 70)
        print()
        print(f"The Gist ID you entered: {GIST_ID!r}")
        print()
        print("Gist IDs should be alphanumeric (e.g., 'abc123def456')")
        print("Please check and re-enter.")
        print()
        raise ValueError("Invalid Gist ID format")
    
    # Store for later use
    GIST_ID = GIST_ID.strip()
    
    print("=" * 70)
    print("✅ GIST ID SAVED")
    print("=" * 70)
    print()
    print(f"Gist ID: {GIST_ID}")
    print()
    print("You can now proceed to run the cells below to:")
    print("  1. Install dependencies")
    print("  2. Load your custom model")
    print("  3. Run tests")
    print()
    print("💡 Tip: Click 'Runtime → Run all' to execute everything automatically")

In [None]:
# ==============================================================================
# RUNTIME FRESHNESS DETECTION - Prevents reused runtimes with corrupted packages
# ==============================================================================

import os

# Check if this runtime was previously used
RUNTIME_MARKER = "/tmp/transformer_builder_runtime_used"

print("=" * 70)
print("🔍 RUNTIME FRESHNESS CHECK")
print("=" * 70)
print()

if os.path.exists(RUNTIME_MARKER):
    print("⚠️  WARNING: This runtime was previously used!")
    print()
    print("Reusing runtimes can cause dependency conflicts.")
    print()
    print("✅ RECOMMENDED: Restart runtime for clean environment")
    print("   (Runtime → Restart runtime → Run all)")
    print()
    print("=" * 70)

    # Give user option to continue anyway (advanced users)
    response = input("Continue with reused runtime anyway? [y/N]: ")

    if response.lower() != 'y':
        print()
        print("✅ Good choice! Please restart the runtime and try again.")
        raise RuntimeError("Runtime restart recommended for clean environment")
    else:
        print()
        print("⚠️  Proceeding with reused runtime...")
        print()
else:
    # Mark runtime as used
    with open(RUNTIME_MARKER, 'w') as f:
        f.write("used")

    print("✅ Fresh runtime detected!")
    print()

print("📌 Version: v3.3.2 (2025-01-13)")
print("📌 Fix: Proactive numpy repair + minimal dependencies")
print()
print("=" * 70)
print("✅ Runtime check complete")
print("=" * 70)
print()

In [None]:
# ==============================================================================
# DEPENDENCY VERIFICATION - v3.4.0 (Zero Installation Strategy)
# ==============================================================================

print("=" * 70)
print("📦 DEPENDENCY VERIFICATION")
print("=" * 70)
print()
print("Strategy: Use Colab pre-installed packages (no pip install)")
print("This prevents NumPy corruption caused by package reinstallation.")
print()

# ==============================================================================
# VERIFY CORE DEPENDENCIES (All pre-installed in Google Colab 2025)
# ==============================================================================

required = {
    'torch': '2.6+',
    'numpy': '2.3+',
    'pandas': '1.5+',
    'matplotlib': '3.7+',
    'seaborn': '0.12+',
}

print("Checking pre-installed packages...")
print()

all_good = True
for package, min_version in required.items():
    try:
        module = __import__(package)
        version = getattr(module, '__version__', 'unknown')
        print(f"  ✅ {package:15s} {version:10s} (required: {min_version})")
    except ImportError:
        print(f"  ❌ {package:15s} NOT FOUND (should be pre-installed!)")
        all_good = False

print()

# ==============================================================================
# NUMPY INTEGRITY CHECK
# ==============================================================================

print("Checking NumPy integrity...")
try:
    from numpy._core.umath import _center
    print("  ✅ NumPy C extensions intact")
except ImportError as e:
    print("  ❌ NumPy corrupted!")
    print()
    print("=" * 70)
    print("ERROR: NumPy corruption detected")
    print("=" * 70)
    print()
    print("This usually happens if you:")
    print("  1. Ran this notebook before without restarting runtime")
    print("  2. Manually installed packages that corrupted NumPy")
    print()
    print("FIX: Runtime → Restart runtime, then run all cells again")
    print()
    raise ImportError("NumPy corrupted - please restart runtime") from e

print()

if not all_good:
    print("=" * 70)
    print("ERROR: Missing required packages")
    print("=" * 70)
    print()
    print("This shouldn't happen in Google Colab.")
    print("Are you running this notebook in a different environment?")
    print()
    raise RuntimeError("Required packages not found")

print("=" * 70)
print("✅ ALL DEPENDENCIES VERIFIED")
print("=" * 70)
print()
print("✅ No installation needed - using Colab pre-installed packages")
print("✅ NumPy corruption risk: ELIMINATED")
print()
print("Note: Advanced features (Tier 2/3) will install packages on-demand")
print()

In [None]:
# ==============================================================================
# DOWNLOAD UTILS PACKAGE
# ==============================================================================

print("📦 Downloading test utilities package...")

# Remove old utils directory if exists
!rm -rf utils/

# Download complete utils package from GitHub
!git clone --depth 1 --branch main https://github.com/matt-hans/transformer-builder-colab-templates.git temp_repo 2>/dev/null

# Copy utils directory
!cp -r temp_repo/utils ./

# Cleanup
!rm -rf temp_repo

# Verify package structure
import sys
import os

# Add current directory to Python path
if './' not in sys.path:
    sys.path.insert(0, './')

# Verify utils package is importable
try:
    import utils
    print(f"✅ Utils package loaded (version {utils.__version__})")
    
    # Verify package structure
    utils_path = os.path.join(os.getcwd(), 'utils')
    subdirs = ['adapters', 'tokenization', 'training', 'ui']
    
    for subdir in subdirs:
        subdir_path = os.path.join(utils_path, subdir)
        if os.path.exists(subdir_path):
            print(f"✅ {subdir}/ directory found")
        else:
            print(f"⚠️  {subdir}/ directory missing")
    
    # Test importing test functions (backward compatibility)
    from utils import (
        test_shape_robustness,
        test_gradient_flow,
        test_output_stability,
        run_all_tier1_tests
    )
    print("✅ Test functions importable")
    
    print("\n✅ Utils package ready!")
    
except ImportError as e:
    print(f"❌ Failed to import utils package: {e}")
    print("Falling back to direct file download...")
    # Fallback: download test_functions.py directly
    !wget -q https://raw.githubusercontent.com/matt-hans/transformer-builder-colab-templates/main/utils/test_functions.py

In [None]:
# ==============================================================================
# LOAD CUSTOM MODEL - v3.4.0 (Simple Modal Approach)
# ==============================================================================

import os, re, json, urllib.request, urllib.error

print("=" * 70)
print("MODEL LOADING")
print("=" * 70)
print()

# ==============================================================================
# VERIFY GIST ID WAS PROVIDED
# ==============================================================================

if 'GIST_ID' not in globals() or not GIST_ID:
    print("❌ ERROR: No Gist ID found!")
    print()
    print("=" * 70)
    print("🔙 GO BACK TO CELL 3")
    print("=" * 70)
    print()
    print("You must run Cell 3 first to provide your Gist ID.")
    print()
    print("Steps:")
    print("  1. Scroll up to Cell 3")
    print("  2. Paste your Gist ID from Transformer Builder")
    print("  3. Run Cell 3")
    print("  4. Come back and run this cell")
    print()
    raise ValueError("Gist ID required - please run Cell 3 first")

gist_id = GIST_ID
model_name = "Model"  # Default name, will be overridden from config

print(f"📥 Loading model from GitHub Gist: {gist_id}")
print()

# ==============================================================================
# FETCH GIST AND LOAD MODEL FILES
# ==============================================================================

def _fetch_gist(gid: str) -> dict:
    """Fetch Gist data from GitHub API."""
    url = f"https://api.github.com/gists/{gid}"
    req = urllib.request.Request(url, headers={
        "Accept": "application/vnd.github+json",
        "User-Agent": "transformer-builder-colab"
    })
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        detail = f"HTTP {e.code}"
        try:
            body = e.read().decode("utf-8")
            if "rate limit" in body.lower():
                detail += " - GitHub API rate limit (try again in an hour)"
            elif e.code == 404:
                detail += " - Gist not found (check your Gist ID)"
        except Exception:
            pass
        raise RuntimeError(f"GitHub API error: {detail}") from e
    except Exception as e:
        raise RuntimeError(f"Network error: {e}") from e

def _write(path: str, text: str):
    """Write text to file."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# Fetch Gist
try:
    gist_data = _fetch_gist(gist_id)
    files = gist_data.get("files") or {}
    
    # Check for required files
    if "model.py" not in files:
        raise RuntimeError("Gist is missing 'model.py' - please re-export from Transformer Builder")
    if "config.json" not in files:
        raise RuntimeError("Gist is missing 'config.json' - please re-export from Transformer Builder")
    
    model_code = files["model.py"].get("content", "")
    config_json = files["config.json"].get("content", "")
    
    if not model_code or not config_json:
        raise RuntimeError("Empty content in model.py or config.json")
    
    # Write to files
    _write("custom_transformer.py", model_code)
    _write("config.json", config_json)
    
    print(f"✅ Model loaded successfully!")
    print(f"✅ Gist URL: {gist_data.get('html_url', 'N/A')}")
    print(f"✅ Model code: {len(model_code):,} characters")
    print(f"✅ Config: {len(config_json):,} characters")
    print()
    
    # Parse model name from config if available
    try:
        config_dict = json.loads(config_json)
        if 'model_name' in config_dict:
            model_name = config_dict['model_name']
            print(f"✅ Model name: {model_name}")
            print()
    except Exception:
        pass

except Exception as e:
    print(f"❌ Failed to load model from Gist!")
    print()
    print(f"Error: {e}")
    print()
    print("=" * 70)
    print("TROUBLESHOOTING")
    print("=" * 70)
    print()
    print("Common issues:")
    print("  1. Check your Gist ID is correct (go back to Cell 3)")
    print("  2. Ensure you exported from Transformer Builder successfully")
    print("  3. Check you're not hitting GitHub rate limit (60 requests/hour)")
    print("  4. Try re-exporting from Transformer Builder")
    print()
    print("If the problem persists:")
    print(f"  • Gist URL: https://gist.github.com/{gist_id}")
    print("  • Verify the Gist contains model.py and config.json")
    print()
    raise

print("=" * 70)
print("✅ MODEL LOADING COMPLETE")
print("=" * 70)
print()
print("Next: Continue to model instantiation and testing below!")
print()

# Store model_name for next cell
params = {"name": model_name}

## 📄 View Loaded Model Code

This cell displays the Python code that was loaded from your Transformer Builder export. You can review the architecture before running tests.

In [None]:
# Display the loaded model code for transparency
print("=" * 80)
print("📄 LOADED MODEL CODE (custom_transformer.py)")
print("=" * 80)
print()

with open('custom_transformer.py', 'r') as f:
    model_code_display = f.read()

# Use syntax highlighting
from IPython.display import Code
display(Code(model_code_display, language='python'))

print()
print("=" * 80)
print("📋 MODEL CONFIGURATION (config.json)")
print("=" * 80)
print()

with open('config.json', 'r') as f:
    config_display = json.load(f)

# Pretty print JSON
print(json.dumps(config_display, indent=2))
print()
print("✅ You can now proceed to run the model instantiation and tests below!")

## Dynamic Dependency Detection

Automatically detect and install any custom dependencies your model needs.
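
As a minimal sketch of the detection step, the standalone function below collects top-level package names from a source string with `ast`. The function name `top_level_imports` and the sample source are illustrative and not part of the notebook. Relative imports are skipped on the assumption that they refer to local files rather than pip packages.

```python
import ast

def top_level_imports(source: str) -> set:
    """Return the top-level package names imported by `source`."""
    found = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                found.add(alias.name.split('.')[0])
        elif isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import such as "from .local import x"
            if node.module and node.level == 0:
                found.add(node.module.split('.')[0])
    return found

# Hypothetical model source used only for illustration
sample = (
    "import torch.nn as nn\n"
    "from einops import rearrange\n"
    "from .local import helper\n"
)
print(sorted(top_level_imports(sample)))
```

Keep in mind that an import name is not always the pip package name (for example, `sklearn` is installed as `scikit-learn`), so a failed install does not necessarily mean the module is unavailable.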

In [None]:
import ast
import subprocess
import sys

# Parse imports from generated code
with open('custom_transformer.py', 'r') as f:
    source_code = f.read()
    tree = ast.parse(source_code)

# Extract all imports
imports = set()
for node in ast.walk(tree):
    if isinstance(node, ast.Import):
        for alias in node.names:
            imports.add(alias.name.split('.')[0])
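    elif isinstance(node, ast.ImportFrom):
        # Skip relative imports (level > 0); they refer to local files, not pip packages
        if node.module and node.level == 0:
            imports.add(node.module.split('.')[0])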

print(f"Detected imports: {', '.join(sorted(imports))}")

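# Standard library modules (don't need pip install).
# sys.stdlib_module_names exists on Python 3.10+; the explicit names are a fallback.
stdlib_modules = set(getattr(sys, 'stdlib_module_names', ())) | {
    'abc', 'collections', 'dataclasses', 'functools', 'json', 'math',
    'typing', 'warnings', 'os', 'sys', 're', 'time', 'copy'
}

# Pre-installed in Colab or installed on demand by the Tier 2/3 cells
installed_modules = {
    'torch', 'transformers', 'numpy', 'scipy', 'matplotlib',
    'pandas', 'seaborn', 'tqdm', 'torchinfo', 'captum', 'optuna'
}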

# Find missing packages
missing = imports - stdlib_modules - installed_modules

if missing:
    print(f"\nInstalling additional dependencies: {', '.join(missing)}")
    for package in missing:
        try:
            subprocess.check_call(
                [sys.executable, '-m', 'pip', 'install', '-q', package],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL
            )
            print(f"  ✅ Installed {package}")
        except subprocess.CalledProcessError:
            print(f"  ⚠️ Failed to install {package} (may not be a pip package)")
else:
    print("\n✅ All dependencies already installed")

## Import and Instantiate Model

Load your custom transformer and prepare for testing.
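
The exported `config.json` can be either a flat dictionary or a `nodes` list whose entries carry a `params` dictionary. The sketch below shows how defaults, node parameters, and flat overrides could combine into the limits the tests need. `extract_limits` and the sample config are hypothetical; the default values (50257 and 512) are the ones used by the `ModelConfig` class in the next cell.

```python
def extract_limits(config: dict) -> dict:
    """Resolve vocab_size and max_seq_len from a flat or nodes-based config."""
    limits = {"vocab_size": 50257, "max_seq_len": 512}  # defaults

    # Nodes-based config: pick up values from each node's params
    for node in config.get("nodes", []):
        node_params = node.get("params", {})
        if "vocab_size" in node_params:
            limits["vocab_size"] = node_params["vocab_size"]
        seq = node_params.get("max_seq_len") or node_params.get("seq_length")
        if seq:
            limits["max_seq_len"] = seq

    # Flat keys, when present, override node values
    for key in limits:
        if key in config:
            limits[key] = config[key]
    return limits

# Hypothetical export used only for illustration
example = {
    "model_name": "TinyLM",
    "nodes": [{"params": {"vocab_size": 1000, "seq_length": 128}}],
}
print(extract_limits(example))
```

If neither shape supplies a value, the defaults apply, which is worth checking against your model before trusting shape-test results.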

In [None]:
import torch
import torch.nn as nn
import inspect

# Import the custom model
exec(open('custom_transformer.py').read())

# Load config
with open('config.json') as f:
    config_dict = json.load(f)

# Find the model class
model_class = None
for name, obj in list(globals().items()):
    if isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module:
        if name == params['name']:
            model_class = obj
            break

if model_class is None:
    # Fallback: find any nn.Module subclass
    for name, obj in list(globals().items()):
        if isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module:
            model_class = obj
            print(f"⚠️ Using {name} (expected {params['name']})")
            break

if model_class:
    # Instantiate model - try both parameterless and parameterized approaches
    try:
        # Check if __init__ accepts parameters (besides self)
        sig = inspect.signature(model_class.__init__)
        params_list = [p for p in sig.parameters.values() if p.name != 'self']

        if len(params_list) == 0:
            # Parameterless constructor (Transformer Builder models)
            print("ℹ️  Model has parameterless constructor (Transformer Builder export)")
            model = model_class()
        else:
            # Parameterized constructor (traditional models)
            print(f"ℹ️  Model accepts {len(params_list)} parameter(s)")
            model = model_class(**config_dict)

        model.eval()

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"✅ Model instantiated: {model_class.__name__}")
        print(f"✅ Total parameters: {total_params:,}")
        print(f"✅ Trainable parameters: {trainable_params:,}")

        # Move to GPU if available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        print(f"✅ Device: {device}")

        # Display model summary (using native torch instead of torchinfo)
        print()
        print("=" * 70)
        print("MODEL SUMMARY")
        print("=" * 70)
        print()
        print(model)
        print()
        print("=" * 70)
        print(f"Total parameters:      {total_params:,}")
        print(f"Trainable parameters:  {trainable_params:,}")
        print(f"Non-trainable params:  {total_params - trainable_params:,}")

        # Calculate model size
        param_size = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
        size_mb = (param_size + buffer_size) / 1024**2
        print(f"Model size:            {size_mb:.2f} MB")
        print("=" * 70)
        print()

    except Exception as e:
        print(f"❌ Failed to instantiate model: {e}")
        import traceback
        traceback.print_exc()
        raise
else:
    raise RuntimeError(f"Could not find model class '{params['name']}' in generated code")

# Create config object for test functions (with proper vocab_size)
class ModelConfig:
    def __init__(self, **kwargs):
        # Set defaults
        self.vocab_size = 50257
        self.max_seq_len = 512
        self.max_batch_size = 8

        # If nodes-based config, extract common params
        if 'nodes' in kwargs:
            for node in kwargs['nodes']:
                node_params = node.get('params', {})
                if 'vocab_size' in node_params:
                    self.vocab_size = node_params['vocab_size']
                if 'max_seq_len' in node_params or 'seq_length' in node_params:
                    self.max_seq_len = node_params.get('max_seq_len') or node_params.get('seq_length', 512)

        # Override with flat params if present
        for key, value in kwargs.items():
            if key not in ['nodes', 'version', 'model_name']:
                setattr(self, key, value)

config = ModelConfig(**config_dict)
print(f"✅ Config prepared (vocab_size={config.vocab_size}, max_seq_len={config.max_seq_len})")
print("✅ Ready for testing!")

---

# 🔍 Tier 1: Critical Validation

These tests verify your model is mathematically sound and ready for training.

**Estimated time:** ~1 minute

**What's tested:**
- ✅ Shape validation across edge cases
- ✅ Gradient flow (detect vanishing/exploding gradients)
- ✅ Numerical stability (NaN/Inf detection)
- ✅ Parameter initialization quality
- ✅ Memory footprint scaling
- ✅ Inference speed benchmarks
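
To make the gradient-flow and stability checks concrete, here is a rough sketch of how per-layer gradient norms might be classified. It uses plain Python so it runs anywhere; in practice the norms would come from something like `p.grad.norm()` on each parameter after a backward pass. The function name, layer names, and thresholds (`1e-7` and `1e3`) are illustrative assumptions, not the values used by `utils.test_functions`.

```python
import math

def classify_gradients(grad_norms, low=1e-7, high=1e3):
    """Label each per-layer gradient norm as ok, vanishing, exploding, or non-finite."""
    report = {}
    for layer, norm in grad_norms.items():
        if not math.isfinite(norm):
            report[layer] = "non-finite"   # NaN or Inf
        elif norm < low:
            report[layer] = "vanishing"
        elif norm > high:
            report[layer] = "exploding"
        else:
            report[layer] = "ok"
    return report

# Hypothetical norms used only for illustration
norms = {"embed": 0.02, "block0.attn": 3e-9, "block1.mlp": 5e4, "head": float("nan")}
for layer, status in classify_gradients(norms).items():
    print(f"{layer:12s} {status}")
```

Sensible thresholds depend on the architecture and loss scale, so treat any fixed cutoff as a starting point rather than a rule.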

In [None]:
# Import test utilities from the cloned utils package
from utils.test_functions import (
    test_shape_robustness,
    test_gradient_flow,
    test_output_stability,
    test_parameter_initialization,
    test_memory_footprint,
    test_inference_speed
)

print("✅ Test functions loaded from utils package")

In [None]:
print("=" * 80)
print("TIER 1: CRITICAL VALIDATION")
print("=" * 80)
print()

# Test 1: Shape Robustness
print("Test 1/6: Shape Validation")
print("-" * 80)
shape_results = test_shape_robustness(model, config)
display(shape_results)
print()

# Test 2: Gradient Flow
print("Test 2/6: Gradient Flow Analysis")
print("-" * 80)
grad_results = test_gradient_flow(model, config)
display(grad_results)
print()

# Test 3: Output Stability
print("Test 3/6: Numerical Stability")
print("-" * 80)
stability_stats = test_output_stability(model, config, n_samples=100)
print()

# Test 4: Parameter Initialization
print("Test 4/6: Parameter Initialization")
print("-" * 80)
param_results = test_parameter_initialization(model)
display(param_results)
print()

# Test 5: Memory Footprint
print("Test 5/6: Memory Footprint Analysis")
print("-" * 80)
memory_results = test_memory_footprint(model, config)
display(memory_results)
print()

# Test 6: Inference Speed
print("Test 6/6: Inference Speed Benchmark")
print("-" * 80)
speed_stats = test_inference_speed(model, config, n_trials=50)
print()

print("=" * 80)
print("✅ TIER 1 VALIDATION COMPLETE")
print("=" * 80)
print()
print("All Tier 1 tests have run. Review the results above for any failures before continuing.")
print()
print("📝 Next steps:")
print("   • Tier 2: Advanced analysis (attention patterns, attribution)")
print("     → Install optional dependencies in the cell before Tier 2")
print("     → Then run Tier 2 tests")
print()
print("   • Tier 3: Training utilities (fine-tuning, hyperparameter search)")
print("     → Install optional dependencies in the cell before Tier 3")
print("     → Then run Tier 3 tests")

---

# 🔬 Tier 2: Advanced Analysis

Deep dive into model behavior with advanced diagnostic tools.

**Estimated time:** ~3-5 minutes

**What's tested:**
- 🎯 **Attention Patterns:** Visualize attention weights, detect collapsed attention, analyze head specialization
- 🔍 **Attribution Analysis:** Identify which input tokens contribute most to predictions (using Captum)
- 🛡️ **Robustness Testing:** Measure stability under input perturbations and noise

**Note:** These tests are optional but highly recommended for understanding model behavior.
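
One common way to quantify collapsed attention is the entropy of each attention row: a row that puts nearly all its weight on one token has entropy close to zero, while a uniform row reaches the maximum `log(n)`. The sketch below illustrates that idea in plain Python. Whether `test_attention_patterns` uses this metric is not confirmed here, and the 10% ratio threshold and function names are arbitrary choices for illustration.

```python
import math

def row_entropy(weights):
    """Shannon entropy (in nats) of one attention row that sums to 1."""
    return -sum(w * math.log(w) for w in weights if w > 0)

def is_collapsed(weights, ratio=0.1):
    """Flag a row whose entropy is below `ratio` of the uniform maximum."""
    max_entropy = math.log(len(weights))
    return row_entropy(weights) < ratio * max_entropy

uniform = [0.25, 0.25, 0.25, 0.25]       # attention spread evenly
peaked = [0.997, 0.001, 0.001, 0.001]    # almost all weight on one token

print(is_collapsed(uniform), is_collapsed(peaked))
```

A head that is peaked on every row for every input is a candidate for collapse; a head that is peaked only on some inputs may simply be specialized.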

## 📦 Optional: Install Tier 2 Dependencies

**Note:** Tier 2 tests use lazy imports - they will skip gracefully if dependencies are missing.

**To enable all Tier 2 functionality, run the cell below:**

In [None]:
# ==============================================================================
# TIER 2 OPTIONAL DEPENDENCIES - Lazy Installation
# ==============================================================================

print("=" * 70)
print("📦 TIER 2: OPTIONAL PACKAGE INSTALLATION")
print("=" * 70)
print()
print("⚠️  WARNING: Installing packages may cause NumPy corruption")
print("⚠️  If you encounter errors, restart runtime and skip this tier")
print()
print("Tier 2 requires: captum (for feature attribution analysis)")
print()

# Check if already installed
try:
    import captum
    print("✅ Captum already available")
    print()
except ImportError:
    print("📦 Installing captum...")
    print("⏳ This may take 10-15 seconds...")
    print()

    # Install with --no-deps to reduce corruption risk
    !pip install -q --no-deps captum

    print()

    # Verify installation
    try:
        import captum
        print("✅ Captum installed successfully")
        print()

        # Check numpy integrity after install
        from numpy._core.umath import _center
        print("✅ NumPy still intact after installation")
        print()

    except ImportError as e:
        print("❌ Installation failed or NumPy corrupted")
        print()
        print("=" * 70)
        print("RECOVERY STEPS:")
        print("=" * 70)
        print()
        print("1. Runtime → Restart runtime")
        print("2. Run all cells EXCEPT this one")
        print("3. Skip to Tier 3 or just use Tier 1 results")
        print()
        print("Note: Tier 1 tests are the most important - Tier 2 is optional")
        print()
        raise

print("=" * 70)
print("✅ Tier 2 dependencies ready")
print("=" * 70)
print()
print("You can now run the Tier 2 tests below")
print()

## 📦 Optional: Install Tier 3 Dependencies

**Note:** Tier 3 tests use lazy imports - they will skip gracefully if dependencies are missing.

**To enable all Tier 3 functionality, run the cell below:**

⚠️ **Warning:** These dependencies are more complex and may take longer to install.

In [None]:
# ==============================================================================
# TIER 3 OPTIONAL DEPENDENCIES - Lazy Installation
# ==============================================================================

print("=" * 70)
print("📦 TIER 3: OPTIONAL PACKAGE INSTALLATION")
print("=" * 70)
print()
print("⚠️  WARNING: Installing packages may cause NumPy corruption")
print("⚠️  If you encounter errors, restart runtime and skip this tier")
print()
print("Tier 3 requires:")
print("  • pytorch-lightning (training utilities)")
print("  • optuna (hyperparameter optimization)")
print()

# Check if already installed
try:
    import pytorch_lightning as pl
    import optuna
    print("✅ Training dependencies already available")
    print()
except ImportError:
    print("📦 Installing training dependencies...")
    print("⏳ This may take 20-30 seconds...")
    print()

    # Install with --no-deps to reduce corruption risk
    !pip install -q --no-deps pytorch-lightning torchmetrics lightning-utilities
    !pip install -q --no-deps optuna alembic colorlog sqlalchemy

    print()

    # Verify installation
    try:
        import pytorch_lightning as pl
        import optuna
        print("✅ Training dependencies installed successfully")
        print()

        # Check numpy integrity after install
        from numpy._core.umath import _center
        print("✅ NumPy still intact after installation")
        print()

    except ImportError as e:
        print("❌ Installation failed or NumPy corrupted")
        print()
        print("=" * 70)
        print("RECOVERY STEPS:")
        print("=" * 70)
        print()
        print("1. Runtime → Restart runtime")
        print("2. Run all cells EXCEPT this one")
        print("3. Use Tier 1/2 results only")
        print()
        print("Note: Tier 3 is compute-intensive and optional")
        print()
        raise

print("=" * 70)
print("✅ Tier 3 dependencies ready")
print("=" * 70)
print()
print("You can now run the Tier 3 tests below")
print()

In [None]:
# Import Tier 2 test functions
from utils.test_functions import (
    test_attention_patterns,
    test_attribution_analysis,
    test_robustness
)

print("=" * 80)
print("TIER 2: ADVANCED ANALYSIS")
print("=" * 80)
print()

# Test 1: Attention Patterns
print("Test 1/3: Attention Pattern Analysis")
print("-" * 80)
try:
    attention_results = test_attention_patterns(model, config)
    if attention_results is not None:
        display(attention_results)
    print("✅ Attention analysis complete")
except Exception as e:
    print(f"⚠️ Attention analysis skipped: {e}")
print()

# Test 2: Attribution Analysis
print("Test 2/3: Input Attribution Analysis")
print("-" * 80)
try:
    attribution_results = test_attribution_analysis(model, config)
    if attribution_results is not None:
        print("\nTop Contributing Tokens:")
        for token, score in attribution_results.get("top_tokens", []):
            print(f"  {token:20s}: {score:+.4f}")
    print("✅ Attribution analysis complete")
except Exception as e:
    print(f"⚠️ Attribution analysis skipped: {e}")
print()

# Test 3: Robustness Testing
print("Test 3/3: Robustness Under Noise")
print("-" * 80)
try:
    robustness_results = test_robustness(model, config, n_samples=20)
    if robustness_results is not None:
        display(robustness_results)
    print("✅ Robustness analysis complete")
except Exception as e:
    print(f"⚠️ Robustness analysis skipped: {e}")
print()

print("=" * 80)
print("✅ TIER 2 ANALYSIS COMPLETE")
print("=" * 80)
print()
print("Next: Scroll down for Tier 3 (Training & Fine-Tuning)")

---

# 🚀 Tier 3: Training & Production Utilities

Advanced utilities for fine-tuning, hyperparameter optimization, and production benchmarking.

**Estimated time:** ~10-20 minutes (depends on training iterations)

**What's included:**
- 🎓 **Fine-Tuning:** Basic training loop with loss tracking and gradient monitoring
- 🔧 **Hyperparameter Search:** Automated optimization using Optuna (learning rate, batch size, warmup)
- 📊 **Benchmark Comparison:** Compare your model against production baselines (distilgpt2, bert-base, etc.)

**Note:** These are compute-intensive operations. Consider using GPU runtime for faster execution.
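
Warmup is listed among the searched hyperparameters, so a small sketch of what it controls may help. The function below shows one common form, linear warmup followed by a constant learning rate. The name `lr_at_step` and `warmup_steps=100` are illustrative assumptions, and the base rate `5e-5` matches the fine-tuning demo in the next cell. The schedule actually used by `utils.test_functions` may differ.

```python
def lr_at_step(step, base_lr=5e-5, warmup_steps=100):
    """Linear warmup from 0 to base_lr over warmup_steps, then constant."""
    if warmup_steps <= 0:
        return base_lr  # no warmup requested
    return base_lr * min(1.0, step / warmup_steps)

for step in (0, 50, 100, 1000):
    print(f"step {step:4d}: lr = {lr_at_step(step):.2e}")
```

A longer warmup trades slower early progress for a lower chance of early divergence, which is why it is often tuned together with the learning rate.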

In [None]:
# Import Tier 3 training utilities
from utils.test_functions import (
    test_fine_tuning,
    test_hyperparameter_search,
    test_benchmark_comparison
)

print("=" * 80)
print("TIER 3: TRAINING & PRODUCTION UTILITIES")
print("=" * 80)
print()

# Test 1: Fine-Tuning
print("Test 1/3: Fine-Tuning Demo")
print("-" * 80)
print("Running 3 epochs of fine-tuning with synthetic data...")
try:
    fine_tune_results = test_fine_tuning(
        model, 
        config, 
        num_epochs=3,
        batch_size=2,
        learning_rate=5e-5
    )
    print(f"\nFinal Loss: {fine_tune_results['final_loss']:.4f}")
    print(f"Best Loss: {fine_tune_results['best_loss']:.4f}")
    print("✅ Fine-tuning complete")
except Exception as e:
    print(f"⚠️ Fine-tuning skipped: {e}")
print()

# Test 2: Hyperparameter Search (OPTIONAL - disabled by default, uncomment below to run)
print("Test 2/3: Hyperparameter Optimization")
print("-" * 80)
print("⚠️ Skipping hyperparameter search (compute-intensive)")
print("To enable: uncomment the code block below")
print()
# Uncomment to run:
# try:
#     hp_results = test_hyperparameter_search(
#         model,
#         config,
#         n_trials=5,
#         epochs_per_trial=2
#     )
#     print("\nBest Parameters:")
#     for param, value in hp_results['best_params'].items():
#         print(f"  {param}: {value}")
#     print("✅ Hyperparameter search complete")
# except Exception as e:
#     print(f"⚠️ Hyperparameter search failed: {e}")

# Test 3: Benchmark Comparison
print("Test 3/3: Benchmark Against Baseline")
print("-" * 80)
print("Comparing against distilgpt2 baseline...")
try:
    benchmark_results = test_benchmark_comparison(
        model,
        config,
        baseline_model="distilgpt2",
        n_samples=10
    )
    if benchmark_results is not None:
        display(benchmark_results)
    print("✅ Benchmark comparison complete")
except Exception as e:
    print(f"⚠️ Benchmark comparison skipped: {e}")
print()

print("=" * 80)
print("✅ TIER 3 TRAINING UTILITIES COMPLETE")
print("=" * 80)
print()
print("🎉 All testing tiers complete! Review any skipped or failed tests above before deploying your model.")