# Few-Shot Learning for Rooftop Segmentation: Approach Setup

#### ATTENTION: This notebook is AI-generated and should be used as a checklist / orientation document. The information has been checked and should be acceptably accurate.

**Goal**: Explore three few-shot learning approaches for semantic segmentation of rooftops across different geographic regions.

**Scenario**: We have a model trained on one area of Geneva (grid 1301_11) and want to deploy it to other areas (grids 1301_13 and 1301_31) with only K labeled examples from each target area.

**Real-world motivation**: A city government trained a rooftop segmentation model for solar panel assessment. They now want to deploy it to neighboring cities/districts, but labeling is expensive. Can we achieve good performance with only 5-10 labeled examples from the new area?

## 0. Dataset and Geographic Split Setup

In [4]:
from huggingface_hub import snapshot_download
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
import os

# Download dataset
data_dir = Path("data")
dataset_path = snapshot_download(
    repo_id="raphaelattias/overfitteam-geneva-satellite-images",
    repo_type="dataset",
    local_dir=str(data_dir),
    local_dir_use_symlinks=True,
)

print(f"Dataset downloaded to: {dataset_path}")

Dataset downloaded to: /Users/giocopp/Desktop/Uni/Hertie School/5th Semester/DL/DL-FinalProject/DL-Tutorial-giorgio-exploration/data





### 0.1 Geographic Split Strategy

The Geneva dataset contains images from three geographic grids:
- **Grid 1301_11**: ~70% of data (295 train images)
- **Grid 1301_13**: ~18% of data (76 train images) 
- **Grid 1301_31**: ~12% of data (49 train images) 

**Our Split**:
```
Source Domain (Base Training):  Grid 1301_11 only
Target Domains (Few-Shot Test): Grid 1301_13 and Grid 1301_31
```

This simulates **geographic domain shift**: training on one part of the city and deploying to other parts with different building densities, roof types, and urban layouts.

In [5]:
def parse_grid_from_filename(filepath):
    """
    Extract grid ID from filename.
    Example: DOP25_LV03_1301_11_2015_1_15_497500.0_119062.5.png -> '1301_11'
    """
    parts = Path(filepath).stem.split("_")
    sheet, subgrid = parts[2], parts[3]
    return f"{sheet}_{subgrid}"


def create_geographic_split(dataset_path, split="train", category="all"):
    """
    Organize images by geographic grid.

    Returns:
        dict: {grid_id: [(image_path, label_path), ...]}

    """
    img_dir = Path(dataset_path) / split / "images" / category
    label_dir = Path(dataset_path) / split / "labels" / category

    grid_data = defaultdict(list)

    for img_path in sorted(img_dir.glob("*.png")):
        grid_id = parse_grid_from_filename(img_path)
        label_path = label_dir / img_path.name.replace(".png", "_label.png")

        if label_path.exists():
            grid_data[grid_id].append((str(img_path), str(label_path)))

    return grid_data


# Create splits
train_by_grid = create_geographic_split(dataset_path, split="train", category="all")
val_by_grid = create_geographic_split(dataset_path, split="val", category="all")
test_by_grid = create_geographic_split(dataset_path, split="test", category="all")

# Print statistics
print("\n=== TRAINING DATA BY GRID ===")
for grid_id in sorted(train_by_grid.keys()):
    print(f"Grid {grid_id}: {len(train_by_grid[grid_id])} images")

print("\n=== OUR SPLIT STRATEGY ===")
print(f"Source domain (base training):  Grid 1301_11 ({len(train_by_grid['1301_11'])} images)")
print(f"Target domain 1 (few-shot):     Grid 1301_13 ({len(train_by_grid['1301_13'])} images)")
print(f"Target domain 2 (few-shot):     Grid 1301_31 ({len(train_by_grid['1301_31'])} images)")


=== TRAINING DATA BY GRID ===
Grid 1301_11: 295 images
Grid 1301_13: 76 images
Grid 1301_31: 49 images

=== OUR SPLIT STRATEGY ===
Source domain (base training):  Grid 1301_11 (295 images)
Target domain 1 (few-shot):     Grid 1301_13 (76 images)
Target domain 2 (few-shot):     Grid 1301_31 (49 images)


---

## 1. Approach 1: Fine-Tuning (Transfer Learning)

### Concept
Traditional transfer learning approach: pre-train a model on the source domain, then fine-tune on K labeled examples from the target domain.

### Methodology

**Phase 1: Base Model Training**
```
Training data: All images from Grid 1301_11 (~295 images)
Architecture:  U-Net or DeepLabV3 with ResNet encoder
Task:          Binary segmentation (rooftop vs. background)
Optimization:  Standard supervised learning with cross-entropy + Dice loss
```

**Phase 2: Few-Shot Adaptation (Fine-Tuning)**
```
For each target grid (1301_13, 1301_31):
    1. Randomly select K images as support set (K = 1, 3, 5, 10, 20)
    2. Fine-tune the pre-trained model on these K examples
    3. Evaluate on remaining images from that grid
```
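The support/query split in steps 1 and 3 can be sketched as follows. This is a minimal illustration, not the notebook's final implementation: `split_support_query` and the toy file names are hypothetical, and the only assumption about the data is that each grid is a list of `(image_path, label_path)` pairs, as returned by `create_geographic_split`.

```python
import random


def split_support_query(pairs, k, seed=0):
    """Split (image_path, label_path) pairs into K support pairs and the rest as query."""
    if not 0 < k < len(pairs):
        raise ValueError("k must be between 1 and len(pairs) - 1")
    rng = random.Random(seed)  # local RNG so each trial is reproducible
    support_idx = set(rng.sample(range(len(pairs)), k))
    support = [p for i, p in enumerate(pairs) if i in support_idx]
    query = [p for i, p in enumerate(pairs) if i not in support_idx]
    return support, query


# Toy stand-in for train_by_grid['1301_13'] (76 images in this dataset)
toy_pairs = [(f"img_{i}.png", f"img_{i}_label.png") for i in range(76)]
support, query = split_support_query(toy_pairs, k=5, seed=42)
print(len(support), len(query))  # 5 71
```

Splitting by index rather than by list membership keeps the two sets disjoint even if a pair were to appear twice, and a per-trial seed makes it possible to give every method the same support set.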

**Fine-Tuning Strategies**:
- **Full fine-tuning**: Update all model parameters
- **Partial fine-tuning**: Freeze encoder, only update decoder
- **Learning rate**: Use small LR (1e-5 to 1e-4) to avoid catastrophic forgetting
- **Early stopping**: Monitor validation loss to prevent overfitting on K examples
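
The early-stopping rule can be sketched in plain Python. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values are illustrative assumptions; where the validation loss comes from when only K target images are labeled is left open here.

```python
class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.60]  # made-up validation losses
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 4 0.65
```

In a fine-tuning loop, the checkpoint saved at the best loss would be the one evaluated on the query set.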

### Advantages
- ✅ Simple to implement and understand
- ✅ Well-established baseline in transfer learning
- ✅ Works with any architecture
- ✅ Fast adaptation (just a few gradient updates)

### Limitations
- ❌ Prone to overfitting when K is very small (K=1,3)
- ❌ Doesn't explicitly learn "how to learn from few examples"
- ❌ Requires careful hyperparameter tuning (LR, epochs, which layers to freeze)
- ❌ May suffer from catastrophic forgetting

### Expected Performance
```
Baseline (zero-shot, no adaptation):  ~XX% IoU on target grids
Fine-tuning with K=1:                 ~XX% IoU (minimal improvement)
Fine-tuning with K=5:                 ~XX% IoU (moderate improvement)
Fine-tuning with K=20:                ~XX% IoU (good improvement)
```

### Implementation Outline (Not Implemented Yet)

In [None]:
# Fine-tuning approach pseudocode
"""
# Phase 1: Train base model
base_model = UNet(encoder='resnet34', classes=2)
train_dataset = Grid_1301_11_Dataset(train_by_grid['1301_11'])
trainer = train(base_model, train_dataset, epochs=50)
save_checkpoint(base_model, 'base_model.pth')

# Phase 2: Few-shot fine-tuning
for target_grid in ['1301_13', '1301_31']:
    for K in [1, 3, 5, 10, 20]:
        # Sample K support examples
        support_set = random.sample(train_by_grid[target_grid], K)
        query_set = [x for x in train_by_grid[target_grid] if x not in support_set]
        
        # Load pre-trained base model
        model = load_checkpoint('base_model.pth')
        
        # Fine-tune on K examples
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
        model.train()
        for epoch in range(10):  # Small number of epochs
            for img, mask in support_set:
                optimizer.zero_grad()  # reset gradients accumulated by the previous step
                loss = compute_loss(model(img), mask)
                loss.backward()
                optimizer.step()
        
        # Evaluate on query set
        metrics = evaluate(model, query_set)
        print(f"Grid {target_grid}, K={K}: IoU = {metrics['iou']:.3f}")
"""
pass

---

## 2. Approach 2: Meta-Learning with Prototypical Networks

### Concept
Instead of pre-training then adapting, **learn how to segment from K examples during training itself**. The model learns to extract good feature embeddings such that pixels from the same class (rooftop/background) cluster together.

### Methodology

**Core Idea**: 
- Learn a feature encoder that maps pixels to an embedding space
- In this space, compute class prototypes (centers) from support set
- Classify query pixels by distance to nearest prototype

**Episodic Training (on Grid 1301_11)**:
```
For each training episode:
    1. Sample K images as support set
    2. Sample Q images as query set  
    3. Encode all images → pixel embeddings
    4. Compute class prototypes from support set:
       - fg_prototype = mean(embeddings where mask == 1)
       - bg_prototype = mean(embeddings where mask == 0)
    5. Classify query pixels by distance to prototypes
    6. Compute loss on query predictions
    7. Update encoder to improve few-shot segmentation
```

**Architecture**:
```
Input Image [H, W, 3]
    ↓
Feature Encoder (ResNet/UNet encoder)
    ↓
Pixel Embeddings [H, W, D]  (D = embedding dimension, e.g., 256)
    ↓
Prototype Computation:
    - For each class c, compute prototype μ_c
    - μ_c = mean of embeddings where support_mask == c
    ↓
Distance-based Classification:
    - For each query pixel embedding z
    - distance_c = ||z - μ_c||²
    - prediction = argmin_c(distance_c)
```
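
The prototype computation and distance-based classification in the diagram can be checked on toy data with NumPy. This is a sketch under simplifying assumptions: embeddings are given directly instead of coming from an encoder, they are stored channel-last as `[K, H, W, D]`, and both classes occur in the support masks. The function names are illustrative.

```python
import numpy as np


def compute_prototypes(embeddings, masks):
    """embeddings: [K, H, W, D]; masks: [K, H, W] with 0=background, 1=rooftop."""
    flat_emb = embeddings.reshape(-1, embeddings.shape[-1])  # [K*H*W, D]
    flat_mask = masks.reshape(-1)                            # [K*H*W]
    # One prototype per class: the mean embedding of that class's pixels
    return np.stack([flat_emb[flat_mask == c].mean(axis=0) for c in (0, 1)])  # [2, D]


def classify(embeddings, prototypes):
    """Assign each pixel to the class of its nearest prototype (squared Euclidean)."""
    diff = embeddings[..., None, :] - prototypes  # [..., 2, D]
    sq_dist = (diff ** 2).sum(axis=-1)            # [..., 2]
    return sq_dist.argmin(axis=-1)


# Toy data: two well-separated clusters, one per class
rng = np.random.default_rng(0)
masks = np.zeros((2, 4, 4), dtype=int)
masks[:, :, 2:] = 1  # right half of every image is "rooftop"
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
embeddings = centers[masks] + 0.1 * rng.standard_normal((2, 4, 4, 2))

prototypes = compute_prototypes(embeddings, masks)
pred = classify(embeddings, prototypes)
print(prototypes.shape, (pred == masks).mean())
```

With clusters this far apart every pixel is assigned to the correct class; the purpose of episodic training is to make real encoder embeddings behave like this.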

**Training Details**:
- Episodes: 1000-2000 episodes sampled from Grid 1301_11
- K (support): 3-5 images per episode
- Q (query): 5-10 images per episode
- Embedding dim: 256 or 512
- Distance metric: Euclidean or cosine
- Loss: Cross-entropy on query predictions
- Augmentation: Rotation, flip, color jitter, random crop
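
For the cross-entropy loss to train the encoder, the hard nearest-prototype rule has to be relaxed into a softmax over negative distances, because `argmin` has no gradient. A NumPy sketch of that loss, assuming squared Euclidean distances and integer targets (the function name is illustrative):

```python
import numpy as np


def prototype_cross_entropy(sq_dist, targets):
    """sq_dist: [N, C] squared distances to each prototype; targets: [N] class ids."""
    logits = -sq_dist                                    # closer prototype -> larger logit
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


sq_dist = np.array([[0.0, 10.0], [10.0, 0.0], [3.0, 3.0]])
loss_correct = prototype_cross_entropy(sq_dist[:2], np.array([0, 1]))
loss_wrong = prototype_cross_entropy(sq_dist[:2], np.array([1, 0]))
loss_tie = prototype_cross_entropy(sq_dist[2:], np.array([0]))
print(loss_correct, loss_wrong, loss_tie)
```

Pixels that sit on their own prototype contribute almost no loss, pixels that sit on the wrong prototype contribute roughly the distance gap, and a pixel equidistant from both prototypes contributes log 2.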

**Few-Shot Inference (on Target Grids)**:
```
Given K labeled images from Grid 1301_13:
    1. Encode support images → embeddings
    2. Compute prototypes from support masks
    3. For each query image:
        - Encode → embeddings
        - Classify each pixel by nearest prototype
    4. No gradient updates needed! (forward pass only; the K support labels are used, so this is few-shot, not zero-shot)
```

### Advantages
- ✅ **Explicitly designed for few-shot learning**: trained to segment from K examples
- ✅ **No fine-tuning needed**: just compute prototypes and classify
- ✅ **Better generalization**: learns transferable embeddings, not dataset-specific features
- ✅ **Interpretable**: can visualize learned embeddings and prototypes
- ✅ **Sample efficient**: should outperform fine-tuning at very low K (1-5 shots)

### Limitations
- ❌ More complex to implement than fine-tuning
- ❌ Requires episodic training (different from standard supervised learning)
- ❌ Slower training (need many episodes to sample diverse tasks)
- ❌ Needs sufficient source domain diversity (Grid 1301_11 should be varied enough)

### Expected Performance
```
Baseline (zero-shot, no adaptation):       ~XX% IoU on target grids
Prototypical Networks with K=1:            ~XX% IoU (better than fine-tuning at K=1)
Prototypical Networks with K=5:            ~XX% IoU (significantly better than fine-tuning)
Prototypical Networks with K=20:           ~XX% IoU (comparable or better than fine-tuning)
```

**Key hypothesis**: Prototypical Networks should outperform fine-tuning especially at **low K (1-5 shots)** because they're explicitly trained for this scenario.

### Implementation Outline (Not Implemented Yet)

In [None]:
# Prototypical Networks pseudocode
"""
import torch
import torch.nn as nn

class PrototypicalSegmentationNetwork(nn.Module):
    def __init__(self, encoder_name='resnet34', embedding_dim=256):
        super().__init__()
        self.encoder = build_encoder(encoder_name)  # e.g., ResNet backbone
        self.embedding_head = nn.Conv2d(encoder_channels, embedding_dim, 1)
    
    def extract_embeddings(self, images):
        # images: [B, 3, H, W]
        features = self.encoder(images)  # [B, C, H, W]
        embeddings = self.embedding_head(features)  # [B, D, H, W]
        return embeddings
    
    def compute_prototypes(self, support_embeddings, support_masks):
        # support_embeddings: [K, D, H, W]
        # support_masks: [K, H, W] (binary: 0=background, 1=rooftop)
        
        # Move channels last and flatten so that each row is one pixel embedding
        D = support_embeddings.shape[1]
        embeddings_flat = support_embeddings.permute(0, 2, 3, 1).reshape(-1, D)  # [K*H*W, D]
        masks_flat = support_masks.reshape(-1)  # [K*H*W]
        
        # Compute prototypes for each class (assumes both classes occur in the support set)
        bg_prototype = embeddings_flat[masks_flat == 0].mean(dim=0)  # [D]
        fg_prototype = embeddings_flat[masks_flat == 1].mean(dim=0)  # [D]
        
        return torch.stack([bg_prototype, fg_prototype])  # [2, D]
    
    def classify_by_prototypes(self, query_embeddings, prototypes):
        # query_embeddings: [Q, D, H, W]
        # prototypes: [2, D]
        
        Q, D, H, W = query_embeddings.shape
        embeddings_flat = query_embeddings.flatten(2).permute(0, 2, 1)  # [Q, H*W, D]
        
        # Squared Euclidean distance from every pixel to each prototype
        protos = prototypes.unsqueeze(0).expand(Q, -1, -1)  # [Q, 2, D]
        distances = torch.cdist(embeddings_flat, protos) ** 2  # [Q, H*W, 2]
        
        # Return negative distances as class logits: argmin is not differentiable,
        # so hard predictions are taken with argmax only at evaluation time
        logits = -distances.permute(0, 2, 1).reshape(Q, 2, H, W)  # [Q, 2, H, W]
        
        return logits

# Training loop
model = PrototypicalSegmentationNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for episode in range(num_episodes):
    # Sample episode from Grid 1301_11
    support_imgs, support_masks = sample_k_images(train_by_grid['1301_11'], k=5)
    query_imgs, query_masks = sample_q_images(train_by_grid['1301_11'], q=5)
    
    # Extract embeddings
    support_emb = model.extract_embeddings(support_imgs)
    query_emb = model.extract_embeddings(query_imgs)
    
    # Compute prototypes from support set
    prototypes = model.compute_prototypes(support_emb, support_masks)
    
    # Score query pixels against the prototypes
    logits = model.classify_by_prototypes(query_emb, prototypes)  # [Q, 2, H, W]
    
    # Compute loss (softmax over negative distances)
    loss = nn.functional.cross_entropy(logits, query_masks.long())
    
    # Update encoder
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Inference on target grid
support_imgs, support_masks = select_k_examples(train_by_grid['1301_13'], k=5)
query_imgs, query_masks = get_remaining_images(train_by_grid['1301_13'])

model.eval()
with torch.no_grad():
    support_emb = model.extract_embeddings(support_imgs)
    query_emb = model.extract_embeddings(query_imgs)
    
    prototypes = model.compute_prototypes(support_emb, support_masks)
    logits = model.classify_by_prototypes(query_emb, prototypes)
    predictions = logits.argmax(dim=1)  # [Q, H, W], nearest prototype per pixel
    
    iou = compute_iou(predictions, query_masks)
    print(f"IoU on Grid 1301_13 with K=5: {iou:.3f}")
"""
pass

---

## 3. Approach 3: PANet (Prototype Alignment Network)

### Concept
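PANet (Prototype Alignment Network) is a metric-learning approach designed specifically for few-shot semantic segmentation. It carries the prototypical idea over to dense prediction in two ways: class prototypes are obtained by **masked average pooling** over support feature maps, and a **prototype alignment regularization** is added during training, in which prototypes computed from the predicted query mask must in turn segment the support images correctly.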

### Key Differences from Prototypical Networks
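While Prototypical Networks only classify queries against support prototypes, PANet as planned here adds:
1. **Masked Average Pooling (MAP)**: Averages support features under the (resized) support mask, so each prototype is built only from pixels of its own class
2. **Bidirectional matching**: Uses support prototypes to segment the query AND, during training only, query prototypes to segment the support (the prototype alignment regularization)
3. **Multi-scale features**: Combines features from multiple encoder levels; this is an optional extension considered in this notebook, not part of the original PANet
4. **Optimized for segmentation**: Designed specifically for dense pixel-wise prediction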

### Methodology

**Core Architecture**:
```
Support Set (K images + masks)           Query Image
         ↓                                      ↓
    Feature Encoder                       Feature Encoder
         ↓                                      ↓
    [K, C, H, W]                              [C, H, W]
         ↓                                      ↓
    Masked Average Pooling ──────────────→  Prototype Matching
    (extract class prototypes                     ↓
     using support masks)              Distance-based prediction
         ↓                                      ↓
    Foreground prototype: μ_fg          Segmentation mask [H, W]
    Background prototype: μ_bg
```

**Masked Average Pooling (MAP)**:
```
For class c (foreground or background):
    1. Collect all feature vectors from support images where mask == c
    2. Compute prototype μ_c = mean of these feature vectors
    3. This gives one prototype per class
    
μ_fg = Σ_{i,x,y} f_i(x,y) * m_i(x,y) / Σ_{i,x,y} m_i(x,y)
μ_bg = Σ_{i,x,y} f_i(x,y) * (1-m_i(x,y)) / Σ_{i,x,y} (1-m_i(x,y))

where:
- f_i(x,y) = feature vector at position (x,y) in support image i
- m_i(x,y) = mask value (0 or 1) at position (x,y) in support image i
```
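
The two MAP formulas translate almost literally into NumPy. The sketch below assumes the masks have already been resized to the feature resolution; the `eps` term, which guards against an empty class, is an addition not present in the formulas.

```python
import numpy as np


def masked_average_pooling(features, masks, eps=1e-5):
    """features: [K, C, H, W]; masks: [K, H, W] in {0, 1}. Returns (mu_fg, mu_bg), each [C]."""
    m = masks[:, None, :, :].astype(features.dtype)  # [K, 1, H, W], broadcasts over C
    mu_fg = (features * m).sum(axis=(0, 2, 3)) / (m.sum() + eps)
    mu_bg = (features * (1 - m)).sum(axis=(0, 2, 3)) / ((1 - m).sum() + eps)
    return mu_fg, mu_bg


# One support image, two channels, 2x2 feature map; top row is foreground
masks = np.array([[[1, 1], [0, 0]]])
features = np.empty((1, 2, 2, 2))
features[0, 0] = [[2.0, 2.0], [-1.0, -1.0]]
features[0, 1] = [[4.0, 4.0], [0.0, 0.0]]

mu_fg, mu_bg = masked_average_pooling(features, masks)
print(mu_fg, mu_bg)
```

The foreground prototype comes out as approximately (2, 4) and the background prototype as approximately (-1, 0), each the per-channel mean over the pixels of its class.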

**Prototype Matching (cosine similarity)**:
```
For each query pixel embedding z:
    1. Compute cosine similarity to each prototype
       sim_fg = cos(z, μ_fg)
       sim_bg = cos(z, μ_bg)
    
    2. Convert to prediction probability (α is a scaling factor; the PANet paper uses α = 20)
       P(y=fg|z) = exp(α·sim_fg) / (exp(α·sim_fg) + exp(α·sim_bg))
```
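
The similarity and probability steps above can be sketched in NumPy. The function name and the channel-last `[H, W, C]` layout are assumptions for illustration; `alpha` is a scaling factor applied to the similarities (1.0 leaves them unscaled, and the PANet paper uses 20).

```python
import numpy as np


def cosine_softmax(query, mu_fg, mu_bg, alpha=1.0, eps=1e-8):
    """query: [H, W, C] pixel embeddings; returns P(y=fg | z) per pixel, shape [H, W]."""
    def cos(z, mu):
        return (z @ mu) / (np.linalg.norm(z, axis=-1) * np.linalg.norm(mu) + eps)

    sim_fg, sim_bg = cos(query, mu_fg), cos(query, mu_bg)
    # A two-class softmax is a sigmoid of the (scaled) similarity difference
    return 1.0 / (1.0 + np.exp(-alpha * (sim_fg - sim_bg)))


mu_fg = np.array([1.0, 0.0])
mu_bg = np.array([0.0, 1.0])
query = np.array([[[2.0, 0.0], [1.0, 1.0], [0.0, 3.0]]])  # [1, 3, 2]

p_unscaled = cosine_softmax(query, mu_fg, mu_bg)
p_scaled = cosine_softmax(query, mu_fg, mu_bg, alpha=20.0)
print(p_unscaled.round(3), p_scaled.round(3))
```

Because cosine similarities lie in [-1, 1], the unscaled probability for a pixel perfectly aligned with the foreground prototype is only about 0.73; scaling sharpens the softmax, which is why a factor such as 20 is commonly applied before the loss.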

**Episodic Training (on Grid 1301_11)**:
```
For each training episode:
    1. Sample K support images with masks
    2. Sample Q query images with masks
    3. Extract multi-scale features for all images
    4. Compute class prototypes via Masked Average Pooling on support set
    5. For each query image:
       a. Compute similarity between query features and prototypes
       b. Generate segmentation prediction
    6. Compute loss (cross-entropy + auxiliary losses)
    7. Update encoder via backpropagation
```

**Training Details**:
- Episodes: 1000-2000 episodes from Grid 1301_11
- K (support): 1-5 images per episode (PANet works well even with K=1)
- Q (query): 1 image per episode (typical PANet setup)
- Feature extractor: ResNet-50 or ResNet-101 backbone
- Multi-scale: Uses features from multiple ResNet blocks (e.g., res3, res4, res5)
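- Loss: Cross-entropy on query predictions, optionally plus the prototype alignment regularization (cross-entropy on support predictions made from query-derived prototypes)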
- Augmentation: Random crop, flip, color jitter

**Few-Shot Inference (on Target Grids)**:
```
Given K labeled images from Grid 1301_13:
    1. Extract features from K support images
    2. Compute fg/bg prototypes using Masked Average Pooling
    3. For each query image:
       a. Extract features
       b. Compute cosine similarity to prototypes
       c. Generate segmentation prediction
    4. No gradient updates! (forward pass only)
```

### Advantages
- ✅ **Strong few-shot segmentation baseline**: Reported state-of-the-art results on PASCAL-5i when it was published (2019)
- ✅ **Better use of support masks**: MAP builds each prototype only from pixels of its own class
- ✅ **Works with K=1**: Effective even with single support example
- ✅ **Multi-scale features**: Captures both fine details and semantic context
- ✅ **No fine-tuning needed**: Like Prototypical Networks, inference is forward-pass only
- ✅ **Well-studied**: Multiple papers and implementations available

### Limitations
- ❌ **More complex**: Harder to implement than Prototypical Networks
- ❌ **Computationally expensive**: Multi-scale features increase memory/compute
- ❌ **Requires careful design**: Choosing which feature levels, how to combine them
- ❌ **Slower training**: More complex architecture means longer training time
- ❌ **More hyperparameters**: Feature levels, similarity metrics, loss weights

### Expected Performance
```
Baseline (zero-shot, no adaptation):   ~XX% IoU on target grids
PANet with K=1:                        ~(XX+5)% IoU (best at K=1)
PANet with K=5:                        ~(XX+8)% IoU (state-of-art performance)
PANet with K=20:                       ~(XX+10)% IoU (approaching upper bound)
```

**Key hypothesis**: PANet should outperform both fine-tuning and Prototypical Networks across all K values, with the **largest advantage at K=1-5** due to its sophisticated prototype extraction.

### Implementation Outline (Not Implemented Yet)

In [None]:
# PANet pseudocode
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class PANet(nn.Module):
    def __init__(self, backbone='resnet50', use_multi_scale=True):
        super().__init__()
        self.encoder = build_resnet_encoder(backbone)  # ResNet-50 or 101
        self.use_multi_scale = use_multi_scale
        
        # If using multi-scale, we'll extract features from multiple layers
        # e.g., res3, res4, res5 from ResNet
        if use_multi_scale:
            self.feature_levels = ['layer2', 'layer3', 'layer4']  # ResNet layers
        
    def extract_features(self, images, multi_scale=False):
        # images: [B, 3, H, W]
        if self.use_multi_scale and multi_scale:
            # Extract features from multiple ResNet blocks
            features = {}
            x = self.encoder.conv1(images)
            x = self.encoder.bn1(x)
            x = self.encoder.relu(x)
            x = self.encoder.maxpool(x)
            
            x = self.encoder.layer1(x)
            x = self.encoder.layer2(x)
            features['layer2'] = x  # 1/8 resolution
            
            x = self.encoder.layer3(x)
            features['layer3'] = x  # 1/16 resolution
            
            x = self.encoder.layer4(x)
            features['layer4'] = x  # 1/32 resolution
            
            return features
        else:
            # Single-scale features from final layer
            features = self.encoder(images)  # [B, C, H', W']
            return features
    
    def masked_average_pooling(self, features, masks):
        \"\"\"
        Compute class prototypes using Masked Average Pooling (MAP).
        
        Args:
            features: [K, C, H, W] - support features
            masks: [K, H, W] - support masks (binary: 0=bg, 1=fg)
        
        Returns:
            fg_prototype: [C] - foreground prototype
            bg_prototype: [C] - background prototype
        \"\"\"
        K, C, H, W = features.shape
        
        # Resize masks to match feature resolution
        masks_resized = F.interpolate(
            masks.unsqueeze(1).float(), 
            size=(H, W), 
            mode='nearest'
        ).squeeze(1)  # [K, H, W]
        
        # Flatten spatial dimensions
        features_flat = features.view(K, C, -1)  # [K, C, H*W]
        masks_flat = masks_resized.view(K, -1)  # [K, H*W]
        
        # Foreground prototype: average of features where mask == 1
        fg_mask = (masks_flat == 1).unsqueeze(1)  # [K, 1, H*W]
        fg_features = features_flat * fg_mask  # [K, C, H*W]
        fg_count = fg_mask.sum(dim=(0, 2), keepdim=True)  # [1, 1, 1]
        fg_prototype = fg_features.sum(dim=(0, 2)) / (fg_count.squeeze() + 1e-5)  # [C]
        
        # Background prototype: average of features where mask == 0
        bg_mask = (masks_flat == 0).unsqueeze(1)  # [K, 1, H*W]
        bg_features = features_flat * bg_mask  # [K, C, H*W]
        bg_count = bg_mask.sum(dim=(0, 2), keepdim=True)  # [1, 1, 1]
        bg_prototype = bg_features.sum(dim=(0, 2)) / (bg_count.squeeze() + 1e-5)  # [C]
        
        return fg_prototype, bg_prototype
    
    def prototype_alignment(self, query_features, fg_prototype, bg_prototype):
        \"\"\"
        Compute cosine similarity between query features and prototypes.
        
        Args:
            query_features: [Q, C, H, W]
            fg_prototype: [C]
            bg_prototype: [C]
        
        Returns:
            predictions: [Q, 2, H, W] - similarity scores for each class
        \"\"\"
        Q, C, H, W = query_features.shape
        
        # Normalize features and prototypes for cosine similarity
        query_norm = F.normalize(query_features, p=2, dim=1)  # [Q, C, H, W]
        fg_proto_norm = F.normalize(fg_prototype.unsqueeze(0), p=2, dim=1)  # [1, C]
        bg_proto_norm = F.normalize(bg_prototype.unsqueeze(0), p=2, dim=1)  # [1, C]
        
        # Compute cosine similarity
        # query_norm: [Q, C, H, W], proto_norm: [1, C] -> need to broadcast
        fg_sim = (query_norm * fg_proto_norm.view(1, C, 1, 1)).sum(dim=1)  # [Q, H, W]
        bg_sim = (query_norm * bg_proto_norm.view(1, C, 1, 1)).sum(dim=1)  # [Q, H, W]
        
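        # Stack similarities for both classes and scale them: raw cosine values lie
        # in [-1, 1], which gives a nearly flat softmax (PANet uses a factor of 20)
        similarity_map = 20.0 * torch.stack([bg_sim, fg_sim], dim=1)  # [Q, 2, H, W]
        
        return similarity_map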
    
    def forward(self, support_images, support_masks, query_images):
        \"\"\"
        PANet forward pass for few-shot segmentation.
        
        Args:
            support_images: [K, 3, H, W]
            support_masks: [K, H, W]
            query_images: [Q, 3, H, W]
        
        Returns:
            predictions: [Q, 2, H, W] - class scores for query images
        \"\"\"
        # Extract features
        support_features = self.extract_features(support_images)  # [K, C, H', W']
        query_features = self.extract_features(query_images)  # [Q, C, H', W']
        
        # Compute prototypes from support set
        fg_prototype, bg_prototype = self.masked_average_pooling(
            support_features, support_masks
        )
        
        # Align query features with prototypes
        similarity_map = self.prototype_alignment(
            query_features, fg_prototype, bg_prototype
        )
        
        # Upsample to original resolution
        Q, _, H, W = query_images.shape
        predictions = F.interpolate(
            similarity_map,
            size=(H, W),
            mode='bilinear',
            align_corners=True
        )  # [Q, 2, H, W]
        
        return predictions

# Training loop
model = PANet(backbone='resnet50')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for episode in range(num_episodes):
    # Sample episode from Grid 1301_11
    support_imgs, support_masks = sample_k_images(train_by_grid['1301_11'], k=5)
    query_imgs, query_masks = sample_q_images(train_by_grid['1301_11'], q=1)
    
    # Forward pass
    predictions = model(support_imgs, support_masks, query_imgs)  # [Q, 2, H, W]
    
    # Compute loss
    loss = F.cross_entropy(predictions, query_masks.long())
    
    # Update model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if episode % 100 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}")

# Inference on target grid
support_imgs, support_masks = select_k_examples(train_by_grid['1301_13'], k=5)
query_imgs, query_masks = get_remaining_images(train_by_grid['1301_13'])

model.eval()
with torch.no_grad():
    predictions = model(support_imgs, support_masks, query_imgs)  # [Q, 2, H, W]
    pred_masks = predictions.argmax(dim=1)  # [Q, H, W]
    
    iou = compute_iou(pred_masks, query_masks)
    print(f"IoU on Grid 1301_13 with K=5: {iou:.3f}")
"""
pass

---

## 4. Comparison of Approaches

| Aspect | Fine-Tuning | Prototypical Networks | PANet |
|--------|-------------|----------------------|-------|
| **Training paradigm** | Standard supervised → fine-tune | Episodic meta-learning | Episodic meta-learning |
| **Adaptation method** | Gradient updates on K examples | Compute prototypes (no updates) | Compute prototypes (no updates) |
| **Prototype extraction** | N/A | Masked mean of pixel embeddings | Masked Average Pooling (MAP) on encoder feature maps |
| **Feature representation** | Task-specific | Learned embeddings | Multi-scale learned embeddings |
| **Sample efficiency** | Moderate (needs K~10-20) | High (works well with K=3-5) | Very high (works well with K=1-3) |
| **Overfitting risk** | High at low K | Lower | Lowest |
| **Computational cost (training)** | Low (few gradient steps) | Medium (episodic training) | High (complex architecture) |
| **Computational cost (inference)** | Low | Low | Medium (multi-scale features) |
| **Implementation complexity** | Simple | Moderate | High |
| **Interpretability** | Standard CNN features | Explicit embeddings + prototypes | Prototypes + similarity maps |
| **Expected performance K=1** | Poor | Moderate | Good |
| **Expected performance K=5** | Moderate | Good | Very good |
| **Expected performance K=20** | Good | Very good | Excellent |

### Detailed Comparison

#### 1. **Fine-Tuning (Baseline)**
- **Best for**: K > 10, similar domains, simple baseline
- **Strengths**: Easy to implement, well-understood, works with any architecture
- **Weaknesses**: Overfits at low K, requires hyperparameter tuning
- **Use case**: When you have enough target data and want a simple solution

#### 2. **Prototypical Networks**
- **Best for**: K = 3-10, learning transferable embeddings
- **Strengths**: Designed for few-shot, no fine-tuning needed, interpretable
- **Weaknesses**: Simpler than PANet, may not capture complex patterns
- **Use case**: Balance between performance and implementation complexity

#### 3. **PANet (State-of-Art)**
- **Best for**: K = 1-5, maximum performance, complex domain shifts
- **Strengths**: Best performance, works with K=1, sophisticated prototype extraction
- **Weaknesses**: Complex to implement, computationally expensive, many hyperparameters
- **Use case**: When you need the best possible performance and can handle complexity

### When to Use Each Approach

| Scenario | Recommended Method | Rationale |
|----------|-------------------|-----------|
| K = 1 shot | **PANet** | Designed and evaluated for the 1-shot setting; fine-tuning tends to overfit a single example |
| K = 3-5 shots | **PANet or Prototypical** | Both work well, trade complexity for performance |
| K = 10+ shots | **Fine-tuning or Prototypical** | Diminishing returns from PANet complexity |
| Tutorial/Educational | **Prototypical Networks** | Best balance of concepts and implementation |
| Production deployment | **PANet** | Maximum performance, worth the engineering |
| Quick baseline | **Fine-tuning** | Simplest, fastest to implement |
| Large domain shift | **PANet** | Best generalization to new domains |
| Limited compute | **Fine-tuning** | Lowest computational requirements |

---

## 5. Evaluation Protocol

For all three approaches, we'll use the same evaluation protocol:

### 5.1 Metrics
- **Intersection over Union (IoU)**: Primary metric for segmentation quality
- **Pixel Accuracy**: Overall percentage of correctly classified pixels
- **Dice Coefficient / F1-Score**: Harmonic mean of precision and recall, 2·TP / (2·TP + FP + FN); for binary masks Dice and F1 are the same quantity, so it is reported once per class
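
The metrics above can all be derived from the four confusion counts of the rooftop class. The helper below is a minimal sketch: its name is illustrative, it scores only the foreground class (not a mean over classes), and returning 1.0 when both prediction and target are empty is an assumed convention.

```python
import numpy as np


def binary_segmentation_metrics(pred, target):
    """pred, target: arrays of the same shape with values in {0, 1} (1 = rooftop)."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    tp = np.sum(pred & target)
    fp = np.sum(pred & ~target)
    fn = np.sum(~pred & target)
    tn = np.sum(~pred & ~target)
    return {
        "iou": tp / (tp + fp + fn) if (tp + fp + fn) else 1.0,
        "dice": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1.0,
        "pixel_accuracy": (tp + tn) / pred.size,
    }


metrics = binary_segmentation_metrics([[1, 1, 0, 0]], [[1, 0, 1, 0]])
print(metrics)  # iou = 1/3, dice = 0.5, pixel_accuracy = 0.5
```

IoU and Dice are monotonically related (Dice = 2·IoU / (1 + IoU)), so they rank methods identically on a single image; they can differ once averaged over images.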

### 5.2 Experimental Setup
```
For K in [1, 3, 5, 10, 20]:
    For target_grid in ['1301_13', '1301_31']:
        For trial in range(5):  # 5 random trials to estimate variance across support sets
            # Randomly select K support examples
            support_set = random.sample(target_grid_data, K)
            query_set = remaining images
            
            # Apply each method
            pred_finetuning = fine_tuning_method.predict(support_set, query_set)
            pred_prototypical = prototypical_method.predict(support_set, query_set)
            pred_panet = panet_method.predict(support_set, query_set)
            
            # Compute metrics for each
            iou_ft = compute_iou(pred_finetuning, query_ground_truth)
            iou_proto = compute_iou(pred_prototypical, query_ground_truth)
            iou_panet = compute_iou(pred_panet, query_ground_truth)
            
        # Report mean ± std across trials
```
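
The final "mean ± std across trials" step can be sketched with the standard library. The record layout `(method, grid, K, iou)` and the `summarize` helper are assumptions, and the IoU values are made-up placeholders, not results.

```python
from collections import defaultdict
from statistics import mean, stdev


def summarize(records):
    """records: iterable of (method, grid, k, iou) -> {(method, grid, k): (mean, std)}."""
    grouped = defaultdict(list)
    for method, grid, k, iou in records:
        grouped[(method, grid, k)].append(iou)
    # Sample standard deviation; defined as 0.0 when there is a single trial
    return {key: (mean(v), stdev(v) if len(v) > 1 else 0.0) for key, v in grouped.items()}


# Placeholder values for illustration only
records = [
    ("fine_tuning", "1301_13", 5, 0.5),
    ("fine_tuning", "1301_13", 5, 0.6),
    ("fine_tuning", "1301_13", 5, 0.7),
    ("panet", "1301_13", 5, 0.8),
]
summary = summarize(records)
for key, (m, s) in sorted(summary.items()):
    print(key, f"{m:.3f} ± {s:.3f}")
```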

### 5.3 Baselines
1. **Zero-shot**: Train on 1301_11, test directly on the target grid (no adaptation)
2. **Upper bound**: Train on full target grid data (supervised learning)
3. **Random**: Random segmentation (sanity check)

### 5.4 Statistical Testing
- **Paired t-test**: Compare methods at each K value
- **Confidence intervals**: 95% CI for mean IoU
- **Significance level**: α = 0.05
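
A paired t-test and 95% confidence interval on per-trial IoU differences can be sketched without extra dependencies. The pairing assumes both methods were evaluated on the same support/query split in each trial, as in the protocol above. The helper name is illustrative, the critical values are the standard two-sided 95% Student's t quantiles for a few small degrees of freedom, and the IoU numbers are made-up placeholders.

```python
import math
from statistics import mean, stdev

# Two-sided 95% critical values of Student's t, keyed by degrees of freedom
T_CRIT_95 = {2: 4.303, 3: 3.182, 4: 2.776, 5: 2.571, 9: 2.262}


def paired_t(a, b):
    """Return (t statistic, 95% CI of the mean difference) for paired samples a, b."""
    diffs = [x - y for x, y in zip(a, b)]
    n = len(diffs)
    se = stdev(diffs) / math.sqrt(n)  # standard error of the mean difference
    if se == 0:
        raise ValueError("all paired differences are identical")
    d_mean = mean(diffs)
    half_width = T_CRIT_95[n - 1] * se
    return d_mean / se, (d_mean - half_width, d_mean + half_width)


# Placeholder IoU values for 5 trials of two methods (not real results)
iou_a = [0.62, 0.60, 0.65, 0.61, 0.63]
iou_b = [0.58, 0.59, 0.60, 0.57, 0.60]
t_stat, ci = paired_t(iou_a, iou_b)
print(f"t = {t_stat:.2f}, 95% CI = ({ci[0]:.3f}, {ci[1]:.3f})")
```

With 5 trials there are 4 degrees of freedom, so the difference is significant at α = 0.05 when |t| exceeds 2.776, which is equivalent to the confidence interval excluding zero.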

### 5.5 Visualization
- **Performance vs. K curves**: Line plot showing all three methods
- **Box plots**: Distribution of IoU scores across trials
- **Qualitative results**: Side-by-side segmentation outputs
- **Error analysis**: Where does each method fail?
- **Prototype visualization** (for Prototypical Networks and PANet):
  - t-SNE of learned embeddings
  - Similarity maps
  - Prototype evolution as K increases

In [None]:
# Evaluation utilities (to be implemented)
"""
def compute_iou(predictions, targets, num_classes=2):
    \"\"\"
    Compute intersection over union for each class.
    
    Args:
        predictions: [N, H, W] - predicted masks
        targets: [N, H, W] - ground truth masks
        num_classes: number of classes
    
    Returns:
        iou: mean IoU across classes
    \"\"\"
    pass

def evaluate_fewshot(model, support_set, query_set, method='prototypical'):
    \"\"\"
    Run few-shot evaluation for a given method.
    
    Args:
        model: trained model
        support_set: K labeled examples
        query_set: unlabeled test examples
        method: 'fine_tuning', 'prototypical', or 'panet'
    
    Returns:
        metrics: dict with IoU, Dice, accuracy, etc.
    \"\"\"
    pass

def plot_performance_curve(results_dict, save_path=None):
    \"\"\"
    Plot IoU vs K for all three methods.
    
    Args:
        results_dict: {
            'fine_tuning': {K: [iou_values]}, 
            'prototypical': {K: [iou_values]},
            'panet': {K: [iou_values]}
        }
    \"\"\"
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    for method_name, results in results_dict.items():
        K_values = sorted(results.keys())
        mean_ious = [np.mean(results[k]) for k in K_values]
        std_ious = [np.std(results[k]) for k in K_values]
        
        ax.plot(K_values, mean_ious, marker='o', label=method_name)
        ax.fill_between(K_values, 
                        np.array(mean_ious) - np.array(std_ious),
                        np.array(mean_ious) + np.array(std_ious),
                        alpha=0.2)
    
    ax.set_xlabel('K (number of support examples)')
    ax.set_ylabel('Mean IoU')
    ax.set_title('Few-Shot Segmentation Performance')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def visualize_predictions(images, masks, predictions_dict, n_samples=5):
    \"\"\"
    Side-by-side comparison of ground truth and all method predictions.
    
    Args:
        images: [N, H, W, 3] - input images
        masks: [N, H, W] - ground truth masks
        predictions_dict: {
            'fine_tuning': [N, H, W],
            'prototypical': [N, H, W],
            'panet': [N, H, W]
        }
        n_samples: number of samples to visualize
    \"\"\"
    pass

def visualize_prototypes(model, support_set, method='prototypical'):
    \"\"\"
    Visualize learned embeddings and prototypes.
    
    For Prototypical Networks and PANet:
    - t-SNE visualization of pixel embeddings
    - Show foreground/background prototype locations
    - Similarity maps for query images
    \"\"\"
    pass

def statistical_comparison(results_dict, K_value=5):
    \"\"\"
    Perform paired t-test between methods at a specific K.
    
    Args:
        results_dict: results from all methods
        K_value: which K to compare
    
    Returns:
        p_values: pairwise p-values
    \"\"\"
    from scipy import stats
    
    methods = list(results_dict.keys())
    p_values = {}
    
    for i, method1 in enumerate(methods):
        for method2 in methods[i+1:]:
            ious1 = results_dict[method1][K_value]
            ious2 = results_dict[method2][K_value]
            
            t_stat, p_val = stats.ttest_rel(ious1, ious2)
            p_values[f'{method1}_vs_{method2}'] = p_val
    
    return p_values
"""
pass

---

## 6. Cross-City Transfer: Integrating Inria Aerial Dataset

### Motivation
To make the few-shot scenario more realistic, we integrate the **Inria Aerial Image Dataset** as a set of cross-city target domains. This tests **cross-city generalization**, moving beyond within-city transfer toward deployment-like conditions.

### Why Inria Aerial Dataset?

**Dataset Overview**:
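- **5 labeled cities**: Austin (US), Chicago (US), Kitsap County (US), Vienna (Austria), Western Tyrol (Austria)
- **Coverage**: 810 km² in total (405 km² for training, 405 km² for testing)
- **Images**: 360 images (180 train + 180 test), 5000×5000 pixels each; in the official release the test images cover five other cities (Bellingham, Bloomington, Innsbruck, San Francisco, Eastern Tyrol) and their labels are not public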
- **Resolution**: 0.3m per pixel
- **Task**: Building footprint segmentation (binary: building vs background)
- **Availability**: Free, well-documented benchmark dataset
- **HuggingFace**: `Jonathan/INRIA-Aerial-Dataset`

**Why This Strategy (Geneva + Inria)**:
- ✅ **Progressive difficulty**: Start with small domain shift (Geneva grids), then cross-city transfer (Inria cities)
- ✅ **Same task throughout**: Binary segmentation (rooftop/building vs background)
- ✅ **Educational value**: Clear learning progression from easy to hard
- ✅ **Multiple test scenarios**: 2 Geneva grids + up to 5 Inria cities
- ✅ **Real-world relevance**: "Train on Geneva, deploy to Vienna" mimics actual government use cases

**Expected Domain Shifts**:
```
Geneva 1301_11 → Geneva 1301_13:  SMALL (same city, different neighborhood)
Geneva 1301_11 → Vienna:          MEDIUM (both European, different city)
Geneva 1301_11 → Austin:          LARGE (European → US, different architecture)
Geneva 1301_11 → Kitsap:          VERY LARGE (European urban → US rural)
```

### Recommended Implementation: Vienna + Austin

For this tutorial, we'll focus on **two Inria cities**:

1. **Vienna** (Medium shift)
   - European city like Geneva
   - Moderate density
   - Different architecture but same continental style
   - Expected: Meta-learning should show moderate advantage

2. **Austin** (Large shift)
   - US suburban sprawl
   - Low-density, very different building types
   - Large domain gap from Geneva
   - Expected: Meta-learning should show large advantage over fine-tuning

### Integration Workflow

#### Step 1: Download and Prepare Inria Dataset

```python
from datasets import load_dataset
from PIL import Image
import numpy as np
from pathlib import Path

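def download_inria_dataset(cities=('vienna', 'austin')):
    """
    Download Inria dataset from HuggingFace.
    
    Assumes every record exposes a `city` field; check the dataset card
    and adjust the filter if the schema differs.
    
    Args:
        cities: Names of the cities to keep
    
    Returns:
        Dictionary with city data
    """
    # Load full dataset
    inria_data = load_dataset("Jonathan/INRIA-Aerial-Dataset")
    
    # Filter by cities. `city=city` binds the current loop value to each lambda.
    # If the mirror follows the official release, 'test' has no labeled images
    # for these cities; hold out part of 'train' in that case.
    city_data = {}
    for city in cities:
        city_data[city] = {
            split: inria_data[split].filter(
                lambda x, city=city: x['city'].lower() == city.lower()
            )
            for split in ('train', 'test')
        }
    
    return city_data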
```

#### Step 2: Tile Images to Match Geneva Format

```python
def tile_image(image, mask, tile_size=250, stride=250):
    """
    Tile large Inria images (5000x5000) into Geneva-compatible patches (250x250).
    
    Args:
        image: PIL Image or numpy array [H, W, 3]
        mask: PIL Image or numpy array [H, W]
        tile_size: Size of output tiles (default 250 to match Geneva)
        stride: Stride for tiling (default 250 for non-overlapping)
    
    Returns:
        List of (image_tile, mask_tile) tuples
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    
    H, W = image.shape[:2]
    tiles = []
    
    for y in range(0, H - tile_size + 1, stride):
        for x in range(0, W - tile_size + 1, stride):
            img_tile = image[y:y+tile_size, x:x+tile_size]
            mask_tile = mask[y:y+tile_size, x:x+tile_size]
            
            # Skip tiles with no buildings (optional: for efficiency)
            if mask_tile.sum() > 0:  # At least some building pixels
                tiles.append((img_tile, mask_tile))
    
    return tiles

def preprocess_inria_city(city_data, output_dir, tile_size=250):
    """
    Process all images from one Inria city.
    
    Args:
        city_data: Dataset for one city (train/test splits)
        output_dir: Where to save tiled images
        tile_size: Tile size (default 250)
    
    Returns:
        Dictionary with paths to tiled images
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    tiled_data = {'train': [], 'test': []}
    
    for split in ['train', 'test']:
        for idx, sample in enumerate(city_data[split]):
            image = sample['image']  # PIL Image
            mask = sample['mask']    # PIL Image (building > 0, background = 0; official Inria masks store buildings as 255)
            
            # Tile into 250x250 patches
            tiles = tile_image(image, mask, tile_size=tile_size)
            
            # Save tiles
            for tile_idx, (img_tile, mask_tile) in enumerate(tiles):
                img_path = output_dir / split / 'images' / f"{idx}_{tile_idx}.png"
                mask_path = output_dir / split / 'masks' / f"{idx}_{tile_idx}.png"
                
                img_path.parent.mkdir(parents=True, exist_ok=True)
                mask_path.parent.mkdir(parents=True, exist_ok=True)
                
                Image.fromarray(img_tile).save(img_path)
                Image.fromarray(mask_tile).save(mask_path)
                
                tiled_data[split].append((str(img_path), str(mask_path)))
    
    print(f"Created {len(tiled_data['train'])} train tiles and {len(tiled_data['test'])} test tiles")
    return tiled_data
```
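
A quick arithmetic check of the tiling parameters: the sketch below counts how many full tiles the loop above visits per image, before any empty-tile filtering. The helper names are illustrative.

```python
def tile_origins(length, tile_size=250, stride=250):
    """Start offsets of full tiles along one axis; any remainder is dropped."""
    return list(range(0, length - tile_size + 1, stride))

def count_tiles(height, width, tile_size=250, stride=250):
    """Number of full tiles in a height x width image."""
    rows = len(tile_origins(height, tile_size, stride))
    cols = len(tile_origins(width, tile_size, stride))
    return rows * cols

print(count_tiles(5000, 5000))              # 400 non-overlapping tiles
print(count_tiles(5000, 5000, stride=125))  # 1521 tiles with 50% overlap
```

A 5000×5000 image divides evenly into 250×250 tiles, so nothing is lost at the borders. For other sizes the loop silently drops the right and bottom remainders.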

#### Step 3: Normalize and Align Preprocessing

```python
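def analyze_dataset_statistics(dataset_paths, max_images=100, seed=0):
    """
    Compute per-channel RGB mean and std for normalization.
    
    Args:
        dataset_paths: List of (image_path, mask_path) tuples
        max_images: Number of images to sample
        seed: Seed for reproducible sampling
    
    Returns:
        mean, std: RGB channel statistics on the [0, 1] scale
    """
    rng = np.random.default_rng(seed)
    n = min(max_images, len(dataset_paths))
    indices = rng.choice(len(dataset_paths), size=n, replace=False)
    
    # Accumulate pixel sums so the std is pooled over all sampled pixels
    # (averaging per-image stds would underestimate it).
    pixel_sum = np.zeros(3)
    pixel_sq_sum = np.zeros(3)
    n_pixels = 0
    
    for i in indices:
        img_path, _ = dataset_paths[i]
        img = np.asarray(Image.open(img_path).convert('RGB'), dtype=np.float64) / 255.0
        pixel_sum += img.sum(axis=(0, 1))
        pixel_sq_sum += (img ** 2).sum(axis=(0, 1))
        n_pixels += img.shape[0] * img.shape[1]
    
    mean = pixel_sum / n_pixels
    std = np.sqrt(pixel_sq_sum / n_pixels - mean ** 2)
    
    return mean, std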

# Compare Geneva vs Inria statistics
geneva_mean, geneva_std = analyze_dataset_statistics(train_by_grid['1301_11'])
vienna_mean, vienna_std = analyze_dataset_statistics(vienna_tiled_data['train'])
austin_mean, austin_std = analyze_dataset_statistics(austin_tiled_data['train'])

print("Dataset Statistics:")
print(f"Geneva: mean={geneva_mean}, std={geneva_std}")
print(f"Vienna: mean={vienna_mean}, std={vienna_std}")
print(f"Austin: mean={austin_mean}, std={austin_std}")
```
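
Once the statistics are known, one option is to standardize target-city tiles with the source (Geneva) statistics, so a frozen encoder sees the input scaling it was trained with. This is a design assumption rather than something the notebook prescribes. The values below are placeholders, not measured statistics.

```python
import numpy as np

def normalize_image(image_uint8, mean, std):
    """Scale an [H, W, 3] uint8 image to [0, 1], then standardize per channel."""
    image = np.asarray(image_uint8, dtype=np.float32) / 255.0
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, 3)
    std = np.asarray(std, dtype=np.float32).reshape(1, 1, 3)
    return (image - mean) / std

# Placeholder statistics; replace with the values computed for Geneva.
source_mean = [0.40, 0.42, 0.38]
source_std = [0.20, 0.19, 0.18]

tile = np.full((250, 250, 3), 128, dtype=np.uint8)   # dummy gray tile
normalized = normalize_image(tile, source_mean, source_std)
print(normalized.shape, normalized.dtype)
```

If the per-city statistics differ strongly, comparing this choice against per-domain normalization is a cheap ablation.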

#### Step 4: Three-Level Evaluation Framework

```
Tutorial Evaluation Structure:

Level 1: Within-City Transfer (Baseline)
    Train: Geneva Grid 1301_11
    Test:  Geneva Grids 1301_13, 1301_31
    Purpose: Validate methods, small domain shift
    Expected: All methods work reasonably well

Level 2: Cross-City Transfer - Similar Domain (Vienna)
    Train: Geneva Grid 1301_11
    Test:  Vienna (Inria)
    Purpose: Medium domain shift, European → European
    Expected: Meta-learning starts showing advantage

Level 3: Cross-City Transfer - Large Domain Shift (Austin)
    Train: Geneva Grid 1301_11
    Test:  Austin (Inria)
    Purpose: Large domain shift, European urban → US suburban
    Expected: Meta-learning shows large advantage, fine-tuning struggles
```
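
At every level of this framework, the support and query sets must be disjoint and each trial should be reproducible. The sketch below shows one way to draw such episodes with the standard library. `sample_episode` and the tile list are hypothetical.

```python
import random

def sample_episode(items, k, seed):
    """Split items into K support examples and the remaining query examples."""
    if not 0 < k < len(items):
        raise ValueError("k must be between 1 and len(items) - 1")
    rng = random.Random(seed)            # local RNG: does not touch global state
    indices = list(range(len(items)))
    rng.shuffle(indices)
    support = [items[i] for i in indices[:k]]
    query = [items[i] for i in indices[k:]]
    return support, query

# Hypothetical tile list; in practice these would be (image_path, mask_path) pairs.
tiles = [(f"img_{i}.png", f"mask_{i}.png") for i in range(50)]
for trial in range(3):
    support, query = sample_episode(tiles, k=5, seed=trial)
    print(trial, len(support), len(query))
```

Using the trial index as the seed means every method sees the same support and query sets in a given trial, which is what a paired comparison requires.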


In [None]:
# Inria dataset integration implementation (to be completed)
"""
# Complete implementation for integrating Inria dataset

import random

from datasets import load_dataset
from PIL import Image
import numpy as np
from pathlib import Path

# Step 1: Download Inria dataset
def load_and_prepare_inria(cities=['vienna', 'austin'], cache_dir='./data/inria'):
    \"\"\"
    Download and prepare Inria dataset for cross-city evaluation.
    \"\"\"
    print(f"Downloading Inria dataset for cities: {cities}")
    
    # Load from HuggingFace
    inria_dataset = load_dataset("Jonathan/INRIA-Aerial-Dataset", cache_dir=cache_dir)
    
    city_data = {}
    for city in cities:
        print(f"Processing {city}...")
        
        # Filter by city (case-insensitive)
        train_city = [x for x in inria_dataset['train'] if x['city'].lower() == city.lower()]
        test_city = [x for x in inria_dataset['test'] if x['city'].lower() == city.lower()]
        
        city_data[city] = {
            'train': train_city,
            'test': test_city
        }
        
        print(f"  {city}: {len(train_city)} train, {len(test_city)} test images")
    
    return city_data

# Step 2: Tile images to 250x250
def tile_inria_images(city_data, output_dir, tile_size=250, stride=250, min_building_pixels=100):
    \"\"\"
    Tile Inria 5000x5000 images into 250x250 patches compatible with Geneva.
    
    Args:
        city_data: Dictionary with train/test splits for a city
        output_dir: Where to save tiles
        tile_size: Size of each tile (default 250)
        stride: Stride for tiling (default 250 for non-overlapping)
        min_building_pixels: Minimum building pixels to keep a tile
    
    Returns:
        Dictionary with tile paths: {'train': [(img, mask), ...], 'test': [...]}
    \"\"\"
    output_dir = Path(output_dir)
    tiled_paths = {'train': [], 'test': []}
    
    for split in ['train', 'test']:
        split_dir = output_dir / split
        (split_dir / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'masks').mkdir(parents=True, exist_ok=True)
        
        for img_idx, sample in enumerate(city_data[split]):
            # Get image and mask (already PIL Images)
            image = sample['image']  # 5000x5000 RGB
            mask = sample['mask']    # 5000x5000 binary
            
            # Convert to numpy for tiling
            img_np = np.array(image)
            mask_np = np.array(mask)
            
            H, W = img_np.shape[:2]
            tile_idx = 0
            
            # Tile the image
            for y in range(0, H - tile_size + 1, stride):
                for x in range(0, W - tile_size + 1, stride):
                    img_tile = img_np[y:y+tile_size, x:x+tile_size]
                    mask_tile = mask_np[y:y+tile_size, x:x+tile_size]
                    
                    # Skip tiles with too few building pixels (count nonzero pixels,
                    # since building masks may be stored as 0/255 rather than 0/1)
                    if (mask_tile > 0).sum() < min_building_pixels:
                        continue
                    
                    # Save tiles
                    img_path = split_dir / 'images' / f"img{img_idx:03d}_tile{tile_idx:03d}.png"
                    mask_path = split_dir / 'masks' / f"img{img_idx:03d}_tile{tile_idx:03d}.png"
                    
                    Image.fromarray(img_tile).save(img_path)
                    Image.fromarray(mask_tile).save(mask_path)
                    
                    tiled_paths[split].append((str(img_path), str(mask_path)))
                    tile_idx += 1
        
        print(f"{split}: Created {len(tiled_paths[split])} tiles")
    
    return tiled_paths

# Step 3: Cross-city evaluation
def evaluate_cross_city_transfer(model, geneva_train, vienna_data, austin_data, K_values=[1, 3, 5, 10, 20]):
    \"\"\"
    Evaluate few-shot transfer from Geneva to Vienna and Austin.
    
    Args:
        model: Trained model (Fine-tuning, Prototypical, or PANet)
        geneva_train: Geneva Grid 1301_11 data (source domain)
        vienna_data: Vienna tiled data (target domain 1)
        austin_data: Austin tiled data (target domain 2)
        K_values: List of K values to evaluate
    
    Returns:
        Dictionary with results for each target city and K value
    \"\"\"
    results = {
        'vienna': {},
        'austin': {}
    }
    
    # Test on Vienna
    print("\\n=== Vienna Cross-City Transfer ===")
    for K in K_values:
        # Sample K support examples from Vienna
        support_vienna = random.sample(vienna_data['test'], K)
        query_vienna = [x for x in vienna_data['test'] if x not in support_vienna]
        
        # Run few-shot evaluation
        iou_vienna = model.evaluate_fewshot(support_vienna, query_vienna)
        results['vienna'][K] = iou_vienna
        
        print(f"K={K}: IoU = {iou_vienna:.3f}")
    
    # Test on Austin
    print("\\n=== Austin Cross-City Transfer ===")
    for K in K_values:
        # Sample K support examples from Austin
        support_austin = random.sample(austin_data['test'], K)
        query_austin = [x for x in austin_data['test'] if x not in support_austin]
        
        # Run few-shot evaluation
        iou_austin = model.evaluate_fewshot(support_austin, query_austin)
        results['austin'][K] = iou_austin
        
        print(f"K={K}: IoU = {iou_austin:.3f}")
    
    return results

# Step 4: Visualization - Compare within-city vs cross-city
def plot_cross_city_comparison(results_geneva, results_vienna, results_austin, save_path=None):
    \"\"\"
    Plot performance comparison across different domain shifts.
    
    Args:
        results_geneva: Results on Geneva grids (1301_13, 1301_31)
        results_vienna: Results on Vienna
        results_austin: Results on Austin
    \"\"\"
    import matplotlib.pyplot as plt
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot 1: Fine-tuning across domains
    K_values = sorted(results_geneva['fine_tuning'].keys())
    
    ax1.plot(K_values, [results_geneva['fine_tuning'][k] for k in K_values], 
             marker='o', label='Geneva (small shift)', linewidth=2)
    ax1.plot(K_values, [results_vienna['fine_tuning'][k] for k in K_values], 
             marker='s', label='Vienna (medium shift)', linewidth=2)
    ax1.plot(K_values, [results_austin['fine_tuning'][k] for k in K_values], 
             marker='^', label='Austin (large shift)', linewidth=2)
    
    ax1.set_xlabel('K (number of support examples)', fontsize=12)
    ax1.set_ylabel('Mean IoU', fontsize=12)
    ax1.set_title('Fine-Tuning: Performance vs Domain Shift', fontsize=14)
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: PANet across domains (should be more robust)
    ax2.plot(K_values, [results_geneva['panet'][k] for k in K_values], 
             marker='o', label='Geneva (small shift)', linewidth=2)
    ax2.plot(K_values, [results_vienna['panet'][k] for k in K_values], 
             marker='s', label='Vienna (medium shift)', linewidth=2)
    ax2.plot(K_values, [results_austin['panet'][k] for k in K_values], 
             marker='^', label='Austin (large shift)', linewidth=2)
    
    ax2.set_xlabel('K (number of support examples)', fontsize=12)
    ax2.set_ylabel('Mean IoU', fontsize=12)
    ax2.set_title('PANet: Performance vs Domain Shift', fontsize=14)
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Usage example:
# city_data = load_and_prepare_inria(cities=['vienna', 'austin'])
# vienna_tiles = tile_inria_images(city_data['vienna'], output_dir='data/inria/vienna')
# austin_tiles = tile_inria_images(city_data['austin'], output_dir='data/inria/austin')
# results = evaluate_cross_city_transfer(model, geneva_train, vienna_tiles, austin_tiles)
"""
pass

---

## 7. Recommended Starting Points for Implementation

### For This Tutorial Specifically

Based on our rooftop segmentation task, here are the most relevant resources:

#### 1. **Start with PFENet for Understanding** ⭐ RECOMMENDED
   ```bash
   git clone https://github.com/dvlab-research/PFENet
   ```
   **Why**: 
   - Cleanest PyTorch implementation
   - Well-documented code structure
   - Easy to adapt for binary segmentation
   - Good episodic training examples
   
   **Use for**: Understanding episodic training loop and prototype computation

#### 2. **Study PANet Official Implementation**
   ```bash
   git clone https://github.com/kaixin96/PANet
   ```
   **Why**:
   - Official implementation
   - Complete masked average pooling code
   - Written in PyTorch, so the code can be adapted directly
   
   **Use for**: Understanding PANet architecture details

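#### 3. **Use HSNet as a Modern PyTorch Reference**
   ```bash
   git clone https://github.com/juhongm999/hsnet
   ```
   **Why**:
   - Modern PyTorch implementation
   - Well documented
   - Matches query and support features through multi-level correlations rather than PANet-style prototypes, so it is a useful contrast rather than a drop-in PANet substitute
   
   **Use for**: Episodic data loading and evaluation code; take masked average pooling from the PANet repository instead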

#### 4. **Leverage segmentation_models.pytorch for Baseline**
   ```bash
   pip install segmentation-models-pytorch
   ```
   **Why**:
   - Pre-built U-Net, DeepLabV3 models
   - Easy to use for fine-tuning baseline
   - Lots of pre-trained encoders (ResNet, EfficientNet)
   
   **Use for**: Quick baseline implementation

#### 5. **Use learn2learn for Prototypical Networks**
   ```bash
   pip install learn2learn
   ```
   **Why**:
   - High-level API for meta-learning
   - Built-in episodic data loaders
   - Prototypical Networks already implemented
   
   **Use for**: Prototypical Networks implementation (the library targets classification, so the pixel-wise segmentation variant needs custom code)

### Suggested Code Reuse Strategy

#### Phase 1: Baseline (Fine-tuning)
```python
# Use segmentation_models.pytorch
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2
)

# Train on Grid 1301_11
# Fine-tune on K examples from target grid
```

#### Phase 2: Prototypical Networks
```python
# Adapt from PFENet or learn2learn
# Reference: https://github.com/dvlab-research/PFENet/blob/master/model/PFENet.py

# Key components to adapt:
# 1. Episodic data loader (sample K-shot tasks)
# 2. Prototype computation (masked average pooling)
# 3. Distance-based classification
```
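
The distance-based classification step can be sketched in a few lines: each pixel embedding is assigned to its nearest prototype under squared Euclidean distance. This is the hard-assignment version; Prototypical Networks train with a softmax over negative distances. The arrays are toy values and `classify_by_prototype` is an illustrative name.

```python
import numpy as np

def classify_by_prototype(features, prototypes):
    """features: [C, H, W]; prototypes: [num_classes, C]. Returns [H, W] labels."""
    features = np.asarray(features, dtype=float)
    prototypes = np.asarray(prototypes, dtype=float)
    # Squared Euclidean distance from every pixel embedding to every prototype.
    diff = features[None, :, :, :] - prototypes[:, :, None, None]
    dist = (diff ** 2).sum(axis=1)        # [num_classes, H, W]
    return dist.argmin(axis=0)

bg_proto = np.array([0.0, 0.0])
fg_proto = np.array([1.0, 1.0])
features = np.zeros((2, 4, 4))
features[:, 1:3, 1:3] = 0.9               # embeddings close to the foreground prototype
pred = classify_by_prototype(features, np.stack([bg_proto, fg_proto]))
print(pred)
```

In a real pipeline the prediction is made at feature resolution and then upsampled to the 250×250 mask size.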

#### Phase 3: PANet
```python
# Adapt from the official PANet repository (PyTorch)
# Reference: https://github.com/kaixin96/PANet

# Key components to adapt:
# 1. Masked average pooling to build foreground/background prototypes
# 2. Cosine similarity matching between query features and prototypes
# 3. Prototype alignment regularization (predict support masks from query-derived prototypes)
```

### Recommended External Dataset

For cross-city evaluation, use **Inria Aerial Dataset (Vienna)**:

```python
# Available on HuggingFace
from datasets import load_dataset

inria = load_dataset("Jonathan/INRIA-Aerial-Dataset", split="train")

# Filter for Vienna
vienna_data = inria.filter(lambda x: x['city'].lower() == 'vienna')
```

**Why Vienna**:
- Similar to Geneva (European city)
- Comparable European building styles, so the imagery looks broadly similar
- But different architecture/urban layout
- Well suited to testing domain adaptation

**Alternative**: Use **Austin** from Inria for larger domain shift (US city, very different)

### Code Adaptation Checklist

When adapting code from repositories above:

- [ ] **Data loader**: Modify to load Geneva dataset format
- [ ] **Number of classes**: Change from N-way to binary (rooftop/background)
- [ ] **Image size**: Adapt to 250x250 (Geneva's size)
- [ ] **Episodic sampling**: Sample from single grid (1301_11) instead of class-based sampling
- [ ] **Evaluation**: Add geographic split evaluation (test on different grids)
- [ ] **Metrics**: Ensure IoU computation is correct for binary segmentation
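
One practical detail behind the image-size item: many encoder-decoder networks with five downsampling stages expect input height and width divisible by 32, and 250 is not. Whether your chosen model has this constraint is an assumption to verify. If it does, padding to 256 and cropping the prediction back is a simple workaround, sketched here with illustrative helper names.

```python
import numpy as np

def pad_to_multiple(image, multiple=32):
    """Reflect-pad an [H, W] or [H, W, C] array so H and W are multiples of `multiple`."""
    h, w = image.shape[:2]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode="reflect"), (h, w)

def crop_to_original(array, original_hw):
    """Undo the padding on an image or prediction."""
    h, w = original_hw
    return array[:h, :w]

tile = np.zeros((250, 250, 3), dtype=np.uint8)
padded, original_hw = pad_to_multiple(tile)
print(padded.shape)                                   # (256, 256, 3)
print(crop_to_original(padded, original_hw).shape)    # (250, 250, 3)
```

Compute metrics on the cropped prediction so the padded border never counts toward IoU.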

### Quick Start Code Snippets

#### Episodic Data Loader (adapt from PFENet)
```python
class EpisodicDataLoader:
    def __init__(self, grid_data, n_shot=5, n_query=5):
        self.grid_data = grid_data  # List of (image, mask) pairs
        self.n_shot = n_shot
        self.n_query = n_query
    
    def __iter__(self):
        while True:
            # Sample K support + Q query from grid
            indices = np.random.choice(len(self.grid_data), 
                                      self.n_shot + self.n_query, 
                                      replace=False)
            
            support_idx = indices[:self.n_shot]
            query_idx = indices[self.n_shot:]
            
            support_imgs = [self.grid_data[i][0] for i in support_idx]
            support_masks = [self.grid_data[i][1] for i in support_idx]
            query_imgs = [self.grid_data[i][0] for i in query_idx]
            query_masks = [self.grid_data[i][1] for i in query_idx]
            
            yield (support_imgs, support_masks, query_imgs, query_masks)
```

#### Masked Average Pooling (adapt from PANet/HSNet)
```python
import torch.nn.functional as F

def masked_average_pooling(features, masks):
    """
    Compute prototypes using masked average pooling.
    
    Args:
        features: [K, C, H, W] - support features
        masks: [K, H, W] - support masks
    
    Returns:
        fg_proto, bg_proto: [C] - class prototypes
    """
    # Resize masks to feature resolution
    masks = F.interpolate(masks.unsqueeze(1).float(), 
                         size=features.shape[-2:], 
                         mode='nearest').squeeze(1)
    
    # Foreground prototype
    fg_mask = (masks == 1).unsqueeze(1)  # [K, 1, H, W]
    fg_features = features * fg_mask
    fg_proto = fg_features.sum(dim=(0, 2, 3)) / (fg_mask.sum() + 1e-5)
    
    # Background prototype  
    bg_mask = (masks == 0).unsqueeze(1)
    bg_features = features * bg_mask
    bg_proto = bg_features.sum(dim=(0, 2, 3)) / (bg_mask.sum() + 1e-5)
    
    return fg_proto, bg_proto
```

### Additional Resources for Debugging

1. **Visualization Tools**
   - Use `tensorboard` for training monitoring
   - Use `matplotlib` for prototype visualization
   - Use `sklearn.manifold.TSNE` for embedding visualization

2. **Debugging Few-Shot Learning**
   - Common issue: Prototypes collapse (all embeddings same)
   - Solution: Check learning rate, add normalization
   - Test on PASCAL-5i first to verify implementation works

3. **Expected Results (Sanity Checks)**: rough, unverified guesses to help spot gross bugs, not measured results
   - Zero-shot on Geneva grids: ~0.45-0.55 IoU
   - Fine-tuning K=5: ~0.60-0.70 IoU
   - Prototypical K=5: ~0.65-0.75 IoU
   - PANet K=5: ~0.70-0.80 IoU
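
Prototype collapse can be detected with a simple check: if the foreground and background prototypes point in nearly the same direction, cosine matching cannot separate the classes. The sketch below flags that case. The 0.99 threshold and the helper names are hypothetical choices, not values from any reference implementation.

```python
import numpy as np

def prototype_cosine(fg_proto, bg_proto, eps=1e-8):
    """Cosine similarity between foreground and background prototypes."""
    fg = np.asarray(fg_proto, dtype=float)
    bg = np.asarray(bg_proto, dtype=float)
    return float(fg @ bg / (np.linalg.norm(fg) * np.linalg.norm(bg) + eps))

def looks_collapsed(fg_proto, bg_proto, threshold=0.99):
    """Flag prototypes that point in almost the same direction."""
    return prototype_cosine(fg_proto, bg_proto) > threshold

healthy = looks_collapsed([1.0, 0.0, 0.2], [0.0, 1.0, 0.1])
collapsed = looks_collapsed([0.5, 0.5, 0.5], [0.51, 0.5, 0.49])
print(healthy, collapsed)
```

Logging this similarity during training makes collapse visible well before the IoU curve flattens.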

### Citation

If you use these approaches in your tutorial, cite the key papers:

```bibtex
@inproceedings{wang2019panet,
  title={PANet: Few-shot image semantic segmentation with prototype alignment},
  author={Wang, Kaixin and Liew, Jun Hao and Zou, Yingtian and Zhou, Daquan and Feng, Jiashi},
  booktitle={ICCV},
  year={2019}
}

@inproceedings{snell2017prototypical,
  title={Prototypical networks for few-shot learning},
  author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
  booktitle={NeurIPS},
  year={2017}
}

@misc{geneva-satellite-dataset,
  title={Geneva Satellite Images Dataset},
  author={OverfitTeam},
  year={2024},
  publisher={HuggingFace},
  url={https://huggingface.co/datasets/raphaelattias/overfitteam-geneva-satellite-images}
}
```

---

## 8. References & Resources

### Academic Papers

#### Few-Shot Learning Foundations

1. **Prototypical Networks for Few-shot Learning**
   - Snell, J., Swersky, K., & Zemel, R. (2017)
   - NeurIPS 2017
   - Paper: https://arxiv.org/abs/1703.05175
   - *Foundation paper for prototypical networks in classification*

2. **Model-Agnostic Meta-Learning (MAML)**
   - Finn, C., Abbeel, P., & Levine, S. (2017)
   - ICML 2017
   - Paper: https://arxiv.org/abs/1703.03400
   - *Alternative meta-learning approach (more complex)*

#### Few-Shot Semantic Segmentation

3. **PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment**
   - Wang, K., Liew, J. H., Zou, Y., Zhou, D., & Feng, J. (2019)
   - ICCV 2019
   - Paper: https://arxiv.org/abs/1908.06391
   - GitHub: https://github.com/kaixin96/PANet
   - *The main PANet paper; state of the art in few-shot segmentation at the time of publication*

4. **Few-Shot Segmentation Without Meta-Learning: A Good Transductive Inference Is All You Need? (RePRI)**
   - Boudiaf, M., Kervadec, H., Masud, Z. I., Piantanida, P., Ben Ayed, I., & Dolz, J. (2021)
   - CVPR 2021
   - Paper: https://arxiv.org/abs/2012.06166
   - GitHub: https://github.com/mboudiaf/RePRI-for-Few-Shot-Segmentation
   - *Transductive inference on top of a conventionally trained backbone, without episodic meta-training*

5. **Hypercorrelation Squeeze for Few-Shot Segmentation (HSNet)**
   - Min, J., Kang, D., & Cho, M. (2021)
   - ICCV 2021
   - Paper: https://arxiv.org/abs/2104.01538
   - GitHub: https://github.com/juhongm999/hsnet
   - *Alternative approach using correlation matching*

6. **Meta-Learning for Few-Shot Semantic Segmentation**
   - Tian, P., Wu, Z., Qi, L., Wang, L., Shi, Y., & Gao, Y. (2020)
   - Pattern Recognition 2020
   - Paper: https://arxiv.org/abs/2004.07730
   - *Good survey of meta-learning for segmentation*

#### Remote Sensing & Building Segmentation

7. **Building Extraction from Remote Sensing Images with Deep Learning**
   - Ji, S., Wei, S., & Lu, M. (2019)
   - Remote Sensing 2019
   - Paper: https://www.mdpi.com/2072-4292/11/7/778
   - *Survey of building segmentation methods*

8. **Domain Adaptation for Semantic Segmentation of Urban Scenes**
   - Hoffman, J., et al. (2018)
   - CVPR 2018
   - Paper: https://arxiv.org/abs/1711.06969
   - *Related work on domain adaptation for urban scenes*

### GitHub Repositories

#### Few-Shot Segmentation Implementations

1. **PANet (Official)**
   - https://github.com/kaixin96/PANet
   - PyTorch implementation of PANet
   - Includes PASCAL-5i and COCO-20i benchmarks

2. **Few-Shot Semantic Segmentation Papers and Code**
   - https://github.com/xiaomengyc/Few-Shot-Semantic-Segmentation-Papers
   - Comprehensive collection of papers and code
   - Updated regularly with new methods

3. **PFENet: Prior Guided Feature Enrichment Network**
   - https://github.com/dvlab-research/PFENet
   - TPAMI 2020
   - PyTorch, very clean implementation
   - Good starting point for understanding few-shot segmentation

4. **HSNet (Hypercorrelation Squeeze Network)**
   - https://github.com/juhongm999/hsnet
   - PyTorch implementation
   - Strong benchmark results at publication, well-documented

5. **RePRI: Region Proportion Regularized Inference**
   - https://github.com/mboudiaf/RePRI-for-Few-Shot-Segmentation
   - PyTorch implementation
   - Transductive inference method that needs no episodic meta-training

#### Prototypical Networks Implementations

6. **Prototypical Networks (PyTorch)**
   - https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch
   - Clean PyTorch reimplementation
   - Good for understanding prototypical networks

7. **Few-Shot Learning Library (learn2learn)**
   - https://github.com/learnables/learn2learn
   - PyTorch library for meta-learning
   - Includes MAML, Prototypical Networks, and more
   - Great for experimentation

#### Building/Rooftop Segmentation

8. **SpaceNet Building Detection**
   - https://github.com/SpaceNetChallenge/BuildingDetectors
   - Collection of winning solutions for building detection
   - Various architectures (U-Net, Mask R-CNN, etc.)

9. **Inria Aerial Image Labeling Benchmark**
   - https://github.com/zorzi-s/projectRegularization
   - Segmentation on Inria dataset
   - Building footprint extraction

10. **Rooftop Segmentation for Solar Potential**
    - https://github.com/mdominguezd/Solar-Panel-Installability
    - Solar panel rooftop segmentation
    - Similar application to our use case

### Datasets

#### Few-Shot Segmentation Benchmarks

1. **PASCAL-5i**
   - Standard few-shot segmentation benchmark
   - 4-fold cross-validation setup
   - Download: http://host.robots.ox.ac.uk/pascal/VOC/

2. **COCO-20i**
   - MS COCO adapted for few-shot segmentation
   - 80 classes split into 4 folds of 20 classes each
   - Download: https://cocodataset.org/

#### Remote Sensing & Building Datasets

3. **Inria Aerial Image Dataset**
   - 5 labeled training cities (plus 5 test-only cities), 810 km² coverage
   - Building footprints
   - Download: https://project.inria.fr/aerialimagelabeling/
   - HuggingFace: https://huggingface.co/datasets/Jonathan/INRIA-Aerial-Dataset

4. **SpaceNet Building Dataset**
   - Multiple cities worldwide
   - High-resolution satellite imagery
   - Download: https://spacenet.ai/datasets/

5. **Massachusetts Buildings Dataset**
   - 151 aerial images
   - Building segmentation masks
   - Download: https://www.cs.toronto.edu/~vmnih/data/

6. **Open Cities AI Challenge Dataset**
   - African cities building footprints
   - Good for domain adaptation
   - Download: https://www.drivendata.org/competitions/60/building-segmentation-disaster-resilience/

7. **xBD (xView2) Building Damage Dataset**
   - Pre/post disaster building assessment
   - Multiple cities
   - Download: https://xview2.org/

### Related Tutorials & Blog Posts

1. **Few-Shot Learning Tutorial**
   - https://lilianweng.github.io/posts/2018-11-30-meta-learning/
   - Excellent overview of meta-learning approaches
   - By Lilian Weng (OpenAI)

2. **Prototypical Networks Explained**
   - https://towardsdatascience.com/prototypical-networks-for-few-shot-learning-eb2c2b86baac
   - Intuitive explanation with code examples

3. **Satellite Image Segmentation**
   - https://github.com/robmarkcole/satellite-image-deep-learning
   - Comprehensive resource for satellite image analysis
   - Includes segmentation techniques

### Useful Libraries

1. **segmentation_models.pytorch**
   - https://github.com/qubvel/segmentation_models.pytorch
   - Pre-built segmentation architectures (U-Net, DeepLab, etc.)
   - Good for baseline implementations

2. **Albumentations**
   - https://github.com/albumentations-team/albumentations
   - Data augmentation library
   - Useful for augmenting the small support sets used in episodic training

3. **PyTorch Metric Learning**
   - https://github.com/KevinMusgrave/pytorch-metric-learning
   - Useful for prototypical networks
   - Distance metrics, losses, miners

### Community & Forums

1. **Papers with Code - Few-Shot Segmentation**
   - https://paperswithcode.com/task/few-shot-semantic-segmentation
   - Leaderboards, benchmarks, and implementations

2. **r/MachineLearning - Meta-Learning discussions**
   - https://www.reddit.com/r/MachineLearning/
   - Active community for questions

3. **Computer Vision Foundation (CVF) Open Access**
   - https://openaccess.thecvf.com/
   - Free access to CVPR/ICCV/ECCV papers