In [1]:
# === VISION TRANSFORMER FROM FIRST PRINCIPLES ===

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# Import our clean backend modules
from backend import (
    setup_data_loaders, debug_data_shapes,
    VisionTransformerEncoder, debug_model_architecture, visualize_model_data_flow,
    train_model_with_debugging,
    analyze_training_history, analyze_predictions, show_prediction_mistakes,
    show_learning_insights
)

print("Vision Transformer Learning Journey")
print("==================================")
print("Understanding transformers from first principles")
print()

# Seed the global RNGs before any data loading or model construction, so that a
# Restart & Run All reproduces the same run instead of a new random one.
# NOTE(review): assumes the backend draws its randomness (e.g. any loader
# shuffling or weight init) from the torch / numpy global RNGs — confirm if
# results still vary between runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Setup data loaders: patch-shaped loaders feed the model, raw-image loaders
# are used only for visualization.
train_patch_loader, val_patch_loader, train_loader, val_loader = setup_data_loaders()

Vision Transformer Learning Journey
Understanding transformers from first principles

Setting up data loaders...
mnist dataset loaded, train data size: 50000 validation data size: 10000
✓ Data ready: 391 train batches, 79 val batches
✓ Image transformation: [128, 1, 28, 28] → [128, 16, 1, 7, 7]
  28x28 images split into 16 patches of 7x7 each



In [None]:
# === TRAINING ===

print("Training the Model")
print("==================")

# Train the model with clean insights
history = train_model_with_debugging(
    model=model,
    train_loader=train_patch_loader,
    val_loader=val_patch_loader,
    epochs=15,
    lr=0.0001
)


In [None]:
# === ANALYSIS ===

print("Analyzing Results")
print("================")

# Training analysis with visualization
analyze_training_history(history)

# Prediction analysis
prediction_results = analyze_predictions(model, val_patch_loader)

# Show mistakes to understand failure cases
show_prediction_mistakes(model, val_loader, val_patch_loader)


In [None]:
# === BUILD VISION TRANSFORMER ===

print("Building Vision Transformer")
print("==========================")

# Create the model
model = VisionTransformerEncoder(debug=False)  # We'll debug manually

# Get sample data for visualization
sample_patches, sample_labels = next(iter(train_patch_loader))
sample_images, _ = next(iter(train_loader))

print(f"Sample batch - Patches: {sample_patches.shape}, Labels: {sample_labels.shape}")
print(f"Model ready! Now let's see how data flows through it...")
print()


In [None]:
# === MODEL DATA FLOW VISUALIZATION ===

print("Model Data Flow - See How Transformers Process Images")
print("====================================================")

# Show complete data transformation through the model
model_flow = visualize_model_data_flow(
    model=model,
    sample_data=sample_patches,
    original_images=sample_images,
    labels=sample_labels
)

print("This shows the complete journey:")
print("1. Image → Patches → Flattened → Embedded")
print("2. Add positional info → Transformer layers")  
print("3. Self-attention processing → Global pooling")
print("4. Final classification → Prediction")
print()

# Model architecture overview
debug_model_architecture(model, sample_patches[:2])


In [None]:
# === INSIGHTS ===
# Summarizes the finished run. Reads `model` and `history` produced by the
# build and training cells.

print("Key Insights")
print("===========")

# Backend summary of what the training run showed
show_learning_insights(history)

print("What we learned about Vision Transformers:")
print("• Patches allow treating images as sequences")
print("• Self-attention captures spatial relationships")
print("• Multi-head attention captures different features")
print("• Positional embeddings encode patch locations")
print("• Layer normalization stabilizes training")
print("• Residual connections help gradient flow")

n_params = sum(p.numel() for p in model.parameters())
val_accuracies = history['val_accuracies']

print("\nModel summary:")
print(f"• Parameters: {n_params:,}")
# max() of an empty sequence raises ValueError, e.g. if training recorded no
# validation epochs; report that explicitly instead of crashing the cell.
if len(val_accuracies) > 0:
    print(f"• Best accuracy: {max(val_accuracies):.2f}%")
else:
    print("• Best accuracy: n/a (no validation epochs recorded)")
print(f"• Epochs trained: {len(history['train_losses'])}")

print("\nNext steps:")
print("• Experiment with different hyperparameters")
print("• Try on more complex datasets")
print("• Visualize attention patterns")
print("• Compare with CNN architectures")

In [None]:
# === OPTIONAL: DEEPER ANALYSIS ===
# Extra backend diagnostics beyond the main narrative. They sit behind a single
# toggle instead of asking readers to comment/uncomment code by hand.
RUN_DEEPER_ANALYSIS = True  # set to False to skip the extra diagnostics

print("Optional: Deeper Analysis")
print("========================")

if RUN_DEEPER_ANALYSIS:
    print("Running additional diagnostics (set RUN_DEEPER_ANALYSIS = False to skip):")
    print()

    # Imported here rather than in the top import cell so that a missing or
    # renamed backend submodule only breaks these optional extras, not the
    # main notebook.
    from backend.data_processing import analyze_batch_statistics
    from backend.transformer_architecture import explain_transformer_components
    from backend.training_engine import analyze_training_health

    # Additional data analysis
    debug_data_shapes(train_patch_loader)

    # Batch statistics
    analyze_batch_statistics(train_patch_loader)

    # Transformer component explanation
    explain_transformer_components()

    # Training health analysis of the `history` returned by training
    analyze_training_health(history)

    print("These tools help you understand the details behind the scenes!")
else:
    print("Skipped (set RUN_DEEPER_ANALYSIS = True to run these diagnostics).")