# 🧪 Transformer Builder - Advanced Testing Lab

**Welcome! This notebook tests your custom transformer architecture.**

---

## 🚀 **Quick Start (3 Steps)**

### **STEP 1:** Paste Your Gist ID
↓ Scroll down to Cell 3 and paste the Gist ID you received from Transformer Builder

### **STEP 2:** Run All Cells
Click **Runtime → Run all** (or run cells one-by-one)

### **STEP 3:** Review Test Results
Your model will be validated through 3 testing tiers

---

## 📋 **What's Included:**

- ✅ **Tier 1:** Critical validation (shape, gradients, numerical stability)
- 🔬 **Tier 2:** Advanced analysis (attention patterns, robustness, profiling)
- 🚀 **Tier 3:** Training utilities (fine-tuning, hyperparameter sweeps, benchmarks)

---

## ⚠️ **First Time Setup:**

If this is your first time OR you're continuing from a previous session:

1. **Runtime** → **Restart runtime** (takes 5 seconds)
2. **Edit** → **Clear all outputs** (optional, cleans up UI)
3. **Scroll down to Cell 3** → Paste your Gist ID
4. **Runtime** → **Run all**

This ensures a clean environment and prevents dependency conflicts.
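
To see which package versions are active after a restart, a small check like the one below may help. This is a minimal sketch, not part of the notebook's own tooling: the helper name `installed_versions` and the package list are illustrative assumptions.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# The package list is an assumption; adjust it to what your model needs.
for name, version in installed_versions(["numpy", "torch", "pytorch-lightning"]).items():
    print(f"{name}: {version or 'not installed'}")
```

Running this before and after the installation cell shows whether any pre-installed package was replaced.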

---

**Source:** Generated from [Transformer Builder](https://transformer-builder.com)

---

## 📋 **STEP 1: Paste Your Gist ID**

When you exported from **Transformer Builder**, you received a **Gist ID**.

**Paste it in the cell below and run it.**

If you don't have a Gist ID yet, go back to Transformer Builder and click **"Export to Colab"**.

In [None]:
# ==============================================================================
# GIST ID INPUT - Paste the ID from Transformer Builder
# ==============================================================================

#@title 📥 **Paste Your Gist ID Here**
GIST_ID = ""  #@param {type:"string"}

#@markdown ---
#@markdown **Where to find your Gist ID:**
#@markdown 1. Go to Transformer Builder
#@markdown 2. Click "Export to Colab"
#@markdown 3. Copy the Gist ID from the modal
#@markdown 4. Paste it above and run this cell

if not GIST_ID or not GIST_ID.strip():
    print("=" * 70)
    print("⚠️  NO GIST ID PROVIDED")
    print("=" * 70)
    print()
    print("Please paste your Gist ID in the field above and re-run this cell.")
    print()
    print("If you don't have a Gist ID:")
    print("  1. Go to Transformer Builder")
    print("  2. Click 'Export to Colab'")
    print("  3. Copy the Gist ID from the modal")
    print("  4. Come back here and paste it")
    print()
    raise ValueError("Gist ID is required to load your custom model")
else:
    # Validate format
    import re
    if not re.fullmatch(r"[A-Za-z0-9]+", GIST_ID.strip()):
        print("=" * 70)
        print("⚠️  INVALID GIST ID FORMAT")
        print("=" * 70)
        print()
        print(f"The Gist ID you entered: {GIST_ID!r}")
        print()
        print("Gist IDs should be alphanumeric (e.g., 'abc123def456')")
        print("Please check and re-enter.")
        print()
        raise ValueError("Invalid Gist ID format")
    
    # Store for later use
    GIST_ID = GIST_ID.strip()
    
    print("=" * 70)
    print("✅ GIST ID SAVED")
    print("=" * 70)
    print()
    print(f"Gist ID: {GIST_ID}")
    print()
    print("You can now proceed to run the cells below to:")
    print("  1. Install dependencies")
    print("  2. Load your custom model")
    print("  3. Run tests")
    print()
    print("💡 Tip: Click 'Runtime → Run all' to execute everything automatically")

In [None]:
# ==============================================================================
# RUNTIME FRESHNESS DETECTION - Prevents reused runtimes with corrupted packages
# ==============================================================================

import os

# Check if this runtime was previously used
RUNTIME_MARKER = "/tmp/transformer_builder_runtime_used"

print("=" * 70)
print("🔍 RUNTIME FRESHNESS CHECK")
print("=" * 70)
print()

if os.path.exists(RUNTIME_MARKER):
    print("⚠️  WARNING: This runtime was previously used!")
    print()
    print("Reusing runtimes can cause dependency conflicts.")
    print()
    print("✅ RECOMMENDED: Restart runtime for clean environment")
    print("   (Runtime → Restart runtime → Run all)")
    print()
    print("=" * 70)
    
    # Give user option to continue anyway (advanced users)
    response = input("Continue with reused runtime anyway? [y/N]: ")
    
    if response.lower() != 'y':
        print()
        print("✅ Good choice! Please restart the runtime and try again.")
        raise RuntimeError("Runtime restart recommended for clean environment")
    else:
        print()
        print("⚠️  Proceeding with reused runtime...")
        print()
else:
    # Mark runtime as used
    with open(RUNTIME_MARKER, 'w') as f:
        f.write("used")
    
    print("✅ Fresh runtime detected!")
    print()

print("📌 Version: v3.3.2 (2025-01-13)")
print("📌 Fix: Proactive numpy repair + minimal dependencies")
print()
print("=" * 70)
print("✅ Runtime check complete")
print("=" * 70)
print()

In [None]:
# ==============================================================================
# DEPENDENCY INSTALLATION - v3.3.2 minimal dependencies (prevents corruption)
# ==============================================================================

from IPython.display import clear_output
import gc

print("📦 Installing core dependencies...")
print()

# ==============================================================================
# INSTALLATION: Minimal safe dependencies
# ==============================================================================

# Step 1: Upgrade pip (silent)
!pip install --upgrade pip -qq

# Step 2: Install minimal safe dependencies (only torchinfo, pytest, pytest-cov)
!wget -qq https://raw.githubusercontent.com/matt-hans/transformer-builder-colab-templates/main/requirements-colab.txt -O requirements-colab.txt
!pip install -qq -r requirements-colab.txt

# Step 3: Install pytorch-lightning with --no-deps (prevents dependency hell)
!pip install -qq --no-deps 'pytorch-lightning>=2.4.0,<2.6.0'
!pip install -qq --no-deps 'torchmetrics>=1.3.0,<2.0.0'
!pip install -qq --no-deps 'lightning-utilities>=0.10.0'

# Clear installation output to prevent console overflow
clear_output(wait=True)

# ==============================================================================
# POST-INSTALLATION VERIFICATION: Ensure numpy still intact
# ==============================================================================

def check_numpy_integrity():
    try:
        from numpy._core.umath import _center
        return True
    except ImportError:
        return False

if not check_numpy_integrity():
    print("=" * 70)
    print("❌ UNEXPECTED: NumPy corrupted DURING installation!")
    print("=" * 70)
    print()
    print("This should NOT happen with v3.3.2 minimal dependencies.")
    print()
    print("🐛 PLEASE REPORT THIS BUG:")
    print("   GitHub: https://github.com/matt-hans/transformer-builder-colab-templates/issues")
    print("   Include: Screenshot + Colab version")
    print()
    print("🔧 WORKAROUND: Runtime → Restart runtime → Run all")
    print()
    raise ImportError("NumPy corrupted during installation - critical bug!")

# ==============================================================================
# VERIFICATION: Check all critical imports
# ==============================================================================

print("=" * 70)
print("INSTALLATION VERIFICATION")
print("=" * 70)

try:
    import numpy as np
    import torch
    import pytorch_lightning as pl
    from transformers import AutoTokenizer
    import torchinfo
    
    print(f"✅ numpy: {np.__version__} (Colab pre-installed)")
    print(f"✅ torch: {torch.__version__} (Colab pre-installed)")
    print(f"✅ pytorch-lightning: {pl.__version__}")
    print("✅ transformers: (Colab pre-installed)")
    print(f"✅ torchinfo: {torchinfo.__version__}")
    print("✅ numpy C extensions: intact")

    # Check for GPU
    if torch.cuda.is_available():
        print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
        print(f"✅ CUDA: {torch.version.cuda}")
    else:
        print("⚠️  GPU: Not available (CPU mode)")
    
    # Cleanup memory
    gc.collect()
    
    print()
    print("=" * 70)
    print("✅ INSTALLATION COMPLETE - All core dependencies ready!")
    print("=" * 70)
    print()
    print("📝 Note: Tier 1 tests work immediately.")
    print("   Tier 2/3 require optional packages (see cells before those tiers).")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print()
    print("🔧 Troubleshooting:")
    print("   1. Runtime → Restart runtime")
    print("   2. Run all cells from beginning")
    print("   3. If persists, report as bug")
    raise

In [None]:
# ==============================================================================
# DOWNLOAD UTILS PACKAGE
# ==============================================================================

print("📦 Downloading test utilities package...")

# Remove old utils directory if exists
!rm -rf utils/

# Download complete utils package from GitHub
!git clone --depth 1 --branch main https://github.com/matt-hans/transformer-builder-colab-templates.git temp_repo 2>/dev/null

# Copy utils directory
!cp -r temp_repo/utils ./

# Cleanup
!rm -rf temp_repo

# Verify package structure
import sys
import os

# Add current directory to Python path
if './' not in sys.path:
    sys.path.insert(0, './')

# Verify utils package is importable
try:
    import utils
    print(f"✅ Utils package loaded (version {utils.__version__})")
    
    # Verify package structure
    utils_path = os.path.join(os.getcwd(), 'utils')
    subdirs = ['adapters', 'tokenization', 'training', 'ui']
    
    for subdir in subdirs:
        subdir_path = os.path.join(utils_path, subdir)
        if os.path.exists(subdir_path):
            print(f"✅ {subdir}/ directory found")
        else:
            print(f"⚠️  {subdir}/ directory missing")
    
    # Test importing test functions (backward compatibility)
    from utils import (
        test_shape_robustness,
        test_gradient_flow,
        test_output_stability,
        run_all_tier1_tests
    )
    print("✅ Test functions importable")

    print("\n✅ Utils package ready!")

except ImportError as e:
    print(f"❌ Failed to import utils package: {e}")
    print("Falling back to direct file download...")
    # Fallback: download test_functions.py directly
    !wget -q https://raw.githubusercontent.com/matt-hans/transformer-builder-colab-templates/main/utils/test_functions.py

In [ ]:
# ==============================================================================
# LOAD CUSTOM MODEL - v3.4.0 (Simple Modal Approach)
# ==============================================================================

import os, re, json, urllib.request, urllib.error

print("=" * 70)
print("MODEL LOADING")
print("=" * 70)
print()

# ==============================================================================
# VERIFY GIST ID WAS PROVIDED
# ==============================================================================

if 'GIST_ID' not in globals() or not GIST_ID:
    print("❌ ERROR: No Gist ID found!")
    print()
    print("=" * 70)
    print("🔙 GO BACK TO CELL 3")
    print("=" * 70)
    print()
    print("You must run Cell 3 first to provide your Gist ID.")
    print()
    print("Steps:")
    print("  1. Scroll up to Cell 3")
    print("  2. Paste your Gist ID from Transformer Builder")
    print("  3. Run Cell 3")
    print("  4. Come back and run this cell")
    print()
    raise ValueError("Gist ID required - please run Cell 3 first")

gist_id = GIST_ID
model_name = "Model"  # Default name, will be overridden from config

print(f"📥 Loading model from GitHub Gist: {gist_id}")
print()

# ==============================================================================
# FETCH GIST AND LOAD MODEL FILES
# ==============================================================================

def _fetch_gist(gid: str) -> dict:
    """Fetch Gist data from GitHub API."""
    url = f"https://api.github.com/gists/{gid}"
    req = urllib.request.Request(url, headers={
        "Accept": "application/vnd.github+json",
        "User-Agent": "transformer-builder-colab"
    })
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        detail = f"HTTP {e.code}"
        try:
            body = e.read().decode("utf-8")
            if "rate limit" in body.lower():
                detail += " - GitHub API rate limit (try again in an hour)"
            elif e.code == 404:
                detail += " - Gist not found (check your Gist ID)"
        except Exception:
            pass
        raise RuntimeError(f"GitHub API error: {detail}") from e
    except Exception as e:
        raise RuntimeError(f"Network error: {e}") from e

def _write(path: str, text: str):
    """Write text to file as UTF-8."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# Fetch Gist
try:
    gist_data = _fetch_gist(gist_id)
    files = gist_data.get("files") or {}
    
    # Check for required files
    if "model.py" not in files:
        raise RuntimeError("Gist is missing 'model.py' - please re-export from Transformer Builder")
    if "config.json" not in files:
        raise RuntimeError("Gist is missing 'config.json' - please re-export from Transformer Builder")
    
    model_code = files["model.py"].get("content", "")
    config_json = files["config.json"].get("content", "")
    
    if not model_code or not config_json:
        raise RuntimeError("Empty content in model.py or config.json")
    
    # Write to files
    _write("custom_transformer.py", model_code)
    _write("config.json", config_json)
    
    print("✅ Model loaded successfully!")
    print(f"✅ Gist URL: {gist_data.get('html_url', 'N/A')}")
    print(f"✅ Model code: {len(model_code):,} characters")
    print(f"✅ Config: {len(config_json):,} characters")
    print()

    # Parse model name from config if available
    try:
        config_dict = json.loads(config_json)
        if 'model_name' in config_dict:
            model_name = config_dict['model_name']
            print(f"✅ Model name: {model_name}")
            print()
    except (ValueError, TypeError):
        # Malformed or non-object JSON: keep the default model name
        pass

except Exception as e:
    print("❌ Failed to load model from Gist!")
    print()
    print(f"Error: {e}")
    print()
    print("=" * 70)
    print("TROUBLESHOOTING")
    print("=" * 70)
    print()
    print("Common issues:")
    print("  1. Check your Gist ID is correct (go back to Cell 3)")
    print("  2. Ensure you exported from Transformer Builder successfully")
    print("  3. Check you're not hitting GitHub rate limit (60 requests/hour)")
    print("  4. Try re-exporting from Transformer Builder")
    print()
    print("If the problem persists:")
    print(f"  • Gist URL: https://gist.github.com/{gist_id}")
    print("  • Verify the Gist contains model.py and config.json")
    print()
    raise

print("=" * 70)
print("✅ MODEL LOADING COMPLETE")
print("=" * 70)
print()
print("Next: Continue to model instantiation and testing below!")
print()

# Store model_name for next cell
params = {"name": model_name}

## 📄 View Loaded Model Code

This cell displays the Python code that was loaded from your Transformer Builder export. You can review the architecture before running tests.
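
For a long export, a quick outline of the classes it defines can make the review easier. The following is a minimal sketch that parses source text without executing it; the helper name `list_classes` and the sample source are hypothetical, not part of the exported code.

```python
import ast


def list_classes(source: str):
    """Return (class_name, [base_names]) for each top-level class in the source."""
    tree = ast.parse(source)
    result = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            # ast.unparse (Python 3.9+) turns each base expression back into text
            bases = [ast.unparse(base) for base in node.bases]
            result.append((node.name, bases))
    return result


# Hypothetical stand-in for the contents of custom_transformer.py
sample = '''
import torch.nn as nn

class Block(nn.Module):
    pass

class TinyTransformer(nn.Module):
    pass
'''
print(list_classes(sample))
```

Passing the contents of `custom_transformer.py` instead of `sample` would list the classes in your own export.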

In [None]:
# Display the loaded model code for transparency
print("=" * 80)
print("📄 LOADED MODEL CODE (custom_transformer.py)")
print("=" * 80)
print()

with open('custom_transformer.py', 'r') as f:
    model_code_display = f.read()

# Use syntax highlighting
from IPython.display import Code
display(Code(model_code_display, language='python'))

print()
print("=" * 80)
print("📋 MODEL CONFIGURATION (config.json)")
print("=" * 80)
print()

with open('config.json', 'r') as f:
    config_display = json.load(f)

# Pretty print JSON
print(json.dumps(config_display, indent=2))
print()
print("✅ You can now proceed to run the model instantiation and tests below!")

## Dynamic Dependency Detection

Automatically detect and install any custom dependencies your model needs.
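
The detection works by parsing the model source into a syntax tree and collecting the top-level package name of every import. The sketch below shows the idea on a small string; the helper name `top_level_imports` and the sample source are illustrative assumptions, and skipping relative imports is a choice made here rather than something the notebook states.

```python
import ast


def top_level_imports(source: str) -> set:
    """Collect top-level package names imported anywhere in the source."""
    names = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # level > 0 marks a relative import, which is never a pip package
            names.add(node.module.split(".")[0])
    return names


sample = "import torch.nn as nn\nfrom einops import rearrange\nfrom . import helpers\nimport math"
print(sorted(top_level_imports(sample)))  # ['einops', 'math', 'torch']
```

Note that an import name is not always the name to pass to pip (for example, `import yaml` is provided by the `PyYAML` distribution), so a failed automatic install may only need a manual `pip install` with the right name.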

In [None]:
import ast
import subprocess
import sys

# Parse imports from generated code
with open('custom_transformer.py', 'r') as f:
    source_code = f.read()
    tree = ast.parse(source_code)

# Extract all imports
imports = set()
for node in ast.walk(tree):
    if isinstance(node, ast.Import):
        for alias in node.names:
            imports.add(alias.name.split('.')[0])
    elif isinstance(node, ast.ImportFrom):
        if node.module:
            imports.add(node.module.split('.')[0])

print(f"Detected imports: {', '.join(sorted(imports))}")

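import importlib.util

# Standard library modules (don't need pip install).
# sys.stdlib_module_names requires Python 3.10+; the hard-coded set is a fallback.
stdlib_modules = set(getattr(sys, 'stdlib_module_names', ())) | {
    'abc', 'collections', 'dataclasses', 'functools', 'json', 'math',
    'typing', 'warnings', 'os', 'sys', 're', 'time', 'copy'
}

# Provided by Colab or installed by other cells in this notebook
installed_modules = {
    'torch', 'transformers', 'numpy', 'scipy', 'matplotlib',
    'pandas', 'seaborn', 'tqdm', 'torchinfo', 'captum', 'optuna'
}

# Find missing packages, skipping anything that is already importable
missing = {
    name for name in imports - stdlib_modules - installed_modules
    if importlib.util.find_spec(name) is None
}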

if missing:
    print(f"\nInstalling additional dependencies: {', '.join(missing)}")
    for package in missing:
        try:
            subprocess.check_call(
                [sys.executable, '-m', 'pip', 'install', '-q', package],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL
            )
            print(f"  ✅ Installed {package}")
        except subprocess.CalledProcessError:
            print(f"  ⚠️ Failed to install {package} (may not be a pip package)")
else:
    print("\n✅ All dependencies already installed")

## Import and Instantiate Model

Load your custom transformer and prepare for testing.
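
The cell below passes every key in `config.json` to the model constructor, so a key the constructor does not accept would raise a `TypeError`. If that happens, one option is to filter the config against the constructor's signature first. This is a minimal sketch: `filter_kwargs`, `TinyModel`, and the sample config are hypothetical, and whether your export contains extra keys is an assumption.

```python
import inspect


def filter_kwargs(cls, config: dict) -> dict:
    """Keep only the config entries that cls.__init__ accepts as keyword arguments."""
    params = inspect.signature(cls.__init__).parameters
    # A constructor that takes **kwargs accepts every key
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(config)
    return {key: value for key, value in config.items() if key in params}


class TinyModel:  # hypothetical stand-in for a generated model class
    def __init__(self, vocab_size, d_model=64):
        self.vocab_size = vocab_size
        self.d_model = d_model


demo_config = {"vocab_size": 1000, "d_model": 32, "model_name": "TinyModel"}
demo_model = TinyModel(**filter_kwargs(TinyModel, demo_config))
print(demo_model.vocab_size, demo_model.d_model)  # 1000 32
```

Dropping keys silently can hide a mismatch between the config and the model, so it is worth printing the discarded keys when you use this approach.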

In [None]:
import json

import torch
import torch.nn as nn
from torchinfo import summary

# Import the custom model by executing its source in the notebook namespace
with open('custom_transformer.py') as f:
    exec(f.read())

# Load config
with open('config.json') as f:
    config_dict = json.load(f)

# Find the model class
model_class = None
for name, obj in list(globals().items()):
    if isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module:
        if name == params['name']:
            model_class = obj
            break

if model_class is None:
    # Fallback: find any nn.Module subclass
    for name, obj in list(globals().items()):
        if isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module:
            model_class = obj
            print(f"⚠️ Using {name} (expected {params['name']})")
            break

if model_class:
    # Instantiate model
    try:
        model = model_class(**config_dict)
        model.eval()
        
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        print(f"✅ Model instantiated: {model_class.__name__}")
        print(f"✅ Total parameters: {total_params:,}")
        print(f"✅ Trainable parameters: {trainable_params:,}")

        # Move to GPU if available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        print(f"✅ Device: {device}")
        
        # Display model summary
        print("\n--- Model Summary ---")
        try:
            # Create dummy input based on config
            vocab_size = config_dict.get('vocab_size', 50257)
            dummy_input = torch.randint(0, vocab_size, (1, 32)).to(device)
            summary(model, input_data=dummy_input, depth=3)
        except Exception as e:
            print(f"⚠️ Could not generate summary: {e}")

    except Exception as e:
        print(f"❌ Failed to instantiate model: {e}")
        raise
else:
    raise RuntimeError(f"Could not find model class '{params['name']}' in generated code")

# Create config object for test functions
class ModelConfig:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

config = ModelConfig(**config_dict)
print("\n✅ Ready for testing!")

---

# 🔍 Tier 1: Critical Validation

These tests verify your model is mathematically sound and ready for training.

**Estimated time:** ~1 minute

**What's tested:**
- ✅ Shape validation across edge cases
- ✅ Gradient flow (detect vanishing/exploding gradients)
- ✅ Numerical stability (NaN/Inf detection)
- ✅ Parameter initialization quality
- ✅ Memory footprint scaling
- ✅ Inference speed benchmarks
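
To make the gradient-flow and NaN/Inf checks concrete, the sketch below labels per-layer gradient norms. It is an illustration only: the function name, the thresholds, and the sample norms are assumptions, and the actual `test_gradient_flow` in the utils package may work differently.

```python
import math


def classify_gradients(grad_norms, vanish_below=1e-7, explode_above=1e3):
    """Label each layer's gradient norm as non-finite, vanishing, exploding, or ok."""
    labels = {}
    for layer, norm in grad_norms.items():
        if not math.isfinite(norm):
            labels[layer] = "non-finite"  # NaN or Inf
        elif norm < vanish_below:
            labels[layer] = "vanishing"
        elif norm > explode_above:
            labels[layer] = "exploding"
        else:
            labels[layer] = "ok"
    return labels


# Made-up norms for four hypothetical layers
demo_norms = {"embed": 2.5e-9, "attn.0": 0.8, "ffn.0": 4.2e4, "head": float("nan")}
print(classify_gradients(demo_norms))
```

In a real run, the norms would come from each parameter's gradient after a backward pass.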

In [None]:
# Import test utilities from the cloned utils package
from utils.test_functions import (
    test_shape_robustness,
    test_gradient_flow,
    test_output_stability,
    test_parameter_initialization,
    test_memory_footprint,
    test_inference_speed
)

print("✅ Test functions loaded from utils package")

In [None]:
print("=" * 80)
print("TIER 1: CRITICAL VALIDATION")
print("=" * 80)
print()

# Test 1: Shape Robustness
print("Test 1/6: Shape Validation")
print("-" * 80)
shape_results = test_shape_robustness(model, config)
display(shape_results)
print()

# Test 2: Gradient Flow
print("Test 2/6: Gradient Flow Analysis")
print("-" * 80)
grad_results = test_gradient_flow(model, config)
display(grad_results)
print()

# Test 3: Output Stability
print("Test 3/6: Numerical Stability")
print("-" * 80)
stability_stats = test_output_stability(model, config, n_samples=100)
print()

# Test 4: Parameter Initialization
print("Test 4/6: Parameter Initialization")
print("-" * 80)
param_results = test_parameter_initialization(model)
display(param_results)
print()

# Test 5: Memory Footprint
print("Test 5/6: Memory Footprint Analysis")
print("-" * 80)
memory_results = test_memory_footprint(model, config)
display(memory_results)
print()

# Test 6: Inference Speed
print("Test 6/6: Inference Speed Benchmark")
print("-" * 80)
speed_stats = test_inference_speed(model, config, n_trials=50)
print()

print("=" * 80)
print("✅ TIER 1 VALIDATION COMPLETE")
print("=" * 80)
print()
print("All Tier 1 tests have run. Review the results above for any failures.")
print()
print("📝 Next steps:")
print("   • Tier 2: Advanced analysis (attention patterns, attribution)")
print("     → Install optional dependencies in the cell before Tier 2")
print("     → Then run Tier 2 tests")
print()
print("   • Tier 3: Training utilities (fine-tuning, hyperparameter search)")
print("     → Install optional dependencies in the cell before Tier 3")
print("     → Then run Tier 3 tests")

---

# 🔬 Tier 2: Advanced Analysis

Deep dive into model behavior with advanced diagnostic tools.

**Estimated time:** ~3-5 minutes

**What's tested:**
- 🎯 **Attention Patterns:** Visualize attention weights, detect collapsed attention, analyze head specialization
- 🔍 **Attribution Analysis:** Identify which input tokens contribute most to predictions (using Captum)
- 🛡️ **Robustness Testing:** Measure stability under input perturbations and noise

**Note:** These tests are optional but highly recommended for understanding model behavior.
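
One common way to spot collapsed attention is to measure the entropy of each attention row: a row that puts nearly all its weight on one token has entropy close to zero. The sketch below illustrates that idea in plain Python. The function names and the 10% threshold are assumptions, and `test_attention_patterns` may use a different criterion.

```python
import math


def attention_entropy(weights):
    """Shannon entropy (in nats) of one attention row that sums to 1."""
    return -sum(w * math.log(w) for w in weights if w > 0)


def is_collapsed(weights, threshold=0.1):
    """Flag a row whose entropy is below a fraction of the uniform maximum."""
    max_entropy = math.log(len(weights))
    return attention_entropy(weights) < threshold * max_entropy


uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.997, 0.001, 0.001, 0.001]
print(is_collapsed(uniform), is_collapsed(peaked))  # False True
```

A uniform row over four tokens reaches the maximum entropy of ln 4, while the peaked row falls well below the threshold.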

In [None]:
# ==============================================================================
# TIER 2 OPTIONAL DEPENDENCIES - Run this cell to enable advanced analysis
# ==============================================================================

print("📦 Installing Tier 2 dependencies (captum for attribution analysis)...")
print()
print("⏳ This may take 10-15 seconds...")
print()

# Install captum (for feature attribution analysis)
# Using --no-deps to prevent numpy corruption
!pip install -qq --no-deps captum

# Verify installation
try:
    import captum
    from captum.attr import IntegratedGradients
    print(f"✅ captum: {captum.__version__} (attribution analysis)")
    print()
    print("=" * 70)
    print("✅ Tier 2 dependencies installed successfully!")
    print("=" * 70)
    print()
    print("You can now run all Tier 2 tests below.")
except ImportError as e:
    print(f"⚠️  Installation incomplete: {e}")
    print("Tier 2 tests will skip features requiring captum.")
print()

## 📦 Optional: Install Tier 2 Dependencies

**Note:** Tier 2 tests use lazy imports - they will skip gracefully if dependencies are missing.

**To enable all Tier 2 functionality, run the Tier 2 dependency cell above.**

In [None]:
# ==============================================================================
# TIER 3 OPTIONAL DEPENDENCIES - Run this cell to enable training utilities
# ==============================================================================

print("📦 Installing Tier 3 dependencies (optuna for hyperparameter search)...")
print()
print("⏳ This may take 20-30 seconds...")
print()

# Install optuna (for hyperparameter optimization)
# Using --no-deps to prevent numpy corruption
!pip install -qq --no-deps optuna

# Install optuna's critical dependencies separately (avoids scipy conflicts)
!pip install -qq alembic colorlog sqlalchemy

# Optionally install datasets if needed for fine-tuning with real data
# !pip install -qq --no-deps datasets
# !pip install -qq pyarrow dill xxhash multiprocess

# Verify installation
try:
    import optuna
    print(f"✅ optuna: {optuna.__version__} (hyperparameter optimization)")
    print()
    print("=" * 70)
    print("✅ Tier 3 dependencies installed successfully!")
    print("=" * 70)
    print()
    print("You can now run all Tier 3 tests below.")
    print()
    print("💡 Tip: Uncomment the 'datasets' installation lines above if you need")
    print("   to fine-tune with real datasets (currently uses synthetic data).")
except ImportError as e:
    print(f"⚠️  Installation incomplete: {e}")
    print("Tier 3 tests will skip features requiring optuna.")
print()

## 📦 Optional: Install Tier 3 Dependencies

**Note:** Tier 3 tests use lazy imports - they will skip gracefully if dependencies are missing.

**To enable all Tier 3 functionality, run the Tier 3 dependency cell above.**

⚠️ **Warning:** These dependencies are more complex and may take longer to install.

In [None]:
# Import Tier 2 test functions
from utils.test_functions import (
    test_attention_patterns,
    test_attribution_analysis,
    test_robustness
)

print("=" * 80)
print("TIER 2: ADVANCED ANALYSIS")
print("=" * 80)
print()

# Test 1: Attention Patterns
print("Test 1/3: Attention Pattern Analysis")
print("-" * 80)
try:
    attention_results = test_attention_patterns(model, config)
    if attention_results is not None:
        display(attention_results)
    print("✅ Attention analysis complete")
except Exception as e:
    print(f"⚠️ Attention analysis skipped: {e}")
print()

# Test 2: Attribution Analysis
print("Test 2/3: Input Attribution Analysis")
print("-" * 80)
try:
    attribution_results = test_attribution_analysis(model, config)
    if attribution_results is not None:
        print("\nTop Contributing Tokens:")
        for token, score in attribution_results.get("top_tokens", []):
            print(f"  {token:20s}: {score:+.4f}")
    print("✅ Attribution analysis complete")
except Exception as e:
    print(f"⚠️ Attribution analysis skipped: {e}")
print()

# Test 3: Robustness Testing
print("Test 3/3: Robustness Under Noise")
print("-" * 80)
try:
    robustness_results = test_robustness(model, config, n_samples=20)
    if robustness_results is not None:
        display(robustness_results)
    print("✅ Robustness analysis complete")
except Exception as e:
    print(f"⚠️ Robustness analysis skipped: {e}")
print()

print("=" * 80)
print("✅ TIER 2 ANALYSIS COMPLETE")
print("=" * 80)
print()
print("Next: Scroll down for Tier 3 (Training & Fine-Tuning)")

---

# 🚀 Tier 3: Training & Production Utilities

Advanced utilities for fine-tuning, hyperparameter optimization, and production benchmarking.

**Estimated time:** ~10-20 minutes (depends on training iterations)

**What's included:**
- 🎓 **Fine-Tuning:** Basic training loop with loss tracking and gradient monitoring
- 🔧 **Hyperparameter Search:** Automated optimization using Optuna (learning rate, batch size, warmup)
- 📊 **Benchmark Comparison:** Compare your model against production baselines (distilgpt2, bert-base, etc.)

**Note:** These are compute-intensive operations. Consider using GPU runtime for faster execution.
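
Of the hyperparameters mentioned above, warmup is the least self-explanatory, so here is a minimal sketch of a linear warmup schedule. The function name and the step count are assumptions; the default base rate mirrors the `5e-5` used in the fine-tuning demo, and the utilities in this notebook may implement warmup differently.

```python
def warmup_lr(step, base_lr=5e-5, warmup_steps=100):
    """Linearly ramp the learning rate from 0 to base_lr, then hold it constant."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


for step in (0, 50, 100, 500):
    print(step, warmup_lr(step))
```

The rate starts at zero, reaches half of `base_lr` at step 50, and stays at `base_lr` from step 100 onward.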

In [None]:
# Import Tier 3 training utilities
from utils.test_functions import (
    test_fine_tuning,
    test_hyperparameter_search,
    test_benchmark_comparison
)

print("=" * 80)
print("TIER 3: TRAINING & PRODUCTION UTILITIES")
print("=" * 80)
print()

# Test 1: Fine-Tuning
print("Test 1/3: Fine-Tuning Demo")
print("-" * 80)
print("Running 3 epochs of fine-tuning with synthetic data...")
try:
    fine_tune_results = test_fine_tuning(
        model, 
        config, 
        num_epochs=3,
        batch_size=2,
        learning_rate=5e-5
    )
    print(f"\nFinal Loss: {fine_tune_results['final_loss']:.4f}")
    print(f"Best Loss: {fine_tune_results['best_loss']:.4f}")
    print("✅ Fine-tuning complete")
except Exception as e:
    print(f"⚠️ Fine-tuning skipped: {e}")
print()

# Test 2: Hyperparameter Search (OPTIONAL - Comment out to skip)
print("Test 2/3: Hyperparameter Optimization")
print("-" * 80)
print("⚠️ Skipping hyperparameter search (compute-intensive)")
print("To enable: uncomment the code block below")
print()
# Uncomment to run:
# try:
#     hp_results = test_hyperparameter_search(
#         model,
#         config,
#         n_trials=5,
#         epochs_per_trial=2
#     )
#     print("\nBest Parameters:")
#     for param, value in hp_results['best_params'].items():
#         print(f"  {param}: {value}")
#     print("✅ Hyperparameter search complete")
# except Exception as e:
#     print(f"⚠️ Hyperparameter search failed: {e}")

# Test 3: Benchmark Comparison
print("Test 3/3: Benchmark Against Baseline")
print("-" * 80)
print("Comparing against distilgpt2 baseline...")
try:
    benchmark_results = test_benchmark_comparison(
        model,
        config,
        baseline_model="distilgpt2",
        n_samples=10
    )
    if benchmark_results is not None:
        display(benchmark_results)
    print("✅ Benchmark comparison complete")
except Exception as e:
    print(f"⚠️ Benchmark comparison skipped: {e}")
print()

print("=" * 80)
print("✅ TIER 3 TRAINING UTILITIES COMPLETE")
print("=" * 80)
print()
print("🎉 All testing tiers complete! Review any skipped or failed tests above.")