In [1]:
from loaders import MnistLoader, SquareImageSplitingLoader

# Build the project's MNIST loader wrapper (imported from the local `loaders` module).
mnist_loader = MnistLoader()

# get_loaders() hands back a (train, validation) pair; the cell output below
# reports 50000 training samples and 10000 validation samples.
train_loader, validation_loader = mnist_loader.get_loaders()

# Wrap the training loader. Presumably this splits every image into square patches:
# the smoke test later in the notebook prints batches of shape [128, 16, 1, 7, 7].
# TODO confirm against loaders.py.
square_image_spliting_loader = SquareImageSplitingLoader(train_loader)

mnist dataset loaded, train data size: 50000 validation data size: 10000


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

class SimpleVisionTransformer(nn.Module):
    """
    A simple Vision Transformer for MNIST classification
    Focus: Understanding the basics step by step

    The network is a single encoder block with one attention head:
    patch embedding -> + learnable positional embedding -> pre-LN
    self-attention (residual) -> pre-LN feed-forward (residual) ->
    mean over patches -> linear classifier.

    Args:
        patch_dim: Number of values in one flattened patch (default 49, i.e. 1x7x7).
        embed_dim: Width of every token embedding.
        num_patches: Number of patches each image is split into.
        num_classes: Number of logits produced per image.
    """
    def __init__(self, patch_dim=49, embed_dim=32, num_patches=16, num_classes=10):
        super().__init__()

        # Step 1: Patch Embedding
        self.patch_embedding = nn.Linear(patch_dim, embed_dim)
        print(f"📦 Patch embedding: {patch_dim} -> {embed_dim}")

        # Step 2: Positional Encoding (learnable - simpler than sinusoidal)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        print(f"📍 Positional embedding: {self.pos_embedding.shape}")

        # Step 3: Self-Attention (single head)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        print(f"🔍 Self-attention matrices: 3 x ({embed_dim} -> {embed_dim})")

        # Step 4: Feed Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),  # Expand
            nn.ReLU(),
            nn.Linear(embed_dim * 2, embed_dim),  # Contract
        )
        print(f"🔄 FFN: {embed_dim} -> {embed_dim * 2} -> {embed_dim}")

        # Step 5: Layer Normalization (one per sub-block, applied before it)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Step 6: Classification Head
        # We'll use global average pooling across patches
        self.classifier = nn.Linear(embed_dim, num_classes)
        print(f"🎯 Classifier: {embed_dim} -> {num_classes}")

        # Kept for the 1/sqrt(d) attention scaling in forward()
        self.embed_dim = embed_dim

    def forward(self, x):
        """
        Forward pass with detailed comments for learning

        Args:
            x: Batched patches, shape [batch_size, num_patches, ...]; everything
               after the patch axis is flattened, e.g. [batch_size, 16, 1, 7, 7].

        Returns:
            Logits of shape [batch_size, num_classes].

        Raises:
            ValueError: If the patch axis does not match the number of
                positional embeddings (for example an unbatched input).
        """
        # Step 1: Flatten patches and embed them
        x = x.flatten(start_dim=2)  # [batch_size, num_patches, patch_dim]

        # Without this check a size-1 patch axis (e.g. an unbatched
        # [16, 1, 7, 7] input) would broadcast against the positional
        # embedding and silently produce meaningless logits.
        expected_patches = self.pos_embedding.size(1)
        if x.size(1) != expected_patches:
            raise ValueError(
                f"Expected input of shape [batch_size, {expected_patches}, ...] "
                f"but got {x.size(1)} patches along dim 1 (shape after flatten: {tuple(x.shape)})"
            )

        x = self.patch_embedding(x)  # [batch_size, num_patches, embed_dim]

        # Step 2: Add positional encoding (broadcast over the batch axis)
        x = x + self.pos_embedding

        # Step 3: Self-Attention Block
        # Save input for residual connection
        residual = x

        # Apply layer norm BEFORE attention (Pre-LN architecture)
        x = self.norm1(x)

        # Compute Q, K, V - each [batch_size, num_patches, embed_dim]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Scaled dot-product attention; scores are [batch_size, num_patches, num_patches]
        attention_scores = Q @ K.transpose(-2, -1) / (self.embed_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)  # each row sums to 1
        attended_values = attention_weights @ V

        # Add residual connection
        x = residual + attended_values

        # Step 4: Feed Forward Block (again Pre-LN + residual)
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x

        # Step 5: Classification
        # Global average pooling across patches
        x = x.mean(dim=1)  # [batch_size, embed_dim]

        # Final classification
        logits = self.classifier(x)  # [batch_size, num_classes]

        return logits

