# Notebook 04: Computer Vision - Image Classification

**Learning Objectives:**
- Understand image classification with deep learning
- Load and use pre-trained vision models
- Classify images into predefined categories
- Work with the Vision Transformer (ViT) architecture

## Prerequisites

### Hardware Requirements

| Model Option | Model Name | Size | Min RAM | Recommended Setup | Notes |
|--------------|------------|------|---------|-------------------|-------|
| **CPU (Small)** | google/vit-base-patch16-224 | 346MB | 4GB | 4GB RAM, CPU | Good accuracy |
| **GPU (Medium)** | google/vit-large-patch16-224 | 1.2GB | 6GB | GPU with 8GB+ VRAM (e.g. RTX 4080) | Better accuracy |
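
A rough way to sanity-check the Size column is that fp32 checkpoints take about 4 bytes per parameter. The sketch below applies that rule of thumb. The parameter counts are approximate, assumed values: about 86.6M for ViT-Base/16 and about 304M for ViT-Large/16. Real files add a little metadata, and running the model needs extra memory for activations.

```python
# Rule-of-thumb size of model weights only (activations and framework overhead excluded)
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

def weight_size_mb(num_params: int, dtype: str = "fp32") -> float:
    """Approximate size of the weights in megabytes (1 MB = 10**6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

# Approximate, rounded parameter counts (assumed)
models = {
    "ViT-Base/16": 86_600_000,
    "ViT-Large/16": 304_300_000,
}

for name, n in models.items():
    print(f"{name:13s} fp32: {weight_size_mb(n):7.1f} MB   fp16: {weight_size_mb(n, 'fp16'):7.1f} MB")
```

The fp32 estimates land close to the 346MB and 1.2GB figures in the table. Half precision roughly halves the weight footprint.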

### Software Requirements
- Python 3.8+
- Libraries: `transformers`, `torch`, `Pillow` (imported as `PIL`), `requests`, plus `torchvision` for the CIFAR-10 example
- See `requirements.txt` for full list

## Overview

**Image Classification** assigns labels to images from a predefined set of categories.

**Use Cases:**
- Object recognition
- Medical image diagnosis
- Content moderation
- Quality control in manufacturing
- Wildlife monitoring

**Vision Transformer (ViT):**
- Applies transformer architecture to images
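- Splits the image into fixed-size patches (16x16 pixels for the `patch16` checkpoints used here)
- Treats patches like tokens in NLP
- Matched or exceeded state-of-the-art CNNs on image classification when pre-trained on large datasets (ImageNet-21k, JFT-300M)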
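
Here is the patch arithmetic behind "splits image into patches." With 16x16 patches, a 224x224 RGB image becomes (224/16)² = 196 patches. Each patch flattens to 16·16·3 = 768 values, which ViT-Base then linearly projects to its 768-dimensional hidden size. The NumPy sketch below covers only the patch extraction step. It leaves out the learned projection, the [CLS] token, and the position embeddings.

```python
import numpy as np

def image_to_patches(img: np.ndarray, patch: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patches.

    Returns shape (num_patches, patch * patch * C), with patches ordered
    left to right, top to bottom, like tokens in a sentence.
    """
    h, w, c = img.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    gh, gw = h // patch, w // patch
    # (gh, patch, gw, patch, c) -> (gh, gw, patch, patch, c): group pixels by patch
    patches = img.reshape(gh, patch, gw, patch, c).transpose(0, 2, 1, 3, 4)
    # Flatten each patch into one vector ("token")
    return patches.reshape(gh * gw, patch * patch * c)

img = np.random.rand(224, 224, 3).astype(np.float32)
tokens = image_to_patches(img)
print(tokens.shape)  # (196, 768)
```

In practice, the Hugging Face ViT implementation performs the split-and-project step as a single convolution whose kernel size and stride both equal the patch size.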

## Expected Behaviors

### First Time Running
- **Model Download**: ~346MB for vit-base (~2-4 minutes)
- Downloads model and image processor
- Cached in `~/.cache/huggingface/hub/`

### Setup Cell Output
```
PyTorch version: 2.x.x
CUDA available: True/False
GPU: NVIDIA GeForce RTX 4080 (if available)
```

### Model Loading
```
Loading google/vit-base-patch16-224...
Model loaded successfully!
```
- **CPU**: 3-7 seconds
- **GPU**: 2-4 seconds

### Classification Output Format
```python
[
  {'label': 'Egyptian cat', 'score': 0.8932},
  {'label': 'tabby cat', 'score': 0.0854},
  {'label': 'tiger cat', 'score': 0.0124}
]
```

### ImageNet Classes
- Pre-trained on ImageNet-21k and fine-tuned on ImageNet-1k, so it predicts exactly **1,000 classes**
- Classes include animals, vehicles, objects, food, etc.
- Full list: [ImageNet Classes](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)

### Expected Accuracy
- **Clear objects** (single cat, car, etc.): 80-95% confidence on top prediction
- **Multiple objects**: May focus on most prominent object
- **Unusual angles/lighting**: Lower confidence (60-80%)
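- **Objects not in ImageNet**: The model always answers with one of its 1,000 labels, so unfamiliar objects get mapped to the closest-looking class, sometimes with misleadingly high confidence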

### Performance Benchmarks
- **Single image**:
  - CPU: 200-500ms
  - GPU: 20-50ms
- **Batch of 10 images**:
  - CPU: 1-2 seconds
  - GPU: 100-200ms
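
Your numbers will vary with hardware, image size, and library versions, so it is worth measuring on your own machine. The minimal timing sketch below takes any zero-argument callable, for example `lambda: classifier(image)`. It runs a few warm-up calls first because the first call often includes one-off setup cost. `benchmark` is a hypothetical helper, not part of `transformers`. If you time the raw model on a GPU, note that CUDA work is asynchronous, so call `torch.cuda.synchronize()` before reading the clock.

```python
import statistics
import time
from typing import Callable

def benchmark(fn: Callable[[], object], warmup: int = 2, repeats: int = 10) -> dict:
    """Call fn() repeatedly and return latency statistics in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up: lazy initialization, caches, kernel selection, etc.
    times_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000)
    return {
        "mean_ms": statistics.mean(times_ms),
        "median_ms": statistics.median(times_ms),
        "min_ms": min(times_ms),
        "runs": repeats,
    }

# Stand-in workload so the sketch runs without a model; swap in e.g. lambda: classifier(image)
stats = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"median {stats['median_ms']:.2f} ms over {stats['runs']} runs")
```

The sketch reports the median alongside the mean because a single slow run (garbage collection, a background process) can skew the mean.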

### Image Loading
- Accepts URLs, local file paths, or PIL Image objects
- Automatically resizes images to 224x224 pixels
- Converts to RGB if needed
- **Common error**: "Connection timeout" for slow/blocked URLs
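
Under the hood, the image processor does roughly the following. This is a simplified sketch that assumes the settings in the `google/vit-base-patch16-224` preprocessor config: a bilinear resize to 224x224, rescaling pixel values to [0, 1], then normalizing each channel with mean 0.5 and standard deviation 0.5. Check `processor.image_mean` and `processor.image_std` for the exact values of the model you load.

```python
import numpy as np
from PIL import Image

# Assumed settings, mirroring the ViT-base preprocessor config
SIZE = 224
MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)

def preprocess(img: Image.Image) -> np.ndarray:
    """Return a (3, SIZE, SIZE) float32 array, one image of a model batch."""
    img = img.convert("RGB")                         # drop alpha, expand grayscale
    img = img.resize((SIZE, SIZE), Image.BILINEAR)   # stretches; aspect ratio is not kept
    arr = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    arr = (arr - MEAN) / STD                         # per-channel normalization, now in [-1, 1]
    return arr.transpose(2, 0, 1)                    # channels-first, as PyTorch expects

# A semi-transparent red RGBA image of arbitrary size
demo = Image.new("RGBA", (300, 200), (255, 0, 0, 128))
x = preprocess(demo)
print(x.shape)  # (3, 224, 224)
```

Because the resize ignores aspect ratio, very wide or tall images get squashed. Cropping to the object of interest first may help in those cases.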

### Top-K Predictions
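- `top_k=5` returns the 5 most likely classes
- Scores are softmax probabilities over all 1,000 classes: the full distribution sums to 1.0, so the top-k scores sum to at most 1.0
- Scores usually drop off quickly after the top prediction, because softmax exponentiates the differences between logits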
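
The pure-Python sketch below shows both effects, using made-up logits (hypothetical values, not real model output) for a 6-class toy model. It mirrors what `F.softmax` followed by `torch.topk` does in Method 2.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtracting the max logit avoids overflow in exp()."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (index, probability) pairs for the k largest probabilities."""
    return sorted(enumerate(probs), key=lambda p: p[1], reverse=True)[:k]

# Hypothetical logits for a 6-class toy model
logits = [8.1, 5.7, 4.0, 1.2, 0.3, -2.0]
probs = softmax(logits)
best = top_k(probs, k=3)

print(f"sum of all probs: {sum(probs):.4f}")
print(f"sum of top-3:     {sum(p for _, p in best):.4f}")
for idx, p in best:
    print(f"class {idx}: {p:.4f}")
```

A logit gap of Δ becomes a probability ratio of e^Δ. For example, the 2.4-point gap between the first two logits here makes class 0 about 11 times more likely than class 1.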

### Common Observations
- Works best on **centered, well-lit objects**
- Background clutter reduces confidence
- Some classes are very specific (e.g., more than 100 dog breeds)
- May confuse similar-looking objects (e.g., wolves vs dogs)

## Setup and Installation

In [None]:
# Import required libraries
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification, pipeline, set_seed
from PIL import Image
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

# Set seed for reproducibility
set_seed(1103)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Model Selection

In [None]:
# CHOOSE YOUR MODEL:

# Option 1: CPU-friendly (recommended for beginners)
MODEL_NAME = "google/vit-base-patch16-224"  # 346MB, ViT base

# Option 2: GPU-optimized (uncomment if you have RTX 4080 or similar)
# MODEL_NAME = "google/vit-large-patch16-224"  # 1.2GB, better accuracy

print(f"Selected model: {MODEL_NAME}")

## Helper Function: Load Images

In [None]:
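def load_image_from_url(url, timeout=10):
    """
    Download an image from a URL and return it as an RGB PIL Image.
    """
    response = requests.get(url, timeout=timeout)  # avoid hanging forever on slow/blocked URLs
    response.raise_for_status()  # fail clearly on HTTP errors instead of passing an error page to PIL
    img = Image.open(BytesIO(response.content)).convert("RGB")  # 3 channels, even for RGBA PNGs
    return img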

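from IPython.display import display

def display_image(img, title="Image"):
    """
    Print basic image info and render the image inline (Jupyter/IPython).
    """
    print(f"\n=== {title} ===")
    print(f"Size: {img.size}, Mode: {img.mode}")
    # Explicit display(): a returned image only renders when it is the last expression in a cell
    display(img)
    return img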

## Method 1: Using Pipeline (Simplest)

In [None]:
# Create image classification pipeline
print(f"Loading {MODEL_NAME}...")
classifier = pipeline(
    "image-classification",
    model=MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1  # 0 = first GPU, -1 = CPU
)
print("Model loaded successfully!")

### Basic Image Classification

In [None]:
# Load and classify a sample image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = load_image_from_url(image_url)

# Display image
display_image(image, "Input Image")

# Classify
results = classifier(image, top_k=5)

print("\n=== TOP 5 PREDICTIONS ===")
for i, result in enumerate(results, 1):
    print(f"{i}. {result['label']:30s} - {result['score']:.4f}")

### Multiple Images

In [None]:
# Test with multiple images
test_urls = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
    "https://images.unsplash.com/photo-1552053831-71594a27632d?w=400",  # dog
    "https://images.unsplash.com/photo-1511367461989-f85a21fda167?w=400"   # banana
]

for i, url in enumerate(test_urls, 1):
    try:
        img = load_image_from_url(url)
        results = classifier(img, top_k=3)
        
        print(f"\n{'='*50}")
        print(f"Image {i}:")
        for result in results:
            print(f"  {result['label']:25s} - {result['score']:.4f}")
    except Exception as e:
        print(f"Error loading image {i}: {e}")

## Method 2: Using Model and Processor Directly

In [None]:
# Load processor and model
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

print(f"Model loaded on: {device}")
print(f"Number of classes: {model.config.num_labels}")

In [None]:
# Classify with more control
import torch.nn.functional as F

image = load_image_from_url("https://images.unsplash.com/photo-1517849845537-4d257902454a?w=400")  # dog

# Process image
inputs = processor(images=image, return_tensors="pt").to(device)

# Get predictions
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)[0]

# Get top 5 predictions
top_probs, top_indices = torch.topk(probabilities, k=5)

print("\n=== DETAILED PREDICTIONS ===")
for prob, idx in zip(top_probs, top_indices):
    label = model.config.id2label[idx.item()]
    print(f"{label:30s} - {prob.item():.6f}")

## Practical Applications

### Example 1: Batch Classification

In [None]:
# Classify multiple images efficiently
image_urls = [
    "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",  # cat
    "https://images.unsplash.com/photo-1546527868-ccb7ee7dfa6a?w=400",  # car
    "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=400"   # birds
]

images = [load_image_from_url(url) for url in image_urls]

# Batch process (the processor resizes every image to 224x224, so no padding is needed)
inputs = processor(images=images, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs.logits.argmax(dim=-1)

print("=== BATCH CLASSIFICATION ===")
for i, pred_idx in enumerate(predictions):
    label = model.config.id2label[pred_idx.item()]
    print(f"Image {i+1}: {label}")
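
Sending every image in one call works for three images but can exhaust memory for hundreds. A common pattern is to split the list into fixed-size chunks and run the processor and model once per chunk. The `batched` helper below is a hypothetical name written for this notebook; Python 3.12 added a similar `itertools.batched`. The commented loop shows how it could plug into the cell above.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most batch_size items (the last may be shorter)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # leftover items

# Sketch of use with the processor/model from Method 2 (not executed here):
# for chunk in batched(images, batch_size=8):
#     inputs = processor(images=chunk, return_tensors="pt").to(device)
#     with torch.no_grad():
#         preds = model(**inputs).logits.argmax(dim=-1)

print([len(b) for b in batched(range(10), 4)])  # [4, 4, 2]
```

The right chunk size depends on available memory. Larger chunks usually help throughput on a GPU until memory runs out, which is what Exercise 4 asks you to explore.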

### Example 2: Confidence Filtering

In [None]:
def classify_with_confidence_threshold(image, threshold=0.5):
    """
    Only return predictions above a confidence threshold.
    """
    results = classifier(image, top_k=10)
    
    confident_predictions = [r for r in results if r['score'] >= threshold]
    
    if confident_predictions:
        print(f"\nPredictions with at least {threshold:.0%} confidence:")
        for pred in confident_predictions:
            print(f"  {pred['label']:30s} - {pred['score']:.4f}")
    else:
        print(f"\nNo predictions reach {threshold:.0%} confidence")
        print("Top prediction:")
        print(f"  {results[0]['label']:30s} - {results[0]['score']:.4f}")
    
    return confident_predictions

# Test
image = load_image_from_url("https://images.unsplash.com/photo-1518791841217-8f162f1e1131?w=400")  # cat
classify_with_confidence_threshold(image, threshold=0.3)
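
Separating the filtering logic from the printing makes it easier to test and reuse. The minimal sketch below assumes results in the pipeline's list-of-dicts format (`label`, `score`). `filter_predictions` is a hypothetical helper, and the sample results are made-up values. Unlike the function above, it falls back to a flagged top prediction instead of returning an empty list.

```python
from typing import Dict, List

def filter_predictions(results: List[Dict], threshold: float = 0.5) -> List[Dict]:
    """Keep predictions scoring at or above threshold.

    If none qualify, return only the top prediction, flagged with
    "below_threshold": True so the caller can tell it missed the cutoff.
    """
    if not results:
        return []
    kept = [r for r in results if r["score"] >= threshold]
    if kept:
        return kept
    best = max(results, key=lambda r: r["score"])
    return [{**best, "below_threshold": True}]

# Made-up pipeline-style output
sample = [
    {"label": "Egyptian cat", "score": 0.41},
    {"label": "tabby cat", "score": 0.35},
    {"label": "tiger cat", "score": 0.12},
]
print(filter_predictions(sample, threshold=0.3))
print(filter_predictions(sample, threshold=0.9))
```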

### Example 3: Local Images

In [None]:
# If you have local images in sample_data/
import os

sample_data_path = "../sample_data"

if os.path.exists(sample_data_path):
    image_files = [f for f in os.listdir(sample_data_path) 
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if image_files:
        print("=== CLASSIFYING LOCAL IMAGES ===")
        for img_file in image_files[:3]:  # Limit to 3
            img_path = os.path.join(sample_data_path, img_file)
            img = Image.open(img_path)
            results = classifier(img, top_k=3)
            
            print(f"\n{img_file}:")
            for result in results:
                print(f"  {result['label']:25s} - {result['score']:.4f}")
    else:
        print("No images found in sample_data/. Add some .jpg or .png files to test!")
else:
    print("sample_data/ directory not found. You can add images there for testing.")
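
The cell above only checks the top level of the folder and three extensions. The `pathlib` sketch below is slightly more general, assuming you also want to search subfolders and accept a few extra formats. The extension set is an assumption (Pillow can open many more), and `find_images` is a hypothetical helper.

```python
import tempfile
from pathlib import Path
from typing import List

# Assumed set of extensions to treat as images
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp"}

def find_images(root: str, recursive: bool = True) -> List[Path]:
    """Return image files under root in sorted order, or [] if root does not exist."""
    base = Path(root)
    if not base.is_dir():
        return []
    pattern = "**/*" if recursive else "*"
    return sorted(
        p for p in base.glob(pattern)
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS  # case-insensitive: .JPG counts
    )

# Demo on a throwaway directory so the sketch runs anywhere
with tempfile.TemporaryDirectory() as tmp:
    for name in ["a.JPG", "notes.txt", "sub/b.png"]:
        path = Path(tmp) / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    found = [p.name for p in find_images(tmp)]
    top_level_only = [p.name for p in find_images(tmp, recursive=False)]

print(found)           # ['a.JPG', 'b.png']
print(top_level_only)  # ['a.JPG']
```

In the notebook, `for img_path in find_images(sample_data_path)[:3]:` could replace the `os.listdir` filter. Sorting gives a stable order from run to run.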

### Example 4: CIFAR-10 Images

In [None]:
# Using CIFAR-10 dataset (~170MB download, 10 classes, 32x32 color images)
import torchvision.datasets as datasets

print("Downloading CIFAR-10 test dataset...")
# Download test set (will cache after first download)
cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)

# CIFAR-10 class names
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                 'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Loaded {len(cifar10_test)} test images\n")

# Classify a few CIFAR-10 images
print("=== CIFAR-10 Classification ===")
for i in range(5):
    img, true_label = cifar10_test[i]
    
    # Classify the image
    results = classifier(img, top_k=3)
    
    print(f"\nImage {i+1}:")
    print(f"  True class: {cifar_classes[true_label]}")
    print("  Predictions:")
    for j, pred in enumerate(results, 1):
        print(f"    {j}. {pred['label']:30s} - {pred['score']:.4f}")
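
The true CIFAR-10 class and the predicted ImageNet label come from different vocabularies. ImageNet-1k has no single "bird" or "automobile" class, only specific ones such as "goldfinch" or "sports car". The 32x32 images are also heavily upscaled, so expect weak matches. To score hits roughly, you could map ImageNet labels to CIFAR-10 classes by keyword. The keyword lists below are hypothetical and deliberately incomplete. A proper mapping would walk the WordNet hierarchy that ImageNet classes are drawn from.

```python
import re
from typing import Dict, List, Optional

# Hypothetical, incomplete keyword lists, matched as whole words in ImageNet labels
CIFAR_KEYWORDS: Dict[str, List[str]] = {
    "airplane": ["airliner", "warplane"],
    "automobile": ["sports car", "convertible", "cab", "minivan", "jeep"],
    "bird": ["goldfinch", "robin", "jay", "magpie", "ostrich"],
    "cat": ["cat"],
    "deer": ["impala", "gazelle"],
    "dog": ["retriever", "terrier", "spaniel", "poodle", "beagle"],
    "frog": ["frog"],
    "horse": ["horse", "sorrel"],
    "ship": ["liner", "container ship", "fireboat", "speedboat"],
    "truck": ["truck"],
}

def to_cifar_class(imagenet_label: str) -> Optional[str]:
    """Map an ImageNet label to a CIFAR-10 class name, or None if no keyword matches."""
    text = imagenet_label.lower()
    for cifar_class, keywords in CIFAR_KEYWORDS.items():
        for kw in keywords:
            # \b word boundaries stop "cat" from matching "catamaran"
            if re.search(rf"\b{re.escape(kw)}\b", text):
                return cifar_class
    return None

def top1_hit_rate(predicted_labels: List[str], true_classes: List[str]) -> float:
    """Fraction of images whose mapped top-1 prediction equals the true CIFAR-10 class."""
    if not true_classes:
        return 0.0
    hits = sum(to_cifar_class(p) == t for p, t in zip(predicted_labels, true_classes))
    return hits / len(true_classes)

print(to_cifar_class("tabby, tabby cat"))  # cat
print(to_cifar_class("catamaran"))         # None
```

With the cell above, you could collect `results[0]['label']` and `cifar_classes[true_label]` for each image and pass the two lists to `top1_hit_rate`.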

## State-of-the-Art Open Models (Not Covered)

While Vision Transformer (ViT) is excellent, there are several cutting-edge image classification models that push the boundaries of accuracy and efficiency. These models represent the latest advances in computer vision research.

### Top SOTA Image Classification Models

#### 1. 🔷 ConvNeXt (Facebook/Meta)
**Modern pure convolutional architecture rivaling transformers**
- **Why it's special**: Modernized ConvNet design matching ViT performance without attention
- **Performance**: 87.8% ImageNet top-1 (ConvNeXt-XL with ImageNet-22k pre-training, 384px), while keeping the simplicity and efficiency of a standard ConvNet
- **Model Card**: [facebook/convnext-large-224](https://huggingface.co/facebook/convnext-large-224)
- **Paper**: [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)
- **Size**: 800MB (ConvNeXt-Large)

#### 2. ⚡ EfficientNetV2 (Google)
**Optimized scaling for speed and accuracy**
- **Why it's special**: Progressive learning strategy, extremely efficient training/inference
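- **Performance**: 87.3% ImageNet top-1 for the largest variant pre-trained on ImageNet-21k; the paper reports training 5x-11x faster than ViT on the same compute
- **Model Card**: [google/efficientnet-b7](https://huggingface.co/google/efficientnet-b7) is the original EfficientNet (V1) B7 in `transformers`; EfficientNetV2 checkpoints are distributed through the `timm` library
- **Paper**: [EfficientNetV2: Smaller Models and Faster Training](https://arxiv.org/abs/2104.00298)
- **Size**: ~264MB for the linked EfficientNet-B7 checkpoint; EfficientNetV2-L has about 120M parameters (roughly 480MB in fp32)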

#### 3. 🪟 Swin Transformer (Microsoft)
**Hierarchical vision transformer with shifted windows**
- **Why it's special**: Computes attention in local windows, scales to high-resolution images
- **Performance**: 87.3% ImageNet top-1 (Swin-Large with ImageNet-22k pre-training, 384px); strong results on downstream tasks such as COCO detection and ADE20K segmentation
- **Model Card**: [microsoft/swin-large-patch4-window7-224](https://huggingface.co/microsoft/swin-large-patch4-window7-224)
- **Paper**: [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
- **Size**: 800MB (Swin-Large)

#### 4. 🎨 BEiT (Microsoft)
**BERT pre-training approach for images**
- **Why it's special**: Masked image modeling (like BERT for text), strong transfer learning
- **Performance**: 88.6% ImageNet top-1 (BEiT-Large with ImageNet-22k intermediate fine-tuning, 512px), excellent fine-tuning capability
- **Model Card**: [microsoft/beit-large-patch16-224](https://huggingface.co/microsoft/beit-large-patch16-224)
- **Paper**: [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
- **Size**: 1.2GB (BEiT-Large)

#### 5. 📚 DeiT (Facebook)
**Data-efficient image transformer**
- **Why it's special**: Distillation-based training, achieves great results with less data
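- **Performance**: up to 85.2% ImageNet top-1 (distilled DeiT-Base at 384px) trained on ImageNet-1k alone; the paper reports training on a single machine in under three days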
- **Model Card**: [facebook/deit-base-distilled-patch16-224](https://huggingface.co/facebook/deit-base-distilled-patch16-224)
- **Paper**: [Training data-efficient image transformers](https://arxiv.org/abs/2012.12877)
- **Size**: 346MB (DeiT-Base)

### Why Not Covered?

Compared with ViT-base, the large variants of these models typically need:
- **GPU Memory**: roughly 0.8-1.2GB of fp32 weights each, and 12-24GB of VRAM to fine-tune comfortably
- **Inference Time**: 2-5x longer than ViT-base on CPU
- **Specialized Use Cases**: Benefits most apparent at scale or on specific domains
- **Training Resources**: Fine-tuning requires significant compute

ViT provides an excellent balance of performance and accessibility for learning!

### Learning Path Recommendation

1. **Start here**: Master ViT (this notebook)
2. **Next step**: Try ConvNeXt or Swin for better accuracy
3. **Efficiency focus**: Experiment with EfficientNetV2 for deployment
4. **Research**: Explore BEiT for transfer learning projects

### Benchmarks & Leaderboards

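- **Reported ImageNet-1K top-1 accuracy** (input resolution and pre-training data differ between entries, so this is not a like-for-like comparison):
  - ViT-Base: 81.8%
  - DeiT-Base (distilled, 384px): 85.2%
  - ConvNeXt-Large (ImageNet-22k pre-training, 384px): 87.5%
  - Swin-Large (ImageNet-22k pre-training, 384px): 87.3%
  - BEiT-Large (ImageNet-22k intermediate fine-tuning, 512px): 88.6%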

- **Explore rankings**: [Papers With Code - ImageNet](https://paperswithcode.com/sota/image-classification-on-imagenet)

### Quick Comparison Table

| Model | Size (fp32) | Speed | Reported Top-1 | Best For |
|-------|------|-------|----------|----------|
| **ViT-Base** ⭐ | 346MB | Fast | 81.8% | Learning, general use |
| **DeiT-Base** | 346MB | Fast | 85.2% (distilled, 384px) | Data-efficient training |
| **EfficientNetV2** | ~480MB (V2-L) | Very Fast | 87.3% (largest variant, 21k pre-training) | Production deployment |
| **ConvNeXt-Large** | 800MB | Medium | 87.5% (22k pre-training, 384px) | High accuracy, pure CNN |
| **Swin-Large** | 800MB | Medium | 87.3% (22k pre-training, 384px) | High-res images, detection |
| **BEiT-Large** | 1.2GB | Slow | 88.6% (22k intermediate fine-tuning, 512px) | Transfer learning, fine-tuning |

**💡 Tip**: For real-world applications with GPU acceleration, ConvNeXt and Swin offer a strong accuracy-efficiency trade-off!

## Exercises

1. **Custom Images**: Test with your own images. How accurate is the model?

2. **Ambiguous Images**: Try images that could fit multiple categories. What does the model predict?

3. **Model Comparison**: If you have GPU, compare ViT-base with ViT-large. Is the larger model better?

4. **Batch Size**: Experiment with batch processing different numbers of images. How does speed change?

5. **Other Models**: Try ResNet-50 (`microsoft/resnet-50`) instead of ViT. Compare results.

In [None]:
# Your code here for exercises


## Key Takeaways

✅ **Vision Transformer (ViT)** treats images as sequences of patches

✅ **ImageNet-1k fine-tuning** lets the model recognize 1,000 object categories

✅ **Batch processing** improves efficiency for multiple images

✅ **Confidence scores** (softmax probabilities) give a rough sense of prediction certainty

✅ Models work on both URLs and local files

## Next Steps

- Try **Notebook 05**: Object Detection for locating multiple objects
- Explore other vision models on [HuggingFace Hub](https://huggingface.co/models?pipeline_tag=image-classification)
- Learn about fine-tuning on custom datasets

## Resources

- [Vision Transformer Paper](https://arxiv.org/abs/2010.11929)
- [Image Classification Guide](https://huggingface.co/docs/transformers/tasks/image_classification)
- [ImageNet Classes](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)