<a href="https://colab.research.google.com/github/hunkim98/earth_science/blob/main/lecture/EPS210_Lab5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Harvard EPS-210 AI for Earth and Planetary Science

Instructor: Mostafa Mousavi

# **Lab 5**:

**Activity 1**: Building Change Detection from Satellite Imagery

**Activity 2**: Martian Crater Detection with YOLO

---


<div style="background: linear-gradient(135deg, #A51C30 0%, #1E5A96 100%); padding: 30px; border-radius: 10px; color: white; margin-bottom: 20px;">
<h1 style="color: white; margin: 0;"> Building Change Detection from Satellite Imagery</h1>
<h2 style="color: #f0d0d0; margin-top: 10px;">Using Siamese Multi-Scale CNNs</h2>
</div>

## 📋 Overview

| | |
|---|---|
| **Topic** | Building Change Detection Using Siamese CNNs |
| **Dataset** | LEVIR-CD (Building Change Detection Benchmark) |
| **Framework** | PyTorch |
| **Duration** | ~ 45 Minutes |
| **Prerequisites** | Basic Python, intro to CNNs, familiarity with PyTorch |

### Learning Objectives

1. Understand the **Siamese network** architecture for bi-temporal image comparison
2. Implement a **Deep Siamese Multi-Scale Fully Convolutional Network (DSMS-FCN)** for pixel-wise change detection
3. Train and evaluate a model on the **LEVIR-CD** building change dataset
4. Interpret change detection results and compute standard evaluation metrics (F1, IoU, OA)
5. Connect remote sensing change detection to real-world applications: urban growth monitoring and post-disaster damage assessment

### Key References

- [DSMSCN](https://github.com/ChenHongruixuan/DSMSCN): Deep Siamese Multi-Scale Convolutional Network (Chen et al., 2019)
- [SNUNet-CD / Siam-NestedUNet](https://github.com/likyoo/Siam-NestedUNet): Densely Connected Siamese Network (Fang et al., 2021)
- [KPCAMNet](https://github.com/ChenHongruixuan/KPCAMNet): Unsupervised Change Detection with Kernel PCA (Chen et al., 2022)
- [ChangeDetectionRepository](https://github.com/ChenHongruixuan/ChangeDetectionRepository): Collection of traditional & DL-based CD methods
- [Change-Detection-Review](https://github.com/MinZHANG-WHU/Change-Detection-Review): Comprehensive review of AI-based CD methods
- [Awesome RS Change Detection](https://github.com/wenhwu/awesome-remote-sensing-change-detection): Datasets, methods, and competitions

---
# 1. Background

## 1.1 Why Change Detection Matters

Change detection from satellite imagery is one of the most impactful applications of remote sensing. By comparing images of the same location acquired at different times, we can automatically identify where buildings have appeared, been demolished, or sustained damage. This capability is critical for:

- **Urban planning**: tracking city expansion and infrastructure development
- **Disaster response**: assessing earthquake, hurricane, or wildfire damage
- **Environmental monitoring**: detecting deforestation, coastal erosion, land-use change

Traditional approaches relied on hand-crafted features and pixel-level differencing, but these struggle with the heterogeneity of very-high-resolution (VHR) satellite images, where illumination changes, seasonal variation, and registration errors produce false alarms.
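
To make the false-alarm problem concrete, here is a minimal NumPy sketch of pixel differencing. It is not part of the lab pipeline: the 8×8 toy scene, the 0.2 brightness offset, and the 0.1 threshold are all illustrative assumptions. A uniform illumination change with no real change ends up flagging every pixel, while a genuine new building flags only its own footprint.

```python
import numpy as np

def difference_change_map(img_t1, img_t2, threshold=0.1):
    """Classic pixel differencing: flag pixels whose absolute change exceeds a threshold."""
    return np.abs(img_t2 - img_t1) > threshold

# Toy grayscale scene with intensities in [0.2, 0.6); values are illustrative only.
rng = np.random.default_rng(0)
scene_t1 = rng.uniform(0.2, 0.6, size=(8, 8))

# Case 1: a real change, a bright 2x2 "building" appears.
scene_new_building = scene_t1.copy()
scene_new_building[2:4, 2:4] = 1.0

# Case 2: no real change, but the whole image is 0.2 brighter (illumination).
scene_brighter = scene_t1 + 0.2

real = difference_change_map(scene_t1, scene_new_building)
false_alarm = difference_change_map(scene_t1, scene_brighter)

print("flagged (new building):     ", int(real.sum()), "of", real.size)
print("flagged (illumination only):", int(false_alarm.sum()), "of", false_alarm.size)
```

A learned feature space is meant to be less sensitive to this kind of nuisance variation than raw intensities are.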

## 1.2 Siamese Networks for Change Detection

A **Siamese neural network** (sometimes called a twin neural network) passes two inputs through identical, weight-sharing branches. Rather than classifying a single input, it measures how similar or different the two inputs are.

<p align="center">
  <img src="https://github.com/smousavi05/Harvard-EPS-210/raw/main/figures/siames_nn.png" width="700">
</p>

In its classic form, a Siamese network works in two steps:
- **Feature extraction**: Each branch transforms its input into a low-dimensional vector, called an embedding.
- **Similarity measurement**: The embeddings are fed into a loss function (such as triplet loss or contrastive loss) that measures the distance between the two vectors.

    * Small distance: the inputs are very similar.
    * Large distance: the inputs are different.

For change detection, the same idea is applied per pixel rather than per image:
- **Input**: Two images of the same area at times T₁ and T₂
- **Shared encoder**: Maps both images into the same feature space
- **Difference module**: Computes |F(T₁) − F(T₂)| at multiple scales
- **Decoder**: Upsamples difference features to produce a pixel-wise change map
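
As a minimal sketch of the weight-sharing and distance idea (not the lab's model), the snippet below uses a hypothetical linear `embed` function with random weights `W` as a stand-in encoder. Both inputs pass through the same weights, so identical inputs get identical embeddings and a distance of exactly zero, while a perturbed input lands farther away.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 16))  # one set of weights, used by both branches

def embed(x):
    """Hypothetical stand-in encoder: a single linear map to a 4-D embedding."""
    return W @ x

def embedding_distance(a, b):
    """Euclidean distance between the embeddings of two inputs."""
    return float(np.linalg.norm(embed(a) - embed(b)))

x_t1 = rng.normal(size=16)
x_unchanged = x_t1.copy()
x_changed = x_t1 + rng.normal(scale=2.0, size=16)

d_same = embedding_distance(x_t1, x_unchanged)
d_diff = embedding_distance(x_t1, x_changed)
print(f"unchanged pair distance: {d_same:.3f}")
print(f"changed pair distance:   {d_diff:.3f}")
```

If each branch had its own weights, even identical inputs could map to different embeddings, and the distance would no longer be meaningful.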

The key innovation in this lab is the **Multi-Scale Feature Convolution Unit (MFCU)**, which extracts features at multiple spatial scales (1×1, 3×3, 5×5 kernels) within a single layer, inspired by the DSMSCN architecture (Chen et al., 2019).

## 1.3 The LEVIR-CD Dataset

We use the **LEVIR-CD** dataset (Chen & Shi, 2020):

- **637** pairs of VHR Google Earth images (0.5 m/pixel, 1024×1024)
- **20 regions** in Texas, USA, spanning 2002–2018
- **31,333** annotated building change instances
- Binary labels: 1 = change (new construction/demolition), 0 = no change
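
The figures above translate into physical scale with a little arithmetic. The sketch below is a rough back-of-the-envelope calculation: the total-area figure assumes tiles do not overlap, and the 128-pixel side length refers to the resize step this lab applies in Section 2.2 to save memory.

```python
GSD_M = 0.5          # ground sampling distance in metres per pixel
TILE_PX = 1024       # tile side length in pixels
N_PAIRS = 637        # number of bi-temporal image pairs
N_INSTANCES = 31333  # annotated building change instances
RESIZED_PX = 128     # tile side length after the lab's resize step

tile_side_m = GSD_M * TILE_PX              # ground distance covered by one tile side
tile_area_km2 = (tile_side_m / 1000) ** 2  # ground area of one tile
total_area_km2 = tile_area_km2 * N_PAIRS   # assumes non-overlapping tiles
resized_gsd_m = tile_side_m / RESIZED_PX   # metres per pixel after resizing
instances_per_pair = N_INSTANCES / N_PAIRS

print(f"tile side:          {tile_side_m:.0f} m")          # 512 m
print(f"tile area:          {tile_area_km2:.3f} km^2")     # 0.262 km^2
print(f"total area:         {total_area_km2:.0f} km^2")    # 167 km^2
print(f"resized resolution: {resized_gsd_m:.1f} m/pixel")  # 4.0 m/pixel
print(f"instances per pair: {instances_per_pair:.1f}")     # 49.2
```

At 4 m/pixel a typical house spans only a few pixels, so expect lower scores than published full-resolution results.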

---
# 2. Part 1: Environment Setup & Data Preparation

## 2.1 Install Dependencies

In [None]:
#@title 📦 Install Required Packages { display-mode: "form" }
!pip install -q torch torchvision torchaudio
!pip install -q torchmetrics matplotlib scikit-learn scikit-image tqdm gdown

In [None]:
# @title Imports
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from tqdm.auto import tqdm
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Check GPU availability
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')
if DEVICE == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## 2.2 Download the LEVIR-CD Dataset

The LEVIR-CD dataset is organized as:
```
LEVIR-CD/
├── train/
│   ├── A/       # Time 1 images
│   ├── B/       # Time 2 images
│   └── label/   # Binary change masks
├── val/
│   ├── A/, B/, label/
└── test/
    ├── A/, B/, label/
```
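
Because the three subfolders are linked only by filename, it can be worth confirming that every tile has all three files before training. The helper below is a hypothetical sanity check (the name `paired_filenames` and the toy filenames are assumptions, not part of the dataset tooling), demonstrated on a throwaway directory so it runs without the real data.

```python
import os
import tempfile

def paired_filenames(root, split):
    """Return filenames present in A/, B/ and label/ of one split; raise if they disagree."""
    names = {sub: set(os.listdir(os.path.join(root, split, sub)))
             for sub in ("A", "B", "label")}
    common = names["A"] & names["B"] & names["label"]
    everything = names["A"] | names["B"] | names["label"]
    if common != everything:
        raise ValueError(f"{split}: unmatched files {sorted(everything - common)}")
    return sorted(common)

# Build a tiny fake split with two tiles to exercise the check.
with tempfile.TemporaryDirectory() as root:
    for sub in ("A", "B", "label"):
        os.makedirs(os.path.join(root, "train", sub))
        for name in ("tile_1.png", "tile_2.png"):
            open(os.path.join(root, "train", sub, name), "w").close()
    pairs = paired_filenames(root, "train")

print(pairs)  # ['tile_1.png', 'tile_2.png']
```

On the real dataset, the same call would be `paired_filenames('/content/LEVIR-CD', 'train')` once the download cell has finished.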




In [None]:
#@title 📥 Download, Unzip & Organize LEVIR-CD { display-mode: "form" }
import os
import shutil
import zipfile
import gdown

DATA_ROOT = '/content/LEVIR-CD'

if not os.path.exists(DATA_ROOT):
    os.makedirs(DATA_ROOT)

print("🚀 Starting download from Google Drive...")

file_ids = {
    'test.zip': '1UPaZuyYe-JufA6042go7pIvxuiuICN1s',
    'train.zip': '1qeyzaXk5ZF7MqVOe1OVxtEd0MnCMzBWf',
    'val.zip': '1L78dDgeKSd7UTP2hjWeAnnwTIpAHvMiL'
}

for filename, file_id in file_ids.items():
    output_path = os.path.join(DATA_ROOT, filename)
    if not os.path.exists(output_path):
        print(f"   Downloading {filename}...")
        gdown.download(id=file_id, output=output_path, quiet=False)
    else:
        print(f"   {filename} already exists, skipping download.")

# Clean, unzip, and reorganize
def clean_and_reorganize():
    print("\n🧹 Cleaning up old extracted folders to prevent conflicts...")
    for folder in ['A', 'B', 'label', 'train', 'val', 'test']:
        path = os.path.join(DATA_ROOT, folder)
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)

    print("📦 Extracting and organizing files...")

    for split in ['train', 'val', 'test']:
        zip_path = os.path.join(DATA_ROOT, f'{split}.zip')
        target_dir = os.path.join(DATA_ROOT, split)
        os.makedirs(target_dir, exist_ok=True)

        if os.path.exists(zip_path):
            print(f"   Processing {split}.zip...")
            try:
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(target_dir)

                # Handle nested folders
                sub_items = os.listdir(target_dir)
                nested_folder = os.path.join(target_dir, split)

                if split in sub_items and os.path.isdir(nested_folder):
                    print(f"   -> Fixing nested folder structure for {split}...")
                    for item in os.listdir(nested_folder):
                        shutil.move(os.path.join(nested_folder, item), target_dir)
                    os.rmdir(nested_folder)
            except zipfile.BadZipFile:
                print(f"   ⚠️ Warning: {split}.zip appears to be corrupted, skipping.")
        else:
            print(f"   ⚠️ Warning: {split}.zip not found in download.")

    print("✅ Reorganization complete.")

clean_and_reorganize()

train_path = os.path.join(DATA_ROOT, 'train', 'A')
if os.path.exists(train_path):
    count = len(os.listdir(train_path))
    print(f"\n🎉 SUCCESS! Real LEVIR-CD dataset is ready.")
    print(f"   Training samples found: {count}")
    print(f"   Location: {train_path}")
else:
    print("\n❌ Error: Data still not found. Please check the /content/LEVIR-CD folder manually.")

In [None]:
#@title ✂️ Resize Dataset to 128x128 (to prevent RAM Crash) { display-mode: "form" }
import os
from PIL import Image
from tqdm.auto import tqdm

# Paths
ORIGINAL_ROOT = '/content/LEVIR-CD'
NEW_ROOT = '/content/LEVIR-CD-128'
TARGET_SIZE = (128, 128)

def resize_dataset_to_disk(src_root, dst_root, size):
    print(f"📉 Resizing dataset from {src_root} to {dst_root}...")

    if not os.path.exists(src_root):
        raise FileNotFoundError(f"Original dataset not found at {src_root}!")

    # Process all splits and subfolders
    for split in ['train', 'val', 'test']:
        for subdir in ['A', 'B', 'label']:
            src_dir = os.path.join(src_root, split, subdir)
            dst_dir = os.path.join(dst_root, split, subdir)

            os.makedirs(dst_dir, exist_ok=True)

            files = sorted(os.listdir(src_dir))
            for fname in tqdm(files, desc=f"{split}/{subdir}", leave=False):
                if not fname.endswith(('.png', '.jpg', '.tif')):
                    continue

                # Load
                src_path = os.path.join(src_dir, fname)
                img = Image.open(src_path)

                # Resize (Nearest Neighbor for masks to keep them binary!)
                if subdir == 'label':
                    img_resized = img.resize(size, resample=Image.NEAREST)
                else:
                    img_resized = img.resize(size, resample=Image.BILINEAR)

                # Save
                dst_path = os.path.join(dst_dir, fname)
                img_resized.save(dst_path)

    print(f"✅ Resize Complete. New dataset located at: {dst_root}")

# Run the resize
resize_dataset_to_disk(ORIGINAL_ROOT, NEW_ROOT, TARGET_SIZE)
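
The resize cell above uses nearest-neighbor resampling for the label masks and bilinear resampling for the images. The NumPy sketch below illustrates why, on a made-up 4×4 mask. Block averaging is used here as a simple stand-in for an interpolating filter (it is not Pillow's bilinear implementation): any filter that blends neighboring pixels can produce values between 0 and 1, whereas picking a single source pixel keeps the mask binary.

```python
import numpy as np

# Toy 4x4 binary change mask (1 = change); values are illustrative only.
mask = np.array([[0, 0, 1, 1],
                 [0, 1, 1, 1],
                 [0, 0, 0, 1],
                 [0, 0, 0, 0]], dtype=float)

# Nearest-neighbor style downsampling: keep one source pixel per 2x2 block.
nearest = mask[::2, ::2]

# Averaging downsampling: mean of each 2x2 block (stand-in for interpolation).
averaged = mask.reshape(2, 2, 2, 2).mean(axis=(1, 3))

print("nearest values: ", np.unique(nearest))   # [0. 1.]
print("averaged values:", np.unique(averaged))  # [0.   0.25 1.  ]
```

Fractional label values would have to be re-thresholded before use, which is why the masks are resized without interpolation.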

## 2.3 Data Loading and Exploration

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

DATA_ROOT_FOR_LOADER = NEW_ROOT # Use the 128x128 resized dataset

class LEVIRCDDataset(Dataset):
    """
    PyTorch Dataset for LEVIR-CD change detection.

    Loads bi-temporal image pairs (A=T1, B=T2) and binary change masks.
    Images are normalized using ImageNet statistics.
    """
    def __init__(self, root_dir, split='train', augment=False):
        self.augment = augment
        self.normalize = T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        # Check for standard structure: root/train/A
        if os.path.exists(os.path.join(root_dir, split, 'A')):
            self.img_A_dir = os.path.join(root_dir, split, 'A')
            self.img_B_dir = os.path.join(root_dir, split, 'B')
            self.label_dir = os.path.join(root_dir, split, 'label')
            # Load all files in the split directory
            self.filenames = sorted([f for f in os.listdir(self.img_A_dir)
                                     if f.endswith(('.png', '.jpg', '.tif'))])
        else:
            raise FileNotFoundError(f"Could not find dataset images in {root_dir}")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        img_A = Image.open(os.path.join(self.img_A_dir, fname)).convert('RGB')
        img_B = Image.open(os.path.join(self.img_B_dir, fname)).convert('RGB')
        label = Image.open(os.path.join(self.label_dir, fname)).convert('L')

        # Data augmentation (training only)
        if self.augment:
            # Random horizontal flip
            if np.random.random() > 0.5:
                img_A = img_A.transpose(Image.FLIP_LEFT_RIGHT)
                img_B = img_B.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)
            # Random vertical flip
            if np.random.random() > 0.5:
                img_A = img_A.transpose(Image.FLIP_TOP_BOTTOM)
                img_B = img_B.transpose(Image.FLIP_TOP_BOTTOM)
                label = label.transpose(Image.FLIP_TOP_BOTTOM)
            # Random 90-degree rotation
            if np.random.random() > 0.5:
                k = np.random.choice([1, 2, 3])
                img_A = img_A.rotate(90 * k)
                img_B = img_B.rotate(90 * k)
                label = label.rotate(90 * k)

        # Convert to tensors
        img_A = self.normalize(T.ToTensor()(img_A))
        img_B = self.normalize(T.ToTensor()(img_B))
        label = (T.ToTensor()(label) > 0.5).float().squeeze(0)

        return img_A, img_B, label


# Create datasets
train_ds = LEVIRCDDataset(DATA_ROOT_FOR_LOADER, split='train', augment=True)
val_ds   = LEVIRCDDataset(DATA_ROOT_FOR_LOADER, split='val',   augment=False)
test_ds  = LEVIRCDDataset(DATA_ROOT_FOR_LOADER, split='test',  augment=False)

print(f'Training samples:   {len(train_ds)}')
print(f'Validation samples: {len(val_ds)}')
print(f'Test samples:       {len(test_ds)}')



### Visualize Sample Pairs

Let's look at a few bi-temporal image pairs and their change masks:

In [None]:
import matplotlib.pyplot as plt

def show_samples(dataset, indices, title=''):
    """Display bi-temporal image pairs and change masks."""
    n = len(indices)
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(n, 3, figsize=(14, 4.5 * n))
    if n == 1: axes = axes[np.newaxis, :]

    col_titles = ['Time 1 (Before)', 'Time 2 (After)', 'Change Mask (GT)']
    for j, t in enumerate(col_titles):
        axes[0, j].set_title(t, fontsize=14, fontweight='bold')

    for i, idx in enumerate(indices):
        imgA, imgB, mask = dataset[idx]
        imgA_np = np.clip(imgA.permute(1,2,0).numpy() * std + mean, 0, 1)
        imgB_np = np.clip(imgB.permute(1,2,0).numpy() * std + mean, 0, 1)

        axes[i, 0].imshow(imgA_np)
        axes[i, 1].imshow(imgB_np)
        axes[i, 2].imshow(mask.numpy(), cmap='hot', vmin=0, vmax=1)

        for ax in axes[i]:
            ax.axis('off')

        # Show change percentage
        pct = mask.sum().item() / mask.numel() * 100
        axes[i, 2].text(5, 15, f'{pct:.1f}% changed', color='cyan',
                        fontsize=11, fontweight='bold',
                        bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()


# Display 4 random samples
np.random.seed(42)
sample_indices = np.random.choice(len(train_ds), 4, replace=False)
show_samples(train_ds, sample_indices, title='LEVIR-CD Training Samples')

### Dataset Statistics

In [None]:
# Compute class balance statistics
total_pixels = 0
change_pixels = 0

sample_size = min(100, len(train_ds))
print(f'Analyzing class balance (sampling {sample_size} images)...')
for i in tqdm(range(sample_size)):
    _, _, mask = train_ds[i]
    total_pixels += mask.numel()
    change_pixels += mask.sum().item()

change_ratio = change_pixels / total_pixels
print(f'\nClass Balance Analysis:')
print(f'  Changed pixels:   {change_ratio:.2%}')
print(f'  Unchanged pixels: {1 - change_ratio:.2%}')
print(f'  Imbalance ratio:  1:{(1 - change_ratio) / change_ratio:.0f}')
print(f'\n⚠️  Significant class imbalance! This motivates our use of Dice Loss.')

<div style="background: #e8f5e9; border-left: 4px solid #4caf50; padding: 12px; margin: 10px 0; border-radius: 4px;">
<b>✅ Checkpoint:</b> Verify that you can load and display image pairs. You should see RGB satellite images and binary change masks. Notice how changed areas (new buildings) appear as white regions in the mask, and that the changed class is a small fraction of total pixels.
</div>

---
# 3. Part 2: Building the Siamese CNN Model

## 3.1 Multi-Scale Feature Convolution Unit (MFCU)

The **MFCU** is the building block of our network, inspired by the [DSMSCN architecture](https://github.com/ChenHongruixuan/DSMSCN). It extracts spatial features at multiple scales using **parallel convolution branches** with kernel sizes 1×1, 3×3, and 5×5. The outputs are concatenated and fused with a 1×1 convolution.

This design (similar to Inception modules) captures both fine-grained building edges and broader contextual information simultaneously.
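
The multi-scale design has a cost: a k×k convolution has roughly k² times the weights of a 1×1 convolution with the same channel counts. The sketch below counts convolution parameters per branch, assuming the channel split used by the `MFCU` class in this section (`out_ch // 3` per branch, with any remainder given to the 5×5 branch). Batch-norm parameters are left out for simplicity, and the helper names are hypothetical.

```python
def conv_params(in_ch, out_ch, k):
    """Weights plus biases of a k x k 2-D convolution."""
    return in_ch * out_ch * k * k + out_ch

def mfcu_conv_params(in_ch, out_ch):
    """Convolution parameter counts for each MFCU branch and the 1x1 fusion layer."""
    mid = out_ch // 3
    remainder = out_ch - 3 * mid
    return {
        "1x1": conv_params(in_ch, mid, 1),
        "3x3": conv_params(in_ch, mid, 3),
        "5x5": conv_params(in_ch, mid + remainder, 5),
        "fuse": conv_params(out_ch, out_ch, 1),
    }

counts = mfcu_conv_params(64, 128)
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:>4}: {n:>6,} ({n / total:.0%})")
print(f"total: {total:,}")  # total: 113,920
```

For a 64 → 128 block the 5×5 branch alone holds about 62% of the convolution weights, so it is the first place to look when the model needs to shrink.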

In [None]:
class MFCU(nn.Module):
    """
    Multi-Scale Feature Convolution Unit.

    Parallel branches with 1x1, 3x3, and 5x5 kernels capture
    features at different spatial scales, then fuse them.

    Inspired by: Chen et al. (2019) "Deep Siamese Multi-scale
    Convolutional Network for Change Detection" (DSMSCN)
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        mid = out_ch // 3
        remainder = out_ch - mid * 3

        # Three parallel branches at different scales
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_ch, mid, kernel_size=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True)
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_ch, mid, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True)
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_ch, mid + remainder, kernel_size=5, padding=2),
            nn.BatchNorm2d(mid + remainder),
            nn.ReLU(inplace=True)
        )

        # Fusion: combine multi-scale features
        self.fuse = nn.Sequential(
            nn.Conv2d(mid * 3 + remainder, out_ch, kernel_size=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        f1 = self.branch1x1(x)  # Fine details
        f3 = self.branch3x3(x)  # Local context
        f5 = self.branch5x5(x)  # Broader context
        return self.fuse(torch.cat([f1, f3, f5], dim=1))


# Quick test
mfcu = MFCU(3, 32)
test_input = torch.randn(1, 3, 64, 64)
test_output = mfcu(test_input)
print(f'MFCU: {test_input.shape} → {test_output.shape}')

## 3.2 Siamese Encoder (Weight-Shared)

The encoder processes each temporal image through 4 MFCU blocks with max-pooling. Both branches **share identical weights**, ensuring the images are mapped into the same feature space for meaningful comparison.
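
The feature-map sizes can be worked out by hand before running the model. This sketch assumes the channel widths (32, 64, 128, 256) and the 2×2 max-pooling between blocks used by the encoder in this section. The function names are hypothetical helpers for the calculation, not part of the lab code.

```python
def encoder_shapes(height, width, channels=(32, 64, 128, 256)):
    """(channels, height, width) of each encoder output; each pooling step halves H and W."""
    shapes = []
    for level, ch in enumerate(channels):
        scale = 2 ** level  # level 0 is unpooled, then three 2x2 poolings
        shapes.append((ch, height // scale, width // scale))
    return shapes

def skip_shapes_align(height, width):
    """True when doubling each pooled size reproduces the size of the scale above it."""
    return height % 8 == 0 and width % 8 == 0

for shape in encoder_shapes(128, 128):
    print(shape)
# (32, 128, 128)
# (64, 64, 64)
# (128, 32, 32)
# (256, 16, 16)

print(skip_shapes_align(128, 128))  # True
print(skip_shapes_align(100, 100))  # False: 100 -> 50 -> 25 -> 12, and 12 * 2 != 25
```

Input sizes divisible by 8 matter because a decoder that upsamples by a factor of 2 per stage can only be concatenated with skip features of exactly matching size.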

In [None]:
class SiameseEncoder(nn.Module):
    """
    Weight-shared encoder with 4 MFCU blocks.
    Produces feature maps at 4 spatial scales.
    """
    def __init__(self):
        super().__init__()
        self.enc1 = MFCU(3, 32)
        self.enc2 = MFCU(32, 64)
        self.enc3 = MFCU(64, 128)
        self.enc4 = MFCU(128, 256)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        f1 = self.enc1(x)               # [B, 32,  H,   W]
        f2 = self.enc2(self.pool(f1))    # [B, 64,  H/2, W/2]
        f3 = self.enc3(self.pool(f2))    # [B, 128, H/4, W/4]
        f4 = self.enc4(self.pool(f3))    # [B, 256, H/8, W/8]
        return [f1, f2, f3, f4]


# Verify feature map sizes
enc = SiameseEncoder()
test_img = torch.randn(1, 3, 256, 256)
features = enc(test_img)
for i, f in enumerate(features):
    print(f'  Scale {i+1}: {f.shape}')

## 3.3 Change Detection Decoder (U-Net Style)

The decoder takes **absolute-difference feature maps** from both branches at each scale and progressively upsamples them with skip connections. This recovers spatial resolution while preserving both fine and coarse change information.
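
One property of the absolute difference is worth checking directly: it is symmetric, so swapping the two dates gives the same difference features. The NumPy sketch below uses random arrays as stand-ins for encoder features (shapes and the +3.0 offset are arbitrary assumptions) and shows that the difference is zero wherever the features agree.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in feature maps with shape (channels, height, width).
feat_t1 = rng.normal(size=(4, 8, 8))
feat_t2 = feat_t1.copy()
feat_t2[:, 2:5, 2:5] += 3.0  # simulate a change in a 3x3 region

diff_forward = np.abs(feat_t1 - feat_t2)
diff_backward = np.abs(feat_t2 - feat_t1)

print("symmetric:", np.array_equal(diff_forward, diff_backward))  # symmetric: True
print("non-zero entries:", np.count_nonzero(diff_forward))        # non-zero entries: 36
```

A consequence is that the decoder sees where features differ but not in which direction, which fits the binary change / no-change labels used here but would not separate construction from demolition.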

In [None]:
class ChangeDecoder(nn.Module):
    """
    U-Net style decoder that processes multi-scale difference features.
    Uses transposed convolutions for upsampling and skip connections.
    """
    def __init__(self):
        super().__init__()
        # Upsampling + MFCU blocks
        self.up4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = MFCU(256, 128)   # 128 (upsampled) + 128 (skip)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = MFCU(128, 64)    # 64 (upsampled) + 64 (skip)

        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = MFCU(64, 32)     # 32 (upsampled) + 32 (skip)

        # Final classification head
        self.head = nn.Conv2d(32, 1, kernel_size=1)  # Binary output

    def forward(self, diffs):
        d1, d2, d3, d4 = diffs  # Multi-scale difference features

        x = self.up4(d4)                            # [B, 128, H/4, W/4]
        x = self.dec3(torch.cat([x, d3], dim=1))    # [B, 128, H/4, W/4]

        x = self.up3(x)                              # [B, 64, H/2, W/2]
        x = self.dec2(torch.cat([x, d2], dim=1))     # [B, 64, H/2, W/2]

        x = self.up2(x)                              # [B, 32, H, W]
        x = self.dec1(torch.cat([x, d1], dim=1))     # [B, 32, H, W]

        return self.head(x)                           # [B, 1, H, W]

## 3.4 Complete Siamese Change Detection Network

In [None]:
class SiamMSCDNet(nn.Module):
    """
    Siamese Multi-Scale Change Detection Network.

    Architecture:
      1. Shared encoder processes both T1 and T2 images
      2. Absolute difference computed at each feature scale
      3. U-Net decoder produces pixel-wise change map

    References:
      - DSMSCN (Chen et al., 2019): Multi-scale feature convolution
      - SNUNet-CD (Fang et al., 2021): Siamese nested architecture
    """
    def __init__(self):
        super().__init__()
        self.encoder = SiameseEncoder()  # Shared weights
        self.decoder = ChangeDecoder()

    def forward(self, img_t1, img_t2):
        # Extract features from both temporal images
        feats_t1 = self.encoder(img_t1)  # [f1, f2, f3, f4]
        feats_t2 = self.encoder(img_t2)  # Same encoder!

        # Compute absolute difference at each scale
        diffs = [torch.abs(f1 - f2)
                 for f1, f2 in zip(feats_t1, feats_t2)]

        # Decode differences into change map
        logits = self.decoder(diffs)
        return logits.squeeze(1)  # [B, H, W]


# --- Verify the full architecture ---
model = SiamMSCDNet()
x1 = torch.randn(2, 3, 256, 256)
x2 = torch.randn(2, 3, 256, 256)
out = model(x1, x2)

n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Input shape:      2 × [B, 3, 256, 256]')
print(f'Output shape:     {list(out.shape)}')
print(f'Total parameters: {n_params:,}')
print(f'Trainable params: {n_trainable:,}')
print(f'Model size:       ~{n_params * 4 / 1e6:.1f} MB (float32)')

<div style="background: #e3f2fd; border-left: 4px solid #1e88e5; padding: 12px; margin: 10px 0; border-radius: 4px;">
<b>🏗️ Architecture Summary:</b><br>
<b>Input:</b> Two H×W×3 RGB satellite images, T₁ and T₂ (256×256 in the shape check above; 128×128 for the resized training data)<br>
<b>Encoder:</b> 4 MFCU blocks with shared weights → features at 4 scales<br>
<b>Difference:</b> |F(T₁) − F(T₂)| computed at each scale<br>
<b>Decoder:</b> U-Net upsampling with skip connections<br>
<b>Output:</b> H×W map of change logits (apply a sigmoid to get per-pixel change probabilities)
</div>

---
# 4. Part 3: Training the Model

## 4.1 Loss Function: Dice + BCE

Change detection is highly **imbalanced** (most pixels are unchanged). We combine:

- **Binary Cross-Entropy (BCE)**: Standard pixel-wise classification loss
- **Dice Loss**: Directly optimizes overlap between prediction and ground truth, helping with rare positive pixels
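
A small worked example shows why the two losses complement each other. The pure-Python sketch below uses a made-up set of 1,000 pixels with 1% change and compares a "lazy" prediction (probability 0.01 everywhere) with a reasonable one. The Dice formula uses a smoothing constant of 1, matching the loss class defined in this section; the probabilities are illustrative assumptions.

```python
import math

def bce(probs, targets):
    """Mean binary cross-entropy over pixels (probabilities strictly inside (0, 1))."""
    losses = [-(t * math.log(p) + (1 - t) * math.log(1 - p))
              for p, t in zip(probs, targets)]
    return sum(losses) / len(losses)

def dice_loss(probs, targets, smooth=1.0):
    """Soft Dice loss: 1 - (2 * overlap + smooth) / (sum of both masks + smooth)."""
    overlap = sum(p * t for p, t in zip(probs, targets))
    return 1 - (2 * overlap + smooth) / (sum(probs) + sum(targets) + smooth)

targets = [1.0] * 10 + [0.0] * 990  # 1% of pixels changed
lazy = [0.01] * 1000                # "almost certainly unchanged" everywhere
good = [0.9] * 10 + [0.01] * 990    # confident on the changed pixels

print(f"lazy: BCE = {bce(lazy, targets):.3f}, Dice = {dice_loss(lazy, targets):.3f}")
# lazy: BCE = 0.056, Dice = 0.943
print(f"good: BCE = {bce(good, targets):.3f}, Dice = {dice_loss(good, targets):.3f}")
# good: BCE = 0.011, Dice = 0.365
```

BCE is already small for the lazy prediction because the 990 easy pixels dominate the average, while Dice stays near its maximum until the changed pixels are actually found.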

In [None]:
class DiceBCELoss(nn.Module):
    """
    Combined Dice Loss + Binary Cross-Entropy Loss.
    Dice Loss handles class imbalance by directly optimizing
    the overlap metric (similar to F1 score).
    """
    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss()

    def dice_loss(self, pred, target, smooth=1.0):
        pred_sig = torch.sigmoid(pred)
        intersection = (pred_sig * target).sum()
        union = pred_sig.sum() + target.sum()
        return 1 - (2.0 * intersection + smooth) / (union + smooth)

    def forward(self, pred, target):
        bce = self.bce(pred, target)
        dice = self.dice_loss(pred, target)
        return self.bce_weight * bce + (1 - self.bce_weight) * dice

## 4.2 Evaluation Metrics

High **Precision** ($\frac{TP}{TP + FP}$) means the model is careful and produces few false positives (it rarely "cries wolf").

High **Recall** ($\frac{TP}{TP + FN}$) means the model is thorough and produces few false negatives (it rarely misses a real change).

**F1-Score** ($ 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$) is the harmonic mean of Precision and Recall; if either one is very low, the F1-score will be low.

**IoU (Intersection over Union)** ($\frac{TP}{TP + FP + FN}$) measures the overlap between the predicted change region and the ground-truth change region. IoU = 1 means perfect overlap; IoU > 0.5 is often treated as a "good" prediction in many benchmarks.

**OA (Overall Accuracy)** ($\frac{TP + TN}{TP + TN + FP + FN}$) is the fraction of all pixels that were classified correctly. OA can be very deceptive if your data is imbalanced. For example, if 99% of your data is "Class A," a model that always predicts "Class A" will have a 99% OA but is useless for finding "Class B."
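
The following pure-Python sketch works through these formulas on two hypothetical confusion matrices for a 10,000-pixel image with 500 changed pixels. The counts are invented for illustration: model A predicts "unchanged" everywhere, and model B finds most of the change.

```python
def metrics_from_counts(tp, fp, fn, tn):
    """Precision, recall, F1, IoU and overall accuracy from pixel counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    iou = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0
    oa = (tp + tn) / (tp + fp + fn + tn)
    return {"precision": precision, "recall": recall, "f1": f1, "iou": iou, "oa": oa}

# Model A: always predicts "unchanged".
model_a = metrics_from_counts(tp=0, fp=0, fn=500, tn=9500)
# Model B: detects 400 of the 500 changed pixels, with 100 false alarms.
model_b = metrics_from_counts(tp=400, fp=100, fn=100, tn=9400)

for name, m in [("A", model_a), ("B", model_b)]:
    print(name, {k: round(v, 3) for k, v in m.items()})
# A {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'iou': 0.0, 'oa': 0.95}
# B {'precision': 0.8, 'recall': 0.8, 'f1': 0.8, 'iou': 0.667, 'oa': 0.98}
```

Model A reaches 95% overall accuracy without detecting a single changed pixel, which is why F1 and IoU are the metrics to watch. The two are also linked by F1 = 2·IoU / (1 + IoU), so F1 is never lower than IoU.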

In [None]:
@torch.no_grad()
def evaluate(model, loader, device=DEVICE):
    """
    Evaluate change detection model.
    Returns: dict with precision, recall, f1, iou, oa
    """
    model.eval()
    TP, FP, FN, TN = 0, 0, 0, 0

    for imgA, imgB, mask in loader:
        imgA = imgA.to(device)
        imgB = imgB.to(device)
        mask = mask.to(device)

        pred = torch.sigmoid(model(imgA, imgB)) > 0.5
        pred_b = pred.bool()
        mask_b = mask.bool()

        TP += (pred_b & mask_b).sum().item()
        FP += (pred_b & ~mask_b).sum().item()
        FN += (~pred_b & mask_b).sum().item()
        TN += (~pred_b & ~mask_b).sum().item()

    eps = 1e-8
    precision = TP / (TP + FP + eps)
    recall    = TP / (TP + FN + eps)
    f1        = 2 * precision * recall / (precision + recall + eps)
    iou       = TP / (TP + FP + FN + eps)
    oa        = (TP + TN) / (TP + FP + FN + TN + eps)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'iou': iou,
        'oa': oa
    }

## 4.3 Training Loop

In [None]:
import torch.cuda.amp as amp

# 1. Set specific allocation config to prevent fragmentation
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# 2. Hyperparameters
EPOCHS = 25
BATCH_SIZE = 8
LR = 1e-3
NUM_WORKERS = 2

# 🧹 CLEANUP: Clear GPU before re-initializing
import gc
gc.collect()
torch.cuda.empty_cache()

# --- Data Loaders ---
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          drop_last=True, persistent_workers=True)

val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=True)

test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=True)

# --- Model, Loss, Optimizer ---
model = SiamMSCDNet().to(DEVICE)
criterion = DiceBCELoss(bce_weight=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# 3. Initialize GradScaler for Mixed Precision
scaler = amp.GradScaler()

print(f'Training on {DEVICE} with {len(train_ds)} samples')
print(f'Batch size: {BATCH_SIZE}, Epochs: {EPOCHS}, LR: {LR}')

history = {'train_loss': [], 'val_f1': [], 'val_iou': [], 'val_oa': [], 'lr': []}
best_f1 = -1.0  # below any valid F1, so a checkpoint is always saved after the first epoch

for epoch in range(1, EPOCHS + 1):
    # --- Train ---
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}', leave=False)

    for imgA, imgB, mask in pbar:
        imgA, imgB, mask = imgA.to(DEVICE), imgB.to(DEVICE), mask.to(DEVICE)

        optimizer.zero_grad()

        with amp.autocast():
            pred = model(imgA, imgB)
            loss = criterion(pred, mask)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * imgA.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # drop_last=True skips the final partial batch, so divide by the samples actually seen
    train_loss = running_loss / (len(train_loader) * BATCH_SIZE)

    # --- Validate ---
    # evaluate() is decorated with torch.no_grad(), so no gradients are stored here
    metrics = evaluate(model, val_loader)
    current_lr = optimizer.param_groups[0]['lr']  # LR that was used during this epoch
    scheduler.step()

    # --- Log ---
    history['train_loss'].append(train_loss)
    history['val_f1'].append(metrics['f1'])
    history['val_iou'].append(metrics['iou'])
    history['val_oa'].append(metrics['oa'])
    history['lr'].append(current_lr)

    star = ''
    if metrics['f1'] > best_f1:
        best_f1 = metrics['f1']
        torch.save(model.state_dict(), 'best_model.pth')
        star = ' ★ Best!'

    print(f"Epoch {epoch:3d}/{EPOCHS} │ Loss: {train_loss:.4f} │ "
          f"F1: {metrics['f1']:.4f} │ IoU: {metrics['iou']:.4f} │ "
          f"OA: {metrics['oa']:.4f} │ LR: {current_lr:.6f}{star}")

print(f'\n✅ Training complete! Best validation F1: {best_f1:.4f}')
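
The learning-rate schedule in the training cell can be previewed without training. The sketch below evaluates the standard closed-form cosine annealing curve, assuming the settings used above (`LR = 1e-3`, `T_max = EPOCHS = 25`) and a minimum learning rate of 0, which is the scheduler's default. The function name is a hypothetical helper, not a PyTorch API.

```python
import math

def cosine_annealing_lr(step, base_lr=1e-3, t_max=25, eta_min=0.0):
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

schedule = [cosine_annealing_lr(t) for t in range(26)]

for t in (0, 5, 12, 20, 25):
    print(f"after {t:2d} scheduler steps: lr = {schedule[t]:.6f}")
```

The curve starts at the base rate, falls slowly at first, drops fastest around the midpoint, and flattens out near zero by the final epoch.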

### Training Curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(history['train_loss'], color='#A51C30', linewidth=2)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Dice + BCE Loss')
axes[0].grid(True, alpha=0.3)

# F1 and IoU
axes[1].plot(history['val_f1'], color='#1E5A96', linewidth=2, label='F1 Score')
axes[1].plot(history['val_iou'], color='#4CAF50', linewidth=2, label='IoU')
axes[1].set_title('Validation Metrics', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].legend(fontsize=12)
axes[1].grid(True, alpha=0.3)

# Learning Rate
axes[2].plot(history['lr'], color='#FF9800', linewidth=2)
axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('LR')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
# 5. Part 4: Evaluation and Visualization

## 5.1 Test Set Evaluation

In [None]:
# Load best model and evaluate on test set
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))
test_metrics = evaluate(model, test_loader)

print('=' * 55)
print('       TEST SET RESULTS')
print('=' * 55)
print(f'  Precision:  {test_metrics["precision"]:.4f}')
print(f'  Recall:     {test_metrics["recall"]:.4f}')
print(f'  F1 Score:   {test_metrics["f1"]:.4f}')
print(f'  IoU:        {test_metrics["iou"]:.4f}')
print(f'  Overall Acc: {test_metrics["oa"]:.4f}')
print('=' * 55)

## 5.2 Visualize Predictions

In [None]:
def visualize_predictions(model, dataset, indices, device=DEVICE):
    """Visualize model predictions alongside inputs and ground truth."""
    model.eval()
    n = len(indices)
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(n, 4, figsize=(18, 4.5 * n))
    if n == 1: axes = axes[np.newaxis, :]

    col_titles = ['Time 1 (Before)', 'Time 2 (After)', 'Ground Truth', 'Prediction']
    for j, t in enumerate(col_titles):
        axes[0, j].set_title(t, fontsize=13, fontweight='bold')

    for i, idx in enumerate(indices):
        imgA, imgB, mask = dataset[idx]
        with torch.no_grad():
            pred = torch.sigmoid(
                model(imgA.unsqueeze(0).to(device),
                      imgB.unsqueeze(0).to(device))).cpu().squeeze()
        pred_mask = (pred > 0.5).float().numpy()

        imgA_np = np.clip(imgA.permute(1,2,0).numpy() * std + mean, 0, 1)
        imgB_np = np.clip(imgB.permute(1,2,0).numpy() * std + mean, 0, 1)

        axes[i, 0].imshow(imgA_np)
        axes[i, 1].imshow(imgB_np)
        axes[i, 2].imshow(mask.numpy(), cmap='hot', vmin=0, vmax=1)
        axes[i, 3].imshow(pred_mask, cmap='hot', vmin=0, vmax=1)

        for ax in axes[i]: ax.axis('off')

    plt.suptitle('Model Predictions on Test Set', fontsize=16, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.savefig('predictions.png', dpi=150, bbox_inches='tight')
    plt.show()


# Show 6 test predictions
np.random.seed(123)
test_indices = np.random.choice(len(test_ds), 6, replace=False)
visualize_predictions(model, test_ds, test_indices)

## 5.3 Error Analysis: Confusion Maps

Color-coded confusion maps reveal where the model succeeds and fails:
- 🟢 **Green** = True Positive (correctly detected change)
- 🔴 **Red** = False Positive (false alarm)
- 🔵 **Blue** = False Negative (missed change)

In [None]:
def plot_confusion_maps(model, dataset, indices, device=DEVICE):
    """Create color-coded confusion maps for error analysis."""
    model.eval()
    n = len(indices)
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(n, 3, figsize=(15, 5 * n))
    if n == 1: axes = axes[np.newaxis, :]

    col_titles = ['Time 2 (After)', 'Ground Truth', 'Confusion Map']
    for j, t in enumerate(col_titles):
        axes[0, j].set_title(t, fontsize=13, fontweight='bold')

    for i, idx in enumerate(indices):
        imgA, imgB, mask = dataset[idx]
        with torch.no_grad():
            pred = (torch.sigmoid(
                model(imgA.unsqueeze(0).to(device),
                      imgB.unsqueeze(0).to(device))).cpu().squeeze() > 0.5)

        # Convert to NumPy so the masks can index the NumPy confusion array directly
        mask_b = mask.bool().numpy()
        pred_b = pred.bool().numpy()

        # Build confusion map
        h, w = mask.shape
        confusion = np.zeros((h, w, 3), dtype=np.uint8)
        confusion[mask_b & pred_b]   = [0, 200, 0]      # TP: Green
        confusion[~mask_b & pred_b]  = [220, 50, 50]     # FP: Red
        confusion[mask_b & ~pred_b]  = [50, 80, 220]     # FN: Blue

        imgB_np = np.clip(imgB.permute(1,2,0).numpy() * std + mean, 0, 1)

        axes[i, 0].imshow(imgB_np)
        axes[i, 1].imshow(mask.numpy(), cmap='hot', vmin=0, vmax=1)
        axes[i, 2].imshow(confusion)

        # Count stats
        tp = (mask_b & pred_b).sum().item()
        fp = (~mask_b & pred_b).sum().item()
        fn = (mask_b & ~pred_b).sum().item()
        axes[i, 2].text(5, 15, f'TP:{tp} FP:{fp} FN:{fn}',
                        color='white', fontsize=9, fontweight='bold',
                        bbox=dict(facecolor='black', alpha=0.7))

        for ax in axes[i]: ax.axis('off')

    # Legend
    from matplotlib.patches import Patch
    # Legend colors match the RGB values used in the confusion map
    legend_elements = [
        Patch(facecolor='#00C800', label='True Positive'),
        Patch(facecolor='#DC3232', label='False Positive'),
        Patch(facecolor='#3250DC', label='False Negative'),
    ]
    fig.legend(handles=legend_elements, loc='lower center',
              ncol=3, fontsize=12, bbox_to_anchor=(0.5, -0.02))
    plt.suptitle('Error Analysis: Confusion Maps', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('confusion_maps.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_confusion_maps(model, test_ds, test_indices[:4])

---
# 6. Part 5: Experiments and Extensions

## Transfer to Damage Detection

The [xBD dataset](https://xview2.org/) (from the xView2 challenge) provides satellite imagery of disaster-affected areas with building damage labels at four severity levels. A model pre-trained on LEVIR-CD for building change can be **fine-tuned** for damage classification.

The idea: your LEVIR-CD encoder already knows how to detect *structural change* in buildings. Damage (from earthquakes, hurricanes, etc.) is fundamentally a type of structural change. By freezing the encoder and training a new classification head, you can transfer these learned features.

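```python
# Sketch of transfer learning to damage detection.
# NOTE: the attribute names `encoder` and `decoder.head`, and the 32 input
# channels of the head, are assumptions; match them to your SiamMSCDNet definition.

# 1. Load pre-trained weights
pretrained = SiamMSCDNet()
pretrained.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))

# 2. Freeze encoder
for param in pretrained.encoder.parameters():
    param.requires_grad = False

# 3. Replace decoder head for 4-class damage grading
# (no damage, minor, major, destroyed)
pretrained.decoder.head = nn.Conv2d(32, 4, kernel_size=1)
pretrained = pretrained.to(DEVICE)

# 4. Fine-tune on xBD with CrossEntropyLoss, updating only unfrozen parameters
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    (p for p in pretrained.parameters() if p.requires_grad),
    lr=1e-4)  # illustrative learning rate
```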
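
To see the freeze-and-swap pattern run end to end without the full model, here is a self-contained toy version. `TinySiameseNet` is a hypothetical stand-in, not the network built in this lab, and random tensors take the place of an xBD batch:

```python
import torch
import torch.nn as nn

class TinySiameseNet(nn.Module):
    """Hypothetical stand-in for a Siamese change-detection model."""
    def __init__(self, n_out=1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Conv2d(16, n_out, kernel_size=1)

    def forward(self, img_a, img_b):
        # Shared weights: the same encoder processes both dates
        feat_a, feat_b = self.encoder(img_a), self.encoder(img_b)
        return self.head(torch.abs(feat_a - feat_b))


toy_model = TinySiameseNet(n_out=1)   # pretend this was trained for binary change

# Freeze the encoder so its weights are not updated
for param in toy_model.encoder.parameters():
    param.requires_grad = False

# Swap the 1-channel change head for a 4-class damage head
toy_model.head = nn.Conv2d(16, 4, kernel_size=1)

# Pass only the trainable parameters to the optimizer
trainable = [p for p in toy_model.parameters() if p.requires_grad]
toy_optimizer = torch.optim.Adam(trainable, lr=1e-3)
toy_criterion = nn.CrossEntropyLoss()

# One fine-tuning step on random data
img_a, img_b = torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)
toy_labels = torch.randint(0, 4, (2, 32, 32))   # per-pixel damage class
toy_logits = toy_model(img_a, img_b)            # shape (2, 4, 32, 32)
toy_loss = toy_criterion(toy_logits, toy_labels)
toy_optimizer.zero_grad()
toy_loss.backward()
toy_optimizer.step()
print(tuple(toy_logits.shape), len(trainable))
```

Only the new head's weight and bias end up in `trainable`, so the encoder features learned for change detection stay fixed during this step.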

---
# 7. Discussion Questions

Answer these questions in markdown cells below (1–2 paragraphs each).

**Q1.** Why do Siamese networks use weight sharing between the two branches? What would happen if each branch had independent weights? What are the implications for the learned feature space?

**Q2.** Examine your confusion maps. What types of errors does the model make most frequently? Are false positives or false negatives more common, and why might that be?

**Q3.** How does the class imbalance between changed and unchanged pixels affect training? How does the Dice Loss component help, compared to using BCE alone?

**Q4.** Compare the multi-scale (MFCU) vs. single-scale ablation results. Why would multi-scale features improve change detection for buildings of varying sizes?

**Q5.** How could this change detection approach be applied to post-earthquake damage assessment? What additional challenges would arise compared to the LEVIR-CD scenario?

**Q6.** What role does the temporal gap between images play in change detection accuracy? How might very short or very long time gaps affect model performance?

**Q1 Answer:**

*(Your answer here)*

**Q2 Answer:**

*(Your answer here)*

**Q3 Answer:**

*(Your answer here)*

**Q4 Answer:**

*(Your answer here)*

**Q5 Answer:**

*(Your answer here)*

**Q6 Answer:**

*(Your answer here)*

---
# 8. References and Resources

### Key Papers

1. Chen, H., Wu, C., Du, B., & Zhang, L. (2019). *Deep Siamese Multi-scale Convolutional Network for Change Detection in Multi-Temporal VHR Images.* MultiTemp 2019. [[Code]](https://github.com/ChenHongruixuan/DSMSCN)
2. Chen, H. & Shi, Z. (2020). *A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection.* Remote Sensing, 12(10). [[LEVIR-CD]](https://justchenhao.github.io/LEVIR/)
3. Fang, S. et al. (2021). *SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images.* IEEE GRSL. [[Code]](https://github.com/likyoo/Siam-NestedUNet)
4. Chen, H. et al. (2022). *Unsupervised Change Detection in Multitemporal VHR Images Based on Deep Kernel PCA Convolutional Mapping Network.* IEEE TCYB. [[Code]](https://github.com/ChenHongruixuan/KPCAMNet)
5. Shi, W., Zhang, M., et al. (2020). *Change Detection Based on Artificial Intelligence: State-of-the-Art and Challenges.* Remote Sensing, 12(10). [[Review]](https://github.com/MinZHANG-WHU/Change-Detection-Review)

### Code Repositories

| Repository | Description |
|---|---|
| [DSMSCN](https://github.com/ChenHongruixuan/DSMSCN) | TensorFlow implementation of Siamese Multi-Scale CNN |
| [Siam-NestedUNet](https://github.com/likyoo/Siam-NestedUNet) | PyTorch Siamese Nested U-Net (SNUNet-CD) |
| [KPCAMNet](https://github.com/ChenHongruixuan/KPCAMNet) | Unsupervised CD with deep kernel PCA |
| [ChangeDetectionRepository](https://github.com/ChenHongruixuan/ChangeDetectionRepository) | Traditional & DL-based CD methods collection |
| [Change-Detection-Review](https://github.com/MinZHANG-WHU/Change-Detection-Review) | Comprehensive review with code & datasets |
| [Awesome RS CD](https://github.com/wenhwu/awesome-remote-sensing-change-detection) | Datasets, methods, competitions |

### Datasets

| Dataset | Description |
|---|---|
| [LEVIR-CD](https://justchenhao.github.io/LEVIR/) | 637 VHR image pairs, 31K building changes, Texas |
| [WHU Building CD](http://sigma.whu.edu.cn/resource.php) | Christchurch, NZ; aerial images from 2012 and 2016, following the February 2011 earthquake |
| [SZTAKI AirChange](http://mplab.sztaki.hu/remotesensing/airchange_benchmark.html) | Multi-temporal aerial change benchmark |
| [xBD / xView2](https://xview2.org/) | Building damage assessment across disasters |

----

# 🔴 Activity 2: Martian Crater Detection with YOLO
### Using the [2022 GeoAI Martian Challenge](http://cici.lab.asu.edu/martian/#home) Dataset


---

## Learning Objectives

By the end of this lab, you will be able to:

1. Work with a **planetary science benchmark dataset**
2. Train (fine-tune) a **YOLO** object detection model to locate Martian craters
3. Evaluate detection performance using standard metrics (precision, recall, mAP)
4. Analyze model behavior across different crater sizes and terrain types

---

## Background

### Why Martian Crater Detection?

Impact craters are the dominant landform on Mars. A global census of craters enables:

- **Surface age dating**: Crater size-frequency distributions are the primary chronometer for planetary surfaces.
- **Geological mapping**: Crater morphology reveals subsurface ice, lava flows, and erosion history.
- **Landing site hazard assessment**: Missions like Perseverance and future crewed landings need automated terrain analysis.
- **Climate history**: Crater degradation patterns record billions of years of atmospheric and fluvial erosion.

### The 2022 GeoAI Martian Challenge Dataset

This benchmark dataset was developed by the [ASU CICI Lab](https://cici.lab.asu.edu/martian/) ([Hsu et al., 2021](https://www.mdpi.com/2072-4292/13/11/2116)) and assembles:

- **102,675 images** extracted from a global Mars mosaic
- **301,912 annotated craters** with bounding boxes
- **Source imagery:** Mars Odyssey THEMIS (Thermal Emission Imaging System) daytime infrared, 100 m/pixel ([Edwards et al., 2011](https://doi.org/10.1029/2010JE003755))
- **Crater labels:** From the [Robbins & Hynek (2012)](https://doi.org/10.1029/2011JE003966) global catalog of 640K+ craters
- **Image size:** 256×256 pixels (25.6×25.6 km per tile)
- **Crater sizes:** 0.2 km (2 px) to 25.5 km (255 px) in diameter

The dataset was designed for a formal AI competition, with train/val/test splits and COCO-format annotations.

### Crater Size Groups

| Group | Diameter | Pixels | Count | % |
|-------|----------|--------|-------|---|
| Small | 0.2–1 km | 2–10 px | 115,871 | 38% |
| Medium | 1.1–5 km | 11–50 px | 172,251 | 57% |
| Large | 5.1–25.5 km | 51–255 px | 13,790 | 5% |
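
At 100 m/pixel, a crater's diameter in kilometres maps directly to pixels, and from there to a size group. The sketch below illustrates that mapping; the function name `crater_size_group` is hypothetical, and assigning a crater of exactly 50 px to the Medium group is an assumption about how the boundary is handled.

```python
PIXEL_SIZE_KM = 0.1  # THEMIS daytime IR mosaic: 100 m per pixel

def crater_size_group(diameter_km, pixel_size_km=PIXEL_SIZE_KM):
    """Return (diameter in whole pixels, size-group name) for one crater."""
    diameter_px = round(diameter_km / pixel_size_km)
    if diameter_px <= 10:
        group = 'Small'
    elif diameter_px <= 50:   # assumption: 50 px counts as Medium
        group = 'Medium'
    else:
        group = 'Large'
    return diameter_px, group


for d_km in (0.2, 1.0, 3.5, 25.5):
    print(d_km, 'km ->', crater_size_group(d_km))
```

The smallest catalogued craters cover only about 2×2 pixels of a 256×256 tile, which is worth keeping in mind when interpreting per-size recall later in the lab.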

### YOLO: You Only Look Once

YOLO is a family of **single-stage object detectors** that predict bounding boxes and class probabilities in one forward pass:

```
Image → Backbone (features) → Neck (multi-scale fusion) → Head → Boxes + Confidence
```

Key concepts: bounding boxes as `(x_center, y_center, width, height)`, confidence thresholds, Non-Maximum Suppression (NMS), and mean Average Precision (mAP).
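
The interplay of box format, confidence threshold, and NMS can be sketched in a few lines of NumPy. This is a simplified greedy NMS for illustration only (all function names are hypothetical); the implementation inside Ultralytics is vectorized and differs in detail.

```python
import numpy as np

def xywh_to_xyxy(boxes):
    """Convert (x_center, y_center, width, height) rows to (x1, y1, x2, y2)."""
    boxes = np.asarray(boxes, dtype=float)
    xc, yc, w, h = boxes.T
    return np.stack([xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2], axis=1)

def iou_one_to_many(box, others):
    """IoU between one xyxy box and an (N, 4) array of xyxy boxes."""
    x1 = np.maximum(box[0], others[:, 0])
    y1 = np.maximum(box[1], others[:, 1])
    x2 = np.minimum(box[2], others[:, 2])
    y2 = np.minimum(box[3], others[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_box = (box[2] - box[0]) * (box[3] - box[1])
    area_others = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    return inter / (area_box + area_others - inter)

def greedy_nms(boxes_xywh, scores, conf_thresh=0.25, iou_thresh=0.5):
    """Return indices of boxes kept after confidence filtering and greedy NMS."""
    boxes = xywh_to_xyxy(boxes_xywh)
    scores = np.asarray(scores, dtype=float)
    # Highest confidence first, dropping boxes below the confidence threshold
    order = [i for i in np.argsort(-scores) if scores[i] >= conf_thresh]
    keep = []
    while order:
        best = order.pop(0)
        keep.append(int(best))
        if order:
            ious = iou_one_to_many(boxes[best], boxes[order])
            # Suppress remaining boxes that overlap the kept box too much
            order = [i for i, iou in zip(order, ious) if iou <= iou_thresh]
    return keep


# Two overlapping detections of one crater, one separate crater, one weak detection
toy_boxes = [(50, 50, 20, 20), (52, 51, 20, 20), (150, 120, 30, 30), (200, 40, 8, 8)]
toy_scores = [0.90, 0.75, 0.60, 0.10]
kept = greedy_nms(toy_boxes, toy_scores)
print(kept)
```

Here the second box is suppressed as a duplicate of the first, and the fourth never reaches NMS because its confidence is below the threshold.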

> ⚠️ **Note on dataset size:** The full dataset is 4.3 GB. For this 1-hour lab, we use a **subset** (~5,000 training images) so training completes in ~10 min on a T4 GPU. The full dataset can be used for research projects.


---
# 1. Background

## 1.1 Semantic Segmentation vs. Object Detection

Two fundamental computer vision tasks in remote sensing are:

- **Semantic segmentation**: Classify *every pixel* into categories (e.g., crater rim vs. background). The output is a mask the same size as the input image.
- **Object detection**: Locate objects with *bounding boxes* and classify them (e.g., draw a box around each ship in a SAR image). The output is a list of (x, y, w, h, class) tuples.

## 1.2 The DeepMoon Approach

The [DeepMoon](https://github.com/silburt/DeepMoon) project (Silburt et al., 2019) demonstrated that a CNN can identify lunar craters from Digital Elevation Maps (DEMs) with high accuracy. Their pipeline:

1. **Input**: DEM image patches of the lunar surface (grayscale elevation data)
2. **Target**: Binary ring masks where crater rims are marked as white annuli
3. **Model**: A U-Net-style encoder‚Äìdecoder network that outputs a pixel-wise probability map
4. **Post-processing**: Template matching on the predicted ring mask to extract (x, y, radius) for each crater

The key insight is that craters appear as circular depressions in elevation data, and their rims produce characteristic ring-shaped gradients that CNNs can learn to detect.
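
To make the ring-mask target concrete, the sketch below draws thin annuli at given crater positions. It is an illustration of the idea only, not DeepMoon's code; the function name `ring_mask` and the `ring_width` parameter are hypothetical.

```python
import numpy as np

def ring_mask(size, craters, ring_width=1.0):
    """Binary mask with a thin annulus drawn at each crater rim.

    craters: iterable of (x_center, y_center, radius) in pixels.
    """
    yy, xx = np.mgrid[0:size, 0:size]
    mask = np.zeros((size, size), dtype=np.uint8)
    for xc, yc, r in craters:
        dist = np.sqrt((xx - xc) ** 2 + (yy - yc) ** 2)
        # Mark pixels whose distance from the center is close to the radius
        mask[np.abs(dist - r) <= ring_width / 2] = 1
    return mask


toy_rings = ring_mask(64, [(20, 20, 8), (45, 40, 12)])
print(toy_rings.shape, int(toy_rings.sum()))
```

A segmentation network trained on targets like this predicts rim pixels, and a separate post-processing step is still needed to turn the rings back into (x, y, radius) triples.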

## 1.3 SAR Ship Detection (YOLO-style)

The [SAR_yolov3](https://github.com/humblecoder612/SAR_yolov3) project applies the YOLO (You Only Look Once) object detection framework to Synthetic Aperture Radar (SAR) satellite imagery for ship detection. SAR sensors image through cloud and at night, making them well suited to maritime surveillance. YOLO-style detectors process the entire image in a single pass and output bounding boxes directly, achieving real-time performance. This activity applies the same single-pass approach to Martian craters.

In [None]:
# @title Part 1: Setup & Installation (~3 min)
!pip install -q ultralytics pycocotools

import os, json, glob, random, shutil, yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from pathlib import Path
from collections import Counter

import torch
print(f"PyTorch: {torch.__version__}  |  CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️  No GPU! Go to Runtime → Change runtime type → GPU (T4)")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

---
## Part 2: Download & Explore the Dataset (~10 min)

We download a zipped copy of the 2022 GeoAI Martian Challenge Dataset (`martian_yolo_dataset.zip`) from Google Drive.

In [None]:
#@title 📥 Download, Unzip & Organize Martian Dataset { display-mode: "form" }

import os
import shutil
import zipfile
import gdown
import glob

DATA_DIR = '/content/MARTIAN_DATASET'
ZIP_PATH = os.path.join(DATA_DIR, "martian_yolo_dataset.zip")
EXTRACT_DIR = os.path.join(DATA_DIR, "extracted")

os.makedirs(DATA_DIR, exist_ok=True)

print("🚀 Starting download from Google Drive ...")

file_ids = {
    'martian_yolo_dataset.zip': '1CIIqqVVdfnyF7EV7X27fhkvPvhbrVSv_',
}

for filename, file_id in file_ids.items():
    output_path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(output_path):
        print(f"   Downloading {filename}...")
        gdown.download(id=file_id, output=output_path, quiet=False)
    else:
        print(f"   {filename} already exists, skipping download.")

# Extract
if not os.path.exists(EXTRACT_DIR):
    print("Extracting dataset...")
    !unzip -q "{ZIP_PATH}" -d "{EXTRACT_DIR}"
else:
    print("Already extracted.")

# Find the root of the extracted data
# It may be nested inside a subfolder
candidates = glob.glob(os.path.join(EXTRACT_DIR, '**', 'ids.json'), recursive=True)
if candidates:
    CHALLENGE_DIR = os.path.dirname(candidates[0])
else:
    # Try direct path
    CHALLENGE_DIR = EXTRACT_DIR

print(f"Dataset root: {CHALLENGE_DIR}")
print("Contents:")
if os.path.exists(CHALLENGE_DIR):
    for item in sorted(os.listdir(CHALLENGE_DIR)):
        full = os.path.join(CHALLENGE_DIR, item)
        if os.path.isdir(full):
            n_files = len(os.listdir(full))
            print(f"  📁 {item}/ ({n_files:,} files)")
        else:
            print(f"  📄 {item} ({os.path.getsize(full)/1e6:.1f} MB)")
else:
    print(f"❌ Error: Directory {CHALLENGE_DIR} does not exist. Extraction might have failed.")


In [None]:
import os
from pathlib import Path

def print_yolo_stats(base_dir):
    print(f"Dataset Statistics for: {base_dir}")
    print("="*50)
    print(f"{'Split':<10} | {'Images':<8} | {'Craters':<8} | {'Avg Craters/Img':<15}")
    print("-"*50)

    for split in ['train', 'val', 'test']:
        img_dir = os.path.join(base_dir, split, 'images')
        lbl_dir = os.path.join(base_dir, split, 'labels')

        if not os.path.exists(img_dir):
            print(f"{split:<10} | Directory not found")
            continue

        images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg'))]
        labels = [f for f in os.listdir(lbl_dir) if f.endswith('.txt')]

        total_craters = 0
        for lbl in labels:
            with open(os.path.join(lbl_dir, lbl), 'r') as f:
                # Count non-blank lines so stray empty lines are not counted as craters
                total_craters += sum(1 for line in f if line.strip())

        n_img = len(images)
        avg = total_craters / n_img if n_img > 0 else 0

        print(f"{split:<10} | {n_img:<8,} | {total_craters:<8,} | {avg:<15.2f}")
    print("="*50)

# Run stats for the Martian YOLO folder
print_yolo_stats('/content/martian_yolo')


**❓ Questions to consider:**

1. Many craters are very small (≤ 10 pixels). Why is small-object detection particularly challenging for neural networks?
2. The images are THEMIS daytime **infrared** rather than optical. What surface properties does IR capture that visible light does not?
3. Some images have many craters, others just one. How might this imbalance affect training?

In [None]:
# Write data.yaml for YOLO
import os
import yaml

# Ensure YOLO_DIR is defined
YOLO_DIR = "/content/martian_yolo"

data_yaml = {
    'path': YOLO_DIR,
    'train': 'train/images',
    'val': 'val/images',
    'test': 'test/images',
    'nc': 1,
    'names': ['crater'],
}

yaml_path = os.path.join("/content/MARTIAN_DATASET/extracted", 'data.yaml')
with open(yaml_path, 'w') as f:
    yaml.dump(data_yaml, f, default_flow_style=False)

print(f"Created config at: {yaml_path}")
print(open(yaml_path).read())

In [None]:
# Verify: visualize YOLO-format annotations on sample tiles
import matplotlib.patches as patches
import random
import matplotlib.pyplot as plt
from PIL import Image

def show_yolo_samples(yolo_dir, split='train', n=6, seed=42):
    img_dir = os.path.join(yolo_dir, split, 'images')
    lbl_dir = os.path.join(yolo_dir, split, 'labels')

    all_imgs = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    # Prefer images with annotations
    with_labels = [p for p in all_imgs
                   if os.path.getsize(os.path.join(lbl_dir, Path(p).stem + '.txt')) > 0]
    random.seed(seed)
    samples = random.sample(with_labels, min(n, len(with_labels)))

    cols = 3
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5.5 * rows))
    axes = axes.flatten()

    for idx, img_path in enumerate(samples):
        img = np.array(Image.open(img_path))
        h, w = img.shape[:2]
        axes[idx].imshow(img, cmap='gray')

        lbl_path = os.path.join(lbl_dir, Path(img_path).stem + '.txt')
        n_cr = 0
        with open(lbl_path) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    _, xc, yc, bw, bh = [float(v) for v in parts]
                    x1 = (xc - bw/2) * w
                    y1 = (yc - bh/2) * h
                    rect = patches.Rectangle(
                        (x1, y1), bw*w, bh*h,
                        linewidth=1.2, edgecolor='cyan', facecolor='none'
                    )
                    axes[idx].add_patch(rect)
                    n_cr += 1

        # axis('off') hides axis labels, so the crater count goes in the title
        axes[idx].set_title(f'{Path(img_path).stem} ({n_cr} craters)', fontsize=9)
        axes[idx].axis('off')

    for idx in range(len(samples), len(axes)):
        axes[idx].axis('off')

    plt.suptitle(f'YOLO-Format Annotations: {split} set',
                fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

show_yolo_samples(YOLO_DIR, 'train', n=6)

**❓ Questions:**

1. Why did we include ~10% **negative images** (no craters) in the training set?
2. The COCO→YOLO conversion normalizes coordinates to [0,1]. Why is this beneficial for training?
3. Some craters span just 2-3 pixels. Should we filter these out, or keep them? What are the tradeoffs?
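
For reference, the coordinate change behind a COCO-to-YOLO conversion looks like the sketch below. The function names are hypothetical and this is not necessarily the exact code used to prepare the dataset; it assumes the standard COCO box layout `[x_min, y_min, width, height]` in pixels.

```python
def coco_to_yolo(bbox, img_w, img_h, class_id=0):
    """Convert a COCO box [x_min, y_min, width, height] in pixels
    to a YOLO label line: 'class x_center y_center width height' in [0, 1]."""
    x_min, y_min, w, h = bbox
    xc = (x_min + w / 2) / img_w
    yc = (y_min + h / 2) / img_h
    return f'{class_id} {xc:.6f} {yc:.6f} {w / img_w:.6f} {h / img_h:.6f}'

def yolo_to_pixels(line, img_w, img_h):
    """Invert the conversion: return (x_min, y_min, width, height) in pixels."""
    _, xc, yc, bw, bh = (float(v) for v in line.split())
    return ((xc - bw / 2) * img_w, (yc - bh / 2) * img_h, bw * img_w, bh * img_h)


toy_label = coco_to_yolo([100, 60, 40, 40], img_w=256, img_h=256)
print(toy_label)
print(yolo_to_pixels(toy_label, 256, 256))
```

Because the stored values are fractions of the image size, the same label file remains valid if the tile is resized before training.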

---
## Part 4: Train YOLO for Crater Detection (~20 min)

The cell below loads the nano variant of YOLOv5 (`yolov5n.pt`) for fast training; to use **YOLOv8n** (nano, 3.2M params) instead, change the weights file. Both are pretrained on COCO (everyday objects), and we fine-tune on Martian craters: a compelling case of **transfer learning** across domains.

In [None]:
from ultralytics import YOLO

model = YOLO('yolov5n.pt')  # alternative pretrained weights to try: 'yolov8n.pt'
print(f"Parameters: {sum(p.numel() for p in model.model.parameters()):,}")

In [None]:
# Train: expect ~20 min on a T4 GPU with our subset
results = model.train(
    data=yaml_path,
    epochs=25,               # Adjust based on time
    imgsz=256,               # Native image size
    batch=64,                # Increase if GPU allows
    name='martian_craters',
    patience=10,             # Early stopping
    save=True,
    plots=True,
    # Augmentation
    flipud=0.5,              # Craters look the same flipped
    fliplr=0.5,
    degrees=180.0,           # Full rotation (craters are rotationally symmetric)
    mosaic=1.0,              # Mosaic: stitch 4 images into one
    mixup=0.1,               # MixUp: blend two images
    hsv_h=0.0,               # No hue shift (grayscale IR)
    hsv_s=0.0,               # No saturation shift
    hsv_v=0.3,               # Brightness variation (simulates different thermal conditions)
    scale=0.3,               # Scale augmentation for multi-size craters
)

**⏱️ While the model trains**, let's think about what's happening:

1. How does YOLOv5 differ from earlier versions and from YOLOv8?
2. **Loss components**:
   - **CIoU box loss**: Penalizes incorrect box placement (center, size, aspect ratio)
   - **BCE classification loss**: Binary cross-entropy for crater vs. background
   - **DFL (Distribution Focal Loss)**: Refined bounding box regression
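
To make the CIoU term concrete, here is a scalar sketch of the loss for a single pair of boxes, following the standard CIoU definition (IoU minus a center-distance penalty and an aspect-ratio penalty). The function name `ciou_loss` is hypothetical; the version used during training is vectorized and may differ in detail.

```python
import math

def ciou_loss(pred, target, eps=1e-9):
    """CIoU loss (1 - CIoU) for two boxes given as (x1, y1, x2, y2)."""
    # Plain IoU
    iw = max(0.0, min(pred[2], target[2]) - max(pred[0], target[0]))
    ih = max(0.0, min(pred[3], target[3]) - max(pred[1], target[1]))
    inter = iw * ih
    w_p, h_p = pred[2] - pred[0], pred[3] - pred[1]
    w_t, h_t = target[2] - target[0], target[3] - target[1]
    union = w_p * h_p + w_t * h_t - inter + eps
    iou = inter / union

    # Center-distance penalty, normalized by the enclosing box diagonal
    cw = max(pred[2], target[2]) - min(pred[0], target[0])
    ch = max(pred[3], target[3]) - min(pred[1], target[1])
    c2 = cw ** 2 + ch ** 2 + eps
    rho2 = (((pred[0] + pred[2]) - (target[0] + target[2])) ** 2 +
            ((pred[1] + pred[3]) - (target[1] + target[3])) ** 2) / 4

    # Aspect-ratio penalty
    v = (4 / math.pi ** 2) * (math.atan(w_t / (h_t + eps)) -
                              math.atan(w_p / (h_p + eps))) ** 2
    alpha = v / (1 - iou + v + eps)

    return 1 - (iou - rho2 / c2 - alpha * v)


gt_box = (0, 0, 10, 10)
for name, box in [('perfect', (0, 0, 10, 10)),
                  ('shifted', (3, 0, 13, 10)),
                  ('disjoint', (20, 0, 30, 10))]:
    print(f'{name:>8}: {ciou_loss(box, gt_box):.3f}')
```

Unlike plain IoU, the loss keeps growing as a non-overlapping box moves further from the target, so the model still receives a useful signal when a predicted box misses the crater entirely.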

---
## Part 5: Evaluate & Visualize Results (~15 min)

In [None]:
from IPython.display import display, Image as IPImage

train_dir = Path(results.save_dir)
print(f"Results saved to: {train_dir}")

# Training curves
for fname in ['results.png']:
    fpath = train_dir / fname
    if fpath.exists():
        display(IPImage(filename=str(fpath), width=900))

This code block performs a formal evaluation of your best-trained model.

**Loading the Best Model**: It initializes a new YOLO object using best.pt, which is the version of the model that achieved the highest performance during training.

**Looping through Splits**: It iterates over the splits listed in the loop. Only 'val' (validation) is evaluated by default; add 'test' to the list to evaluate the test set as well.

**Running Validation**: The model.val() function runs the model on those specific images and calculates standard object detection metrics.

**Printing Metrics**: Finally, it prints key performance indicators:

  * Precision: How many of the detected craters were actually real craters.

  * Recall: What percentage of all real craters the model successfully found.
  
  * mAP@50: The Mean Average Precision at an Intersection over Union (IoU) threshold of 0.5 (a standard accuracy measure).
  
  * mAP@50-95: A stricter metric that averages AP over IoU thresholds from 0.50 to 0.95 in steps of 0.05, reflecting how tightly the bounding boxes fit the craters.

In [None]:
# Formal evaluation
best_model = YOLO(str(train_dir / 'weights' / 'best.pt'))

for split_name in ['val']:
    metrics = best_model.val(data=yaml_path, imgsz=256, split=split_name, verbose=False)
    print(f"\n{'='*50}")
    print(f"{split_name.upper()} SET METRICS")
    print(f"{'='*50}")
    print(f"  Precision:   {metrics.box.mp:.3f}")
    print(f"  Recall:      {metrics.box.mr:.3f}")
    print(f"  mAP@50:      {metrics.box.map50:.3f}")
    print(f"  mAP@50-95:   {metrics.box.map:.3f}")
    print(f"{'='*50}")
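
To build intuition for how the IoU threshold affects these numbers, the sketch below matches a few hand-made predictions to ground-truth boxes and reports precision and recall at two thresholds. It is a simplified greedy matcher with hypothetical names, not the evaluation code that `model.val()` runs.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def precision_recall(gt_boxes, pred_boxes, pred_scores, iou_thresh=0.5):
    """Greedily match predictions (highest confidence first) to ground truth."""
    order = sorted(range(len(pred_boxes)), key=lambda i: -pred_scores[i])
    matched_gt = set()
    tp = 0
    for i in order:
        # Find the best still-unmatched ground-truth box for this prediction
        best_iou, best_j = 0.0, None
        for j, gt in enumerate(gt_boxes):
            if j in matched_gt:
                continue
            iou = box_iou(pred_boxes[i], gt)
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_j is not None and best_iou >= iou_thresh:
            matched_gt.add(best_j)
            tp += 1
    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp
    precision = tp / (tp + fp) if pred_boxes else 0.0
    recall = tp / (tp + fn) if gt_boxes else 0.0
    return precision, recall


toy_gt = [(10, 10, 30, 30), (100, 100, 140, 140), (200, 50, 210, 60)]
toy_preds = [(12, 11, 31, 29), (98, 104, 138, 150), (60, 60, 80, 80)]
toy_scores = [0.9, 0.8, 0.4]
for thr in (0.5, 0.75):
    p, r = precision_recall(toy_gt, toy_preds, toy_scores, iou_thresh=thr)
    print(f'IoU >= {thr}: precision={p:.2f}, recall={r:.2f}')
```

The second prediction overlaps its crater with an IoU of about 0.66, so it counts as a hit at the 0.5 threshold but as a miss at 0.75. Loosely placed boxes like this one are what open a gap between mAP@50 and mAP@50-95.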

In [None]:
# Visualize detections vs. ground truth
def visualize_detections(model, yolo_dir, split='val', n=6, conf=0.25, seed=99):
    img_dir = os.path.join(yolo_dir, split, 'images')
    lbl_dir = os.path.join(yolo_dir, split, 'labels')

    all_imgs = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    with_labels = [p for p in all_imgs
                   if os.path.getsize(os.path.join(lbl_dir, Path(p).stem + '.txt')) > 0]
    random.seed(seed)
    samples = random.sample(with_labels, min(n, len(with_labels)))

    cols = 3
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(16, 5.5 * rows))
    axes = axes.flatten()

    for idx, img_path in enumerate(samples):
        img = np.array(Image.open(img_path))
        h, w = img.shape[:2]
        axes[idx].imshow(img, cmap='gray')

        # Ground truth (green dashed)
        lbl_path = os.path.join(lbl_dir, Path(img_path).stem + '.txt')
        n_gt = 0
        with open(lbl_path) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    _, xc, yc, bw, bh = [float(v) for v in parts]
                    rect = patches.Rectangle(
                        ((xc-bw/2)*w, (yc-bh/2)*h), bw*w, bh*h,
                        linewidth=1.5, edgecolor='lime', facecolor='none', linestyle='--'
                    )
                    axes[idx].add_patch(rect)
                    n_gt += 1

        # Predictions (red solid)
        preds = model.predict(img_path, conf=conf, verbose=False)
        n_pred = 0
        if preds[0].boxes is not None:
            for box in preds[0].boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                c = box.conf[0].cpu().numpy()
                rect = patches.Rectangle(
                    (x1, y1), x2-x1, y2-y1,
                    linewidth=1.5, edgecolor='red', facecolor='none'
                )
                axes[idx].add_patch(rect)
                axes[idx].text(x1, max(y1-3, 8), f'{c:.2f}', color='red', fontsize=7,
                              fontweight='bold',
                              bbox=dict(boxstyle='round,pad=0.1', facecolor='black', alpha=0.5))
                n_pred += 1

        # axis('off') hides axis labels, so the counts go in the title
        axes[idx].set_title(
            f'{Path(img_path).stem}\nGT: {n_gt} (green) | Pred: {n_pred} (red)',
            fontsize=9)
        axes[idx].axis('off')

    for idx in range(len(samples), len(axes)):
        axes[idx].axis('off')

    from matplotlib.lines import Line2D
    legend = [Line2D([0],[0], color='lime', ls='--', lw=2, label='Ground Truth'),
              Line2D([0],[0], color='red', lw=2, label='Prediction')]
    fig.legend(handles=legend, loc='upper center', ncol=2, fontsize=11,
              bbox_to_anchor=(0.5, 1.02))
    plt.suptitle(f'Crater Detection: {split} (conf ≥ {conf})',
                fontsize=14, fontweight='bold', y=1.04)
    plt.tight_layout()
    plt.show()

visualize_detections(best_model, YOLO_DIR, 'val', n=6)

In [None]:
# Analyze performance by crater size
# Bin detections into Small / Medium / Large and measure recall for each
print("Analyzing detection performance by crater size...\n")

size_bins = {'Small (≤10 px)': (0, 10), 'Medium (11-50 px)': (11, 50), 'Large (>50 px)': (51, 999)}
size_stats = {k: {'total': 0, 'detected': 0} for k in size_bins}

val_imgs = sorted(glob.glob(os.path.join(YOLO_DIR, 'val', 'images', '*.png')))
for img_path in val_imgs[:200]:  # Analyze first 200 val images
    lbl_path = os.path.join(YOLO_DIR, 'val', 'labels', Path(img_path).stem + '.txt')
    if not os.path.exists(lbl_path) or os.path.getsize(lbl_path) == 0:
        continue

    gt_boxes = []
    with open(lbl_path) as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) == 5:
                _, xc, yc, bw, bh = [float(v) for v in parts]
                gt_boxes.append((xc, yc, bw, bh))

    preds = best_model.predict(img_path, conf=0.25, verbose=False)
    pred_boxes = []
    if preds[0].boxes is not None:
        for box in preds[0].boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            pred_boxes.append(((x1+x2)/2/256, (y1+y2)/2/256, (x2-x1)/256, (y2-y1)/256))

    for gxc, gyc, gbw, gbh in gt_boxes:
        diam_px = round(max(gbw, gbh) * 256)  # whole pixels, so the integer bins leave no gaps
        for label, (lo, hi) in size_bins.items():
            if lo <= diam_px <= hi:
                size_stats[label]['total'] += 1
                # Check if any prediction overlaps (simple center-distance match)
                for pxc, pyc, pbw, pbh in pred_boxes:
                    dist = ((gxc - pxc)**2 + (gyc - pyc)**2)**0.5
                    if dist < max(gbw, gbh) * 0.5:
                        size_stats[label]['detected'] += 1
                        break
                break

print(f"{'Size Group':<20} {'Total':>8} {'Detected':>10} {'Recall':>8}")
print("-" * 48)
for label, s in size_stats.items():
    recall = s['detected'] / max(s['total'], 1)
    print(f"{label:<20} {s['total']:>8} {s['detected']:>10} {recall:>8.1%}")

In [None]:
# Confidence threshold analysis
sample_img = sorted(glob.glob(os.path.join(YOLO_DIR, 'val', 'images', '*.png')))
# Pick an image with several craters
for p in sample_img:
    lp = os.path.join(YOLO_DIR, 'val', 'labels', Path(p).stem + '.txt')
    if os.path.exists(lp) and os.path.getsize(lp) > 30:
        test_img = p
        break

thresholds = [0.10, 0.25, 0.50, 0.75]
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for i, thresh in enumerate(thresholds):
    img = np.array(Image.open(test_img))
    axes[i].imshow(img, cmap='gray')
    preds = best_model.predict(test_img, conf=thresh, verbose=False)
    n_det = 0
    if preds[0].boxes is not None:
        for box in preds[0].boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            rect = patches.Rectangle(
                (x1, y1), x2-x1, y2-y1,
                linewidth=1.5, edgecolor='cyan', facecolor='none'
            )
            axes[i].add_patch(rect)
            n_det += 1
    axes[i].set_title(f'Conf ≥ {thresh} → {n_det} detections', fontsize=11)
    axes[i].axis('off')

plt.suptitle('Effect of Confidence Threshold on Detections',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

**❓ Questions:**

1. How does detection recall vary across the three size groups? Why are small craters harder to detect?
2. What confidence threshold would you choose for (a) building a complete crater catalog vs. (b) safe landing site selection?
3. Look at the gap between mAP@50 and mAP@50-95. What does this tell you about localization precision?

---
## Part 6: Optional Experiments

Try one of these if time permits.

In [None]:
# ============================================================
# EXPERIMENT A: Larger model (YOLOv5s, YOLOv5m, or YOLOv5l)
# Parameter counts grow with model size; print them as in Part 4.
# ============================================================
# model_s = YOLO('yolov5s.pt')
# results_s = model_s.train(
#     data=yaml_path, epochs=40, imgsz=256, batch=32,
#     name='martian_craters_small', patience=10, verbose=False
# )
# metrics_s = model_s.val(data=yaml_path, split='val')
# print(f"YOLOv5s mAP@50: {metrics_s.box.map50:.3f}")

In [None]:
# ============================================================
# EXPERIMENT B: Newer architecture (YOLOv8n or YOLO11n)
# ============================================================
# model_v8 = YOLO('yolov8n.pt')   # or 'yolo11n.pt'
# results_v8 = model_v8.train(
#     data=yaml_path, epochs=40, imgsz=256, batch=32,
#     name='martian_craters_v8n', patience=10, verbose=False
# )
# metrics_v8 = model_v8.val(data=yaml_path, split='val')
# print(f"YOLOv8n mAP@50: {metrics_v8.box.map50:.3f}")

In [None]:
# ============================================================
# EXPERIMENT C: Filter tiny craters
# Remove craters < 5 pixels from training labels.
# Does focusing on detectable craters improve overall mAP?
# ============================================================
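# Hint: modify the build_yolo_split function to skip annotations whose
# longer side is under 5 pixels: max(w, h) * img_size < 5 for the
# normalized YOLO width/height, or max(a[2], a[3]) < 5 for the pixel bbox a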
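
One possible way to implement the filter in Experiment C is sketched below. It assumes annotations are plain COCO-style pixel boxes `[x_min, y_min, w, h]`, as described for `gt_public.json` in this lab; `filter_small_boxes` and the toy boxes are hypothetical and not part of the notebook.

```python
def filter_small_boxes(anns, min_size_px=5):
    """Keep COCO-style boxes [x_min, y_min, w, h] whose longer side is >= min_size_px."""
    return [a for a in anns if max(a[2], a[3]) >= min_size_px]


# Toy annotations (made up for illustration): the first box is 3 x 4 px
anns = [[10, 10, 3, 4], [50, 60, 12, 11], [100, 20, 5, 2]]
kept = filter_small_boxes(anns)
print(f"kept {len(kept)} of {len(anns)} boxes: {kept}")
```

Inside `build_yolo_split`, the same call could be applied to `anns` just before the labels are written. Filtering only the training split keeps validation metrics comparable with the unfiltered baseline.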

---
## 📝 Lab Wrap-Up & Discussion

### Discussion Questions

1. **Small object detection**: Nearly 40% of craters are ≤10 pixels. What architectural modifications could improve small-crater detection? (Hint: consider the Feature Pyramid Network, higher input resolution, or specialized anchor sizes.)

2. **Transfer learning**: We transferred from COCO (everyday objects like cars and cats) to Martian craters. Why does this work? What features from COCO are useful for craters? Would starting from a model pretrained on satellite imagery be better?

3. **Catalog quality**: The Robbins & Hynek catalog was manually compiled using THEMIS IR + topographic data. What are the implications of using the same imagery for both labeling and detection? Could the model find craters the catalog missed?

4. **From Mars to the Moon**: How would you adapt this pipeline for lunar crater detection? What differences in imagery (LROC vs. THEMIS), crater morphology, and surface conditions would you need to account for?

5. **Scientific validation**: If you deployed this globally on Mars, how would you validate the output? How could you estimate completeness and contamination rates as a function of crater size?
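
For question 5, completeness (the fraction of catalog craters recovered) and contamination (the fraction of detections that are false positives) can be tabulated per size bin once detections have been matched to the catalog. The sketch below assumes that matching, for example by an IoU threshold, has already been done elsewhere. `rates_by_size` is a hypothetical helper, the arrays are made-up toy values, and the bin edges simply reuse this lab's small/medium/large groups.

```python
import numpy as np


def rates_by_size(gt_sizes, gt_matched, pred_sizes, pred_is_tp, bin_edges):
    """Return (lo, hi, completeness, contamination) for each size bin (lo, hi]."""
    gt_sizes = np.asarray(gt_sizes, dtype=float)
    gt_matched = np.asarray(gt_matched, dtype=bool)
    pred_sizes = np.asarray(pred_sizes, dtype=float)
    pred_is_tp = np.asarray(pred_is_tp, dtype=bool)

    rows = []
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        gt_in = (gt_sizes > lo) & (gt_sizes <= hi)
        pr_in = (pred_sizes > lo) & (pred_sizes <= hi)
        # Completeness: share of ground-truth craters in the bin that were detected
        completeness = gt_matched[gt_in].mean() if gt_in.any() else float("nan")
        # Contamination: share of detections in the bin that are false positives
        contamination = (~pred_is_tp[pr_in]).mean() if pr_in.any() else float("nan")
        rows.append((lo, hi, float(completeness), float(contamination)))
    return rows


# Toy inputs (illustrative only): diameters in pixels plus match flags
gt_sizes = [4, 8, 9, 20, 35, 60]
gt_matched = [False, True, False, True, True, True]
pred_sizes = [8, 7, 22, 33, 40, 62]
pred_is_tp = [True, False, True, True, False, True]

for lo, hi, comp, cont in rates_by_size(gt_sizes, gt_matched, pred_sizes,
                                        pred_is_tp, [0, 10, 50, 256]):
    print(f"({lo:>3}, {hi:>3}] px  completeness={comp:.2f}  contamination={cont:.2f}")
```

On real data the ground truth would need to be independent of the training catalog, otherwise apparent contamination may include real craters that the catalog missed.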


### Going Further

- Train on the **full 50K training set** for a research-grade model
- Try **YOLOv8m** or **YOLOv8l** for higher accuracy (at the cost of speed)
- Fuse THEMIS IR imagery with **MOLA topographic data** as additional input channels
- Submit results to the [GeoAI challenge leaderboard](https://codalab.lisn.upsaclay.fr/competitions/1934)
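
The MOLA fusion idea above could be prototyped by packing co-registered layers into the three channels that a COCO-pretrained detector already expects. This is a minimal sketch under strong assumptions: the IR and elevation tiles are already resampled to the same 256x256 grid, the third channel (elevation gradient magnitude) is one arbitrary choice, and `to_uint8` and `fuse_ir_and_elevation` are hypothetical helpers. Random arrays stand in for real tiles.

```python
import numpy as np


def to_uint8(arr):
    """Min-max scale an array to the 0..255 range and cast to uint8."""
    arr = np.asarray(arr, dtype=np.float64)
    span = arr.max() - arr.min()
    if span == 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    return np.round(255 * (arr - arr.min()) / span).astype(np.uint8)


def fuse_ir_and_elevation(ir_tile, elev_tile):
    """Stack IR, elevation, and elevation-gradient magnitude into an H x W x 3 image."""
    if ir_tile.shape != elev_tile.shape:
        raise ValueError("IR and elevation tiles must be co-registered on the same grid")
    gy, gx = np.gradient(np.asarray(elev_tile, dtype=np.float64))
    slope = np.hypot(gx, gy)
    return np.stack([to_uint8(ir_tile), to_uint8(elev_tile), to_uint8(slope)], axis=-1)


# Placeholder data, not real THEMIS or MOLA values
rng = np.random.default_rng(0)
ir = rng.integers(0, 256, size=(256, 256)).astype(np.uint8)
elev = rng.normal(0.0, 500.0, size=(256, 256))

fused = fuse_ir_and_elevation(ir, elev)
print(fused.shape, fused.dtype)
```

Keeping three channels lets the pretrained first convolution be reused unchanged. Adding a true fourth channel would instead require modifying the model's input layer.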

### Citations

```bibtex
@article{hsu2021knowledge,
  title={Knowledge-Driven GeoAI: Integrating Spatial Knowledge into Multi-Scale
         Deep Learning for Mars Crater Detection},
  author={Hsu, Chia-Yu and Li, Wenwen and Wang, Sizhe},
  journal={Remote Sensing},
  volume={13}, number={11}, pages={2116},
  year={2021}, publisher={MDPI}
}

@article{robbins2012new,
  title={A new global database of Mars impact craters $\geq$ 1 km:
         1. Database creation, properties, and parameters},
  author={Robbins, Stuart J and Hynek, Brian M},
  journal={Journal of Geophysical Research: Planets},
  volume={117}, number={E5}, year={2012}
}

@article{edwards2011mosaicking,
  title={Mosaicking of global planetary image datasets: 1. Techniques and data
         processing for THEMIS multi-spectral data},
  author={Edwards, Christopher S and others},
  journal={Journal of Geophysical Research: Planets},
  volume={116}, number={E10}, year={2011}
}
```

In [None]:
# # ============================================================
# # EXPERIMENT D: Train on MORE data
# # Increase N_TRAIN to 20000 or even 50000 (full dataset)
# # and re-run Parts 3-5. Does more data help?
# # ============================================================

# # Download & Unzip FULL Martian Dataset

# import os
# import shutil
# import zipfile
# import gdown
# import glob

# DATA_DIR = '/content/MARTIAN_DATASET'
# ZIP_PATH = os.path.join(DATA_DIR, "2022_GeoAI_Martian_Challenge_Dataset.zip")
# EXTRACT_DIR = os.path.join(DATA_DIR, "extracted")

# if not os.path.exists(DATA_DIR):
#     os.makedirs(DATA_DIR)

# print("🚀 Starting download from Google Drive ...")

# file_ids = {
#     '2022_GeoAI_Martian_Challenge_Dataset.zip': '1eGCBMeyDzKL7DNk01qSbqOrqpLOnQiR6',
# }

# for filename, file_id in file_ids.items():
#     output_path = os.path.join(DATA_DIR, filename)
#     if not os.path.exists(output_path):
#         print(f"   Downloading {filename}...")
#         gdown.download(id=file_id, output=output_path, quiet=False)
#     else:
#         print(f"   {filename} already exists, skipping download.")

# # Extract
# if not os.path.exists(EXTRACT_DIR):
#     print("Extracting dataset...")
#     !unzip -q "{ZIP_PATH}" -d "{EXTRACT_DIR}"
# else:
#     print("Already extracted.")

# # Find the root of the extracted data
# # It may be nested inside a subfolder
# candidates = glob.glob(os.path.join(EXTRACT_DIR, '**', 'ids.json'), recursive=True)
# if candidates:
#     CHALLENGE_DIR = os.path.dirname(candidates[0])
# else:
#     # Try direct path
#     CHALLENGE_DIR = EXTRACT_DIR

# print(f"Dataset root: {CHALLENGE_DIR}")
# print("Contents:")
# if os.path.exists(CHALLENGE_DIR):
#     for item in sorted(os.listdir(CHALLENGE_DIR)):
#         full = os.path.join(CHALLENGE_DIR, item)
#         if os.path.isdir(full):
#             n_files = len(os.listdir(full))
#             print(f"  📁 {item}/ ({n_files:,} files)")
#         else:
#             print(f"  📄 {item} ({os.path.getsize(full)/1e6:.1f} MB)")
# else:
#     print(f"❌ Error: Directory {CHALLENGE_DIR} does not exist. Extraction might have failed.")



# # @title Load the split IDs and annotations
# import json
# import os
# import numpy as np

# with open(os.path.join(CHALLENGE_DIR, 'ids.json'), 'r') as f:
#     ids = json.load(f)

# with open(os.path.join(CHALLENGE_DIR, 'gt_public.json'), 'r') as f:
#     gt_public = json.load(f)

# print(f"Split sizes:")
# for split_name in ['train', 'val', 'test']:
#     print(f"  {split_name:6s}: {len(ids.get(split_name, [])):>6,} images")

# print(f"\nAnnotation file keys: {list(gt_public.keys())[:5]}...")
# print(f"  Images with annotations:      {len(gt_public):,}")
# print(f"  Total annotations (bounding boxes): {sum(len(v) for v in gt_public.values()):,}")
# print(f"  Categories: ['crater'] ")


# sample_image_id = next(iter(gt_public))
# print(f"Sample Image ID: {sample_image_id}")
# print(json.dumps(gt_public[sample_image_id][:2], indent=2))

# print("\n→ Bounding box format: [x_min, y_min, width, height] in pixels (COCO format)")
# print("→ We need to convert this to YOLO format: (x_center, y_center, width, height) normalized to [0,1]")

# # image_id → image info
# # The gt_public.json provided for this lab contains only image_id -> annotations.
# # We need to construct image info based on the fixed image size and known image IDs.
# # All images are 256x256 as per the lab description.
# img_info = {}
# # Collect all unique image IDs from the splits to ensure all images have info
# all_image_ids = set(ids['train']) | set(ids['val']) | set(ids['test'])
# for img_id in all_image_ids:
#     img_info[img_id] = {
#         'id': img_id,
#         'file_name': f"{img_id}.png",
#         'width': 256,
#         'height': 256
#     }

# # image_id → list of annotations
# from collections import defaultdict
# # Populate img_anns using the gt_public directly, as it already maps image IDs to lists of annotations
# img_anns = defaultdict(list)
# for img_id, annotations_list in gt_public.items():
#     img_anns[img_id] = annotations_list

# # Quick stats on training annotations
# train_ids_set = set(ids['train'])
# train_crater_counts = []
# train_crater_sizes = []

# for img_id in ids['train']:
#     # Only consider images that actually have annotations in img_anns
#     if img_id in img_anns:
#         anns = img_anns[img_id]
#         train_crater_counts.append(len(anns))
#         for a in anns:
#             # Bounding box format is [x_min, y_min, w, h]
#             bw, bh = a[2], a[3]
#             train_crater_sizes.append(max(bw, bh))  # diameter in pixels

# print(f"Training set statistics:")
# print(f"  Total craters: {sum(train_crater_counts):,}")
# print(f"  Mean craters/image: {np.mean(train_crater_counts):.1f}")
# print(f"  Crater size (pixels): median={np.median(train_crater_sizes):.0f}, "
#       f"range=[{np.min(train_crater_sizes):.0f}, {np.max(train_crater_sizes):.0f}]")


# # Visualize dataset statistics
# fig, axes = plt.subplots(1, 3, figsize=(18, 4.5))

# # Craters per image
# axes[0].hist(train_crater_counts, bins=range(0, max(train_crater_counts)+2),
#              color='steelblue', edgecolor='white', linewidth=0.3)
# axes[0].set_xlabel('Craters per image')
# axes[0].set_ylabel('Number of images')
# axes[0].set_title('Crater Count Distribution')
# axes[0].set_xlim(0, 20)
# axes[0].axvline(np.mean(train_crater_counts), color='red', ls='--',
#                 label=f'Mean: {np.mean(train_crater_counts):.1f}')
# axes[0].legend()

# # Crater size distribution (pixels)
# axes[1].hist(train_crater_sizes, bins=np.arange(0, 260, 5),
#              color='coral', edgecolor='white', linewidth=0.3)
# axes[1].set_xlabel('Crater diameter (pixels)')
# axes[1].set_ylabel('Count')
# axes[1].set_title('Crater Size Distribution')
# axes[1].axvline(10, color='blue', ls='--', label='Small/Medium (10 px)')
# axes[1].axvline(50, color='green', ls='--', label='Medium/Large (50 px)')
# axes[1].legend(fontsize=9)

# # Size groups pie chart
# small = sum(1 for s in train_crater_sizes if s <= 10)
# medium = sum(1 for s in train_crater_sizes if 10 < s <= 50)
# large = sum(1 for s in train_crater_sizes if s > 50)
# axes[2].pie([small, medium, large], labels=['Small\n(≤10 px)', 'Medium\n(11-50 px)', 'Large\n(>50 px)'],
#             colors=['#e74c3c', '#f39c12', '#3498db'], autopct='%1.1f%%',
#             textprops={'fontsize': 11})
# axes[2].set_title('Crater Size Groups')

# plt.suptitle('GeoAI Martian Challenge: Training Set Statistics',
#              fontsize=14, fontweight='bold')
# plt.tight_layout()
# plt.show()


# # ============================================================
# # Configuration
# # ============================================================
# N_TRAIN = 5000     # Subset size for training (full: ~50K)
# N_VAL = 1000       # Use the entire val set
# N_TEST = 1000      # Subset for testing
# IMG_SIZE = 256     # Image dimensions
# YOLO_DIR = "/content/martian_yolo"

# print(f"Building YOLO dataset: {N_TRAIN} train / {N_VAL} val / {N_TEST} test images")
# print(f"(Full dataset has ~50K train images; use it for research projects)")


# def coco_to_yolo_box(bbox, img_w, img_h):
#     """Convert COCO bbox [x_min, y_min, w, h] to YOLO [xc, yc, w, h] normalized."""
#     x_min, y_min, bw, bh = bbox
#     xc = (x_min + bw / 2.0) / img_w
#     yc = (y_min + bh / 2.0) / img_h
#     w = bw / img_w
#     h = bh / img_h
#     # Clamp to [0, 1]
#     xc = max(0, min(1, xc))
#     yc = max(0, min(1, yc))
#     w = max(0.001, min(1, w))
#     h = max(0.001, min(1, h))
#     return xc, yc, w, h


# def build_yolo_split(split_name, image_ids, n_images, yolo_dir, challenge_dir,
#                      img_info, img_anns, img_size=256):
#     """
#     Sample n_images from image_ids, copy images and create YOLO label files.
#     Prioritizes images WITH annotations.
#     """
#     img_dir = os.path.join(yolo_dir, split_name, 'images')
#     lbl_dir = os.path.join(yolo_dir, split_name, 'labels')
#     os.makedirs(img_dir, exist_ok=True)
#     os.makedirs(lbl_dir, exist_ok=True)

#     # Prefer images with annotations, but include some without
#     with_anns = [i for i in image_ids if len(img_anns.get(i, [])) > 0]
#     without_anns = [i for i in image_ids if len(img_anns.get(i, [])) == 0]

#     # Sample: 90% with craters, 10% negatives
#     n_pos = min(int(n_images * 0.9), len(with_anns))
#     n_neg = min(n_images - n_pos, len(without_anns))

#     selected = random.sample(with_anns, n_pos)
#     if n_neg > 0 and without_anns:
#         selected += random.sample(without_anns, n_neg)
#     random.shuffle(selected)

#     total_craters = 0
#     for img_id in selected:
#         info = img_info.get(img_id)
#         if info is None:
#             continue

#         fname = info['file_name']
#         src_path = os.path.join(challenge_dir, 'images', fname)
#         if not os.path.exists(src_path):
#             src_path = os.path.join(challenge_dir, 'images', os.path.basename(fname))
#         if not os.path.exists(src_path):
#             continue

#         # Copy image
#         dst_img = os.path.join(img_dir, os.path.basename(fname))
#         shutil.copy2(src_path, dst_img)

#         # Write YOLO labels
#         stem = Path(fname).stem
#         lbl_path = os.path.join(lbl_dir, f"{stem}.txt")
#         anns = img_anns.get(img_id, [])

#         with open(lbl_path, 'w') as f:
#             for a in anns:
#                 xc, yc, w, h = coco_to_yolo_box(a, img_size, img_size)
#                 f.write(f"0 {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}\n")
#                 total_craters += 1

#     return len(selected), total_craters


# print("Building YOLO dataset...")


# # %%time  (cell magic: only valid as the first line of its own cell)
# import random
# import json
# import os
# import shutil
# from pathlib import Path

# # Also load gt_eval.json if available (contains val annotations)
# eval_path = os.path.join(CHALLENGE_DIR, 'gt_eval.json')
# if os.path.exists(eval_path):
#     with open(eval_path, 'r') as f:
#         gt_eval = json.load(f)
#     # Merge eval annotations into our lookup
#     for img_entry in gt_eval.get('images', []):
#         if img_entry['id'] not in img_info:
#             img_info[img_entry['id']] = img_entry
#     for ann in gt_eval.get('annotations', []):
#         # Store the bare [x_min, y_min, w, h] list so entries match gt_public
#         img_anns[ann['image_id']].append(ann['bbox'])
#     print(f"Loaded gt_eval.json: {len(gt_eval.get('annotations', []))} annotations")

# # --- FIX: Create a local test set from unused training images ---
# # The official 'test' set has no labels (blind test set for competition).
# # We will use a slice of the 'train' pool that wasn't used for training as our 'test' set.

# # Shuffle train IDs to ensure random selection
# all_train_ids = ids['train']
# random.shuffle(all_train_ids)

# # Split: Train (0 to N_TRAIN) | Test (N_TRAIN to N_TRAIN + N_TEST)
# train_split_ids = all_train_ids[:N_TRAIN]
# test_split_ids  = all_train_ids[N_TRAIN : N_TRAIN + N_TEST]
# val_split_ids   = ids['val'] # Keep official val set

# print(f"Redefined splits to ensure Test set has labels:")
# print(f"  Train source: ids['train'][:{N_TRAIN}]")
# print(f"  Test source:  ids['train'][{N_TRAIN}:{N_TRAIN+N_TEST}]")

# # Build each split
# stats = {}
# for split_name, split_ids, n in [('train', train_split_ids, N_TRAIN),
#                                    ('val', val_split_ids, N_VAL),
#                                    ('test', test_split_ids, N_TEST)]:
#     n_imgs, n_craters = build_yolo_split(
#         split_name, split_ids, n, YOLO_DIR, CHALLENGE_DIR,
#         img_info, img_anns, IMG_SIZE
#     )
#     stats[split_name] = (n_imgs, n_craters)
#     print(f"  {split_name:6s}: {n_imgs:>5,} images, {n_craters:>6,} craters")

# print("\n✅ YOLO dataset ready!")


# # Write data.yaml for YOLO
# import os
# import yaml

# # Ensure YOLO_DIR is defined
# YOLO_DIR = "/content/martian_yolo"

# data_yaml = {
#     'path': YOLO_DIR,
#     'train': 'train/images',
#     'val': 'val/images',
#     'test': 'test/images',
#     'nc': 1,
#     'names': ['crater'],
# }

# yaml_path = os.path.join(YOLO_DIR, 'data.yaml')
# with open(yaml_path, 'w') as f:
#     yaml.dump(data_yaml, f, default_flow_style=False)

# print(f"Created config at: {yaml_path}")
# print(open(yaml_path).read())