# Instantiate the transformer with its default hyper-parameters
print("=== CREATING SIMPLE VISION TRANSFORMER ===")
model = SimpleVisionTransformer()

# Parameter census: everything, then only the tensors that receive gradients
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    map(torch.Tensor.numel, filter(lambda t: t.requires_grad, model.parameters()))
)

print("\n📊 Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")

# Smoke test: push exactly one batch through the freshly built model
print("\n🧪 Testing forward pass...")
first_batch = next(iter(square_image_spliting_loader), None)
if first_batch is not None:
    patches, labels = first_batch
    print("Input shape:", patches.shape)
    print("Labels shape:", labels.shape)

    logits = model(patches)
    print("Output logits:", logits.shape)
    print("Sample predictions:", logits[0])

print("\n✅ Model ready for training!")


=== CREATING SIMPLE VISION TRANSFORMER ===
📦 Patch embedding: 49 -> 32
📍 Positional embedding: torch.Size([1, 16, 32])
🔍 Self-attention matrices: 3 x (32 -> 32)
🔄 FFN: 32 -> 64 -> 32
🎯 Classifier: 32 -> 10

📊 Model Statistics:
   Total parameters: 9,930
   Trainable parameters: 9,930

🧪 Testing forward pass...
Input shape: torch.Size([128, 16, 1, 7, 7])
Labels shape: torch.Size([128])
Output logits: torch.Size([128, 10])
Sample predictions: tensor([-0.0328, -0.0004,  0.0249,  0.2919,  0.0047,  0.0679,  0.0348, -0.0929,
        -0.1317, -0.0372], grad_fn=<SelectBackward0>)

✅ Model ready for training!


In [None]:
def _percent(numerator, denominator):
    """Return ``100 * numerator / denominator``, or NaN when ``denominator`` is 0."""
    if denominator == 0:
        return float("nan")
    return 100. * numerator / denominator


def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.0001):
    """
    Simple training loop with detailed explanations
    Focus: Understanding each step of the training process

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of passes over ``train_loader``.
        learning_rate: Step size handed to the Adam optimizer.

    Returns:
        Dict with per-epoch lists under 'train_losses', 'train_accuracies'
        and 'val_accuracies' (accuracies in percent). A statistic is NaN for
        an epoch in which the corresponding loader yielded no data. All lists
        are empty when ``num_epochs`` is 0.

    Training stops early once validation accuracy exceeds 95%.
    """

    # Setup training components
    criterion = nn.CrossEntropyLoss()  # For multi-class classification
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f"=== TRAINING SETUP ===")
    print(f"📚 Epochs: {num_epochs}")
    print(f"📈 Learning rate: {learning_rate}")
    print(f"💡 Optimizer: Adam")
    print(f"🎯 Loss function: CrossEntropyLoss")

    # Track training progress
    train_losses = []
    train_accuracies = []
    val_accuracies = []

    print(f"\n=== STARTING TRAINING ===")

    for epoch in range(num_epochs):
        # === TRAINING PHASE ===
        model.train()  # Set model to training mode
        total_train_loss = 0
        correct_predictions = 0
        total_samples = 0
        num_batches = 0  # Counted manually because the loader may not support len()

        print(f"\n🔄 Epoch {epoch + 1}/{num_epochs}")
        print("   Training...")

        for batch_idx, (patches, labels) in enumerate(train_loader):
            # Forward pass
            logits = model(patches)  # [batch_size, num_classes]
            loss = criterion(logits, labels)

            # Backward pass
            optimizer.zero_grad()  # Clear gradients
            loss.backward()        # Compute gradients
            optimizer.step()       # Update parameters

            # Track statistics (argmax returns plain indices, no autograd involvement)
            batch_loss = loss.item()
            total_train_loss += batch_loss
            predicted = logits.argmax(dim=1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            num_batches += 1

            # Print progress every 50 batches
            if batch_idx % 50 == 0:
                current_acc = _percent(correct_predictions, total_samples)
                print(f"     Batch {batch_idx:3d}: Loss = {batch_loss:.4f}, Acc = {current_acc:.1f}%")

        # An exhausted or empty loader yields NaN instead of a ZeroDivisionError
        avg_train_loss = total_train_loss / num_batches if num_batches else float("nan")
        train_accuracy = _percent(correct_predictions, total_samples)

        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)

        # === VALIDATION PHASE ===
        model.eval()  # Set model to evaluation mode
        val_correct = 0
        val_total = 0

        print("   Validating...")
        with torch.no_grad():  # Disable gradient computation for efficiency
            for patches, labels in val_loader:
                predicted = model(patches).argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = _percent(val_correct, val_total)
        val_accuracies.append(val_accuracy)

        # Print epoch results
        print(f"   📊 Results:")
        print(f"      Train Loss: {avg_train_loss:.4f}")
        print(f"      Train Acc:  {train_accuracy:.2f}%")
        print(f"      Val Acc:    {val_accuracy:.2f}%")

        # Simple early stopping if validation accuracy is very high
        # (a NaN accuracy compares False, so it never triggers the stop)
        if val_accuracy > 95.0:
            print(f"   🎉 Great accuracy achieved! Stopping early.")
            break

    print(f"\n=== TRAINING COMPLETED ===")
    if val_accuracies:
        print(f"Final validation accuracy: {val_accuracies[-1]:.2f}%")
    else:
        # num_epochs == 0: there is no final accuracy to report
        print("No epochs were run.")

    return {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies
    }

