# Full model pipeline for optimizing anchor sizes for RPN

## Overview

This notebook provides a complete pipeline for optimizing RPN anchor sizes using:
1. **Dataset Analysis** - Analyze object sizes and aspect ratios in your dataset
2. **Data-Driven Suggestions** - Get initial anchor recommendations based on statistics
3. **Optuna Optimization** - Fine-tune anchors using Bayesian hyperparameter search
4. **Validation** - Compare default vs optimized anchors

## 1. Setup and Imports

In [None]:
import os
import sys
import gc
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path

# Put the parent of the working directory first on sys.path so project modules resolve
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Request expandable segments from the CUDA caching allocator
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Choose the compute device once and report what was found
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install optuna if not available
try:
    import optuna
except ImportError:
    # Install with the interpreter that is running this kernel. A bare
    # `!pip install` uses whichever `pip` is first on PATH, which may belong to
    # a different environment than the kernel.
    import importlib
    import subprocess
    import sys

    print("Installing optuna...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna"])
    # The package appeared after interpreter start-up, so refresh the import
    # system's finder caches before importing it.
    importlib.invalidate_caches()
    import optuna

# Report the version on both paths so the run records which release was used.
print(f"Optuna version: {optuna.__version__}")

In [None]:
# Import project modules
# NOTE(review): these packages are presumably resolved through the parent-directory
# entry added to sys.path in the setup cell — confirm the notebook's location.
from datasets.isaid_dataset import iSAIDDataset
from training.transforms import get_transforms
from training.anchor_optimizer import (
    AnchorConfig,
    AnchorOptimizer, 
    DatasetAnchorAnalyzer,
    optimize_anchors_for_dataset,
    analyze_dataset_anchors,
)
from models.maskrcnn_model import CustomMaskRCNN, get_custom_maskrcnn

# Reaching this line means every import above succeeded.
print("All modules imported successfully!")

## 2. Configuration

In [None]:
# Tunable settings for the whole notebook
CONFIG = {
    # --- Dataset ---
    "data_root": "../iSAID_patches",   # dataset location
    "num_classes": 16,                 # 15 object classes plus background
    "image_size": 800,                 # image size handed to the dataset

    # --- Anchor optimization ---
    "n_trials": 20,                    # Optuna trial count (raise for better results)
    "num_analysis_samples": 500,       # samples used by the dataset analysis

    # --- Output ---
    "cache_path": "../optimized_anchors.pt",  # destination for the optimized anchors
}

# Baseline anchors used as the comparison point:
# four size groups, each paired with the same three aspect ratios.
_default_sizes = ((16, 24), (32, 48), (64, 96), (128, 192))
DEFAULT_ANCHORS = {
    "sizes": _default_sizes,
    "aspect_ratios": tuple((0.5, 1.0, 2.0) for _ in _default_sizes),
}

