# CIFAR-10 Model Comparison: MLP vs CNN vs ViT

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/juho127/ClassificationTest/blob/main/model_comparison.ipynb)

This notebook compares three different architectures for CIFAR-10 classification:
1. **MLP** (Multi-Layer Perceptron) - Basic feedforward network
2. **CNN** (Convolutional Neural Network) - Spatial feature extraction
3. **ViT** (Vision Transformer) - Attention-based architecture

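## Expected Performance

Rough test-accuracy ranges for this setup; actual results depend on the number of epochs, the random seed, and the hardware: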
- MLP: ~50-55%
- CNN: ~70-75%
- ViT: ~65-70%

## Learning Goals
- Understand the differences between architectures
- See how inductive biases affect performance
- Compare training time and model complexity


## 0. Environment Setup


In [None]:
# Check if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
    print("📌 Tip: Runtime > Change runtime type > GPU for faster training!")
except ImportError:
    IN_COLAB = False
    print("✓ Running on local environment")

# Install required packages on Colab
if IN_COLAB:
    print("\nInstalling packages...")
    import sys
    !{sys.executable} -m pip install -q torch torchvision tqdm matplotlib einops


## 1. Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("🎉 You can use GPU for faster training!")


## 2. Hyperparameters and Data Loading


In [None]:
# Hyperparameters
BATCH_SIZE = 128
LEARNING_RATE = 0.001
NUM_EPOCHS = 20  # Reduce to 10 for faster testing
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 classes
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Device: {DEVICE}")
if DEVICE.type == 'cuda':
    print("✓ Using GPU!")
else:
    print("ℹ Using CPU (Colab: Runtime > Change runtime type > GPU)")


In [None]:
# Data preprocessing with augmentation for training
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
print("Loading dataset...")
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")


In [None]:
# Visualize sample images
def show_images(loader, num_images=10):
    dataiter = iter(loader)
    images, labels = next(dataiter)
    
    # Denormalize
    images = images / 2 + 0.5
    
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    fig.suptitle('CIFAR-10 Sample Images', fontsize=16, fontweight='bold')
    
    for idx, ax in enumerate(axes.flat):
        if idx < num_images:
            img = images[idx].numpy().transpose((1, 2, 0))
            ax.imshow(img)
            ax.set_title(f'{CLASSES[labels[idx]]}', fontsize=10)
            ax.axis('off')
    
    plt.tight_layout()
    plt.show()

show_images(train_loader)


## 3. Model Definitions

We'll define three different architectures and compare their performance.


### 3.1 MLP Model (Multi-Layer Perceptron)

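A simple feedforward network: each 32×32×3 image is flattened into a 3,072-dimensional vector and passed through three fully connected layers (3072 → 512 → 256 → 10) with ReLU and dropout.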


In [None]:
class MLP(nn.Module):
    """Multi-Layer Perceptron"""
    
    def __init__(self, input_size=3072, hidden_size1=512, hidden_size2=256, num_classes=10):
        super(MLP, self).__init__()
        
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        
        self.fc3 = nn.Linear(hidden_size2, num_classes)
    
    def forward(self, x):
        # Flatten image (32x32x3 = 3072)
        x = x.view(x.size(0), -1)
        
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        
        x = self.fc3(x)
        return x

# Create MLP model
mlp_model = MLP().to(DEVICE)
print(mlp_model)
print(f"\nMLP Parameters: {sum(p.numel() for p in mlp_model.parameters() if p.requires_grad):,}")


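### 3.2 CNN Model (Convolutional Neural Network)

Three convolutional blocks, each with two 3×3 convolutions (with batch normalization and ReLU) followed by 2×2 max pooling and dropout, shrink the 32×32 input to 4×4 feature maps before a fully connected classifier.


In [None]: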
class CNN(nn.Module):
    """Convolutional Neural Network"""
    
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.3)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.4)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        x = self.conv1(x)  # 32 -> 16
        x = self.conv2(x)  # 16 -> 8
        x = self.conv3(x)  # 8 -> 4
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x

# Create CNN model
cnn_model = CNN().to(DEVICE)
print(cnn_model)
print(f"\nCNN Parameters: {sum(p.numel() for p in cnn_model.parameters() if p.requires_grad):,}")
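The `256 * 4 * 4` input size of the first fully connected layer follows from the usual output-size formula floor((n + 2p - k) / s) + 1: each 3×3 convolution with padding 1 keeps the size, and each `MaxPool2d(2, 2)` halves it, so 32 → 16 → 8 → 4. The sketch checks this both by arithmetic and with a stripped-down stack that keeps only the shape-changing layers (BatchNorm, ReLU, and Dropout leave the shape unchanged, so they are omitted here).

```python
import torch
import torch.nn as nn


def conv_out_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32
for _ in range(3):  # three blocks: conv1, conv2, conv3
    size = conv_out_size(size, kernel=3, padding=1)  # first 3x3 conv keeps the size
    size = conv_out_size(size, kernel=3, padding=1)  # second 3x3 conv keeps the size
    size = conv_out_size(size, kernel=2, stride=2)   # MaxPool2d(2, 2) halves it
print(size, 256 * size * size)

# Same check with real layers (shape-preserving BatchNorm/ReLU/Dropout omitted)
layers, in_ch = [], 3
for out_ch in (64, 128, 256):
    layers += [nn.Conv2d(in_ch, out_ch, 3, padding=1),
               nn.Conv2d(out_ch, out_ch, 3, padding=1),
               nn.MaxPool2d(2, 2)]
    in_ch = out_ch
features = nn.Sequential(*layers)(torch.randn(2, 3, 32, 32))
print(features.shape)
```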


### 3.3 ViT Model (Vision Transformer)

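Splits each image into non-overlapping 4×4 patches (64 patches for a 32×32 input), embeds each patch, prepends a learnable class token, adds positional embeddings, and processes the sequence with pre-norm Transformer encoder blocks built on multi-head self-attention.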


In [None]:
class PatchEmbedding(nn.Module):
    """Split image into patches and embed them"""
    
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e')
        )
    
    def forward(self, x):
        x = self.projection(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer Encoder Block"""
    
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        # Multi-head self-attention (pre-norm): normalize once, use it as query, key, and value
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # MLP
        x = x + self.mlp(self.norm2(x))
        return x


class ViT(nn.Module):
    """Vision Transformer (Small size for CIFAR-10)"""
    
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)
        
        # Add class token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        
        # Classification token
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        
        return x

# Create ViT model
vit_model = ViT().to(DEVICE)
print(vit_model)
print(f"\nViT Parameters: {sum(p.numel() for p in vit_model.parameters() if p.requires_grad):,}")
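`PatchEmbedding` uses a `Conv2d` whose kernel size equals its stride, so the 4×4 windows never overlap. That is equivalent to cutting the image into 64 patches, flattening each to 3·4·4 = 48 values, and applying one shared linear projection, which is how the original ViT paper describes patch embedding. The sketch checks the equivalence numerically with `torch.nn.functional.unfold`; the small embedding size `dim = 8` is an arbitrary choice to keep it fast.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, dim = 4, 8
conv = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
x = torch.randn(2, 3, 32, 32)

# Conv path: (B, dim, 8, 8) -> (B, 64, dim), like Rearrange('b e h w -> b (h w) e')
conv_tokens = conv(x).flatten(2).transpose(1, 2)

# Linear path: extract non-overlapping 4x4 patches as rows of 48 values
patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)  # (B, 64, 48)
weight = conv.weight.reshape(dim, -1)                                    # (dim, 48)
linear_tokens = patches @ weight.T + conv.bias

num_tokens = conv_tokens.shape[1] + 1  # +1 for the class token
print(conv_tokens.shape, num_tokens, torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
```

The 65 tokens (64 patches plus the class token) are why `pos_embed` has shape `(1, num_patches + 1, embed_dim)`.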


In [None]:
# Compare model sizes
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Comparison:")
print("=" * 50)
print(f"MLP Parameters: {count_parameters(mlp_model):,}")
print(f"CNN Parameters: {count_parameters(cnn_model):,}")
print(f"ViT Parameters: {count_parameters(vit_model):,}")
print("=" * 50)
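The counts printed above can be reproduced by hand, which shows where each model's parameters live. A sketch, assuming the default constructor arguments used in this notebook; BatchNorm running statistics are buffers, not parameters, so they are not counted, and `nn.MultiheadAttention` is assumed to use its fused query/key/value projection (the case when key and value dimensions equal `embed_dim`).

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out          # weight + bias


def conv(n_in, n_out, k):
    return n_in * n_out * k * k + n_out  # k x k kernels + bias


def norm(channels):
    return 2 * channels                  # BatchNorm/LayerNorm affine weight + bias


# MLP: 3072 -> 512 -> 256 -> 10
mlp = linear(3072, 512) + linear(512, 256) + linear(256, 10)

# CNN: three blocks of two 3x3 convs (each followed by BatchNorm), then the classifier
cnn_conv, in_ch = 0, 3
for out_ch in (64, 128, 256):
    cnn_conv += conv(in_ch, out_ch, 3) + norm(out_ch) + conv(out_ch, out_ch, 3) + norm(out_ch)
    in_ch = out_ch
cnn_fc = linear(256 * 4 * 4, 512) + linear(512, 10)
cnn = cnn_conv + cnn_fc

# ViT: embed_dim 256, depth 6, MLP hidden 4 * 256, 64 patches + 1 class token
E, depth, hidden, tokens = 256, 6, 1024, 65
attn = 3 * E * E + 3 * E + linear(E, E)  # fused q/k/v projection + output projection
block = norm(E) + attn + norm(E) + linear(E, hidden) + linear(hidden, E)
vit = (conv(3, E, 4)      # patch embedding
       + E                # class token
       + tokens * E       # positional embedding
       + depth * block    # transformer blocks
       + norm(E)          # final LayerNorm
       + linear(E, 10))   # classification head

print(f"MLP: {mlp:,}  CNN: {cnn:,}  ViT: {vit:,}")
print(f"Share of CNN parameters in the fully connected layers: {cnn_fc / cnn:.0%}")
```

Under these assumptions the totals come to about 1.71M (MLP), 3.25M (CNN), and 4.77M (ViT), and roughly 65% of the CNN's parameters sit in its two fully connected layers rather than in the convolutions.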


## 4. Training Functions


In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, epoch, model_name):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'{model_name} Epoch {epoch+1}/{NUM_EPOCHS}')
    for images, labels in pbar:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        pbar.set_postfix({
            'loss': f'{running_loss/total:.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def evaluate(model, test_loader, criterion):
    """Evaluate model"""
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_loss = test_loss / total
    test_acc = 100 * correct / total
    return test_loss, test_acc


def train_model(model, model_name, num_epochs=NUM_EPOCHS):
    """Complete training loop for a model"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    
    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch, model_name)
        test_loss, test_acc = evaluate(model, test_loader, criterion)
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.2f}%")
        print("-" * 60)
    
    training_time = time.time() - start_time
    
    print(f"\n{model_name} Training Complete!")
    print(f"Total training time: {training_time/60:.2f} minutes")
    print(f"Best test accuracy: {max(test_accs):.2f}%")
    
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'training_time': training_time,
        'best_acc': max(test_accs)
    }

print("Training functions defined!")
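Both loops accumulate `loss.item() * images.size(0)` and divide by `total` instead of averaging batch losses. `CrossEntropyLoss` returns a per-batch mean and the last batch is smaller (50,000 = 390 × 128 + 80), so weighting by batch size recovers the exact per-sample mean. The sketch checks this on synthetic logits split into one batch of 128 and one of 16.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()  # mean over the batch, as in train_model
logits = torch.randn(144, 10)
labels = torch.randint(0, 10, (144,))

running_loss, total, batch_means = 0.0, 0, []
for start, size in [(0, 128), (128, 16)]:  # a full batch plus a smaller final batch
    loss = criterion(logits[start:start + size], labels[start:start + size])
    running_loss += loss.item() * size      # same accumulation as train_one_epoch
    total += size
    batch_means.append(loss.item())

weighted = running_loss / total
naive = sum(batch_means) / len(batch_means)  # ignores unequal batch sizes
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, labels).mean().item()
print(f"weighted={weighted:.4f} naive={naive:.4f} per-sample={per_sample:.4f}")
```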


## 5. Train Models

Now let's train each model and compare the results. Run the cells below one by one.


### 5.1 Train MLP


In [None]:
# Train MLP model
mlp_results = train_model(mlp_model, "MLP")


### 5.2 Train CNN


In [None]:
# Train CNN model
cnn_results = train_model(cnn_model, "CNN")


### 5.3 Train ViT


In [None]:
# Train ViT model
vit_results = train_model(vit_model, "ViT")


## 6. Compare Results

Now let's visualize and compare the performance of all three models.


### 6.1 Summary Table


In [None]:
# Summary comparison
results = {
    'MLP': mlp_results,
    'CNN': cnn_results,
    'ViT': vit_results
}

print("\n" + "=" * 70)
print("FINAL COMPARISON")
print("=" * 70)
print(f"{'Model':<10} {'Parameters':<15} {'Best Acc':<12} {'Time (min)':<12}")
print("-" * 70)

models = [mlp_model, cnn_model, vit_model]
names = ['MLP', 'CNN', 'ViT']

for name, model in zip(names, models):
    params = count_parameters(model)
    best_acc = results[name]['best_acc']
    time_taken = results[name]['training_time'] / 60
    acc_str = f"{best_acc:.2f}%"  # format first so the % sign stays inside the column
    print(f"{name:<10} {params:<15,} {acc_str:<12} {time_taken:<12.2f}")

print("=" * 70)


### 6.2 Training Curves Comparison


In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

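# Use the recorded history length so the plot still works if train_model was called with a different num_epochs
epochs = range(1, len(mlp_results['test_accs']) + 1)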

# Training Loss
ax = axes[0, 0]
ax.plot(epochs, mlp_results['train_losses'], 'b-', label='MLP', linewidth=2)
ax.plot(epochs, cnn_results['train_losses'], 'g-', label='CNN', linewidth=2)
ax.plot(epochs, vit_results['train_losses'], 'r-', label='ViT', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Test Loss
ax = axes[0, 1]
ax.plot(epochs, mlp_results['test_losses'], 'b-', label='MLP', linewidth=2)
ax.plot(epochs, cnn_results['test_losses'], 'g-', label='CNN', linewidth=2)
ax.plot(epochs, vit_results['test_losses'], 'r-', label='ViT', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Test Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Training Accuracy
ax = axes[1, 0]
ax.plot(epochs, mlp_results['train_accs'], 'b-', label='MLP', linewidth=2)
ax.plot(epochs, cnn_results['train_accs'], 'g-', label='CNN', linewidth=2)
ax.plot(epochs, vit_results['train_accs'], 'r-', label='ViT', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Training Accuracy', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Test Accuracy
ax = axes[1, 1]
ax.plot(epochs, mlp_results['test_accs'], 'b-', label='MLP', linewidth=2)
ax.plot(epochs, cnn_results['test_accs'], 'g-', label='CNN', linewidth=2)
ax.plot(epochs, vit_results['test_accs'], 'r-', label='ViT', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Test Accuracy', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("Comparison plot saved as 'model_comparison.png'")


### 6.3 Bar Chart Comparison


In [None]:
# Bar chart for final metrics
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

model_names = ['MLP', 'CNN', 'ViT']
colors = ['#3498db', '#2ecc71', '#e74c3c']

# Best Accuracy
ax = axes[0]
accuracies = [mlp_results['best_acc'], cnn_results['best_acc'], vit_results['best_acc']]
bars = ax.bar(model_names, accuracies, color=colors, alpha=0.8)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Best Test Accuracy', fontsize=14, fontweight='bold')
ax.set_ylim([0, 100])
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{acc:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Training Time
ax = axes[1]
times = [mlp_results['training_time']/60, cnn_results['training_time']/60, vit_results['training_time']/60]
bars = ax.bar(model_names, times, color=colors, alpha=0.8)
ax.set_ylabel('Time (minutes)', fontsize=12)
ax.set_title('Training Time', fontsize=14, fontweight='bold')
# Don't name the loop variable `time`: it would shadow the `time` module used by train_model
for bar, minutes in zip(bars, times):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{minutes:.1f}m', ha='center', va='bottom', fontsize=11, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Parameters
ax = axes[2]
params = [count_parameters(mlp_model)/1e6, count_parameters(cnn_model)/1e6, count_parameters(vit_model)/1e6]
bars = ax.bar(model_names, params, color=colors, alpha=0.8)
ax.set_ylabel('Parameters (Millions)', fontsize=12)
ax.set_title('Model Size', fontsize=14, fontweight='bold')
for bar, param in zip(bars, params):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{param:.2f}M', ha='center', va='bottom', fontsize=11, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('model_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print("Metrics comparison saved as 'model_metrics.png'")


### 6.4 Per-Class Accuracy


In [None]:
# Calculate per-class accuracy for each model
def get_class_accuracy(model, num_classes=len(CLASSES)):
    """Return the test accuracy (%) of `model` for each class."""
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            predicted = model(images).argmax(dim=1)
            correct = predicted == labels
            # Count samples and correct predictions per class, one vectorized call each
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(labels[correct], minlength=num_classes).cpu()
    
    class_acc = 100 * class_correct.float() / class_total.float()
    return class_acc.tolist()

print("Calculating per-class accuracy...")
mlp_class_acc = get_class_accuracy(mlp_model)
cnn_class_acc = get_class_accuracy(cnn_model)
vit_class_acc = get_class_accuracy(vit_model)

# Plot per-class accuracy
x = np.arange(len(CLASSES))
width = 0.25

fig, ax = plt.subplots(figsize=(14, 6))

bars1 = ax.bar(x - width, mlp_class_acc, width, label='MLP', color='#3498db', alpha=0.8)
bars2 = ax.bar(x, cnn_class_acc, width, label='CNN', color='#2ecc71', alpha=0.8)
bars3 = ax.bar(x + width, vit_class_acc, width, label='ViT', color='#e74c3c', alpha=0.8)

ax.set_xlabel('Class', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
ax.set_title('Per-Class Accuracy Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(CLASSES, fontsize=11)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0, 100])

plt.tight_layout()
plt.savefig('per_class_accuracy.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPer-class accuracy saved as 'per_class_accuracy.png'")
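Per-class accuracy is the diagonal of the confusion matrix divided by its row sums (rows are true classes). A confusion matrix also shows *which* classes get mixed up, which helps with the discussion question about hard classes. This toy sketch uses made-up labels and predictions for three classes; for the notebook, `labels` and `preds` would be the concatenated test labels and `argmax` predictions.

```python
import torch


def confusion_matrix(labels, preds, num_classes):
    """Rows = true class, columns = predicted class."""
    # Encode each (true, predicted) pair as a single index, then count occurrences
    idx = labels * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


labels = torch.tensor([0, 0, 1, 1, 2])
preds = torch.tensor([0, 1, 1, 1, 0])
cm = confusion_matrix(labels, preds, num_classes=3)
per_class_acc = 100 * cm.diag().float() / cm.sum(dim=1).float()
print(cm)
print(per_class_acc)  # class 0: 1/2, class 1: 2/2, class 2: 0/1
```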


In [None]:
# Print detailed per-class accuracy
print("\nDetailed Per-Class Accuracy:")
print("=" * 70)
print(f"{'Class':<12} {'MLP':<15} {'CNN':<15} {'ViT':<15}")
print("-" * 70)

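for i, cls in enumerate(CLASSES):
    # Format each value with its % sign first so the columns stay aligned
    accs = [f"{acc[i]:.2f}%" for acc in (mlp_class_acc, cnn_class_acc, vit_class_acc)]
    print(f"{cls:<12} {accs[0]:<15} {accs[1]:<15} {accs[2]:<15}")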

print("=" * 70)


## 7. Key Observations and Analysis

### MLP (Multi-Layer Perceptron)
**Pros:**
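- ✓ Simple architecture, easy to understand
- ✓ Fast training
- ✓ Fewest parameters of the three models in this notebook

**Cons:**
- ✗ Flattening discards the 2D layout, so it cannot exploit spatial relationships between nearby pixels
- ✗ No weight sharing: a pattern learned at one position must be relearned at every other position
- ✗ Limited accuracy (~50-55%)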

**Best for:** Quick baseline, understanding basics
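One way to see that the MLP has no notion of spatial layout: if every image is shuffled with the same fixed pixel permutation and the first layer's weight columns are permuted to match, the output is unchanged. A permuted copy of any MLP therefore computes the same function on shuffled images, so neighboring pixels are not special to it. The sketch uses a single `nn.Linear(3072, 16)` as a hypothetical small stand-in for `fc1`, with random inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(3072, 16)                   # small stand-in for the MLP's first layer
x = torch.randn(4, 3, 32, 32).flatten(1)   # flattened like MLP.forward does
perm = torch.randperm(3072)                # one fixed shuffle of the 3,072 input values

# Layer whose weight columns follow the same shuffle
fc_perm = nn.Linear(3072, 16)
with torch.no_grad():
    fc_perm.weight.copy_(fc.weight[:, perm])
    fc_perm.bias.copy_(fc.bias)

same = torch.allclose(fc(x), fc_perm(x[:, perm]), atol=1e-4)
print(same)
```

A convolution has no such symmetry: its 3×3 kernels only see pixels that are adjacent in the original layout.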

---

### CNN (Convolutional Neural Network)
**Pros:**
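- ✓ Best accuracy (~70-75%)
- ✓ Exploits the spatial structure of images through local 3×3 filters
- ✓ Parameter-efficient convolutions (each filter's weights are shared across all positions); in this model most parameters actually sit in the final fully connected layers
- ✓ Convolutions are translation equivariant, and pooling adds some translation invariance

**Cons:**
- ✗ Fixed, local receptive field per layer
- ✗ Global context only emerges after several layers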

**Best for:** Image classification, spatial pattern recognition
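Strictly, a convolution is translation *equivariant*: shifting the input shifts the feature map by the same amount, and pooling plus the classifier then make the prediction approximately invariant. The check below assumes `padding_mode='circular'` and a circular shift (`torch.roll`), which makes the property exact; with the zero padding used in the notebook's `CNN`, it holds only away from the image borders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Circular padding so that a wrap-around shift commutes exactly with the conv
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='circular')
x = torch.randn(1, 3, 32, 32)


def shift(t):
    """Circularly shift the spatial dims by 2 rows and 5 columns."""
    return torch.roll(t, shifts=(2, 5), dims=(2, 3))


equivariant = torch.allclose(conv(shift(x)), shift(conv(x)), atol=1e-5)
print(equivariant)
```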

---

### ViT (Vision Transformer)
**Pros:**
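- ✓ Global attention: every patch can attend to every other patch from the first layer
- ✓ Flexible architecture
- ✓ Good performance (~65-70%)

**Cons:**
- ✗ Weak inductive biases, so it needs more data (or pre-training) to excel
- ✗ Slower training
- ✗ Most parameters of the three models in this configuration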

**Best for:** Large datasets, transfer learning
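"Global attention" means every token can attend to every other token within a single layer. With 4×4 patches on a 32×32 image there are (32/4)² = 64 patch tokens plus the class token, so each attention map is 65×65. The sketch shows this with `nn.MultiheadAttention`, as used in `TransformerBlock`, in eval mode; `embed_dim=32` and `num_heads=4` are smaller than the notebook's 256 and 8 only to keep it light.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens = (32 // 4) ** 2 + 1  # 64 patch tokens + 1 class token
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()
x = torch.randn(2, num_tokens, 32)

# Returned weights are averaged over heads by default: (batch, query, key)
out, weights = attn(x, x, x)
print(out.shape, weights.shape)

# Each query's weights are a softmax over all 65 keys, so every row sums to 1
row_sums = weights.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5))
```

By contrast, a 3×3 convolution mixes only a 3×3 neighborhood per layer; the CNN needs several layers before distant pixels interact.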

---

## Conclusion

For CIFAR-10:
1. **CNN performs best** - Strong inductive biases for images (locality and weight sharing)
2. **ViT is competitive** - But needs more data to excel (it shines with large-scale pre-training, e.g. on ImageNet-21k or JFT-300M)
3. **MLP is limited** - Cannot exploit spatial structure

**Recommendation:** 
- Use **CNN** for small image datasets like CIFAR-10
- Consider **ViT** for larger datasets (ImageNet) or with pre-training
- Use **MLP** only for learning purposes or non-spatial data


## 8. Discussion Questions

Think about and discuss:

1. **Why does CNN outperform MLP on CIFAR-10?**
   - Hint: Think about spatial structure and parameter sharing

2. **Why doesn't ViT outperform CNN here?**
   - Hint: Consider the dataset size and inductive biases

3. **When would you choose each architecture?**
   - MLP: ?
   - CNN: ?
   - ViT: ?

4. **What happens if we train longer?**
   - Will the gap between models increase or decrease?

5. **Which classes are harder to classify? Why?**
   - Look at the per-class accuracy chart

6. **How can we improve each model?**
   - MLP: ?
   - CNN: ?
   - ViT: ?


## 9. Save Models (Optional)


In [None]:
# Save trained models
torch.save(mlp_model.state_dict(), 'mlp_model.pth')
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
torch.save(vit_model.state_dict(), 'vit_model.pth')

print("Models saved!")
print("  - mlp_model.pth")
print("  - cnn_model.pth")
print("  - vit_model.pth")
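To reuse the saved weights later, create a fresh instance of the same class and load the state dict into it. The self-contained sketch below does the round trip with a small hypothetical `nn.Sequential` model in a temporary directory. For the notebook's models the same pattern applies, e.g. `model = CNN().to(DEVICE)` followed by `model.load_state_dict(torch.load('cnn_model.pth', map_location=DEVICE))`. Call `model.eval()` before inference so dropout is disabled and batch norm uses its running statistics.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small model; the architecture must match the saved one exactly
    return nn.Sequential(nn.Linear(3072, 64), nn.ReLU(), nn.Linear(64, 10))


torch.manual_seed(0)
model = make_model()
x = torch.randn(2, 3072)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_model.pth')
    torch.save(model.state_dict(), path)
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location='cpu'))

model.eval()
restored.eval()
with torch.no_grad():
    identical = torch.equal(model(x), restored(x))
print(identical)
```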