# Wrap both MNIST loaders with the patch-splitting loader
train_patch_loader, val_patch_loader = (
    SquareImageSplitingLoader(train_loader),
    SquareImageSplitingLoader(validation_loader),
)

print(
    "=== DATA SETUP ===",
    "📚 Training data ready",
    "📖 Validation data ready",
    sep="\n",
)

# Kick off training of the model built in the previous cell
print()
print("Starting training process...")
history = train_model(
    model=model,
    train_loader=train_patch_loader,
    val_loader=val_patch_loader,
    num_epochs=100,
)


=== DATA SETUP ===
📚 Training data ready
📖 Validation data ready

Starting training process...
=== TRAINING SETUP ===
📚 Epochs: 100
📈 Learning rate: 0.0001
💡 Optimizer: Adam
🎯 Loss function: CrossEntropyLoss

=== STARTING TRAINING ===

🔄 Epoch 1/100
   Training...
     Batch   0: Loss = 2.3093, Acc = 11.7%
     Batch  50: Loss = 2.3009, Acc = 9.8%
     Batch 100: Loss = 2.2847, Acc = 11.8%
     Batch 150: Loss = 2.2895, Acc = 15.4%
     Batch 200: Loss = 2.2751, Acc = 17.8%
     Batch 250: Loss = 2.2725, Acc = 18.9%
     Batch 300: Loss = 2.2428, Acc = 20.3%
     Batch 350: Loss = 2.2136, Acc = 21.6%
   Validating...
   📊 Results:
      Train Loss: 2.2656
      Train Acc:  22.42%
      Val Acc:    30.05%

🔄 Epoch 2/100
   Training...
     Batch   0: Loss = 2.1893, Acc = 31.2%
     Batch  50: Loss = 2.1067, Acc = 30.8%
     Batch 100: Loss = 2.0465, Acc = 30.4%
     Batch 150: Loss = 1.8809, Acc = 30.9%
     Batch 200: Loss = 1.8490, Acc = 31.7%
     Batch 250: Loss = 1.7871, Acc = 32.8

In [None]:
# First, let's fix the validation loader issue and create a proper validation set
from loaders import MnistLoader, SquareImageSplitingLoader

# Rebuild the base MNIST loaders from scratch
mnist_loader = MnistLoader()
train_loader, validation_loader = mnist_loader.get_loaders()

# Patch-splitting wrappers for the training and validation streams
train_patch_loader, val_patch_loader = (
    SquareImageSplitingLoader(train_loader),
    SquareImageSplitingLoader(validation_loader),
)

print("=== DATA SETUP ===")
print("📚 Training batches:", len(train_loader))
print("📖 Validation batches:", len(validation_loader))

