# Part 4.2: Computer Vision — Beyond Classification

CNNs can do much more than tell you "this is a cat." The same spatial feature hierarchies that power image classification — edges, textures, parts, objects — serve as the backbone for a rich family of vision tasks: **locating** objects in images, **drawing pixel-perfect masks** around them, and even **connecting images to language**. This notebook explores the architectures that take convolutional features and push them far beyond a single class label.

We will build up from classification to detection, segmentation, and finally to modern vision transformers and multimodal models like CLIP. Each section introduces the core idea, implements a key component from scratch, and connects it to the broader deep learning landscape.

---

## Learning Objectives

By the end of this notebook, you should be able to:

- [ ] Describe the progression from classification → localization → detection → segmentation
- [ ] Implement Intersection over Union (IoU) from scratch and explain its role in detection
- [ ] Explain anchor boxes and Non-Maximum Suppression (NMS) and implement both
- [ ] Describe the YOLO single-pass detection philosophy and its grid-based output structure
- [ ] Explain the U-Net encoder-decoder architecture and implement a mini version in PyTorch
- [ ] Contrast semantic segmentation, instance segmentation, and panoptic segmentation
- [ ] Implement patch embeddings for Vision Transformers (ViT) and explain positional embeddings
- [ ] Describe CLIP's contrastive learning approach for connecting vision and language
- [ ] Apply transfer learning with pretrained torchvision models

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')
np.random.seed(42)
torch.manual_seed(42)

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)

---

## 1. From Classification to Detection

### Intuitive Explanation

In notebook 13, we trained CNNs to answer a single question: **"What is in this image?"** That is classification — one label per image. But real-world vision demands much more:

| Task | Question | Output | Example |
|------|----------|--------|--------|
| **Classification** | "What is this?" | One label | "Cat" |
| **Localization** | "What is this and where?" | Label + one bounding box | "Cat at (x, y, w, h)" |
| **Object Detection** | "What are all the things and where?" | Multiple labels + boxes | "Cat at ..., Dog at ..." |
| **Semantic Segmentation** | "What is every pixel?" | Per-pixel class label | Pixel grid of classes |
| **Instance Segmentation** | "What object does every pixel belong to?" | Per-pixel class + instance ID | "Cat #1, Cat #2" |

The key insight is that **the same CNN backbone** (VGG, ResNet, etc.) can serve all these tasks. The difference lies in what you attach to the end of the feature extractor — a classification head, a bounding-box regressor, a pixel-wise decoder, or a combination.
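To make the "same backbone, different heads" idea concrete, here is a minimal sketch. The toy two-stage backbone, the 5-class setup, and the head designs are illustrative assumptions rather than any specific published model. One feature map feeds a classification head, a single-box localization head, and a segmentation head that upsamples back to the input resolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy backbone (illustrative assumption): two strided convs, 3x64x64 -> 32x16x16
backbone = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
)
num_classes = 5

# Classification head: global average pool -> linear -> one label per image
cls_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, num_classes))

# Localization head: global average pool -> linear -> 4 numbers for one box
box_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 4))

# Segmentation head: 1x1 conv to per-class scores, then upsample to input size
seg_conv = nn.Conv2d(32, num_classes, kernel_size=1)

def seg_head(features, out_size):
    logits = seg_conv(features)
    return F.interpolate(logits, size=out_size, mode='bilinear', align_corners=False)

x = torch.randn(2, 3, 64, 64)
features = backbone(x)                       # shared spatial features
class_logits = cls_head(features)            # (batch, num_classes)
box_coords = box_head(features)              # (batch, 4)
mask_logits = seg_head(features, x.shape[-2:])  # (batch, num_classes, 64, 64)

for name, t in [("features", features), ("class logits", class_logits),
                ("box coords", box_coords), ("mask logits", mask_logits)]:
    print(f"{name:13s} {tuple(t.shape)}")
```

Only the heads differ; the classification and localization heads discard spatial layout through pooling, while the segmentation head keeps it.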

In [None]:
# Visualize the hierarchy of computer vision tasks
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for ax in axes:
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_aspect('equal')
    ax.axis('off')

# 1. Classification
ax = axes[0]
ax.set_title('Classification', fontsize=13, fontweight='bold')
rect = patches.FancyBboxPatch((1, 1), 8, 8, boxstyle="round,pad=0.2",
                               facecolor='lightblue', edgecolor='blue', linewidth=2)
ax.add_patch(rect)
circle = plt.Circle((5, 4.5), 2, color='gray', alpha=0.6)
ax.add_patch(circle)
ax.plot([3.5, 3.0, 4.0], [6.3, 7.5, 6.8], color='gray', linewidth=2)
ax.plot([6.5, 7.0, 6.0], [6.3, 7.5, 6.8], color='gray', linewidth=2)
ax.text(5, 0.3, '"Cat"', ha='center', fontsize=12, fontweight='bold', color='blue')

# 2. Localization
ax = axes[1]
ax.set_title('Localization', fontsize=13, fontweight='bold')
rect = patches.FancyBboxPatch((1, 1), 8, 8, boxstyle="round,pad=0.2",
                               facecolor='lightyellow', edgecolor='gray', linewidth=1)
ax.add_patch(rect)
circle = plt.Circle((5, 4.5), 2, color='gray', alpha=0.6)
ax.add_patch(circle)
ax.plot([3.5, 3.0, 4.0], [6.3, 7.5, 6.8], color='gray', linewidth=2)
ax.plot([6.5, 7.0, 6.0], [6.3, 7.5, 6.8], color='gray', linewidth=2)
bbox = patches.Rectangle((2.5, 2), 5, 6.2, linewidth=2, edgecolor='red',
                          facecolor='none', linestyle='--')
ax.add_patch(bbox)
ax.text(5, 0.3, '"Cat" + box', ha='center', fontsize=12, fontweight='bold', color='red')

# 3. Detection
ax = axes[2]
ax.set_title('Detection', fontsize=13, fontweight='bold')
rect = patches.FancyBboxPatch((1, 1), 8, 8, boxstyle="round,pad=0.2",
                               facecolor='lightyellow', edgecolor='gray', linewidth=1)
ax.add_patch(rect)
circle1 = plt.Circle((3.5, 4.5), 1.4, color='gray', alpha=0.6)
ax.add_patch(circle1)
ax.plot([2.5, 2.2, 2.9], [5.7, 6.5, 6.1], color='gray', linewidth=2)
ax.plot([4.5, 4.8, 4.1], [5.7, 6.5, 6.1], color='gray', linewidth=2)
circle2 = plt.Circle((7, 3.5), 1.2, color='brown', alpha=0.4)
ax.add_patch(circle2)
bbox1 = patches.Rectangle((1.7, 2.5), 3.6, 4.5, linewidth=2, edgecolor='red',
                           facecolor='none', linestyle='--')
ax.add_patch(bbox1)
bbox2 = patches.Rectangle((5.5, 1.8), 3, 3.4, linewidth=2, edgecolor='green',
                           facecolor='none', linestyle='--')
ax.add_patch(bbox2)
ax.text(3.5, 7.5, 'Cat', fontsize=10, color='red', fontweight='bold', ha='center')
ax.text(7, 5.8, 'Dog', fontsize=10, color='green', fontweight='bold', ha='center')

# 4. Segmentation
ax = axes[3]
ax.set_title('Segmentation', fontsize=13, fontweight='bold')
rect = patches.FancyBboxPatch((1, 1), 8, 8, boxstyle="round,pad=0.2",
                               facecolor='lightyellow', edgecolor='gray', linewidth=1)
ax.add_patch(rect)
circle1 = plt.Circle((3.5, 4.5), 1.4, color='red', alpha=0.4)
ax.add_patch(circle1)
circle2 = plt.Circle((7, 3.5), 1.2, color='green', alpha=0.4)
ax.add_patch(circle2)
ax.text(5, 0.3, 'Per-pixel labels', ha='center', fontsize=12, fontweight='bold', color='purple')

plt.suptitle('The Computer Vision Task Hierarchy', fontsize=15, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

### Deep Dive: What CNNs Actually Learn

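Recall from notebook 13 that CNN layers learn a hierarchy of increasingly abstract features. The layer depths and receptive-field sizes below are rough illustrations; exact values depend on the architecture: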

| Layer Depth | What It Detects | Receptive Field | Analogy |
|------------|----------------|-----------------|--------|
| Layer 1 | Edges, gradients | 3x3 - 5x5 pixels | Individual brush strokes |
| Layer 2 | Textures, corners | 10-20 pixels | Patterns in the paint |
| Layer 3 | Parts (eyes, wheels) | 40-80 pixels | Recognizable components |
| Layer 4 | Objects (faces, cars) | 100+ pixels | Complete things |
| Layer 5 | Scenes, contexts | Entire image | The full picture |

**Key insight:** For classification, we only care about the final layer's global summary. For detection and segmentation, we need the **spatial information** from intermediate layers — we need to know not just *what* features are present, but *where* they are. This is why detection and segmentation architectures carefully preserve and combine features from multiple depths.
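The receptive-field sizes in the table above can be estimated with a standard recurrence: for a layer with kernel size k and stride s, the receptive field grows by (k - 1) times the current "jump" (the spacing, in input pixels, between adjacent outputs), and the jump is multiplied by s. The sketch below applies it to a hypothetical VGG-like stack (two 3x3 convs plus a 2x2 max-pool, repeated four times) and assumes no dilation. It gives the theoretical receptive field; the region that actually influences a unit in practice is usually smaller.

```python
def receptive_field(layers):
    """layers: list of (kernel_size, stride) from input to output.
    Returns the theoretical receptive field after each layer."""
    r, j = 1, 1  # receptive field size, jump between adjacent outputs (in input pixels)
    history = []
    for k, s in layers:
        r = r + (k - 1) * j
        j = j * s
        history.append(r)
    return history

# Hypothetical VGG-like stack: conv3x3, conv3x3, maxpool2x2 (stride 2), repeated 4 times
vgg_like = [(3, 1), (3, 1), (2, 2)] * 4
rfs = receptive_field(vgg_like)
for i, ((k, s), r) in enumerate(zip(vgg_like, rfs), start=1):
    kind = "pool" if s == 2 else "conv"
    print(f"layer {i:2d} ({kind} {k}x{k}, stride {s}): receptive field = {r}x{r}")
```

Each pooling step doubles the jump, so later 3x3 convs widen the receptive field much faster than early ones.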

---

## 2. Object Detection Foundations

### Intuitive Explanation

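Object detection answers: "What objects are in this image, and where is each one?" This requires predicting both a **class label** and a **bounding box** for every object. A box is usually written either as (x, y, width, height) or as its top-left and bottom-right corners (x1, y1, x2, y2); the code in this section uses the corner format. The fundamental building blocks are: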

1. **Anchor boxes** — a grid of pre-defined candidate regions tiled across the image
2. **IoU (Intersection over Union)** — the metric that measures how well a predicted box matches a ground truth box
3. **Non-Maximum Suppression (NMS)** — the post-processing step that removes duplicate detections

Let's implement each one from scratch.
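Because detectors mix box formats (YOLO-style heads predict centers and sizes, while IoU is easiest with corners), a pair of conversion helpers is handy. This is a minimal NumPy sketch; the function names are hypothetical, and "center format" here means (cx, cy, w, h) with (cx, cy) the box center.

```python
import numpy as np

def cxcywh_to_xyxy(boxes):
    """Center format (cx, cy, w, h) -> corner format (x1, y1, x2, y2).
    Works on a single box or on an array of shape (..., 4)."""
    boxes = np.asarray(boxes, dtype=float)
    cx, cy, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)

def xyxy_to_cxcywh(boxes):
    """Corner format (x1, y1, x2, y2) -> center format (cx, cy, w, h)."""
    boxes = np.asarray(boxes, dtype=float)
    x1, y1, x2, y2 = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    return np.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], axis=-1)

center_box = [2.5, 2.5, 3.0, 3.0]      # a 3x3 box centered at (2.5, 2.5)
corners = cxcywh_to_xyxy(center_box)
print("corners:", corners)
print("round trip:", xyxy_to_cxcywh(corners))
```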

In [None]:
def compute_iou(box1, box2):
    """
    Compute Intersection over Union between two bounding boxes.

    Args:
        box1: [x1, y1, x2, y2] -- top-left and bottom-right corners
        box2: [x1, y1, x2, y2] -- top-left and bottom-right corners

    Returns:
        IoU value between 0 and 1
    """
    # Compute intersection coordinates
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2], box2[2])
    inter_y2 = min(box1[3], box2[3])

    # Compute intersection area (0 if no overlap)
    inter_width = max(0, inter_x2 - inter_x1)
    inter_height = max(0, inter_y2 - inter_y1)
    inter_area = inter_width * inter_height

    # Compute union area
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area1 + area2 - inter_area

    # Avoid division by zero
    if union_area == 0:
        return 0.0

    return inter_area / union_area

# Test IoU with examples
box_a = [1, 1, 4, 4]  # 3x3 box
box_b = [2, 2, 5, 5]  # 3x3 box, partially overlapping
box_c = [5, 5, 8, 8]  # 3x3 box, no overlap
box_d = [1, 1, 4, 4]  # identical to box_a

print("IoU Examples:")
print(f"  Partial overlap:  IoU(A, B) = {compute_iou(box_a, box_b):.4f}")
print(f"  No overlap:       IoU(A, C) = {compute_iou(box_a, box_c):.4f}")
print(f"  Perfect overlap:  IoU(A, D) = {compute_iou(box_a, box_d):.4f}")

# Verify: A and B overlap in a 2x2 region. Union = 9+9-4 = 14. IoU = 4/14
print(f"\nManual check: 4/14 = {4/14:.4f}")
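In practice, detectors compare many predictions (or anchors) against many ground-truth boxes at once, so IoU is usually computed as a full matrix with broadcasting rather than one pair at a time. This NumPy sketch assumes corner-format boxes and returns an (N, M) matrix; the function name `pairwise_iou` is hypothetical.

```python
import numpy as np

def pairwise_iou(boxes_a, boxes_b):
    """IoU between every box in boxes_a (N, 4) and every box in boxes_b (M, 4).
    Boxes are [x1, y1, x2, y2]. Returns an (N, M) matrix."""
    a = np.asarray(boxes_a, dtype=float)[:, None, :]  # (N, 1, 4)
    b = np.asarray(boxes_b, dtype=float)[None, :, :]  # (1, M, 4)

    # Intersection width/height, clipped at 0 when boxes do not overlap
    inter_w = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    inter_h = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = inter_w * inter_h                                   # (N, M)

    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])  # (N, 1)
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])  # (1, M)
    union = area_a + area_b - inter
    # Degenerate (zero-area) pairs get IoU 0 instead of a division error
    return np.where(union > 0, inter / np.maximum(union, 1e-12), 0.0)

preds = [[1, 1, 4, 4], [5, 5, 8, 8]]
truths = [[2, 2, 5, 5], [1, 1, 4, 4], [5, 5, 8, 8]]
print(np.round(pairwise_iou(preds, truths), 4))
```

A matrix like this is what anchor-based detectors typically use during training to decide which anchors are matched to which ground-truth boxes (for example, by thresholding each anchor's best IoU).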

In [None]:
# Visualize IoU with different overlap levels
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

examples = [
    ("No Overlap\nIoU = 0.00", [1, 1, 4, 4], [5, 5, 8, 8]),
    ("Small Overlap\nIoU = 0.06", [1, 1, 4, 4], [3, 3, 6, 6]),
    ("Large Overlap\nIoU = 0.29", [1, 1, 4, 4], [2, 2, 5, 5]),
    ("Perfect Overlap\nIoU = 1.00", [1, 1, 4, 4], [1, 1, 4, 4]),
]

for ax, (title, b1, b2) in zip(axes, examples):
    ax.set_xlim(0, 9)
    ax.set_ylim(0, 9)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    # Draw box 1 (blue)
    rect1 = patches.Rectangle((b1[0], b1[1]), b1[2]-b1[0], b1[3]-b1[1],
                                linewidth=2, edgecolor='blue', facecolor='blue', alpha=0.3)
    ax.add_patch(rect1)

    # Draw box 2 (red)
    rect2 = patches.Rectangle((b2[0], b2[1]), b2[2]-b2[0], b2[3]-b2[1],
                                linewidth=2, edgecolor='red', facecolor='red', alpha=0.3)
    ax.add_patch(rect2)

    # Highlight intersection (green)
    ix1, iy1 = max(b1[0], b2[0]), max(b1[1], b2[1])
    ix2, iy2 = min(b1[2], b2[2]), min(b1[3], b2[3])
    if ix2 > ix1 and iy2 > iy1:
        inter = patches.Rectangle((ix1, iy1), ix2-ix1, iy2-iy1,
                                   linewidth=2, edgecolor='green', facecolor='green', alpha=0.5)
        ax.add_patch(inter)

    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('x')
    ax.set_ylabel('y')

axes[0].legend([patches.Patch(color='blue', alpha=0.3),
                patches.Patch(color='red', alpha=0.3),
                patches.Patch(color='green', alpha=0.5)],
               ['Box A', 'Box B', 'Intersection'], loc='upper right', fontsize=9)

plt.suptitle('Intersection over Union (IoU)', fontsize=14, fontweight='bold', y=1.03)
plt.tight_layout()
plt.show()

### Anchor Boxes

Instead of searching every possible rectangle in an image, detection models tile the image with a set of **anchor boxes** (also called "priors" or "default boxes") at each spatial location. Each anchor has a predefined **aspect ratio** and **scale**. The network then predicts:

1. **Offsets** — small adjustments (dx, dy, dw, dh) to shift each anchor to better fit an object
2. **Objectness score** — probability that this anchor actually contains an object
3. **Class probabilities** — what kind of object it is

This is far more efficient than searching from scratch — the anchors provide a good starting point, and the network only needs to learn small refinements.
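One common way to define the offsets (dx, dy, dw, dh) is the parameterization popularized by Faster R-CNN: center shifts are divided by the anchor's width and height, and size changes are expressed as log ratios. The sketch below assumes boxes in (cx, cy, w, h) center format; other detector families use different parameterizations, and the example boxes are made up.

```python
import numpy as np

def encode_offsets(anchor, gt):
    """Regression target (dx, dy, dw, dh) that moves `anchor` onto `gt`.
    Both boxes are (cx, cy, w, h)."""
    ax, ay, aw, ah = anchor
    gx, gy, gw, gh = gt
    return np.array([(gx - ax) / aw,      # center shift, in units of anchor width
                     (gy - ay) / ah,      # center shift, in units of anchor height
                     np.log(gw / aw),     # log size ratio keeps decoded widths positive
                     np.log(gh / ah)])

def decode_offsets(anchor, offsets):
    """Inverse of encode_offsets: apply (predicted) offsets to an anchor."""
    ax, ay, aw, ah = anchor
    dx, dy, dw, dh = offsets
    return np.array([ax + dx * aw, ay + dy * ah, aw * np.exp(dw), ah * np.exp(dh)])

anchor = (4.0, 4.0, 2.0, 2.0)   # 1:1 anchor centered at (4, 4)
gt = (4.5, 3.8, 3.0, 2.4)       # hypothetical nearby ground-truth box
t = encode_offsets(anchor, gt)
print("target offsets:", np.round(t, 3))
print("decoded box:   ", decode_offsets(anchor, t))
```

Normalizing by anchor size makes the targets roughly scale-invariant, so small and large anchors produce offsets of similar magnitude.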

In [None]:
# Visualize anchor boxes tiled across an image
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: grid of anchor centers
ax = axes[0]
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_aspect('equal')
ax.set_title('Anchor Box Grid (Centers)', fontsize=13, fontweight='bold')

# Create a grid of anchor centers
grid_size = 4
for i in range(grid_size):
    for j in range(grid_size):
        cx = i * 2 + 1
        cy = j * 2 + 1
        ax.plot(cx, cy, 'k+', markersize=10, markeredgewidth=2)
        rect = patches.Rectangle((i*2, j*2), 2, 2, linewidth=1,
                                  edgecolor='gray', facecolor='none', linestyle='--')
        ax.add_patch(rect)

ax.set_xlabel('x')
ax.set_ylabel('y')

# Right: multiple anchors at one location
ax = axes[1]
ax.set_xlim(0, 8)
ax.set_ylim(0, 8)
ax.set_aspect('equal')
ax.set_title('Multiple Anchors at One Location', fontsize=13, fontweight='bold')

cx, cy = 4, 4
ax.plot(cx, cy, 'k+', markersize=15, markeredgewidth=3)

# Different aspect ratios and scales
anchors = [
    (2, 2, 'blue', '1:1 small'),
    (3, 3, 'red', '1:1 large'),
    (4, 2, 'green', '2:1'),
    (2, 4, 'orange', '1:2'),
    (5, 2.5, 'purple', '2:1 large'),
]

for w, h, color, label in anchors:
    rect = patches.Rectangle((cx - w/2, cy - h/2), w, h,
                              linewidth=2, edgecolor=color, facecolor='none',
                              linestyle='-', label=label)
    ax.add_patch(rect)

ax.legend(loc='upper right', fontsize=10)
ax.set_xlabel('x')
ax.set_ylabel('y')

plt.suptitle('How Anchor Boxes Work', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("At each grid cell, multiple anchors with different shapes cover various object geometries.")
print(f"With a {grid_size}x{grid_size} grid and {len(anchors)} anchors per cell: {grid_size**2 * len(anchors)} total anchors.")
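The hand-picked anchor shapes above can also be generated systematically from a set of **scales** and **aspect ratios**, a scheme used by several anchor-based detectors. In this sketch (function name and parameter values are illustrative assumptions), `scale` is the square root of the anchor area and `ratio` is width / height. The 4x4 grid with stride 2 reproduces the cell centers plotted above.

```python
import numpy as np

def generate_anchors(grid_size, stride, scales, ratios):
    """Tile anchors over a grid_size x grid_size feature map.
    Returns (grid_size**2 * len(scales) * len(ratios), 4) boxes as (x1, y1, x2, y2)."""
    shapes = []
    for s in scales:
        for r in ratios:
            w = s * np.sqrt(r)   # chosen so that w * h == s**2 and w / h == r
            h = s / np.sqrt(r)
            shapes.append((w, h))
    shapes = np.array(shapes)                               # (A, 2)

    centers_1d = (np.arange(grid_size) + 0.5) * stride      # cell centers along one axis
    cx, cy = np.meshgrid(centers_1d, centers_1d, indexing='xy')
    centers = np.stack([cx.ravel(), cy.ravel()], axis=1)    # (G, 2)

    c = centers[:, None, :]                                 # (G, 1, 2)
    half = shapes[None, :, :] / 2                           # (1, A, 2)
    boxes = np.concatenate([c - half, c + half], axis=-1)   # (G, A, 4)
    return boxes.reshape(-1, 4)

anchors_xyxy = generate_anchors(grid_size=4, stride=2, scales=[2.0, 3.0], ratios=[0.5, 1.0, 2.0])
print("total anchors:", anchors_xyxy.shape[0])
print(anchors_xyxy[:3].round(2))
```

Anchors near the border can extend past the image; implementations commonly clip them or ignore them during training.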

### Non-Maximum Suppression (NMS)

A detection model with many anchors will typically produce **multiple overlapping detections** for the same object. NMS cleans this up:

1. Sort all detections by confidence score (highest first)
2. Take the highest-scoring detection as a "keep"
3. Remove all other detections that overlap with it above an IoU threshold (e.g., 0.5)
4. Repeat with the next highest-scoring surviving detection
5. Continue until no detections remain

**What this means:** NMS keeps the single best detection for each object and suppresses the duplicates.

In [None]:
def nms(boxes, scores, iou_threshold=0.5):
    """
    Non-Maximum Suppression: remove duplicate detections.

    Args:
        boxes: list of [x1, y1, x2, y2] bounding boxes
        scores: list of confidence scores (one per box)
        iou_threshold: IoU above which a box is considered duplicate

    Returns:
        List of indices of boxes to keep
    """
    if len(boxes) == 0:
        return []

    # Sort by score (descending)
    order = np.argsort(scores)[::-1]

    keep = []
    while len(order) > 0:
        # Keep the highest scoring box
        best_idx = order[0]
        keep.append(int(best_idx))  # store a plain int rather than a NumPy integer

        # Compare with all remaining boxes
        remaining = order[1:]
        suppress = []

        for i, idx in enumerate(remaining):
            iou = compute_iou(boxes[best_idx], boxes[idx])
            if iou > iou_threshold:  # IoU above the threshold counts as a duplicate
                suppress.append(i)

        # Remove suppressed boxes
        order = np.delete(remaining, suppress)

    return keep

# Test NMS: three overlapping detections of the same object
boxes = [
    [1, 1, 4, 4],   # detection A
    [1.2, 1.1, 4.2, 4.1],  # detection B (overlaps A heavily)
    [1.5, 1.3, 4.5, 4.3],  # detection C (overlaps A)
    [6, 6, 9, 9],   # detection D (separate object)
]
scores = [0.9, 0.75, 0.8, 0.85]

kept = nms(boxes, scores, iou_threshold=0.5)
print("NMS Results:")
print(f"  Input: {len(boxes)} detections")
print(f"  Kept indices: {kept}")
print(f"  Kept scores: {[scores[i] for i in kept]}")
print(f"  Output: {len(kept)} detections (duplicates removed)")
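The loop above calls `compute_iou` once per remaining box. A common refinement is to compute the IoU of the current best box against all remaining boxes in a single vectorized step. This NumPy sketch follows the same greedy algorithm, suppressing boxes whose IoU with the kept box exceeds the threshold. It assumes corner-format boxes with positive area, and the name `nms_vectorized` is hypothetical.

```python
import numpy as np

def nms_vectorized(boxes, scores, iou_threshold=0.5):
    """Greedy NMS; boxes are [x1, y1, x2, y2] with positive area."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores, dtype=float)
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = np.argsort(-scores, kind='stable')  # highest score first

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # Intersection of the best box with every remaining box at once
        w = np.clip(np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]), 0, None)
        h = np.clip(np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]), 0, None)
        inter = w * h
        iou = inter / (areas[best] + areas[rest] - inter)
        # Survivors are the boxes that do not overlap the kept box too much
        order = rest[iou <= iou_threshold]
    return keep

demo_boxes = [[1, 1, 4, 4], [1.2, 1.1, 4.2, 4.1], [1.5, 1.3, 4.5, 4.3], [6, 6, 9, 9]]
demo_scores = [0.9, 0.75, 0.8, 0.85]
print(nms_vectorized(demo_boxes, demo_scores, iou_threshold=0.5))  # [0, 3]
```

Raising the threshold makes NMS more permissive: at 0.9, none of these boxes overlap enough to be suppressed.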

In [None]:
# Visualize NMS before and after
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

colors = ['blue', 'red', 'green', 'orange']

# Before NMS
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
ax.set_title('Before NMS (4 detections)', fontsize=13, fontweight='bold')

for i, (box, score, color) in enumerate(zip(boxes, scores, colors)):
    rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],
                              linewidth=2, edgecolor=color, facecolor=color, alpha=0.2,
                              label=f'Box {i}: score={score:.2f}')
    ax.add_patch(rect)

ax.legend(loc='upper right', fontsize=10)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.grid(True, alpha=0.3)

# After NMS
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
ax.set_title(f'After NMS (kept {len(kept)} detections)', fontsize=13, fontweight='bold')

for i in kept:
    box = boxes[i]
    rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],
                              linewidth=3, edgecolor=colors[i], facecolor=colors[i], alpha=0.3,
                              label=f'Box {i}: score={scores[i]:.2f}')
    ax.add_patch(rect)

ax.legend(loc='upper right', fontsize=10)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.grid(True, alpha=0.3)

plt.suptitle('Non-Maximum Suppression', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

---

## 3. YOLO: You Only Look Once

### Intuitive Explanation

Early detection methods like R-CNN were **two-stage**: first propose candidate regions, then classify each one. This was accurate but slow: R-CNN ran the CNN separately on each of roughly 2,000 region proposals per image.

**YOLO** revolutionized detection by framing it as a **single regression problem**: one neural network pass predicts all bounding boxes and class probabilities simultaneously.

#### How YOLO Works

1. **Divide** the image into an S x S grid (e.g., 7x7)
2. Each grid cell predicts **B bounding boxes**, each with:
   - 4 coordinates: (x, y, w, h), where (x, y) is the box center relative to the bounds of the grid cell and (w, h) are relative to the whole image
   - 1 objectness confidence: P(object) x IoU(pred, truth)
3. Each grid cell also predicts **C class probabilities** (shared across all B boxes)
4. The output tensor has shape: **S x S x (B x 5 + C)**

For example, with S=7, B=2, and C=20 (PASCAL VOC classes):
- Output shape: 7 x 7 x (2 x 5 + 20) = 7 x 7 x 30
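The arithmetic above, and the mapping from a cell-relative prediction back to image pixels, can be checked with a few lines. This sketch follows the YOLOv1 convention (center offsets inside the cell, width and height as fractions of the image) and assumes a 448x448 input; the helper names and example values are hypothetical.

```python
def yolo_v1_output_size(S, B, C):
    """Total predicted values: S*S cells, each with B boxes * 5 values + C class probs."""
    return S * S * (B * 5 + C)

def cell_to_image_box(row, col, x, y, w, h, S=7, img_w=448, img_h=448):
    """Convert one YOLOv1-style prediction to an image-space box (cx, cy, w, h).
    x, y: center offset inside cell (row, col), in [0, 1]
    w, h: box size as a fraction of the whole image, in [0, 1]"""
    cx = (col + x) / S * img_w
    cy = (row + y) / S * img_h
    return cx, cy, w * img_w, h * img_h

print(yolo_v1_output_size(S=7, B=2, C=20))  # 1470 = 7 * 7 * 30
# Object centered at grid position (3.6, 4.3): cell row 4, column 3
print(cell_to_image_box(row=4, col=3, x=0.6, y=0.3, w=0.43, h=0.37))
```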

In [None]:
# Visualize YOLO grid-based detection
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

S = 7  # grid size

# 1. The grid
ax = axes[0]
ax.set_xlim(0, S)
ax.set_ylim(0, S)
ax.set_aspect('equal')
ax.set_title(f'{S}x{S} Grid Over Image', fontsize=13, fontweight='bold')

for i in range(S + 1):
    ax.axhline(y=i, color='gray', linewidth=0.5, alpha=0.5)
    ax.axvline(x=i, color='gray', linewidth=0.5, alpha=0.5)

ax.fill_between([0, S], 0, S, color='lightblue', alpha=0.3)
ax.set_xlabel('Grid columns')
ax.set_ylabel('Grid rows')

# 2. Cell responsibility
ax = axes[1]
ax.set_xlim(0, S)
ax.set_ylim(0, S)
ax.set_aspect('equal')
ax.set_title('Cell Responsible for Object', fontsize=13, fontweight='bold')

for i in range(S + 1):
    ax.axhline(y=i, color='gray', linewidth=0.5, alpha=0.5)
    ax.axvline(x=i, color='gray', linewidth=0.5, alpha=0.5)

obj_cx, obj_cy = 3.6, 4.3
ax.plot(obj_cx, obj_cy, 'r*', markersize=20, label='Object center')
responsible_cell = patches.Rectangle((3, 4), 1, 1, linewidth=3,
                                      edgecolor='red', facecolor='red', alpha=0.3,
                                      label='Responsible cell')
ax.add_patch(responsible_cell)

obj_box = patches.Rectangle((2.1, 3.0), 3.0, 2.6, linewidth=2,
                             edgecolor='blue', facecolor='none', linestyle='--',
                             label='Ground truth box')
ax.add_patch(obj_box)
ax.legend(loc='upper right', fontsize=9)
ax.set_xlabel('Grid columns')
ax.set_ylabel('Grid rows')

# 3. Output tensor structure
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.axis('off')
ax.set_title('Output Tensor per Cell', fontsize=13, fontweight='bold')

y_start = 6.5
box_h = 0.6
box_w = 1.2

elements_b1 = ['x1', 'y1', 'w1', 'h1', 'conf1']
for i, elem in enumerate(elements_b1):
    rect = patches.FancyBboxPatch((i * box_w + 0.2, y_start), box_w - 0.1, box_h,
                                   boxstyle="round,pad=0.05", facecolor='lightcoral', edgecolor='red')
    ax.add_patch(rect)
    ax.text(i * box_w + 0.2 + (box_w-0.1)/2, y_start + box_h/2, elem,
            ha='center', va='center', fontsize=9, fontweight='bold')
ax.text(3.1, y_start + box_h + 0.3, 'Box 1 (5 values)', ha='center', fontsize=10, color='red')

y_b2 = y_start - 1.2
for i, elem in enumerate(elements_b1):
    elem2 = elem.replace('1', '2')
    rect = patches.FancyBboxPatch((i * box_w + 0.2, y_b2), box_w - 0.1, box_h,
                                   boxstyle="round,pad=0.05", facecolor='lightgreen', edgecolor='green')
    ax.add_patch(rect)
    ax.text(i * box_w + 0.2 + (box_w-0.1)/2, y_b2 + box_h/2, elem2,
            ha='center', va='center', fontsize=9, fontweight='bold')
ax.text(3.1, y_b2 + box_h + 0.3, 'Box 2 (5 values)', ha='center', fontsize=10, color='green')

y_cls = y_b2 - 1.5
cls_labels = ['P(cat)', 'P(dog)', 'P(car)', '...', 'P(clsC)']
for i, elem in enumerate(cls_labels):
    rect = patches.FancyBboxPatch((i * box_w + 0.2, y_cls), box_w - 0.1, box_h,
                                   boxstyle="round,pad=0.05", facecolor='lightyellow', edgecolor='orange')
    ax.add_patch(rect)
    ax.text(i * box_w + 0.2 + (box_w-0.1)/2, y_cls + box_h/2, elem,
            ha='center', va='center', fontsize=8, fontweight='bold')
ax.text(3.1, y_cls + box_h + 0.3, 'Class probs (C values)', ha='center', fontsize=10, color='orange')

ax.text(3.1, y_cls - 0.6, 'Total per cell: B*5 + C values', ha='center', fontsize=11,
        fontweight='bold', style='italic')

plt.tight_layout()
plt.show()

In [None]:
class SimpleYOLOHead(nn.Module):
    """
    Simplified YOLO-style detection head.
    Takes a feature map and predicts bounding boxes + classes per grid cell.
    """
    def __init__(self, in_channels, S=7, B=2, C=20):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, B * 5 + C, 1),
        )

    def forward(self, x):
        """
        Args:
            x: feature map of shape (batch, in_channels, S, S)
        Returns:
            predictions: (batch, S, S, B*5 + C)
        """
        out = self.conv(x)
        out = out.permute(0, 2, 3, 1)
        return out

    def decode(self, predictions):
        """Decode raw predictions into boxes, confidences, and class probs."""
        batch_size = predictions.shape[0]

        box_preds = predictions[..., :self.B * 5].reshape(
            batch_size, self.S, self.S, self.B, 5
        )
        class_preds = predictions[..., self.B * 5:]

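        xy = torch.sigmoid(box_preds[..., :2])   # center offsets within the cell, in [0, 1]
        wh = box_preds[..., 2:4]                 # left raw here; real YOLO variants transform these (e.g. exp() times an anchor size)
        conf = torch.sigmoid(box_preds[..., 4])  # objectness confidence, in [0, 1]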
        class_probs = torch.softmax(class_preds, dim=-1)

        return xy, wh, conf, class_probs

# Demo
S, B, C = 7, 2, 20
head = SimpleYOLOHead(in_channels=512, S=S, B=B, C=C)

dummy_features = torch.randn(1, 512, S, S)
predictions = head(dummy_features)

print(f"Input feature map:  {dummy_features.shape}")
print(f"Output predictions: {predictions.shape}")
print(f"  Per cell: {B*5 + C} values = {B} boxes x 5 + {C} classes")

xy, wh, conf, class_probs = head.decode(predictions)
print(f"\nDecoded:")
print(f"  Box centers (xy):    {xy.shape}")
print(f"  Box sizes (wh):      {wh.shape}")
print(f"  Confidence:          {conf.shape}")
print(f"  Class probabilities: {class_probs.shape}")
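Before NMS, YOLO-style detectors combine the two kinds of output: in YOLOv1, the class-specific score for a box is the cell's conditional class probability multiplied by the box's confidence. The sketch below uses random stand-ins with the same shapes as `decode` returns (a trained head would supply real values), and the 0.1 threshold is an illustrative assumption.

```python
import torch

torch.manual_seed(0)
S, B, C = 7, 2, 20

# Random stand-ins shaped like the decoded outputs: conf (1, S, S, B), class_probs (1, S, S, C)
demo_conf = torch.rand(1, S, S, B)
demo_class_probs = torch.softmax(torch.randn(1, S, S, C), dim=-1)

# Class-specific score per (cell, box, class) = P(class | object) * box confidence
cls_scores = demo_conf.unsqueeze(-1) * demo_class_probs.unsqueeze(-2)  # (1, S, S, B, C)

# For each box, keep its best class and that class's score
best_scores, best_classes = cls_scores.max(dim=-1)                    # (1, S, S, B) each

score_threshold = 0.1  # illustrative value
mask = best_scores > score_threshold
print(f"{int(mask.sum())} of {S * S * B} boxes pass the score threshold {score_threshold}")
```

The surviving boxes would then be converted to image coordinates and passed through NMS, typically per class.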

### Deep Dive: One-Stage vs Two-Stage Detectors

| Aspect | Two-Stage (R-CNN family) | One-Stage (YOLO, SSD) |
|--------|------------------------|----------------------|
| **Pipeline** | Region proposal, then classify each region | Single pass: grid predictions for all boxes |
| **Speed** | Slower (separate steps) | Faster (single forward pass) |
| **Accuracy** | Generally higher for small objects | Competitive, sometimes lower on small objects |
| **Examples** | R-CNN, Fast R-CNN, Faster R-CNN | YOLO, SSD, RetinaNet |
| **Key innovation** | Region Proposal Network (RPN) | Grid-based direct prediction |
| **Use case** | When accuracy matters most | When speed matters (real-time) |

#### Key Insight

The fundamental tradeoff is **thoroughness vs speed**. Two-stage detectors carefully examine each proposed region, while one-stage detectors make predictions everywhere simultaneously. Modern one-stage detectors such as YOLOv8 have narrowed the accuracy gap considerably, which makes them a popular choice for real-time applications.

#### Common Misconceptions

| Misconception | Reality |
|---------------|--------|
| "YOLO is always less accurate" | Recent YOLO versions are competitive with two-stage detectors on common benchmarks such as COCO |
| "Two-stage is always better for small objects" | Techniques like Feature Pyramid Networks (FPN) help one-stage detectors with small objects |
| "You must choose one approach" | Many systems combine ideas from both families |

---

## 4. Semantic Segmentation

### Intuitive Explanation

Semantic segmentation assigns a **class label to every single pixel** in the image. Unlike detection (which draws boxes), segmentation produces a precise mask showing exactly which pixels belong to each category.

The core challenge: CNNs for classification progressively **downsample** the spatial resolution (via pooling and striding) to build high-level features. But segmentation needs **pixel-level output** — the same resolution as the input. How do we get back to full resolution?

The answer: **encoder-decoder architectures** that first compress (encode) then expand (decode) the spatial dimensions.

In [None]:
# Demonstrate upsampling methods
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Create a small "feature map"
small = np.array([[1, 2],
                   [3, 4]], dtype=float)

# Method 1: Nearest neighbor upsampling
ax = axes[0]
nearest = np.repeat(np.repeat(small, 2, axis=0), 2, axis=1)
im = ax.imshow(nearest, cmap='Blues', vmin=0, vmax=5)
ax.set_title('Nearest Neighbor\nUpsampling', fontsize=12, fontweight='bold')
for i in range(4):
    for j in range(4):
        ax.text(j, i, f'{nearest[i,j]:.0f}', ha='center', va='center', fontsize=14, fontweight='bold')
ax.set_xticks([]); ax.set_yticks([])

# Method 2: Bilinear interpolation
ax = axes[1]
small_tensor = torch.tensor(small).unsqueeze(0).unsqueeze(0).float()
bilinear = F.interpolate(small_tensor, scale_factor=2, mode='bilinear', align_corners=True)
bilinear_np = bilinear.squeeze().numpy()
im = ax.imshow(bilinear_np, cmap='Blues', vmin=0, vmax=5)
ax.set_title('Bilinear\nInterpolation', fontsize=12, fontweight='bold')
for i in range(4):
    for j in range(4):
        ax.text(j, i, f'{bilinear_np[i,j]:.1f}', ha='center', va='center', fontsize=12, fontweight='bold')
ax.set_xticks([]); ax.set_yticks([])

# Method 3: Transposed convolution (learnable!)
ax = axes[2]
conv_t = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
conv_t.weight.data = torch.tensor([[[[1.0, 0.5], [0.5, 0.25]]]])
with torch.no_grad():
    trans_out = conv_t(small_tensor)
trans_np = trans_out.squeeze().numpy()
im = ax.imshow(trans_np, cmap='Blues', vmin=0, vmax=5)
ax.set_title('Transposed Convolution\n(Learnable!)', fontsize=12, fontweight='bold')
for i in range(trans_np.shape[0]):
    for j in range(trans_np.shape[1]):
        ax.text(j, i, f'{trans_np[i,j]:.1f}', ha='center', va='center', fontsize=12, fontweight='bold')
ax.set_xticks([]); ax.set_yticks([])

plt.suptitle('Upsampling Methods: From Low-Res Feature Maps to High-Res Output',
             fontsize=13, fontweight='bold', y=1.03)
plt.tight_layout()
plt.show()

print("Key difference: Transposed convolution has LEARNABLE parameters.")
print("The network can learn the best way to upsample for the task.")
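The output size of a transposed convolution follows the formula given in the `torch.nn.ConvTranspose2d` documentation: `(H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`. The sketch below checks the formula against PyTorch for a few configurations. It shows why `kernel_size=2, stride=2` (used above) gives an exact 2x upsampling, while `kernel_size=3, stride=2, padding=1` needs `output_padding=1` to do the same.

```python
import torch
import torch.nn as nn

def conv_transpose_out(h_in, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Output size formula for nn.ConvTranspose2d (per spatial dimension)."""
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

configs = [
    dict(kernel_size=2, stride=2),                              # exact 2x
    dict(kernel_size=3, stride=2, padding=1),                   # one pixel short of 2x
    dict(kernel_size=3, stride=2, padding=1, output_padding=1), # fixed with output_padding
    dict(kernel_size=4, stride=2, padding=1),                   # another exact-2x choice
]
fmap = torch.randn(1, 1, 16, 16)
for cfg in configs:
    actual = nn.ConvTranspose2d(1, 1, **cfg)(fmap).shape[-1]
    predicted = conv_transpose_out(16, **cfg)
    print(f"{cfg}: 16 -> {actual} (formula: {predicted})")
```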

### The U-Net Architecture

U-Net is one of the most influential segmentation architectures, originally designed for biomedical image segmentation. Its key innovation is **skip connections** between the encoder and decoder at matching resolutions:

```
Encoder (Downsampling)          Decoder (Upsampling)

Input (256x256)  ------------------>  Output (256x256)
    | conv+pool                          ^ upconv+concat
    v                                    |
  (128x128)  ------------------------> (128x128)
    | conv+pool                          ^ upconv+concat
    v                                    |
   (64x64)  ------------------------->  (64x64)
    | conv+pool                          ^ upconv+concat
    v                                    |
   (32x32)  ------------------------->  (32x32)
    | conv+pool                          ^ upconv+concat
    v                                    |
          (16x16) Bottleneck
```

#### Why Skip Connections?

The encoder captures **what** features are present (semantics) but loses **where** they are (spatial detail). The decoder needs to recover precise boundaries. Skip connections pass the high-resolution spatial information directly from the encoder to the decoder, giving it both:

- **Deep features** (from the bottleneck path): "This region contains a cat"
- **Fine details** (from the skip connections): "The exact boundary of the cat"

This combination produces sharp, accurate segmentation masks.
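The mechanics of a single decoder step can be sketched in a few lines: upsample the deep features 2x, concatenate the encoder features from the same resolution along the channel axis, and fuse them with a convolution. This `UpBlock` is a hypothetical, simplified module (one 3x3 conv where U-Net uses two, and channel counts scaled down from the diagram to keep it light).

```python
import torch
import torch.nn as nn

class UpBlock(nn.Module):
    """Simplified U-Net-style decoder step: upsample 2x, concat skip, fuse."""
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # Learnable 2x upsampling that also halves the channel count
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels // 2 + skip_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                   # (N, in/2, 2H, 2W): coarse semantic features
        x = torch.cat([x, skip], dim=1)  # append fine spatial detail from the encoder
        return self.fuse(x)

deep_feats = torch.randn(1, 128, 16, 16)  # low-resolution, semantically rich
skip_feats = torch.randn(1, 64, 32, 32)   # encoder output at the matching 32x32 resolution
up_block = UpBlock(in_channels=128, skip_channels=64, out_channels=64)
decoded = up_block(deep_feats, skip_feats)
print(tuple(decoded.shape))  # (1, 64, 32, 32)
```

Concatenation (rather than addition) lets the fusing convolution learn how much to rely on each source.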

In [None]:
# Visualize U-Net architecture
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')
ax.set_title('U-Net Architecture', fontsize=16, fontweight='bold')

# Encoder blocks (left side, going down)
encoder_blocks = [
    (1, 8, 2.0, 1.2, '3x256x256\n64 ch', 'lightblue'),
    (2, 6.3, 1.6, 1.2, '64x128x128\n128 ch', 'cornflowerblue'),
    (3, 4.6, 1.3, 1.2, '128x64x64\n256 ch', 'royalblue'),
    (4, 2.9, 1.0, 1.2, '256x32x32\n512 ch', 'mediumblue'),
]

# Bottleneck
bottleneck = (5.5, 1.2, 0.8, 2.0, '512x16x16\n1024 ch', 'darkblue')

# Decoder blocks (right side, going up)
decoder_blocks = [
    (8.5, 2.9, 1.0, 1.2, '512x32x32', 'lightsalmon'),
    (9.5, 4.6, 1.3, 1.2, '256x64x64', 'salmon'),
    (10.5, 6.3, 1.6, 1.2, '128x128x128', 'lightcoral'),
    (11.5, 8, 2.0, 1.2, '64x256x256', 'indianred'),
]

# Draw encoder
for x, y, w, h, label, color in encoder_blocks:
    rect = patches.FancyBboxPatch((x, y), h, w, boxstyle="round,pad=0.1",
                                   facecolor=color, edgecolor='black', alpha=0.7)
    ax.add_patch(rect)
    ax.text(x + h/2, y + w/2, label, ha='center', va='center', fontsize=7, color='white', fontweight='bold')

# Draw bottleneck
bx, by, bw, bh, blabel, bcolor = bottleneck
rect = patches.FancyBboxPatch((bx, by), bh, bw, boxstyle="round,pad=0.1",
                               facecolor=bcolor, edgecolor='black', alpha=0.8)
ax.add_patch(rect)
ax.text(bx + bh/2, by + bw/2, blabel, ha='center', va='center', fontsize=7, color='white', fontweight='bold')

# Draw decoder
for x, y, w, h, label, color in decoder_blocks:
    rect = patches.FancyBboxPatch((x, y), h, w, boxstyle="round,pad=0.1",
                                   facecolor=color, edgecolor='black', alpha=0.7)
    ax.add_patch(rect)
    ax.text(x + h/2, y + w/2, label, ha='center', va='center', fontsize=7, color='white', fontweight='bold')

# Draw skip connections (horizontal arrows)
skip_ys = [8.8, 7.1, 5.3, 3.4]
enc_rights = [2.2, 3.2, 4.2, 5.2]
dec_lefts = [11.5, 10.5, 9.5, 8.5]
for sy, er, dl in zip(skip_ys, enc_rights, dec_lefts):
    ax.annotate('', xy=(dl, sy), xytext=(er, sy),
                arrowprops=dict(arrowstyle='->', color='green', lw=2, linestyle='--'))
    ax.text((er + dl) / 2, sy + 0.15, 'skip', ha='center', fontsize=8, color='green', fontweight='bold')

# Down arrows (encoder)
for i in range(3):
    x_from = encoder_blocks[i][0] + encoder_blocks[i][3]/2
    y_from = encoder_blocks[i][1]
    x_to = encoder_blocks[i+1][0] + encoder_blocks[i+1][3]/2
    y_to = encoder_blocks[i+1][1] + encoder_blocks[i+1][2]
    ax.annotate('', xy=(x_to, y_to), xytext=(x_from, y_from),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2))

# Arrow from last encoder to bottleneck
ax.annotate('', xy=(bx + bh/2, by + bw),
            xytext=(encoder_blocks[-1][0] + encoder_blocks[-1][3]/2, encoder_blocks[-1][1]),
            arrowprops=dict(arrowstyle='->', color='blue', lw=2))

# Arrow from bottleneck to first decoder
ax.annotate('', xy=(decoder_blocks[0][0] + decoder_blocks[0][3]/2, decoder_blocks[0][1] + decoder_blocks[0][2]),
            xytext=(bx + bh/2, by + bw),
            arrowprops=dict(arrowstyle='->', color='red', lw=2))

# Up arrows (decoder)
for i in range(3):
    x_from = decoder_blocks[i][0] + decoder_blocks[i][3]/2
    y_from = decoder_blocks[i][1] + decoder_blocks[i][2]
    x_to = decoder_blocks[i+1][0] + decoder_blocks[i+1][3]/2
    y_to = decoder_blocks[i+1][1]
    ax.annotate('', xy=(x_to, y_to), xytext=(x_from, y_from),
                arrowprops=dict(arrowstyle='->', color='red', lw=2))

# Labels
ax.text(2, 0.5, 'ENCODER\n(Downsample)', ha='center', fontsize=12, fontweight='bold', color='blue')
ax.text(6.5, 0.5, 'BOTTLENECK', ha='center', fontsize=12, fontweight='bold', color='darkblue')
ax.text(11, 0.5, 'DECODER\n(Upsample)', ha='center', fontsize=12, fontweight='bold', color='red')

plt.tight_layout()
plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive conv-batchnorm-relu blocks (the basic U-Net building block)."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class MiniUNet(nn.Module):
    """
    A compact U-Net for semantic segmentation.

    Architecture:
        Encoder: 3 stages (64 -> 128 -> 256 channels), each followed by 2x2 max pooling
        Bottleneck: 512 channels
        Decoder: 3 upsampling stages with skip connections
        Output: num_classes channels (one per class)
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        # Encoder path
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.pool = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512)

        # Decoder path (note: input channels = skip + upsampled)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)  # 256 (up) + 256 (skip) = 512 in

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)  # 128 (up) + 128 (skip) = 256 in

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)   # 64 (up) + 64 (skip) = 128 in

        # Final 1x1 convolution to get class predictions
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)             # (B, 64, H, W)
        e2 = self.enc2(self.pool(e1)) # (B, 128, H/2, W/2)
        e3 = self.enc3(self.pool(e2)) # (B, 256, H/4, W/4)

        # Bottleneck
        b = self.bottleneck(self.pool(e3))  # (B, 512, H/8, W/8)

        # Decoder with skip connections
        d3 = self.up3(b)                    # (B, 256, H/4, W/4)
        d3 = torch.cat([d3, e3], dim=1)     # (B, 512, H/4, W/4)
        d3 = self.dec3(d3)                  # (B, 256, H/4, W/4)

        d2 = self.up2(d3)                   # (B, 128, H/2, W/2)
        d2 = torch.cat([d2, e2], dim=1)     # (B, 256, H/2, W/2)
        d2 = self.dec2(d2)                  # (B, 128, H/2, W/2)

        d1 = self.up1(d2)                   # (B, 64, H, W)
        d1 = torch.cat([d1, e1], dim=1)     # (B, 128, H, W)
        d1 = self.dec1(d1)                  # (B, 64, H, W)

        return self.final(d1)               # (B, num_classes, H, W)

# Test the Mini U-Net
model = MiniUNet(in_channels=3, num_classes=10)

x = torch.randn(2, 3, 64, 64)
output = model(x)

print(f"Input shape:  {x.shape}  -> (batch, channels, height, width)")
print(f"Output shape: {output.shape} -> (batch, num_classes, height, width)")
print("\nOutput holds per-pixel class logits (one score per class at every pixel).")
print(f"To get the predicted class per pixel: output.argmax(dim=1) -> shape {output.argmax(dim=1).shape}")

total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

In [None]:
# Visualize what the U-Net produces
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

np.random.seed(42)
torch.manual_seed(42)
fake_image = torch.randn(1, 3, 64, 64)

model.eval()  # inference mode: BatchNorm uses its running statistics instead of updating them
with torch.no_grad():
    pred = model(fake_image)
    pred_classes = pred.argmax(dim=1).squeeze().numpy()

# Display input channels
ax = axes[0]
ax.imshow(fake_image[0, 0].numpy(), cmap='gray')
ax.set_title('Input (channel 0)', fontsize=12, fontweight='bold')
ax.axis('off')

# Display raw predictions (one channel)
ax = axes[1]
im = ax.imshow(pred[0, 0].detach().numpy(), cmap='RdBu_r')
ax.set_title('Raw Prediction (class 0 logits)', fontsize=12, fontweight='bold')
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)

# Display argmax segmentation map
ax = axes[2]
im = ax.imshow(pred_classes, cmap='tab10', vmin=0, vmax=9)
ax.set_title('Segmentation Map (argmax)', fontsize=12, fontweight='bold')
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle('Mini U-Net Output (Untrained -- Random Predictions)', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("Note: The model is untrained, so predictions are random.")
print("After training on labeled data, each pixel would be correctly classified.")

---

## 5. Instance Segmentation

### Intuitive Explanation

Semantic segmentation labels every pixel with a class, but it cannot distinguish between **individual instances** of the same class. If two cats overlap, semantic segmentation labels all their pixels as "cat" — but cannot tell you which pixels belong to cat #1 vs cat #2.

**Instance segmentation** solves this by combining detection (separate bounding boxes per object) with segmentation (pixel-level mask within each box).

| Approach | What It Outputs | Limitation |
|----------|----------------|------------|
| Semantic Segmentation | A class label for every pixel | Cannot separate touching or overlapping objects of the same class |
| Instance Segmentation | A class, box, and mask for each detected object | Usually ignores amorphous background regions (sky, road); needs detection + segmentation |
| Panoptic Segmentation | A class for every pixel, plus an instance ID for each countable object | Most complex; must reconcile instance and background predictions |
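
To make the three output formats concrete, here is a small NumPy sketch with two touching cats and a dog on a toy grid. The class IDs, the `box_mask` helper, and the `class_id * 1000 + instance_number` panoptic encoding are illustrative assumptions chosen for readability, not a fixed standard.

```python
import numpy as np

# Toy 6x8 scene. Class IDs are illustrative: 0 = background, 1 = cat, 2 = dog.
GRID_H, GRID_W = 6, 8

def box_mask(r0, r1, c0, c1):
    """Boolean mask that is True for rows r0:r1 and columns c0:c1."""
    m = np.zeros((GRID_H, GRID_W), dtype=bool)
    m[r0:r1, c0:c1] = True
    return m

# Hypothetical per-object predictions as (class_id, mask). The two cats touch.
objects = [
    (1, box_mask(1, 5, 0, 3)),  # cat #1
    (1, box_mask(1, 5, 3, 5)),  # cat #2, directly adjacent to cat #1
    (2, box_mask(3, 6, 6, 8)),  # dog #1
]

# Semantic: one class per pixel, so both cats merge into a single "cat" region.
semantic = np.zeros((GRID_H, GRID_W), dtype=np.int64)
for class_id, mask in objects:
    semantic[mask] = class_id

# Instance: one mask per detected object; background pixels belong to no instance.
instance_masks = [mask for _, mask in objects]

# Panoptic: every pixel labeled, and each object also gets its own ID.
# Assumed encoding: class_id * 1000 + per-class instance number (0 = background).
panoptic = np.zeros((GRID_H, GRID_W), dtype=np.int64)
per_class_count = {}
for class_id, mask in objects:
    per_class_count[class_id] = per_class_count.get(class_id, 0) + 1
    panoptic[mask] = class_id * 1000 + per_class_count[class_id]

print("semantic labels:", np.unique(semantic).tolist())   # [0, 1, 2]
print("instance masks: ", len(instance_masks))            # 3
print("panoptic IDs:   ", np.unique(panoptic).tolist())   # [0, 1001, 1002, 2001]
```

From `semantic` alone there is no way to tell that the cat region contains two animals; `panoptic` keeps that information while still labeling every pixel.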

### Mask R-CNN

A widely used instance segmentation architecture is **Mask R-CNN**, which extends Faster R-CNN by adding a small mask prediction branch:

1. **Backbone CNN** extracts feature maps
2. **Region Proposal Network (RPN)** proposes candidate object regions
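3. **RoIAlign** extracts fixed-size features for each proposed region. It improves on RoI Pooling by sampling the feature map with bilinear interpolation instead of rounding region boundaries to the feature grid, which keeps features aligned with the input pixels.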
4. **Classification head** predicts class and refines bounding box
5. **Mask head** predicts a binary mask for each detected object (new in Mask R-CNN)

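The mask head is a small fully convolutional network (FCN) that predicts a 28x28 mask for *every* class for each region. At inference, the mask belonging to the predicted class is resized to the bounding box and thresholded at 0.5 to produce a binary mask.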

#### Key Insight

Mask R-CNN **decouples** class prediction from mask prediction. The mask head predicts a mask for *each* class independently, using a per-pixel sigmoid with binary cross-entropy rather than a softmax across classes, and the classification head decides which class's mask to use. Because classes never compete for a pixel at the mask level, mask quality improves substantially in the paper's ablations.
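
The decoupling can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions (toy sizes, random logits, ground-truth classes standing in for the classifier's predictions), not torchvision's implementation; `mask_loss` and `soft_masks_for` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F

def mask_loss(mask_logits, gt_classes, gt_masks):
    """Per-class mask loss in the style of Mask R-CNN (a minimal sketch).

    mask_logits: (R, K, 28, 28) raw scores, one mask per class for each of R regions
    gt_classes:  (R,) ground-truth class index of each region
    gt_masks:    (R, 28, 28) float targets in {0, 1}, already cropped/resized to each region
    """
    rows = torch.arange(mask_logits.shape[0], device=mask_logits.device)
    # Keep only the channel for each region's own class. The other K-1 masks get
    # no loss, so classes never compete for a pixel (no softmax across classes).
    selected = mask_logits[rows, gt_classes]                 # (R, 28, 28)
    return F.binary_cross_entropy_with_logits(selected, gt_masks)

def soft_masks_for(mask_logits, pred_classes):
    """At inference, take the mask of the class chosen by the classification head."""
    rows = torch.arange(mask_logits.shape[0], device=mask_logits.device)
    return torch.sigmoid(mask_logits[rows, pred_classes])   # (R, 28, 28) probabilities

torch.manual_seed(0)
R, K = 4, 3                                   # 4 regions, 3 classes (toy sizes)
mask_logits = torch.randn(R, K, 28, 28, requires_grad=True)
gt_classes = torch.tensor([0, 2, 1, 2])
gt_masks = (torch.rand(R, 28, 28) > 0.5).float()

loss = mask_loss(mask_logits, gt_classes, gt_masks)
loss.backward()
print(f"mask loss: {loss.item():.4f}")
# Region 0 is class 0, so its class-1 mask receives no gradient at all.
print("grad norm, region 0 / class 1:", mask_logits.grad[0, 1].norm().item())  # 0.0

# Inference (ground-truth classes stand in for predicted ones here):
# resize the chosen soft mask to its box (assumed 50x80 px), then threshold at 0.5.
soft = soft_masks_for(mask_logits.detach(), gt_classes)
binary_mask = F.interpolate(soft[0:1, None], size=(50, 80), mode="bilinear",
                            align_corners=False)[0, 0] > 0.5
print(binary_mask.shape)  # torch.Size([50, 80])
```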

In [None]:
# Visualize the difference between semantic and instance segmentation
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax in axes:
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_aspect('equal')
    ax.axis('off')

# Original "image"
ax = axes[0]
ax.set_title('Original Image', fontsize=13, fontweight='bold')
ax.add_patch(patches.FancyBboxPatch((0.5, 0.5), 9, 9, boxstyle="round,pad=0.1",
             facecolor='lightyellow', edgecolor='gray'))
c1 = plt.Circle((3.5, 5), 2, color='gray', alpha=0.6)
c2 = plt.Circle((5.5, 5), 2, color='gray', alpha=0.6)
c3 = plt.Circle((8, 3), 1.2, color='brown', alpha=0.5)
ax.add_patch(c1); ax.add_patch(c2); ax.add_patch(c3)
ax.text(3.5, 5, 'Cat', ha='center', va='center', fontsize=10, fontweight='bold')
ax.text(5.5, 5, 'Cat', ha='center', va='center', fontsize=10, fontweight='bold')
ax.text(8, 3, 'Dog', ha='center', va='center', fontsize=10, fontweight='bold')

# Semantic segmentation
ax = axes[1]
ax.set_title('Semantic Segmentation', fontsize=13, fontweight='bold')
ax.add_patch(patches.FancyBboxPatch((0.5, 0.5), 9, 9, boxstyle="round,pad=0.1",
             facecolor='lightyellow', edgecolor='gray'))
c1 = plt.Circle((3.5, 5), 2, color='red', alpha=0.4)
c2 = plt.Circle((5.5, 5), 2, color='red', alpha=0.4)
c3 = plt.Circle((8, 3), 1.2, color='blue', alpha=0.4)
ax.add_patch(c1); ax.add_patch(c2); ax.add_patch(c3)
ax.text(4.5, 8, 'cat (same color)', fontsize=10, color='red', fontweight='bold', ha='center')
ax.text(8, 1.2, 'dog', fontsize=10, color='blue', fontweight='bold', ha='center')

# Instance segmentation
ax = axes[2]
ax.set_title('Instance Segmentation', fontsize=13, fontweight='bold')
ax.add_patch(patches.FancyBboxPatch((0.5, 0.5), 9, 9, boxstyle="round,pad=0.1",
             facecolor='lightyellow', edgecolor='gray'))
c1 = plt.Circle((3.5, 5), 2, color='red', alpha=0.4)
c2 = plt.Circle((5.5, 5), 2, color='green', alpha=0.4)
c3 = plt.Circle((8, 3), 1.2, color='blue', alpha=0.4)
ax.add_patch(c1); ax.add_patch(c2); ax.add_patch(c3)
ax.text(3.5, 7.5, 'cat #1', fontsize=10, color='red', fontweight='bold', ha='center')
ax.text(5.5, 7.5, 'cat #2', fontsize=10, color='green', fontweight='bold', ha='center')
ax.text(8, 1.2, 'dog #1', fontsize=10, color='blue', fontweight='bold', ha='center')

plt.suptitle('Semantic vs Instance Segmentation', fontsize=14, fontweight='bold', y=1.03)
plt.tight_layout()
plt.show()

---

## 6. Vision Transformers (ViT)

### Intuitive Explanation

Transformers revolutionized NLP (as we will see in notebooks 16-17). The natural question: can we use the same self-attention mechanism for images?

The challenge is scale: a 224x224 image has 50,176 pixels. Self-attention is O(n^2) in sequence length — applying it to every pixel would be prohibitively expensive.

**The ViT solution:** Divide the image into non-overlapping **patches** (e.g., 16x16 pixels), treat each patch as a "token" (like a word in NLP), and apply a standard Transformer encoder.
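
To make the scale argument concrete, the sketch below counts tokens and attention scores (one head, one layer) for a 224x224 image, first treating every pixel as a token and then using 16x16 patches. `attention_cost` is a hypothetical helper for this arithmetic.

```python
def attention_cost(img_size, patch_size=1):
    """Tokens and attention scores (one head, one layer) for a square image.

    patch_size=1 means every pixel is its own token.
    """
    tokens = (img_size // patch_size) ** 2
    return tokens, tokens ** 2

for p in (1, 16):
    n_tokens, n_scores = attention_cost(224, p)
    print(f"{p}x{p} patches: {n_tokens:,} tokens -> {n_scores:,} attention scores")
# 1x1 patches: 50,176 tokens -> 2,517,630,976 attention scores
# 16x16 patches: 196 tokens -> 38,416 attention scores
```

Patching reduces the sequence length by a factor of 256, so the attention matrix shrinks by 256^2 = 65,536.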

#### The ViT Pipeline

1. **Split** the image into P x P patches (e.g., 16x16)
2. **Flatten** each patch into a vector (16 x 16 x 3 = 768 dimensions)
3. **Project** each flattened patch through a linear layer (the "patch embedding")
4. **Add positional embeddings** (so the model knows where each patch was)
5. **Prepend a [CLS] token** (for classification, like BERT)
6. **Feed through a standard Transformer encoder** (self-attention + FFN)
7. **Use the [CLS] token output** for classification

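This is remarkably simple, and it works very well when the model is pretrained on a large dataset: the original ViT was pretrained on ImageNet-21k (about 14M images) or JFT-300M before fine-tuning.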

In [None]:
# Visualize how an image gets split into patches
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

H, W = 224, 224
P = 16  # patch size
n_patches = (H // P) * (W // P)

# Generate a colorful image
x = np.linspace(0, 1, W)
y = np.linspace(0, 1, H)
xx, yy = np.meshgrid(x, y)
image = np.stack([xx, yy, 1 - xx], axis=-1)  # RGB gradient

# 1. Original image
ax = axes[0]
ax.imshow(image)
ax.set_title(f'Original Image ({H}x{W})', fontsize=12, fontweight='bold')
ax.axis('off')

# 2. Image with patch grid
ax = axes[1]
ax.imshow(image)
for i in range(0, H + 1, P):
    ax.axhline(y=i, color='white', linewidth=0.5, alpha=0.8)
    ax.axvline(x=i, color='white', linewidth=0.5, alpha=0.8)
ax.set_title(f'Divided into {P}x{P} Patches\n({n_patches} patches total)', fontsize=12, fontweight='bold')
ax.axis('off')

# 3. Patches as a sequence
ax = axes[2]
n_show = 10
patch_sequence = []
for i in range(0, min(n_show * P, H), P):
    patch = image[0:P, i:i+P]
    patch_sequence.append(patch)

combined = np.concatenate(patch_sequence, axis=1)
ax.imshow(combined)
ax.set_title(f'First {n_show} Patches as Sequence', fontsize=12, fontweight='bold')
ax.axis('off')

for i in range(n_show):
    ax.text(i * P + P/2, P + 3, f'[{i}]', ha='center', fontsize=8, fontweight='bold')

plt.suptitle('Vision Transformer: Image -> Patch Sequence', fontsize=14, fontweight='bold', y=1.03)
plt.tight_layout()
plt.show()

print(f"Image size: {H}x{W}x3 = {H*W*3:,} values")
print(f"Patch size: {P}x{P}x3 = {P*P*3} values per patch")
print(f"Number of patches: {n_patches}")
print(f"Sequence length for Transformer: {n_patches} + 1 (CLS token) = {n_patches + 1}")

In [None]:
class PatchEmbedding(nn.Module):
    """
    Convert an image into a sequence of patch embeddings.

    This is the core innovation of ViT: treat an image as a sequence of
    flattened patches, just like a sentence is a sequence of words.

    Args:
        img_size: Input image size (assumes square)
        patch_size: Size of each patch (assumes square)
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Dimension of the embedding space
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

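        # Linear projection of flattened patches, implemented as a Conv2d with
        # kernel_size = stride = patch_size: each output position sees exactly one
        # non-overlapping patch and applies the same learned weights to it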
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # Project patches: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.projection(x)
        # Flatten spatial dims: (B, embed_dim, H/P, W/P) -> (B, embed_dim, n_patches)
        x = x.flatten(2)
        # Transpose: (B, embed_dim, n_patches) -> (B, n_patches, embed_dim)
        x = x.transpose(1, 2)
        return x


class SimpleViT(nn.Module):
    """
    Simplified Vision Transformer for classification.

    Args:
        img_size: Input image size
        patch_size: Patch size
        in_channels: Input channels
        num_classes: Number of output classes
        embed_dim: Embedding dimension
        num_heads: Number of attention heads
        num_layers: Number of transformer layers
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=10, embed_dim=768, num_heads=8, num_layers=6):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

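        # Learnable [CLS] token and positional embeddings, initialized with a small
        # truncated normal (std=0.02) as in common ViT implementations
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)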

        # Transformer encoder (pre-norm like ViT: LayerNorm before attention and FFN)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 4, dropout=0.1,
            activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers,
                                                 enable_nested_tensor=False)

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches+1, embed_dim)

        # Add positional embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.transformer(x)

        # Use CLS token for classification
        cls_output = self.norm(x[:, 0])
        return self.classifier(cls_output)

# Test the Vision Transformer
torch.manual_seed(42)
vit = SimpleViT(img_size=64, patch_size=8, in_channels=3,
                num_classes=10, embed_dim=256, num_heads=8, num_layers=4)

x = torch.randn(2, 3, 64, 64)
output = vit(x)

print(f"Input:  {x.shape} -> (batch, channels, H, W)")
print(f"Output: {output.shape} -> (batch, num_classes)")

n_patches = (64 // 8) ** 2
print(f"\nPatch size: 8x8 = 64 pixels per patch")
print(f"Number of patches: {n_patches}")
print(f"Sequence length: {n_patches + 1} (patches + CLS)")

total_params = sum(p.numel() for p in vit.parameters())
print(f"Total parameters: {total_params:,}")

### Deep Dive: CNN vs Vision Transformer

| Aspect | CNN | Vision Transformer (ViT) |
|--------|-----|-------------------------|
| **Inductive bias** | Strong: locality, translation equivariance | Weak: only sequence order (via positional embeddings) |
| **Receptive field** | Grows gradually (layer by layer) | Global from layer 1 (self-attention sees all patches) |
| **Data efficiency** | Better with small datasets (biases help) | Needs large datasets (or pretraining) to shine |
| **Scalability** | Strong with smaller pretraining sets, but gains plateau sooner as data grows (per the ViT paper) | Keeps improving as pretraining data and compute grow |
| **Architecture** | Conv, Pool, Conv, Pool, FC | Patch Embed, Transformer Blocks, CLS head |
| **Position awareness** | Built-in (convolution is local) | Must be learned (positional embeddings) |
| **Compute** | O(kernel^2 x C_in x C_out) per output position | O(n_patches^2 x embed_dim) for attention, plus O(n_patches x embed_dim^2) for projections and FFN |
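
The receptive-field row can be checked numerically. This sketch stacks 3x3, stride-1 convolutions with every weight set to 1 (an assumption that avoids gradient cancellation) and counts how many input columns influence one output pixel via its gradient. Each layer widens the field by 2 pixels, so it is 2L + 1 after L layers. `receptive_field_width` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def receptive_field_width(num_layers, size=64):
    """Width in pixels of the input region that influences the center output pixel
    of `num_layers` stacked 3x3, stride-1, padding-1 convolutions."""
    convs = [nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) for _ in range(num_layers)]
    for conv in convs:
        nn.init.constant_(conv.weight, 1.0)  # all-positive weights: no gradient cancellation
    net = nn.Sequential(*convs)

    x = torch.zeros(1, 1, size, size, requires_grad=True)
    net(x)[0, 0, size // 2, size // 2].backward()        # d(center output) / d(input)
    influenced_cols = x.grad[0, 0].abs().sum(dim=0) > 0   # columns with any nonzero gradient
    return int(influenced_cols.sum())

for n_layers in (1, 2, 4, 8):
    print(f"{n_layers} layers: {receptive_field_width(n_layers)} px (2L+1 = {2 * n_layers + 1})")
```

Without pooling or striding, covering a 224-pixel-wide image would take 112 such layers (2L + 1 >= 224). Real CNNs downsample to grow the field faster, while a ViT's first attention layer already relates every pair of patches.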

#### Key Insight

CNNs build in **strong assumptions** about images (local patterns, translation invariance). These assumptions help with limited data but can also limit the model. ViTs make **fewer assumptions**, allowing them to learn more flexible representations — but they need more data to discover what CNNs get "for free." In practice, the best modern models often combine ideas from both: convolutional patch embeddings, local attention windows, etc.

#### Common Misconceptions

| Misconception | Reality |
|---------------|--------|
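| "ViTs completely replaced CNNs" | CNNs remain competitive, and the two families borrow from each other: ConvNeXt is a pure CNN modernized with Transformer design choices, while Swin Transformer adds CNN-like local windows and a feature hierarchy |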
| "ViTs need no spatial inductive bias" | Positional embeddings encode position; many variants add locality |
| "Bigger patches are always better" | Bigger patches are cheaper but coarser. Smaller patches capture more detail, but the token count grows as (H/P)^2 and attention cost grows with the square of the token count |

---

## 7. CLIP: Connecting Vision and Language

### Intuitive Explanation

What if, instead of training a vision model on fixed class labels, we trained it to **understand the relationship between images and text**? That is the core idea behind CLIP (Contrastive Language-Image Pre-training) from OpenAI.

CLIP trains two separate encoders simultaneously:
- An **image encoder** (CNN or ViT) that maps images to a shared embedding space
- A **text encoder** (Transformer) that maps text descriptions to the same embedding space

The training objective is **contrastive**: in a batch of (image, text) pairs, maximize the similarity between matching pairs and minimize it for non-matching pairs.

#### Why CLIP Matters

1. **Zero-shot classification**: Classify images into categories the model was never explicitly trained to label, just by providing text descriptions
2. **Flexible**: No fixed label set — describe any class in natural language
3. **Transfer learning**: CLIP features are remarkably general
4. **Multimodal foundation**: CLIP is a building block for text-to-image models (DALL-E 2, Stable Diffusion), visual question answering, and more

In [None]:
# Visualize CLIP's dual-encoder architecture and contrastive learning
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: Architecture
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis('off')
ax.set_title('CLIP Architecture', fontsize=14, fontweight='bold')

# Image encoder
img_box = patches.FancyBboxPatch((0.5, 6.5), 3.5, 2.5, boxstyle="round,pad=0.2",
                                  facecolor='lightblue', edgecolor='blue', linewidth=2)
ax.add_patch(img_box)
ax.text(2.25, 7.75, 'Image\nEncoder\n(ViT/CNN)', ha='center', va='center', fontsize=11, fontweight='bold')

# Text encoder
txt_box = patches.FancyBboxPatch((6, 6.5), 3.5, 2.5, boxstyle="round,pad=0.2",
                                  facecolor='lightyellow', edgecolor='orange', linewidth=2)
ax.add_patch(txt_box)
ax.text(7.75, 7.75, 'Text\nEncoder\n(Transformer)', ha='center', va='center', fontsize=11, fontweight='bold')

# Embeddings
ax.annotate('', xy=(2.25, 5.0), xytext=(2.25, 6.5),
            arrowprops=dict(arrowstyle='->', color='blue', lw=2))
ax.text(2.25, 5.5, 'Image\nembedding', ha='center', fontsize=9, color='blue')

ax.annotate('', xy=(7.75, 5.0), xytext=(7.75, 6.5),
            arrowprops=dict(arrowstyle='->', color='orange', lw=2))
ax.text(7.75, 5.5, 'Text\nembedding', ha='center', fontsize=9, color='orange')

# Shared embedding space
space_box = patches.FancyBboxPatch((1, 2.5), 8, 2.5, boxstyle="round,pad=0.2",
                                    facecolor='lightgreen', edgecolor='green', linewidth=2, alpha=0.3)
ax.add_patch(space_box)
ax.text(5, 4.5, 'Shared Embedding Space', ha='center', fontsize=12, fontweight='bold', color='green')

ax.plot(2.25, 3.5, 'bo', markersize=12)
ax.plot(7.75, 3.5, 's', color='orange', markersize=12)
ax.annotate('', xy=(7.5, 3.5), xytext=(2.5, 3.5),
            arrowprops=dict(arrowstyle='<->', color='green', lw=2))
ax.text(5, 3.2, 'cosine similarity', ha='center', fontsize=10, color='green', style='italic')

ax.text(2.25, 9.4, 'Image input', ha='center', fontsize=10, style='italic')
ax.text(7.75, 9.4, 'Text input', ha='center', fontsize=10, style='italic')

# Right: Contrastive learning matrix
ax = axes[1]
ax.set_title('Contrastive Learning Objective', fontsize=14, fontweight='bold')

batch_size = 5
np.random.seed(42)
sim_matrix = np.random.rand(batch_size, batch_size) * 0.3
for i in range(batch_size):
    sim_matrix[i, i] = 0.8 + np.random.rand() * 0.2

im = ax.imshow(sim_matrix, cmap='RdYlGn', vmin=0, vmax=1)
ax.set_xlabel('Text descriptions', fontsize=12)
ax.set_ylabel('Images', fontsize=12)
ax.set_xticks(range(batch_size))
ax.set_yticks(range(batch_size))
ax.set_xticklabels([f'Text {i}' for i in range(batch_size)], fontsize=9)
ax.set_yticklabels([f'Img {i}' for i in range(batch_size)], fontsize=9)

for i in range(batch_size):
    for j in range(batch_size):
        color = 'white' if sim_matrix[i, j] > 0.5 else 'black'
        ax.text(j, i, f'{sim_matrix[i, j]:.2f}', ha='center', va='center',
                fontsize=10, color=color, fontweight='bold')

plt.colorbar(im, ax=ax, label='Cosine Similarity')
ax.text(2, -1.2, 'Goal: maximize diagonal (matching pairs)\nminimize off-diagonal (non-matching)',
        ha='center', fontsize=10, style='italic')

plt.tight_layout()
plt.show()

In [None]:
def clip_contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Compute CLIP-style contrastive loss (InfoNCE / NT-Xent).

    Args:
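        image_embeddings: (batch_size, embed_dim); L2-normalized inside the function
        text_embeddings: (batch_size, embed_dim); L2-normalized inside the function
        temperature: logits are divided by this (lower = sharper softmax). CLIP learns
            it during training (initialized to 0.07); here it is fixed.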

    Returns:
        Scalar loss value
    """
    # Normalize embeddings
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute similarity matrix: (batch, batch)
    logits = image_embeddings @ text_embeddings.T / temperature

    # Labels: matching pairs are on the diagonal
    batch_size = logits.shape[0]
    labels = torch.arange(batch_size, device=logits.device)

    # Cross-entropy loss in both directions
    loss_i2t = F.cross_entropy(logits, labels)       # image -> text
    loss_t2i = F.cross_entropy(logits.T, labels)     # text -> image

    return (loss_i2t + loss_t2i) / 2

# Demo: simulate CLIP training on a mini batch
torch.manual_seed(42)
batch_size = 8
embed_dim = 128

image_emb = torch.randn(batch_size, embed_dim)
text_emb = torch.randn(batch_size, embed_dim)

# Make matching pairs slightly similar (simulate partially trained model)
text_emb = text_emb + 0.3 * image_emb

loss = clip_contrastive_loss(image_emb, text_emb, temperature=0.07)
print(f"CLIP contrastive loss: {loss.item():.4f}")
print(f"Chance-level loss (uniform over {batch_size} candidates): log({batch_size}) = {np.log(batch_size):.4f}")
print("  (Loss should fall below chance level because matching pairs are correlated)")

print("\nEffect of temperature:")
for temp in [0.01, 0.07, 0.1, 0.5, 1.0]:
    l = clip_contrastive_loss(image_emb, text_emb, temperature=temp)
    print(f"  temperature={temp:.2f} -> loss={l.item():.4f}")

### Zero-Shot Classification with CLIP

One of CLIP's most powerful capabilities is classifying images into categories **it has never been explicitly trained on**. Here is how:

1. Given candidate class names: ["cat", "dog", "car", "airplane"]
2. Create text prompts: ["a photo of a cat", "a photo of a dog", ...]
3. Encode all text prompts with the text encoder
4. Encode the input image with the image encoder
5. Compute cosine similarity between the image embedding and each text embedding
6. The class with the highest similarity wins
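
Steps 5 and 6 can be sketched with stand-in embeddings. The random vectors below are placeholders for real CLIP encoder outputs, `zero_shot_classify` is a hypothetical helper, and the temperature of 0.01 is an illustrative value that only affects the probabilities, not the predicted class.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embedding, text_embeddings, class_names, temperature=0.01):
    """Steps 5-6: score one image against one text prompt per class.

    image_embedding: (embed_dim,) output of the image encoder
    text_embeddings: (num_classes, embed_dim) outputs of the text encoder
    temperature: illustrative value; changes the probabilities, not the argmax
    """
    img = F.normalize(image_embedding, dim=-1)
    txt = F.normalize(text_embeddings, dim=-1)
    similarities = txt @ img                              # (num_classes,) cosine similarities
    probs = F.softmax(similarities / temperature, dim=-1)
    return class_names[int(similarities.argmax())], probs

# Stand-in embeddings. A real pipeline would run CLIP's encoders on the image and on
# prompts such as "a photo of a dog"; that step is not shown here.
torch.manual_seed(0)
class_names = ["cat", "dog", "car", "airplane"]
text_embeddings = torch.randn(len(class_names), 512)
image_embedding = text_embeddings[1] + 0.1 * torch.randn(512)   # simulate a "dog" image

label, probs = zero_shot_classify(image_embedding, text_embeddings, class_names)
print(label)  # dog
```

Swapping in a different list of class names (and their prompt embeddings) changes the classifier without touching any model weights.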

No retraining needed — just change the text descriptions to classify into any set of categories.

**What this means:** CLIP turns classification from a fixed-label problem into a **language-guided** problem. Instead of retraining a model for every new task, you simply describe the task in words.

---

## 8. Transfer Learning for Vision

### Intuitive Explanation

Training a large vision model from scratch requires enormous datasets and compute. **Transfer learning** lets us reuse knowledge from a model pre-trained on a large dataset (like ImageNet with 1.2M images) and adapt it to our specific task.

There are two main approaches:

| Approach | What You Do | When to Use |
|----------|-------------|-------------|
| **Feature extraction** | Freeze the pretrained backbone, only train a new classification head | Small dataset, similar domain to pretraining |
| **Fine-tuning** | Unfreeze some/all of the pretrained backbone, train end-to-end with a small learning rate | Medium-large dataset, different domain |

#### Practical Guidelines

1. **Start with feature extraction** — it is fast and often surprisingly effective
2. **If accuracy is insufficient, try fine-tuning** the last few layers first
3. **Use a small learning rate for fine-tuning** (10-100x smaller than for the new head)
4. **Apply data augmentation** — especially important with small datasets
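5. **Freeze batch normalization** layers if fine-tuning with very small batches: statistics estimated from a few images are noisy, so keep the pretrained running statistics (eval mode) instead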

In [None]:
# Quick transfer learning demo with torchvision models
import torchvision.models as models

# Load a pretrained ResNet-18
resnet = models.resnet18(weights='IMAGENET1K_V1')

print("ResNet-18 final layers:")
print(f"  Average pool: {resnet.avgpool}")
print(f"  FC layer: {resnet.fc}")
print(f"  Original output: {resnet.fc.out_features} classes (ImageNet)")

# Approach 1: Feature extraction (freeze everything, replace head)
for param in resnet.parameters():
    param.requires_grad = False

num_features = resnet.fc.in_features
resnet.fc = nn.Linear(num_features, 10)  # New head for 10 classes

# Count trainable vs frozen parameters
trainable = sum(p.numel() for p in resnet.parameters() if p.requires_grad)
total = sum(p.numel() for p in resnet.parameters())
frozen = total - trainable

print(f"\nAfter modification for 10-class task:")
print(f"  Total parameters:     {total:>10,}")
print(f"  Frozen parameters:    {frozen:>10,} ({100*frozen/total:.1f}%)")
print(f"  Trainable parameters: {trainable:>10,} ({100*trainable/total:.1f}%)")

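# Test forward pass. eval() makes BatchNorm use its pretrained running statistics;
# requires_grad=False alone would not stop train-mode BatchNorm from updating them.
resnet.eval()
x = torch.randn(2, 3, 224, 224)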
with torch.no_grad():
    output = resnet(x)
print(f"\nInput:  {x.shape}")
print(f"Output: {output.shape}")

In [None]:
# Approach 2: Fine-tuning (unfreeze last layers)
resnet_ft = models.resnet18(weights='IMAGENET1K_V1')

# First freeze everything
for param in resnet_ft.parameters():
    param.requires_grad = False

# Unfreeze the last residual block (layer4) and the new head
for param in resnet_ft.layer4.parameters():
    param.requires_grad = True

# Replace the classification head
num_features = resnet_ft.fc.in_features
resnet_ft.fc = nn.Linear(num_features, 10)

# Count trainable parameters
trainable_ft = sum(p.numel() for p in resnet_ft.parameters() if p.requires_grad)
total_ft = sum(p.numel() for p in resnet_ft.parameters())

print("Fine-tuning approach (unfreeze layer4 + new head):")
print(f"  Total parameters:     {total_ft:>10,}")
print(f"  Trainable parameters: {trainable_ft:>10,} ({100*trainable_ft/total_ft:.1f}%)")

# Typical optimizer setup: small LR for pretrained layers, larger LR for the new head
optimizer = torch.optim.Adam([
    {'params': resnet_ft.layer4.parameters(), 'lr': 1e-4},  # pretrained: small LR
    {'params': resnet_ft.fc.parameters(), 'lr': 1e-3},      # new head: larger LR
])
print("\nOptimizer learning rates per param group:",
      [group['lr'] for group in optimizer.param_groups])

print("\n" + "="*60)
print(f"{'Approach':<25} {'Trainable Params':>20} {'% of Total':>12}")
print("="*60)
print(f"{'Feature extraction':<25} {trainable:>20,} {100*trainable/total:>11.1f}%")
print(f"{'Fine-tune layer4':<25} {trainable_ft:>20,} {100*trainable_ft/total_ft:>11.1f}%")
print(f"{'Full fine-tuning':<25} {total_ft:>20,} {'100.0':>11}%")
print("="*60)

In [None]:
# Visualize what gets frozen vs fine-tuned
fig, ax = plt.subplots(figsize=(14, 5))
ax.set_xlim(0, 14)
ax.set_ylim(0, 5)
ax.axis('off')

# x positions leave room for the row labels on the left and a gap for each arrow
layers = [
    ('Conv1', 1.8, 'lightblue'),
    ('Layer1', 3.8, 'lightblue'),
    ('Layer2', 5.8, 'lightblue'),
    ('Layer3', 7.8, 'lightblue'),
    ('Layer4', 9.8, 'lightyellow'),
    ('FC Head', 11.8, 'lightgreen'),
]

ax.text(0.5, 4.5, 'Feature\nExtraction:', fontsize=10, fontweight='bold', va='top')
ax.text(0.5, 2.5, 'Fine-\nTuning:', fontsize=10, fontweight='bold', va='top')

for name, x_pos, color in layers:
    w = 1.5 if name != 'FC Head' else 2.0

    # Feature extraction row (top)
    c = 'lightgreen' if name == 'FC Head' else 'lightcoral'
    label = 'Train' if name == 'FC Head' else 'Frozen'
    rect = patches.FancyBboxPatch((x_pos, 3.5), w, 1.0, boxstyle="round,pad=0.1",
                                   facecolor=c, edgecolor='black', alpha=0.7)
    ax.add_patch(rect)
    ax.text(x_pos + w/2, 4.0, f'{name}\n({label})', ha='center', va='center', fontsize=8, fontweight='bold')

    # Fine-tuning row (bottom)
    if name in ['Layer4', 'FC Head']:
        c = 'lightgreen'
        label = 'Train'
    else:
        c = 'lightcoral'
        label = 'Frozen'
    rect = patches.FancyBboxPatch((x_pos, 1.5), w, 1.0, boxstyle="round,pad=0.1",
                                   facecolor=c, edgecolor='black', alpha=0.7)
    ax.add_patch(rect)
    ax.text(x_pos + w/2, 2.0, f'{name}\n({label})', ha='center', va='center', fontsize=8, fontweight='bold')

# Arrows
for i in range(len(layers)-1):
    x_from = layers[i][1] + (1.5 if layers[i][0] != 'FC Head' else 2.0)
    x_to = layers[i+1][1]
    for y in [4.0, 2.0]:
        ax.annotate('', xy=(x_to, y), xytext=(x_from, y),
                    arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))

# Legend
ax.add_patch(patches.Rectangle((1, 0.3), 0.5, 0.3, facecolor='lightcoral', edgecolor='black', alpha=0.7))
ax.text(1.7, 0.45, 'Frozen', fontsize=9, va='center')
ax.add_patch(patches.Rectangle((3.5, 0.3), 0.5, 0.3, facecolor='lightgreen', edgecolor='black', alpha=0.7))
ax.text(4.2, 0.45, 'Trainable', fontsize=9, va='center')

ax.set_title('Transfer Learning: What to Freeze vs Train', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
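
The two rows of the figure translate into a few lines of PyTorch: set `requires_grad` on each parameter according to its stage, then give the unfrozen pretrained stage a smaller learning rate than the new head. The sketch below is a minimal illustration under stated assumptions: `TinyResNetLike`, `TRAINABLE_STAGES`, and `configure` are hypothetical names, the model is a tiny randomly initialized stand-in whose attribute names (`conv1`, `layer1` to `layer4`, `fc`) mirror torchvision's ResNet, and the learning rates are placeholders rather than recommended values.

```python
import torch
import torch.nn as nn


class TinyResNetLike(nn.Module):
    """Hypothetical stand-in: stage names mirror torchvision's ResNet, layers are tiny."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.layer1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.layer2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.layer3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.layer4 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        for stage in [self.conv1, self.layer1, self.layer2, self.layer3, self.layer4]:
            x = torch.relu(stage(x))
        x = x.mean(dim=(2, 3))  # global average pooling -> (batch, 64)
        return self.fc(x)


# Stages marked 'Train' in each row of the figure
TRAINABLE_STAGES = {
    'feature_extraction': {'fc'},
    'fine_tuning': {'layer4', 'fc'},
}


def configure(model, mode):
    """Freeze every parameter except those belonging to the stages listed for `mode`."""
    trainable = TRAINABLE_STAGES[mode]
    for name, param in model.named_parameters():
        # Parameter names look like 'layer4.weight'; the prefix is the stage name
        param.requires_grad = name.split('.')[0] in trainable
    return model


model = configure(TinyResNetLike(), 'fine_tuning')

# Only trainable parameters go to the optimizer; the pretrained stage gets a smaller LR
optimizer = torch.optim.SGD([
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-2},
], momentum=0.9)

for name, param in model.named_parameters():
    print(f"{name:14s} trainable={param.requires_grad}")
```

Setting `requires_grad = False` stops gradient updates, but in a real ResNet the BatchNorm layers still update their running statistics while the model is in training mode; calling `.eval()` on frozen stages is a common way to keep those fixed too.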

---

## Exercises

### Exercise 1: Batch IoU Computation

Implement a function that computes IoU between every pair of boxes in two sets — a common operation in detection evaluation. Your function should use vectorized NumPy operations (no loops) for efficiency.
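
The broadcasting trick mentioned in the docstring's hint is easiest to see on a single coordinate. The sketch below reuses the toy boxes from the test cell (`a`, `b`, `inter_x1` are illustrative names) and computes only the left edge of the intersection for every (i, j) pair at once; the remaining coordinates, the clamping, and the areas are left for the exercise.

```python
import numpy as np

# Toy inputs: 2 boxes in set A, 3 boxes in set B, each row is [x1, y1, x2, y2]
a = np.array([[0, 0, 3, 3],
              [1, 1, 4, 4]])
b = np.array([[0, 0, 3, 3],
              [2, 2, 5, 5],
              [6, 6, 9, 9]])

# Insert singleton axes so the shapes broadcast: (2, 1, 4) against (1, 3, 4)
a_exp = a[:, None, :]
b_exp = b[None, :, :]

# Left edge of each pairwise intersection = larger of the two x1 values
inter_x1 = np.maximum(a_exp[..., 0], b_exp[..., 0])
print(inter_x1.shape)  # (2, 3)
print(inter_x1)        # [[0 2 6]
                       #  [1 2 6]]
```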

In [None]:
# EXERCISE 1: Batch IoU computation
def batch_iou(boxes_a, boxes_b):
    """
    Compute pairwise IoU between two sets of boxes.

    Args:
        boxes_a: numpy array of shape (N, 4), each row is [x1, y1, x2, y2]
        boxes_b: numpy array of shape (M, 4), each row is [x1, y1, x2, y2]

    Returns:
        IoU matrix of shape (N, M) where iou[i, j] = IoU(boxes_a[i], boxes_b[j])

    Hint: Use broadcasting! Expand boxes_a to (N, 1, 4) and boxes_b to (1, M, 4)
    """
    # TODO: Implement this!
    # Step 1: Compute intersection coordinates using np.maximum and np.minimum
    # Step 2: Compute intersection areas (clamp to 0 for non-overlapping)
    # Step 3: Compute areas of each box
    # Step 4: Compute union = area_a + area_b - intersection
    # Step 5: Return intersection / union

    pass

# Test
boxes_a = np.array([
    [0, 0, 3, 3],
    [1, 1, 4, 4],
])
boxes_b = np.array([
    [0, 0, 3, 3],
    [2, 2, 5, 5],
    [6, 6, 9, 9],
])

result = batch_iou(boxes_a, boxes_b)

expected = np.array([
    [1.0, compute_iou([0,0,3,3], [2,2,5,5]), 0.0],
    [compute_iou([1,1,4,4], [0,0,3,3]), compute_iou([1,1,4,4], [2,2,5,5]), 0.0],
])

if result is not None:
    print(f"Your result:\n{result}")
    print(f"\nExpected:\n{expected}")
    print(f"\nCorrect: {np.allclose(result, expected)}")
else:
    print("Expected result:")
    print(expected)
    print("\nImplement batch_iou to verify!")

### Exercise 2: Implement Patch Embedding from Scratch (No Conv2d)

The `PatchEmbedding` class above uses `nn.Conv2d` as a shortcut. Implement it from scratch using only reshaping and a linear layer — this makes the process more explicit and reinforces understanding.
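
Why is `nn.Conv2d` a valid shortcut? A convolution whose `kernel_size` and `stride` both equal the patch size visits each non-overlapping patch exactly once and applies the same weights to it, which is a shared linear projection of the flattened patch. The sketch below checks this numerically using `torch.nn.functional.unfold`, which extracts patches flattened in (channel, row, column) order. It is meant as a reference for checking your work, not a solution: `unfold` is off-limits in the exercise, and the sizes are arbitrary toy values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, P, E = 2, 3, 16, 16, 4, 8  # toy sizes: a 4x4 grid of 4x4 patches
x = torch.randn(B, C, H, W)

# Path 1: strided convolution, then flatten the spatial grid into a token sequence
conv = nn.Conv2d(C, E, kernel_size=P, stride=P)
conv_tokens = conv(x).flatten(2).transpose(1, 2)                 # (B, 16, E)

# Path 2: explicit patches + a linear layer that shares the conv's weights
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)   # (B, 16, C*P*P)
linear = nn.Linear(C * P * P, E)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(E, -1))  # (E, C, P, P) -> (E, C*P*P)
    linear.bias.copy_(conv.bias)
linear_tokens = linear(patches)

print(conv_tokens.shape, linear_tokens.shape)
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
# Patch 0 is the top-left P x P block, flattened channel-first
print(torch.equal(patches[0, 0], x[0, :, :P, :P].reshape(-1)))
```

If the patches produced by your reshape and permute steps match `patches` here, your flattening order agrees with the one `Conv2d` implicitly uses.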

In [None]:
# EXERCISE 2: Patch embedding without Conv2d
class ManualPatchEmbedding(nn.Module):
    """
    Implement patch embedding using explicit reshape + linear projection.
    No Conv2d allowed!

    Steps:
    1. Reshape image from (B, C, H, W) to (B, n_patches, patch_size*patch_size*C)
    2. Apply a linear layer to project each flattened patch to embed_dim

    Hint: Use tensor.reshape() and careful dimension ordering.
    """
    def __init__(self, img_size=64, patch_size=8, in_channels=3, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        patch_dim = patch_size * patch_size * in_channels

        # TODO: Create a linear layer that maps patch_dim -> embed_dim
        self.projection = None  # Replace with nn.Linear(...)

    def forward(self, x):
        """
        Args:
            x: (batch, channels, height, width)
        Returns:
            (batch, n_patches, embed_dim)
        """
        B, C, H, W = x.shape
        P = self.patch_size

        # TODO: Reshape x into patches
        # Step 1: Reshape to (B, C, H//P, P, W//P, P)
        # Step 2: Permute to (B, H//P, W//P, C, P, P)
        # Step 3: Reshape to (B, n_patches, C*P*P)
        # Step 4: Apply self.projection

        pass

# Test
torch.manual_seed(42)
# Uncomment after implementing:
# manual_pe = ManualPatchEmbedding(img_size=64, patch_size=8, in_channels=3, embed_dim=256)
# x = torch.randn(2, 3, 64, 64)
# output = manual_pe(x)
# print(f"Input:  {x.shape}")
# print(f"Output: {output.shape}")
# print(f"Expected: torch.Size([2, 64, 256])")
# print(f"Correct shape: {output.shape == torch.Size([2, 64, 256])}")

print("Uncomment the test code above after implementing ManualPatchEmbedding!")
print("Expected output shape: (2, 64, 256)")
print("  2 = batch size")
print("  64 = (64/8)^2 = 64 patches")
print("  256 = embed_dim")

### Exercise 3: Add Dice Loss for Segmentation

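Cross-entropy is the standard loss for classification, but for segmentation it can be problematic when classes are imbalanced (e.g., a small tumor in a large medical image): every pixel contributes equally to the average, so a model can reach a low loss while largely ignoring the rare foreground class. **Dice loss** instead optimizes a differentiable ("soft") version of the Dice coefficient, which measures the overlap between prediction and ground truth relative to their combined size.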

$$\text{Dice} = \frac{2 |A \cap B|}{|A| + |B|}$$

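Here $A$ is the set of predicted foreground pixels and $B$ is the set of ground-truth foreground pixels. Implement Dice loss for binary segmentation.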
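
A small worked example makes the formula concrete. Substituting $|A \cup B| = |A| + |B| - |A \cap B|$ also shows that $\text{Dice} = 2\,\text{IoU} / (1 + \text{IoU})$, so for a single mask pair the two scores always move together, while Dice is higher than IoU for any partial overlap. The sketch uses two hypothetical 4x4 hard (0/1) masks; the exercise below works with soft sigmoid probabilities instead.

```python
import numpy as np

# Two hypothetical 4x4 binary masks (1 = foreground pixel)
A = np.zeros((4, 4), dtype=int)
A[0:2, 0:2] = 1  # predicted foreground: 4 pixels
B = np.zeros((4, 4), dtype=int)
B[0:2, 1:3] = 1  # ground truth: 4 pixels, shifted one column right

intersection = np.logical_and(A, B).sum()  # pixels in both masks: 2
union = np.logical_or(A, B).sum()          # pixels in either mask: 6

dice = 2 * intersection / (A.sum() + B.sum())  # 2*2 / (4+4) = 0.5
iou = intersection / union                      # 2 / 6, about 0.333

print(f"Dice = {dice:.3f}, IoU = {iou:.3f}")    # Dice = 0.500, IoU = 0.333
print(np.isclose(dice, 2 * iou / (1 + iou)))    # True
```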

In [None]:
# EXERCISE 3: Dice loss for segmentation
def dice_loss(pred, target, smooth=1e-6):
    """
    Compute Dice loss for binary segmentation.

    Args:
        pred: (batch, 1, H, W) -- raw logits (apply sigmoid first!)
        target: (batch, 1, H, W) -- binary ground truth (0 or 1)
        smooth: smoothing factor to avoid division by zero

    Returns:
        1 - Dice coefficient (so that minimizing loss = maximizing Dice)

    Steps:
        1. Apply sigmoid to pred to get probabilities
        2. Flatten both pred and target to (batch, H*W)
        3. Compute intersection: sum(pred * target) per sample
        4. Compute Dice = (2 * intersection + smooth) / (sum(pred) + sum(target) + smooth)
        5. Return 1 - mean(Dice)
    """
    # TODO: Implement this!

    pass

# Test
torch.manual_seed(42)

# Ground-truth mask
target = torch.zeros(1, 1, 8, 8)
target[0, 0, 2:6, 2:6] = 1.0  # 4x4 square of 1s

# Near-perfect prediction (high logits where target=1)
pred_good = torch.zeros(1, 1, 8, 8) - 5.0  # low logits everywhere
pred_good[0, 0, 2:6, 2:6] = 5.0  # high logits where target=1

# Bad prediction (random)
pred_bad = torch.randn(1, 1, 8, 8)

loss_good = dice_loss(pred_good, target)
if loss_good is not None:
    loss_bad = dice_loss(pred_bad, target)
    print(f"Dice loss (good prediction): {loss_good.item():.4f} (should be close to 0)")
    print(f"Dice loss (bad prediction):  {loss_bad.item():.4f} (should be higher)")
    print(f"\nGood < Bad: {loss_good.item() < loss_bad.item()}")
else:
    print("Implement dice_loss to verify!")
    print("Expected: good prediction loss close to 0.0, bad prediction loss > 0.3")

---

## Summary

### Key Concepts

- **Object Detection** (in the anchor-based form covered here) decomposes into anchor boxes, IoU scoring, class prediction, and NMS post-processing
- **IoU (Intersection over Union)** is the standard metric for measuring bounding box overlap — values range from 0 (no overlap) to 1 (perfect overlap)
- **Non-Maximum Suppression (NMS)** removes duplicate detections by suppressing lower-confidence boxes that overlap with higher-confidence ones
- **YOLO** frames detection as a single-pass regression problem, predicting boxes and classes from a grid of cells simultaneously
- **Semantic Segmentation** assigns a class to every pixel using encoder-decoder architectures like U-Net
- **U-Net** uses skip connections to combine deep semantic features with fine spatial details for precise segmentation
- **Transposed convolutions** are a common learnable upsampling operation that lets decoders increase spatial resolution
- **Instance Segmentation** (Mask R-CNN) combines detection with per-instance mask prediction
- **Vision Transformers (ViT)** treat images as sequences of patches and apply Transformer self-attention
- **Patch embeddings** convert image patches into token embeddings, analogous to word embeddings in NLP
- **CLIP** connects vision and language through contrastive learning in a shared embedding space
- **Transfer learning** reuses pretrained features, either by freezing the backbone (feature extraction) or fine-tuning with a small learning rate

### Connection to Deep Learning

| Concept | Where It Appears | Why It Matters |
|---------|-----------------|----------------|
| IoU and NMS | Every detection model | Standard evaluation and post-processing |
| Anchor boxes | SSD, Faster R-CNN, YOLOv2/v3 | Efficient candidate generation |
| Encoder-decoder | U-Net, segmentation, autoencoders | Compress then reconstruct spatial information |
| Skip connections | U-Net, ResNet, DenseNet | Preserve information across depth |
| Patch embeddings | ViT, DINO, MAE | Bridge images and Transformers |
| Contrastive learning | CLIP, SimCLR, MoCo | Learn representations without explicit class labels |
| Transfer learning | Nearly every practical vision system | Train with limited data by reusing pretrained features |

### Checklist

- [ ] I can compute IoU between two bounding boxes by hand
- [ ] I understand how anchor boxes tile an image with candidate regions
- [ ] I can explain NMS and why it is needed
- [ ] I understand YOLO's grid-based, single-pass detection approach
- [ ] I can describe the U-Net encoder-decoder architecture and the role of skip connections
- [ ] I know the difference between semantic, instance, and panoptic segmentation
- [ ] I can explain how ViT converts an image into a sequence of patch tokens
- [ ] I understand CLIP's contrastive learning objective
- [ ] I know when to use feature extraction vs fine-tuning for transfer learning

---

## Next Steps

In **Notebook 15: Recurrent Neural Networks (RNNs)**, we shift from spatial data (images) to **sequential data** (text, time series, audio). We will explore how RNNs maintain a hidden state that carries information forward through a sequence, why vanilla RNNs struggle with long-range dependencies, and how LSTM and GRU architectures mitigate the vanishing gradient problem.

The ideas from this notebook — particularly attention mechanisms (ViT) and encoder-decoder architectures (U-Net) — will reappear in new forms as we move into sequence modeling and eventually the full Transformer architecture.