# Full pipeline for optimizing RPN anchor sizes

## Overview

This notebook provides a complete pipeline for optimizing RPN anchor sizes using **GEOMETRIC COVERAGE**:
1. **Dataset Analysis** - Analyze object sizes and aspect ratios in your dataset
2. **Stride-Constrained Suggestions** - Get anchor recommendations respecting FPN constraints
3. **Optuna Optimization** - Find optimal anchors using theoretical recall (no training required!)
4. **Comparison** - Compare different anchor configurations

### Key Insight
Anchor optimization should find anchors that **geometrically cover** GT boxes with high IoU.
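Because no model is trained, this is roughly 100x faster than training-based optimization, and more stable (results are not perturbed by training noise such as random initialization).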

## 1. Setup and Imports

In [None]:
import os

# Configure the CUDA caching allocator before torch is imported, so the setting is
# guaranteed to be in place before any CUDA initialization happens.
# `setdefault` keeps a value the user has already exported in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import sys
import gc
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch

# Add project root to path so the project packages imported below can be found
# (assumes the notebook is run from a direct subdirectory of the project root).
sys.path.insert(0, str(Path.cwd().parent))

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # `total_memory` is the device's total capacity, not the currently free amount.
    print(f"Total memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Install optuna if not available
try:
    import optuna
    print(f"Optuna version: {optuna.__version__}")
except ImportError:
    print("Installing optuna...")
    !pip install optuna
    import optuna

In [None]:
# Import project modules (presumably resolved via the project root that the
# setup cell inserts into sys.path).
from datasets.isaid_dataset import iSAIDDataset
from training.transforms import get_transforms
from training.anchor_optimizer import (
    AnchorConfig,
    GeometricAnchorOptimizer,  # Scores anchors by geometric recall (IoU coverage); no training
    DatasetAnchorAnalyzer,
    optimize_anchors_for_dataset,  # NOTE(review): not referenced elsewhere in this notebook as shown
    analyze_dataset_anchors,  # NOTE(review): not referenced elsewhere in this notebook as shown
    compare_anchor_configs,
    generate_anchors_for_image,  # NOTE(review): not referenced elsewhere in this notebook as shown
    FPN_STRIDES,
)
# NOTE(review): CustomMaskRCNN only appears inside the printed usage snippet in
# Section 9, and get_custom_maskrcnn is not referenced in this notebook as shown.
from models.maskrcnn_model import CustomMaskRCNN, get_custom_maskrcnn

print("All modules imported successfully!")
# NOTE(review): CONFIG["strides"] below is defined independently of FPN_STRIDES;
# compare this printout with the config cell to make sure they agree.
print(f"FPN Strides: {FPN_STRIDES}")

## 2. Configuration

In [None]:
# Configuration: every tunable value used by later cells lives here.
CONFIG = {
    # Dataset
    "data_root": "../iSAID_patches",  # Path to dataset
    "num_classes": 16,                 # 15 classes + background
    "image_size": 800,                 # Image size for training

    # Dataset analysis (Section 4)
    "num_analysis_samples": 1000,      # Samples used for box-statistics analysis

    # Anchor Optimization (geometric recall, no training)
    "n_trials": 50,                    # Number of Optuna trials
    "num_samples": 500,                # Samples for geometric recall evaluation

    # FPN Strides (critical for anchor constraints!)
    "strides": [4, 8, 16, 32],         # P2, P3, P4, P5

    # Output
    "cache_path": "../optimized_anchors.pt",  # Where to save optimized anchors
}

# Default stride-based anchor configuration (respects FPN constraints).
# Heuristic: anchor_size >= stride * 2 at each FPN level.
# One (2*stride, 4*stride) size pair and one ratio set per level, derived from
# `strides` so the config stays consistent if the stride list is edited.
DEFAULT_ANCHORS = AnchorConfig(
    sizes=tuple((s * 2, s * 4) for s in CONFIG["strides"]),  # (8,16), (16,32), (32,64), (64,128)
    aspect_ratios=((0.5, 1.0, 2.0),) * len(CONFIG["strides"]),
)

print("Configuration loaded:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

print("\nDefault (stride-based) anchors:")
print(f"  Sizes: {DEFAULT_ANCHORS.sizes}")
print(f"  Ratios: {DEFAULT_ANCHORS.aspect_ratios}")
print("\nNote: Anchor size >= stride*2 for each FPN level!")

## 3. Load Dataset

In [None]:
# Load both splits with the non-augmenting (train=False) transform pipeline,
# so the statistics below reflect the data rather than random augmentations.
print("Loading datasets...")


def _load_isaid_split(split):
    """Build an iSAIDDataset for `split` using evaluation-time transforms."""
    return iSAIDDataset(
        CONFIG["data_root"],
        split=split,
        transforms=get_transforms(train=False),
        image_size=CONFIG["image_size"],
    )


train_dataset = _load_isaid_split("train")
val_dataset = _load_isaid_split("val")

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

## 4. Dataset Analysis

First, let's analyze the object sizes and aspect ratios in the dataset to understand what anchor configurations might work best.

In [None]:
# Analyze dataset for bounding box statistics.
# `.get` falls back to the geometric-recall sample count so this cell still runs
# when CONFIG has no dedicated analysis budget (previously a KeyError).
num_analysis_samples = CONFIG.get("num_analysis_samples", CONFIG["num_samples"])
analyzer = DatasetAnchorAnalyzer(train_dataset, num_samples=num_analysis_samples)
stats = analyzer.compute_box_statistics()

# min()/max() on an empty array raise an opaque error; fail with a clear message instead.
if len(stats['widths']) == 0:
    raise ValueError(
        f"No ground-truth boxes found in {num_analysis_samples} analyzed samples; "
        "check CONFIG['data_root'] and the dataset split."
    )

print(f"\n{'='*50}")
print("Dataset Bounding Box Statistics")
print(f"{'='*50}")
print(f"Number of boxes analyzed: {len(stats['widths']):,}")
print(f"\nWidth:  min={stats['widths'].min():.1f}, max={stats['widths'].max():.1f}, mean={stats['widths'].mean():.1f}, std={stats['widths'].std():.1f}")
print(f"Height: min={stats['heights'].min():.1f}, max={stats['heights'].max():.1f}, mean={stats['heights'].mean():.1f}, std={stats['heights'].std():.1f}")
print(f"Area:   min={stats['areas'].min():.1f}, max={stats['areas'].max():.1f}, mean={stats['areas'].mean():.1f}")
print(f"Aspect Ratio: min={stats['aspect_ratios'].min():.2f}, max={stats['aspect_ratios'].max():.2f}, mean={stats['aspect_ratios'].mean():.2f}")

In [None]:
# Visualize distributions of box scale, aspect ratio, and width/height.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Object scales (sqrt of area)
scales = np.sqrt(stats['areas'])
median_scale = np.median(scales)
axes[0, 0].hist(scales, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
axes[0, 0].axvline(median_scale, color='red', linestyle='--', label=f'Median: {median_scale:.1f}')
axes[0, 0].set_xlabel('Object Scale (√area)')  # was mis-encoded as 'âˆšarea'
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Distribution of Object Scales')
axes[0, 0].legend()

# Aspect ratios
median_ratio = np.median(stats['aspect_ratios'])
axes[0, 1].hist(stats['aspect_ratios'], bins=50, edgecolor='black', alpha=0.7, color='coral')
axes[0, 1].axvline(median_ratio, color='red', linestyle='--', label=f'Median: {median_ratio:.2f}')
axes[0, 1].set_xlabel('Aspect Ratio (w/h)')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Distribution of Aspect Ratios')
axes[0, 1].set_xlim(0, 5)
axes[0, 1].legend()

# Width vs Height scatter on a fixed-seed subsample so the figure is reproducible
n_boxes = len(stats['widths'])
sample_rng = np.random.default_rng(0)
sample_idx = sample_rng.choice(n_boxes, size=min(2000, n_boxes), replace=False)
axes[1, 0].scatter(stats['widths'][sample_idx], stats['heights'][sample_idx], alpha=0.3, s=5)
axes[1, 0].set_xlabel('Width')
axes[1, 0].set_ylabel('Height')
axes[1, 0].set_title('Width vs Height (sample)')
axes[1, 0].set_aspect('equal')

# Scale percentiles (for anchor size selection)
percentiles = [10, 25, 50, 75, 90]
scale_percentiles = np.percentile(scales, percentiles)
axes[1, 1].barh(range(len(percentiles)), scale_percentiles, color='teal', edgecolor='black')
axes[1, 1].set_yticks(range(len(percentiles)))
axes[1, 1].set_yticklabels([f'{p}th' for p in percentiles])
axes[1, 1].set_xlabel('Object Scale')
axes[1, 1].set_title('Scale Percentiles')
for i, v in enumerate(scale_percentiles):
    axes[1, 1].text(v + 1, i, f'{v:.1f}', va='center')

plt.tight_layout()
plt.show()

print("\nScale percentiles (useful for anchor sizes):")
for p in [10, 20, 30, 50, 70, 80, 90]:
    print(f"  {p}th percentile: {np.percentile(scales, p):.1f}")

## 5. Data-Driven Anchor Suggestions

Get initial anchor recommendations based on dataset statistics (without Optuna optimization).

In [None]:
# Data-driven suggestions: per-FPN-level sizes that honour the stride constraint,
# plus a small set of representative aspect ratios.
suggested_sizes = analyzer.suggest_anchor_sizes_with_stride_constraints(strides=CONFIG["strides"])
suggested_ratios = analyzer.suggest_aspect_ratios(num_ratios=3)

banner = "=" * 60

print("\n" + banner)
print("STRIDE-CONSTRAINED Anchor Suggestions")
print(banner)
print("\nFPN Level | Stride | Min Valid | Suggested Sizes")
print("-" * 60)
# Levels are numbered from P2, matching the first stride.
for level, (stride, sizes) in enumerate(zip(CONFIG["strides"], suggested_sizes), start=2):
    print(f"   P{level}    |   {stride:2d}   |    {2 * stride:3d}    | {sizes}")

print(f"\nSuggested aspect ratios: {suggested_ratios}")

print("\n" + banner)
print("Comparison")
print(banner)
print(f"Default sizes:    {DEFAULT_ANCHORS.sizes}")
print(f"Suggested sizes:  {suggested_sizes}")

## 6. Geometric Anchor Optimization (FAST!)

Run Optuna optimization using **geometric recall** (theoretical IoU coverage).

This is:
- **100x faster** than training-based optimization (no model training!)
- **More stable** (no random initialization noise)
- **Stride-aware** (anchor sizes respect the FPN stride constraints)

In [None]:
# The geometric optimizer scores candidate anchors by IoU coverage of GT boxes
# taken from the training set, so no model has to be trained.
geo_optimizer = GeometricAnchorOptimizer(
    dataset=train_dataset,
    image_size=(CONFIG["image_size"],) * 2,  # square input: (H, W) both equal image_size
    strides=CONFIG["strides"],
    base_aspect_ratios=(0.5, 1.0, 2.0),
    num_samples=CONFIG["num_samples"],
)

print("Geometric Optimizer initialized!")
print("Cached {} GT boxes for fast evaluation".format(geo_optimizer.total_gt_boxes))
print("Suggested ratios from data: {}".format(geo_optimizer.suggested_ratios))

In [None]:
# Each trial only evaluates IoU coverage of cached GT boxes (no model training),
# so the whole study is expected to finish quickly.
print("Starting GEOMETRIC optimization with {} trials...".format(CONFIG['n_trials']))
print("This is FAST because it only computes IoU coverage, no model training!", end="\n\n")

best_config = geo_optimizer.optimize(
    study_name="geometric_anchor_optimization",
    n_trials=CONFIG["n_trials"],
)

In [None]:
# Checkpoint the optimized anchors immediately, so the result survives even if a
# later cell fails. Section 9 re-saves this file with the geometric recall added.
cache_path = Path(CONFIG["cache_path"])
cache_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories
torch.save({
    'sizes': best_config.sizes,
    'aspect_ratios': best_config.aspect_ratios,
}, cache_path)

print(f"Optimized anchors saved to: {CONFIG['cache_path']}")
print(f"\nBest configuration found:")
print(f"  Sizes: {best_config.sizes}")
print(f"  Aspect ratios: {best_config.aspect_ratios}")

## 7. Compare Anchor Configurations

Compare geometric recall for different anchor configurations.
This shows the **theoretical maximum recall** each configuration can achieve.

In [None]:
# Compare multiple anchor configurations using geometric recall.
# The data-suggested ratio set is replicated once per FPN level (derived from
# `strides` rather than hard-coded) and normalized to a tuple, since the analyzer
# may return a list/array — NOTE(review): confirm the return type of suggest_aspect_ratios.
configs_to_compare = {
    "Default (stride-based)": DEFAULT_ANCHORS,
    "Optimized": best_config,
    "Data-suggested": AnchorConfig(
        sizes=suggested_sizes,
        aspect_ratios=(tuple(suggested_ratios),) * len(CONFIG["strides"]),
    ),
}

print("=" * 60)
print("GEOMETRIC RECALL COMPARISON")
print("(Theoretical maximum recall - no training involved)")
print("=" * 60)

comparison_results = compare_anchor_configs(
    dataset=train_dataset,
    configs=configs_to_compare,
    image_size=(CONFIG["image_size"], CONFIG["image_size"]),
    num_samples=CONFIG["num_samples"],
)

In [None]:
# Visualize comparison: one group of bars per recall threshold, one bar per configuration.
fig, ax = plt.subplots(figsize=(12, 6))

metrics = ['recall@0.5', 'recall@0.7', 'recall@0.75']
n_configs = len(comparison_results)
x = np.arange(len(metrics))
# Share 80% of each group's slot among the configs so any number of them fits.
width = 0.8 / max(n_configs, 1)
base_colors = ['steelblue', 'coral', 'seagreen']

for i, (name, results) in enumerate(comparison_results.items()):
    vals = [results.get(m, 0) for m in metrics]
    # Fall back to matplotlib's default color cycle beyond the three named colors.
    color = base_colors[i] if i < len(base_colors) else f'C{i}'
    bars = ax.bar(x + i * width, vals, width, label=name, color=color, edgecolor='black')

    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}', xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)

# Draw the reference line BEFORE creating the legend so its label is included.
ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.5, label='90% target')

ax.set_ylabel('Geometric Recall (Theoretical Max)')
ax.set_title('Anchor Configuration Comparison\n(Higher is better - this is the theoretical ceiling)')
# Center each tick under its group of bars.
ax.set_xticks(x + width * (n_configs - 1) / 2)
ax.set_xticklabels([m.capitalize() for m in metrics])
ax.legend(loc='upper right')
# Leave headroom so value labels on bars near 1.0 are not clipped.
ax.set_ylim(0, 1.05)

plt.tight_layout()
plt.show()

# Print summary
print(f"\n{'='*60}")
print("SUMMARY: Geometric Recall Comparison")
print(f"{'='*60}")
for name, results in comparison_results.items():
    print(f"\n{name}:")
    for m, v in results.items():
        print(f"  {m}: {v:.4f}")

## 8. Visualize Anchor Coverage

See how anchors cover the object size distribution.

In [None]:
# Visualize anchor sizes vs object scale distribution
fig, ax = plt.subplots(figsize=(14, 6))

# Plot object scale histogram
scales = np.sqrt(stats['areas'])
ax.hist(scales, bins=100, alpha=0.5, color='gray', edgecolor='none', label='Objects')

# Plot anchor sizes for each configuration.
anchor_colors = {'Default (stride-based)': 'blue', 'Optimized': 'red', 'Data-suggested': 'green'}
legend_labels = {
    'Default (stride-based)': 'Default anchors',
    'Optimized': 'Optimized anchors',
    'Data-suggested': 'Data-suggested anchors',
}

for name, config in configs_to_compare.items():
    is_optimized = name == 'Optimized'
    color = anchor_colors.get(name, 'purple')
    # Label only the first line of each configuration. Passing a bare label list to
    # ax.legend() would attach the labels to the first artists drawn (histogram bars).
    label = legend_labels.get(name, f'{name} anchors')
    for level_sizes in config.sizes:
        for size in level_sizes:
            ax.axvline(size, color=color, alpha=0.7 if is_optimized else 0.4,
                       linestyle='--', linewidth=3 if is_optimized else 1.5, label=label)
            label = '_nolegend_'

# Add FPN stride markers (unlabelled, so they stay out of the legend)
for stride in CONFIG["strides"]:
    min_anchor = stride * 2
    ax.axvline(min_anchor, color='black', alpha=0.3, linestyle=':', linewidth=1)
    ax.text(min_anchor + 1, ax.get_ylim()[1] * 0.95, f'min@s={stride}', fontsize=7, rotation=90, va='top')

ax.set_xlabel('Object Scale / Anchor Size (pixels)')
ax.set_ylabel('Frequency')
ax.set_title('Object Scale Distribution with Anchor Sizes\n(Dotted lines = minimum valid anchor for each FPN stride)')
ax.legend(loc='upper right')
ax.set_xlim(0, 300)

plt.tight_layout()
plt.show()

# Show the configurations clearly
print(f"\n{'='*60}")
print("ANCHOR CONFIGURATION DETAILS")
print(f"{'='*60}")
print(f"\nFPN Level | Stride | Min Valid |  Default  |  Optimized")
print("-" * 60)
for i, stride in enumerate(CONFIG["strides"]):
    min_valid = stride * 2
    default_sz = DEFAULT_ANCHORS.sizes[i]
    opt_sz = best_config.sizes[i]
    print(f"   P{i+2}    |   {stride:2d}   |    {min_valid:3d}    | {str(default_sz):12s} | {str(opt_sz):12s}")

## 9. Save and Usage

In [None]:
# Save optimized anchors together with their geometric recall.
# Recall values are cast to plain Python floats: if they are numpy scalars, the file
# could not be read back by `torch.load` with `weights_only=True` (the default since
# PyTorch 2.6), which is how the usage snippet below loads it.
cache_path = Path(CONFIG["cache_path"])
cache_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories
optimized_recall = {k: float(v) for k, v in comparison_results.get('Optimized', {}).items()}
torch.save({
    'sizes': best_config.sizes,
    'aspect_ratios': best_config.aspect_ratios,
    'geometric_recall': optimized_recall,
}, cache_path)

print(f"Optimized anchors saved to: {CONFIG['cache_path']}")

# Print the optimized configuration for copy-paste
print("\n" + "="*60)
print("OPTIMIZED ANCHOR CONFIGURATION")
print("="*60)
print(f"""
# Option 1: Use directly in model creation
model = CustomMaskRCNN(
    num_classes={CONFIG['num_classes']},
    rpn_anchor_sizes={best_config.sizes},
    rpn_aspect_ratios={best_config.aspect_ratios},
)

# Option 2: Load from cached file
cached = torch.load("{CONFIG['cache_path']}")
model = CustomMaskRCNN(
    num_classes={CONFIG['num_classes']},
    rpn_anchor_sizes=cached['sizes'],
    rpn_aspect_ratios=cached['aspect_ratios'],
)
""")

In [None]:
# Cleanup: drop large objects so their memory can be reclaimed.
# `globals().pop(..., None)` keeps this cell safe to re-run; a plain `del` raises
# NameError the second time.
for _name in ("train_dataset", "val_dataset", "analyzer", "geo_optimizer"):
    globals().pop(_name, None)
del _name
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Cleanup complete!")