# Train a freshly initialised network so earlier runs do not leak into the result
print("=== RETRAINING WITH PROPER VALIDATION ===")
model_v2 = SimpleVisionTransformer()
history = train_model(
    model=model_v2,
    train_loader=train_patch_loader,
    val_loader=val_patch_loader,
    num_epochs=100,
    learning_rate=0.0001,
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def analyze_model_understanding(model, val_loader):
    """
    Analyze what the model learned - for educational purposes

    Looks at the first 5 samples of each of the first 6 validation batches
    and sorts them into correct and wrong predictions. Collection stops as
    soon as at least 5 correct and at least 3 wrong examples are gathered.

    Args:
        model: Module mapping a batch of inputs to class logits.
        val_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(correct_predictions, wrong_predictions)``: two lists of
        ``(input_sample, actual_label, predicted_label)`` tuples.
    """
    print("=== MODEL ANALYSIS ===")

    model.eval()

    # Test on a few examples
    correct_predictions = []
    wrong_predictions = []

    def collected_enough():
        # Single definition of the stop condition shared by both loops
        return len(correct_predictions) >= 5 and len(wrong_predictions) >= 3

    with torch.no_grad():
        for batch_idx, (patches, labels) in enumerate(val_loader):
            if batch_idx > 5:  # Just analyze first few batches
                break

            # Get model predictions as class indices
            predicted = model(patches).argmax(dim=1)

            # Store examples
            for i in range(min(5, patches.size(0))):  # First 5 images in batch
                actual = labels[i].item()
                pred = predicted[i].item()

                bucket = correct_predictions if actual == pred else wrong_predictions
                bucket.append((patches[i], actual, pred))

                if collected_enough():
                    break

            if collected_enough():
                break

    # Print statistics
    print(f"✅ Found {len(correct_predictions)} correct predictions")
    print(f"❌ Found {len(wrong_predictions)} wrong predictions")

    # Show some examples (at most three of each kind)
    print("\n📊 CORRECT PREDICTIONS:")
    for number, (_, actual, pred) in enumerate(correct_predictions[:3], start=1):
        print(f"   Example {number}: Actual = {actual}, Predicted = {pred} ✓")

    print("\n📊 WRONG PREDICTIONS:")
    for number, (_, actual, pred) in enumerate(wrong_predictions[:3], start=1):
        print(f"   Example {number}: Actual = {actual}, Predicted = {pred} ✗")

    return correct_predictions, wrong_predictions

def visualize_training_progress(history):
    """
    Simple visualization of training progress

    Prints loss, training accuracy and validation accuracy for every
    recorded epoch, followed by the change in validation accuracy between
    the first and the last epoch.

    Args:
        history: Dict with equally long lists under 'train_losses',
            'train_accuracies' and 'val_accuracies' (accuracies in percent).

    Returns:
        ``range`` of 1-based epoch numbers, one per recorded epoch (empty
        when the history holds no epochs).
    """
    print("\n=== TRAINING PROGRESS ===")

    epochs = range(1, len(history['train_losses']) + 1)

    print("📈 Training Progress:")
    for i, epoch in enumerate(epochs):
        print(f"   Epoch {epoch}:")
        print(f"      Loss: {history['train_losses'][i]:.4f}")
        print(f"      Train Acc: {history['train_accuracies'][i]:.2f}%")
        print(f"      Val Acc: {history['val_accuracies'][i]:.2f}%")

    # With no recorded epoch there is no first/last accuracy to compare
    if not history['val_accuracies']:
        print("   No epochs recorded.")
        return epochs

    # Show improvement
    initial_val_acc = history['val_accuracies'][0]
    final_val_acc = history['val_accuracies'][-1]
    improvement = final_val_acc - initial_val_acc

    print(f"\n🎯 OVERALL IMPROVEMENT:")
    print(f"   Initial validation accuracy: {initial_val_acc:.2f}%")
    print(f"   Final validation accuracy: {final_val_acc:.2f}%")
    # The '+' format flag emits the correct sign; a literal "+" printed "+-x" on regressions
    print(f"   Improvement: {improvement:+.2f}%")

    return epochs

# Inspect a sample of the retrained model's predictions, then summarise its training history
correct_examples, wrong_examples = analyze_model_understanding(
    model=model_v2,
    val_loader=val_patch_loader,
)
epochs = visualize_training_progress(history=history)