print("Configuration loaded:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 3. Load Dataset

In [None]:
# Build both splits with the non-training transforms (no augmentation for analysis)
print("Loading datasets...")


def _load_split(split_name):
    """Construct the iSAID dataset for one split using get_transforms(train=False)."""
    return iSAIDDataset(
        CONFIG["data_root"],
        split=split_name,
        transforms=get_transforms(train=False),
        image_size=CONFIG["image_size"],
    )


train_dataset = _load_split("train")
val_dataset = _load_split("val")

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

## 4. Dataset Analysis

First, let's analyze the object sizes and aspect ratios in the dataset to understand what anchor configurations might work best.

In [None]:
# Collect bounding-box statistics from a sample of the training set
analyzer = DatasetAnchorAnalyzer(train_dataset, num_samples=CONFIG["num_analysis_samples"])
stats = analyzer.compute_box_statistics()

# Short aliases for the four arrays that are summarized below
box_widths = stats['widths']
box_heights = stats['heights']
box_areas = stats['areas']
box_ratios = stats['aspect_ratios']
rule = '=' * 50

print(f"\n{rule}")
print("Dataset Bounding Box Statistics")
print(f"{rule}")
print(f"Number of boxes analyzed: {len(box_widths):,}")
print(f"\nWidth:  min={box_widths.min():.1f}, max={box_widths.max():.1f}, mean={box_widths.mean():.1f}, std={box_widths.std():.1f}")
print(f"Height: min={box_heights.min():.1f}, max={box_heights.max():.1f}, mean={box_heights.mean():.1f}, std={box_heights.std():.1f}")
print(f"Area:   min={box_areas.min():.1f}, max={box_areas.max():.1f}, mean={box_areas.mean():.1f}")
print(f"Aspect Ratio: min={box_ratios.min():.2f}, max={box_ratios.max():.2f}, mean={box_ratios.mean():.2f}")

In [None]:
# Visualize distributions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Object scales (sqrt of area)
scales = np.sqrt(stats['areas'])
median_scale = np.median(scales)
axes[0, 0].hist(scales, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
axes[0, 0].axvline(median_scale, color='red', linestyle='--', label=f'Median: {median_scale:.1f}')
axes[0, 0].set_xlabel('Object Scale (√area)')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Distribution of Object Scales')
axes[0, 0].legend()

# Aspect ratios
median_ratio = np.median(stats['aspect_ratios'])
axes[0, 1].hist(stats['aspect_ratios'], bins=50, edgecolor='black', alpha=0.7, color='coral')
axes[0, 1].axvline(median_ratio, color='red', linestyle='--',
                   label=f'Median: {median_ratio:.2f}')
axes[0, 1].set_xlabel('Aspect Ratio (w/h)')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Distribution of Aspect Ratios')
axes[0, 1].set_xlim(0, 5)
axes[0, 1].legend()

# Width vs Height scatter
# A locally seeded generator keeps the plotted subsample identical across
# Restart & Run All without touching NumPy's global random state.
rng = np.random.default_rng(0)
num_boxes = len(stats['widths'])
sample_idx = rng.choice(num_boxes, size=min(2000, num_boxes), replace=False)
axes[1, 0].scatter(stats['widths'][sample_idx], stats['heights'][sample_idx], alpha=0.3, s=5)
axes[1, 0].set_xlabel('Width')
axes[1, 0].set_ylabel('Height')
axes[1, 0].set_title('Width vs Height (sample)')
axes[1, 0].set_aspect('equal')

# Scale percentiles (for anchor size selection)
percentiles = [10, 25, 50, 75, 90]
scale_percentiles = np.percentile(scales, percentiles)
axes[1, 1].barh(range(len(percentiles)), scale_percentiles, color='teal', edgecolor='black')
axes[1, 1].set_yticks(range(len(percentiles)))
axes[1, 1].set_yticklabels([f'{p}th' for p in percentiles])
axes[1, 1].set_xlabel('Object Scale')
axes[1, 1].set_title('Scale Percentiles')
for i, v in enumerate(scale_percentiles):
    # Place the numeric value just to the right of each bar's end.
    axes[1, 1].text(v + 1, i, f'{v:.1f}', va='center')

plt.tight_layout()
plt.show()

# Text summary on a finer percentile grid than the one plotted above
print("\nScale percentiles (useful for anchor sizes):")
report_percentiles = [10, 20, 30, 50, 70, 80, 90]
for p, scale_at_p in zip(report_percentiles, np.percentile(scales, report_percentiles)):
    print(f"  {p}th percentile: {scale_at_p:.1f}")

## 5. Data-Driven Anchor Suggestions

Get initial anchor recommendations based on dataset statistics (without Optuna optimization).

In [None]:
# Ask the analyzer for anchor sizes and aspect ratios derived from the statistics
suggested_sizes = analyzer.suggest_anchor_sizes(num_scales=4)
suggested_ratios = analyzer.suggest_aspect_ratios(num_ratios=3)

rule = '=' * 50

print(f"\n{rule}")
print("Data-Driven Anchor Suggestions")
print(f"{rule}")
print(f"\nSuggested anchor sizes (per FPN level):")
# Entries are labelled P2, P3, ... in the order the analyzer returned them
for level, level_sizes in enumerate(suggested_sizes, start=2):
    print(f"  P{level}: {level_sizes}")

print(f"\nSuggested aspect ratios: {suggested_ratios}")

print(f"\n{rule}")
print("Comparison with Default Anchors")
print(f"{rule}")
print(f"\nDefault sizes:    {DEFAULT_ANCHORS['sizes']}")
print(f"Suggested sizes:  {suggested_sizes}")

## 6. Optuna Anchor Optimization

Now run the full Optuna-based optimization to find the best anchor configuration by maximizing RPN recall.

In [None]:
# Initialize the anchor optimizer
# It receives both dataset splits, the class count from CONFIG and the device
# selected in the setup cell.
optimizer = AnchorOptimizer(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    num_classes=CONFIG["num_classes"],
    device=device,
    # NOTE(review): presumably matches the four size groups in DEFAULT_ANCHORS — confirm
    num_fpn_levels=4,
    # Same ratio triple that DEFAULT_ANCHORS uses for every level
    base_aspect_ratios=(0.5, 1.0, 2.0),
)

print("Optimizer initialized!")
# Both attributes are read from the optimizer immediately after construction.
print(f"Data-suggested sizes: {optimizer.data_suggested_sizes}")
print(f"Data-suggested ratios: {optimizer.data_suggested_ratios}")

In [None]:
# Launch the Optuna search (slow step)
# Per the original note: each trial creates a model, trains briefly, and evaluates RPN recall
print(f"Starting optimization with {CONFIG['n_trials']} trials...")
print("This may take 10-30 minutes depending on your hardware.\n")

study_options = {
    "n_trials": CONFIG["n_trials"],
    "study_name": "anchor_optimization",
    "storage": None,  # Use "sqlite:///anchor_study.db" for persistence
}
best_config = optimizer.optimize(**study_options)

In [None]:
# Persist the winning anchors so later runs can reuse them
anchor_payload = {
    'sizes': best_config.sizes,
    'aspect_ratios': best_config.aspect_ratios,
}
torch.save(anchor_payload, CONFIG["cache_path"])

print(f"Optimized anchors saved to: {CONFIG['cache_path']}")
print(f"\nBest configuration found:")
print(f"  Sizes: {anchor_payload['sizes']}")
print(f"  Aspect ratios: {anchor_payload['aspect_ratios']}")

## 7. Validation: Compare Default vs Optimized Anchors

Let's compare RPN proposal recall (at IoU thresholds 0.5 and 0.75, measured on a subset of the validation set) between the default and optimized anchor configurations.

In [None]:
from torch.utils.data import DataLoader, Subset
from torchvision.ops import box_iou
from torchvision.models.detection.image_list import ImageList

def evaluate_anchor_config(anchor_sizes, anchor_ratios, name, num_batches=50):
    """Evaluate an anchor configuration by computing RPN proposal recall.

    A fresh ``CustomMaskRCNN`` is built with the given anchors, its backbone and
    RPN are run on the first images of the notebook-level ``val_dataset``, and
    for every image the fraction of ground-truth boxes matched by at least one
    proposal is recorded at IoU thresholds 0.5 and 0.75.

    Reads the notebook globals ``CONFIG``, ``device`` and ``val_dataset``.

    Args:
        anchor_sizes: Forwarded to ``CustomMaskRCNN`` as ``rpn_anchor_sizes``.
        anchor_ratios: Forwarded to ``CustomMaskRCNN`` as ``rpn_aspect_ratios``.
        name: Label printed in the progress output.
        num_batches: Maximum number of loader batches (batch size 1) to evaluate.

    Returns:
        dict with keys ``"recall@0.5"`` and ``"recall@0.75"``; each value is the
        mean per-image recall, or ``0.0`` when no image could be evaluated.
        Batches that raise during the forward pass are skipped, counted, and
        reported in the printed output instead of being dropped silently.
    """
    print(f"\nEvaluating: {name}")
    print(f"  Sizes: {anchor_sizes}")
    print(f"  Ratios: {anchor_ratios}")

    # Create model with specified anchors
    model = CustomMaskRCNN(
        num_classes=CONFIG["num_classes"],
        rpn_anchor_sizes=anchor_sizes,
        rpn_aspect_ratios=anchor_ratios,
    )
    model.to(device)
    model.eval()

    # Create validation loader
    def collate_fn(batch):
        # Turn a list of (image, target) pairs into (images, targets) tuples.
        return tuple(zip(*batch))

    val_subset = Subset(val_dataset, range(min(100, len(val_dataset))))
    val_loader = DataLoader(val_subset, batch_size=1, shuffle=False,
                            collate_fn=collate_fn, num_workers=0)

    recalls_05 = []
    recalls_075 = []
    failed_batches = 0
    first_error = None  # stored as text so no traceback (and its tensors) stays alive

    try:
        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(val_loader):
                if batch_idx >= num_batches:
                    break

                try:
                    images = [img.to(device) for img in images]
                    images_tensor = torch.stack(images)

                    # NOTE(review): the backbone/RPN are called directly, so
                    # model.transform (if any) is bypassed — assumes the dataset
                    # already yields model-ready tensors; confirm.
                    original_sizes = [img.shape[-2:] for img in images]
                    image_list = ImageList(images_tensor, original_sizes)

                    features = model.backbone(images_tensor)
                    proposals, _ = model.rpn(image_list, features, None)

                    for props, target in zip(proposals, targets):
                        gt_boxes = target['boxes'].to(device)
                        if len(gt_boxes) == 0 or len(props) == 0:
                            continue

                        # Rows are proposals, columns are ground-truth boxes:
                        # the column-wise max is each GT box's best overlap.
                        ious = box_iou(props, gt_boxes)
                        max_ious, _ = ious.max(dim=0)

                        recalls_05.append((max_ious >= 0.5).float().mean().item())
                        recalls_075.append((max_ious >= 0.75).float().mean().item())

                except Exception as e:
                    # Stay best-effort (skip the batch) but remember that it
                    # happened so a broken run is not mistaken for zero recall.
                    failed_batches += 1
                    if first_error is None:
                        first_error = f"{type(e).__name__}: {e}"
                    continue
    finally:
        # Cleanup runs even if the loader itself raises, so the model's memory
        # is always released.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    if failed_batches:
        print(f"  WARNING: skipped {failed_batches} batch(es) due to errors; "
              f"first error: {first_error}")
    if not recalls_05:
        print("  WARNING: no image was evaluated; reported recall defaults to 0.0")

    recall_05 = np.mean(recalls_05) if recalls_05 else 0.0
    recall_075 = np.mean(recalls_075) if recalls_075 else 0.0

    print(f"  Recall@0.5: {recall_05:.4f}")
    print(f"  Recall@0.75: {recall_075:.4f}")

    return {"recall@0.5": recall_05, "recall@0.75": recall_075}

In [None]:
# Baseline: recall with the default anchors
default_results = evaluate_anchor_config(
    anchor_sizes=DEFAULT_ANCHORS["sizes"],
    anchor_ratios=DEFAULT_ANCHORS["aspect_ratios"],
    name="Default Anchors",
)

# Candidate: recall with the anchors found by the optimizer
optimized_results = evaluate_anchor_config(
    anchor_sizes=best_config.sizes,
    anchor_ratios=best_config.aspect_ratios,
    name="Optimized Anchors",
)

In [None]:
# Grouped bar chart: default vs optimized recall at both IoU thresholds
fig, ax = plt.subplots(figsize=(10, 6))

metrics = ['recall@0.5', 'recall@0.75']
x = np.arange(len(metrics))
width = 0.35
half_width = width / 2

default_vals = [default_results[key] for key in metrics]
optimized_vals = [optimized_results[key] for key in metrics]

bars1 = ax.bar(x - half_width, default_vals, width, label='Default Anchors', color='steelblue')
bars2 = ax.bar(x + half_width, optimized_vals, width, label='Optimized Anchors', color='coral')

ax.set(ylabel='Recall', title='RPN Recall Comparison: Default vs Optimized Anchors')
ax.set_xticks(x)
ax.set_xticklabels(['Recall@0.5', 'Recall@0.75'])
ax.legend()
ax.set_ylim(0, 1)

# Write each bar's value just above it (default bars first, then optimized)
for bar in (*bars1, *bars2):
    bar_top = bar.get_height()
    ax.annotate(f'{bar_top:.3f}', xy=(bar.get_x() + bar.get_width()/2, bar_top),
                xytext=(0, 3), textcoords="offset points", ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Relative change per metric; the denominator is floored at 1e-6 to avoid dividing by zero
rule = '=' * 50
print(f"\n{rule}")
print("Improvement Summary")
print(f"{rule}")
for metric in metrics:
    before = default_results[metric]
    after = optimized_results[metric]
    improvement = (after - before) / max(before, 1e-6) * 100
    print(f"{metric}: {before:.4f} → {after:.4f} ({improvement:+.1f}%)")

## 8. Usage: How to Use Optimized Anchors

Copy one of the snippets printed below into your training code to use the optimized anchors.

In [None]:
# Print the optimized configuration for copy-paste
# Everything inside the triple-quoted f-string is emitted as text only; none of it
# is executed here. The values of CONFIG and best_config are interpolated into it.
# NOTE(review): the snippets reference `Trainer`, which is not imported anywhere in
# the visible part of this notebook — confirm its import path before pasting.
print("="*60)
print("OPTIMIZED ANCHOR CONFIGURATION")
print("="*60)
print(f"""
# Option 1: Use directly in model creation
model = CustomMaskRCNN(
    num_classes={CONFIG['num_classes']},
    rpn_anchor_sizes={best_config.sizes},
    rpn_aspect_ratios={best_config.aspect_ratios},
)

# Option 2: Use with Trainer class
trainer = Trainer(
    data_root="{CONFIG['data_root']}",
    rpn_anchor_sizes={best_config.sizes},
    rpn_aspect_ratios={best_config.aspect_ratios},
)

# Option 3: Load from cached file
cached = torch.load("{CONFIG['cache_path']}")
model = CustomMaskRCNN(
    num_classes={CONFIG['num_classes']},
    rpn_anchor_sizes=cached['sizes'],
    rpn_aspect_ratios=cached['aspect_ratios'],
)

# Option 4: Use automatic optimization in Trainer
trainer = Trainer(
    data_root="{CONFIG['data_root']}",
    optimize_anchors=True,  # Will use cached or re-optimize
    anchor_cache_path="{CONFIG['cache_path']}",
)
""")

## 9. Cleanup

In [None]:
# Free memory
# Names are popped from the notebook namespace instead of using `del`, so this
# cell can be re-run (or run after a partial execution) without a NameError for
# names that are already gone.
for _name in ("train_dataset", "val_dataset", "analyzer", "optimizer"):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()

print("Cleanup complete!")