# GeoIntel - MVRSD Dataset Exploration

## Military Vehicle Remote Sensing Dataset Analysis

This notebook provides a comprehensive exploration of the MVRSD dataset for military vehicle detection in satellite imagery.

### Contents:
1. Dataset Loading & Verification
2. Ground Truth Visualization
3. Class Distribution Analysis
4. Object Size Analysis (Critical for Small Object Detection)
5. Tiling Strategy Visualization
6. SAHI Inference Demo (SAHI vs Standard; Requires Model)

In [None]:
# Setup and Imports
import sys
sys.path.append('..')

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pathlib import Path
from collections import Counter
import yaml

# GeoIntel modules
from src.data_loader import MVRSDDataLoader
from src.tiling_utils import ImageTiler, visualize_tiling_grid

%matplotlib inline
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [14, 8]

## 1. Dataset Loading & Verification

In [None]:
# Dataset root and project configuration (relative paths)
DATA_DIR = '../data/raw'
CONFIG_PATH = '../config/config.yaml'

# The exploration cells below are gated on `is_valid`, so verify once here
loader = MVRSDDataLoader(DATA_DIR, CONFIG_PATH)
is_valid = loader.verify_structure()

status_message = (
    "\n✅ Dataset is ready for exploration!"
    if is_valid
    else "\n❌ Please ensure dataset is properly structured in data/raw/"
)
print(status_message)

In [None]:
# Print the statistics reported by the loader
if is_valid:
    stats = loader.get_statistics()

    # Headline totals, read and printed one at a time in this order
    print("\nDataset Statistics:")
    for caption, attribute in (
        ("Total Images", "total_images"),
        ("Total Annotations", "total_annotations"),
        ("Average Objects/Image", "avg_objects_per_image"),
    ):
        print(f"  {caption}: {getattr(stats, attribute)}")

    # One line per entry of stats.class_distribution
    print("\nClass Distribution:")
    for class_name, n_objects in stats.class_distribution.items():
        print(f"    {class_name}: {n_objects}")

## 2. Ground Truth Visualization

Visualize ground truth bounding boxes on satellite images to understand the annotation quality.

In [None]:
# Display names, indexed by class id: 0=tank, 1=truck, 2=cargo, 3=military_vehicle
CLASS_NAMES = ['tank', 'truck', 'cargo', 'military_vehicle']

# Drawing color per class id, listed in the same order as CLASS_NAMES
CLASS_COLORS = dict(enumerate(('red', 'green', 'blue', 'yellow')))

def visualize_ground_truth(image_path, label_path, ax=None):
    """
    Visualize ground truth bounding boxes on an image.

    YOLO format: class_id x_center y_center width height (normalized)

    Label lines that do not have exactly five whitespace-separated fields are
    skipped. A missing label file is not an error: the image is drawn without
    boxes.

    Args:
        image_path: Path to the image file.
        label_path: Path to the YOLO-format label file for that image.
        ax: Matplotlib axes to draw on. When None, a new 14x10 figure is
            created.

    Returns:
        The axes that the image and boxes were drawn on.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
    """
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising, so fail here with a message that names the path.
    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(14, 10))

    ax.imshow(img)
    ax.set_title(f'{Path(image_path).name} ({w}x{h})', fontsize=12)

    # Load and draw annotations
    if Path(label_path).exists():
        with open(label_path, 'r') as f:
            for ann in f:
                parts = ann.strip().split()
                if len(parts) != 5:
                    continue

                class_id = int(parts[0])
                x_center, y_center, box_w, box_h = map(float, parts[1:])

                # Convert from normalized center/size to a pixel top-left corner and size
                x_min = (x_center - box_w / 2) * w
                y_min = (y_center - box_h / 2) * h
                box_w_px = box_w * w
                box_h_px = box_h * h

                # Draw rectangle
                color = CLASS_COLORS.get(class_id, 'white')
                rect = patches.Rectangle(
                    (x_min, y_min), box_w_px, box_h_px,
                    linewidth=2, edgecolor=color, facecolor='none'
                )
                ax.add_patch(rect)

                # The lower bound matters: a negative id would otherwise index
                # CLASS_NAMES from the end and show a wrong class name.
                if 0 <= class_id < len(CLASS_NAMES):
                    label = CLASS_NAMES[class_id]
                else:
                    label = f'class_{class_id}'
                ax.text(x_min, y_min - 5, label, color=color, fontsize=8, fontweight='bold')

    ax.axis('off')
    return ax

In [None]:
# Visualize sample images with ground truth
if is_valid:
    images_dir = Path(DATA_DIR) / 'images'
    labels_dir = Path(DATA_DIR) / 'labels'

    # Take the first six JPEGs by name; fall back to PNGs only when there are no JPEGs
    image_files = sorted(images_dir.glob('*.jpg'))[:6]
    if not image_files:
        image_files = sorted(images_dir.glob('*.png'))[:6]

    if image_files:
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()

        for ax, img_path in zip(axes, image_files):
            # Labels share the image's file stem, with a .txt extension
            label_path = labels_dir / f"{img_path.stem}.txt"
            visualize_ground_truth(img_path, label_path, ax)

        # With fewer than six images the remaining panels would be drawn as
        # empty framed axes; hide them so the grid only shows real samples.
        for unused_ax in axes[len(image_files):]:
            unused_ax.axis('off')

        plt.suptitle('Ground Truth Annotations (MVRSD Dataset)', fontsize=14, y=1.02)
        plt.tight_layout()
        plt.show()
    else:
        print("No images found in dataset")

## 3. Class Distribution Analysis

In [None]:
def analyze_class_distribution(labels_dir):
    """
    Count annotations per class and per label file.

    Every ``*.txt`` file directly inside ``labels_dir`` is read. Blank lines
    are ignored; the first whitespace-separated field of each remaining line
    is parsed as the integer class id.

    Args:
        labels_dir: Directory containing the label files.

    Returns:
        A ``(class_counts, objects_per_image)`` pair: a Counter mapping class
        id to number of annotations, and a list with the number of non-blank
        lines in each label file (in the order the files were visited).
    """
    per_class = Counter()
    per_image = []

    for txt_path in Path(labels_dir).glob('*.txt'):
        with open(txt_path, 'r') as handle:
            records = [row.strip() for row in handle if row.strip()]

        per_image.append(len(records))
        # Records are stripped and non-empty, so split() always yields a first field
        per_class.update(int(record.split()[0]) for record in records)

    return per_class, per_image

if is_valid:
    class_counts, objects_per_image = analyze_class_distribution(Path(DATA_DIR) / 'labels')

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    bar_ax, hist_ax = axes

    # Left panel: one bar per class id, ordered by id
    class_ids = sorted(class_counts.keys())
    classes = [CLASS_NAMES[i] if i < len(CLASS_NAMES) else f'class_{i}' for i in class_ids]
    counts = [class_counts[i] for i in class_ids]
    colors = [CLASS_COLORS.get(i, 'white') for i in class_ids]

    bar_ax.bar(classes, counts, color=colors, edgecolor='white')
    bar_ax.set_xlabel('Class')
    bar_ax.set_ylabel('Count')
    bar_ax.set_title('Class Distribution')

    # Write each count slightly above its bar (2% of the tallest bar)
    label_offset = max(counts)*0.02 if counts else 0
    for position, total in enumerate(counts):
        bar_ax.text(position, total + label_offset, str(total), ha='center', fontsize=10)

    # Right panel: how many objects each labelled image contains
    hist_ax.hist(objects_per_image, bins=20, color='cyan', edgecolor='white', alpha=0.7)
    hist_ax.set_xlabel('Objects per Image')
    hist_ax.set_ylabel('Frequency')
    mean_objects = np.mean(objects_per_image)
    hist_ax.set_title(f'Objects per Image Distribution (mean: {mean_objects:.1f})')
    hist_ax.axvline(mean_objects, color='red', linestyle='--', label='Mean')
    hist_ax.legend()

    plt.tight_layout()
    plt.show()

## 4. Object Size Analysis - THE SMALL OBJECT PROBLEM

This is **CRITICAL** for understanding why SAHI is necessary. We analyze:
- Object sizes relative to image dimensions
- What happens when we resize to standard YOLO input (640x640)

In [None]:
def analyze_object_sizes(images_dir, labels_dir, max_images=100):
    """
    Analyze object sizes in pixels and relative to image dimensions.

    Images are sampled from ``images_dir`` (``*.jpg`` first, then ``*.png``).
    Each extension is sorted by path before it is truncated, so the same
    subset is analyzed on every run. Images without a matching label file,
    and images OpenCV cannot read, are skipped. Label lines that do not have
    exactly five whitespace-separated fields are skipped.

    Args:
        images_dir: Directory containing the images.
        labels_dir: Directory containing YOLO-format ``<image stem>.txt`` files.
        max_images: Maximum number of images taken per file extension.
            ``None`` removes the limit.

    Returns:
        A ``(sizes_pixels, sizes_relative, image_dimensions)`` tuple of lists:
        per-object ``(width, height)`` in pixels, per-object
        ``(width, height)`` as a percentage of the image, and per-image
        ``(width, height)`` in pixels for every image that was read.
    """
    sizes_pixels = []  # (width, height) in pixels
    sizes_relative = []  # (width, height) relative to image
    image_dimensions = []

    images_dir = Path(images_dir)
    labels_dir = Path(labels_dir)

    # glob() yields entries in arbitrary order; sort before truncating so the
    # sample does not change between runs or machines.
    image_paths = []
    for pattern in ('*.jpg', '*.png'):
        image_paths.extend(sorted(images_dir.glob(pattern))[:max_images])

    for img_path in image_paths:
        label_path = labels_dir / f"{img_path.stem}.txt"
        if not label_path.exists():
            continue

        # Get image dimensions
        img = cv2.imread(str(img_path))
        if img is None:
            continue
        h, w = img.shape[:2]
        image_dimensions.append((w, h))

        # Parse annotations
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) != 5:
                    continue
                # Only the normalized box width/height (last two fields) are needed
                _, _, _, box_w, box_h = map(float, parts)

                # Pixel dimensions
                sizes_pixels.append((box_w * w, box_h * h))
                sizes_relative.append((box_w * 100, box_h * 100))  # Percentage

    return sizes_pixels, sizes_relative, image_dimensions

if is_valid:
    sizes_pixels, sizes_relative, image_dims = analyze_object_sizes(
        Path(DATA_DIR) / 'images',
        Path(DATA_DIR) / 'labels'
    )

    if sizes_pixels:
        sizes_pixels = np.array(sizes_pixels)
        sizes_relative = np.array(sizes_relative)

        # Column views and medians reused by several panels and the summary
        widths_px = sizes_pixels[:, 0]
        heights_px = sizes_pixels[:, 1]
        median_width_px = np.median(widths_px)

        # Estimate object sizes after downscaling to a 640px-wide input.
        # Bound before plotting so the panel and the printed summary agree on
        # whether resize statistics exist (previously the summary read
        # `resized_sizes` even on the path where it was never assigned).
        # NOTE(review): one factor derived from the mean image width is applied
        # to both axes; this is an approximation for non-square or mixed-size imagery.
        resized_sizes = None
        if image_dims:
            avg_img_width = np.mean([d[0] for d in image_dims])
            resize_factor = 640 / avg_img_width
            resized_sizes = sizes_pixels * resize_factor

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Object sizes in pixels
        axes[0, 0].scatter(widths_px, heights_px, alpha=0.5, s=10, c='cyan')
        axes[0, 0].set_xlabel('Width (pixels)')
        axes[0, 0].set_ylabel('Height (pixels)')
        axes[0, 0].set_title('Object Sizes (Pixels)')
        axes[0, 0].axhline(20, color='red', linestyle='--', alpha=0.7, label='20px threshold')
        axes[0, 0].axvline(20, color='red', linestyle='--', alpha=0.7)
        axes[0, 0].legend()

        # Histogram of object widths
        axes[0, 1].hist(widths_px, bins=50, color='cyan', edgecolor='white', alpha=0.7)
        axes[0, 1].set_xlabel('Object Width (pixels)')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].set_title(f'Object Width Distribution (median: {median_width_px:.1f}px)')
        axes[0, 1].axvline(median_width_px, color='red', linestyle='--', label='Median')
        axes[0, 1].legend()

        # Relative sizes
        axes[1, 0].scatter(sizes_relative[:, 0], sizes_relative[:, 1], alpha=0.5, s=10, c='yellow')
        axes[1, 0].set_xlabel('Width (% of image)')
        axes[1, 0].set_ylabel('Height (% of image)')
        axes[1, 0].set_title('Object Sizes (Relative to Image)')

        # What happens after resize to 640x640
        if resized_sizes is not None:
            axes[1, 1].hist(resized_sizes[:, 0], bins=50, color='red', edgecolor='white', alpha=0.7)
            axes[1, 1].set_xlabel('Object Width After 640px Resize')
            axes[1, 1].set_ylabel('Frequency')
            axes[1, 1].set_title(f'⚠️ PROBLEM: After Resize to 640px (median: {np.median(resized_sizes[:, 0]):.1f}px)')
            axes[1, 1].axvline(5, color='yellow', linestyle='--', label='5px - Nearly invisible!')
            axes[1, 1].legend()
        else:
            # Nothing to show: hide the panel rather than leave an empty frame
            axes[1, 1].axis('off')

        plt.suptitle('THE SMALL OBJECT PROBLEM - Why SAHI is Essential', fontsize=14, y=1.02)
        plt.tight_layout()
        plt.show()

        # Print statistics
        print("\n" + "="*60)
        print("SMALL OBJECT ANALYSIS")
        print("="*60)
        print(f"Average object size: {np.mean(widths_px):.1f}x{np.mean(heights_px):.1f} px")
        print(f"Median object size: {median_width_px:.1f}x{np.median(heights_px):.1f} px")
        # "Small" is judged on width only, matching the width histograms above
        print(f"Objects < 20px: {np.sum(widths_px < 20)} ({100*np.mean(widths_px < 20):.1f}%)")
        if resized_sizes is not None:
            print(f"\nAfter resize to 640px:")
            print(f"  Median object becomes: {np.median(resized_sizes[:, 0]):.1f}x{np.median(resized_sizes[:, 1]):.1f} px")
            print(f"  Objects < 5px: {np.sum(resized_sizes[:, 0] < 5)} ({100*np.mean(resized_sizes[:, 0] < 5):.1f}%)")
        print("\n⚠️  This is why standard YOLO FAILS on satellite imagery!")
        print("✅  SAHI processes tiles at full resolution, preserving small objects.")

## 5. Tiling Strategy Visualization

Visualize how SAHI slices images into overlapping tiles.

In [None]:
def visualize_tiling_strategy(image_path, tile_sizes=(256, 512, 640), overlap=0.2):
    """
    Visualize different tiling strategies.

    Draws the original image in the first panel, followed by one panel per
    entry of ``tile_sizes`` with the tile grid from ``ImageTiler`` overlaid.
    The figure is shown with ``plt.show()``; nothing is returned.

    Args:
        image_path: Path to the image to tile.
        tile_sizes: Tile edge lengths in pixels, one panel each. Any number of
            sizes is supported; outline colors repeat after three panels.
        overlap: Overlap ratio passed to ``ImageTiler`` as ``overlap_ratio``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
    """
    # cv2.imread returns None for a missing or unreadable file
    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    # squeeze=False keeps `axes` indexable even when tile_sizes is empty and
    # only the original-image panel is created.
    fig, axes = plt.subplots(1, len(tile_sizes) + 1, figsize=(18, 5), squeeze=False)
    axes = axes[0]

    # Original image
    axes[0].imshow(img)
    axes[0].set_title(f'Original ({w}x{h})')
    axes[0].axis('off')

    # Different tile sizes. Colors are cycled by index: zipping the color list
    # directly would silently drop every tile size after the third.
    colors = ['red', 'green', 'blue']
    for index, (ax, tile_size) in enumerate(zip(axes[1:], tile_sizes)):
        color = colors[index % len(colors)]
        ax.imshow(img, alpha=0.7)

        tiler = ImageTiler(tile_size=tile_size, overlap_ratio=overlap)
        n_cols, n_rows, positions = tiler.calculate_tile_grid(w, h)

        # Each position is used as the top-left corner of one tile outline
        for x_offset, y_offset in positions:
            rect = patches.Rectangle(
                (x_offset, y_offset), tile_size, tile_size,
                linewidth=1, edgecolor=color, facecolor='none', alpha=0.8
            )
            ax.add_patch(rect)

        ax.set_title(f'{tile_size}px tiles\n{n_cols}x{n_rows}={len(positions)} tiles')
        ax.axis('off')

    # round() rather than int(): 0.29 * 100 is 28.999... and would display as 28%
    plt.suptitle(f'Tiling Strategies (Overlap: {round(overlap * 100)}%)', fontsize=12)
    plt.tight_layout()
    plt.show()

if is_valid:
    images_dir = Path(DATA_DIR) / 'images'

    # Use one JPEG if any exist, otherwise one PNG
    sample_images = list(images_dir.glob('*.jpg'))[:1]
    if not sample_images:
        sample_images = list(images_dir.glob('*.png'))[:1]

    if not sample_images:
        print("No images found for tiling visualization")
    else:
        visualize_tiling_strategy(sample_images[0])

## 6. SAHI Inference Demo (Requires Model)

Compare standard YOLO inference vs SAHI sliced inference.

In [None]:
# This cell demonstrates SAHI inference
# Requires: pip install sahi ultralytics

DEMO_MODEL_PATH = '../models/geointel_best.pt'  # Update with your model path

def run_sahi_comparison(image_path, model_path):
    """
    Run SAHI sliced inference and standard inference on one image.

    Args:
        image_path: Image to scan; converted to ``str`` before use.
        model_path: Model weights passed to ``GeoIntelEye``.

    Returns:
        ``(sahi_result, standard_result)`` on success, or ``(None, None)``
        after printing a hint when an ImportError or FileNotFoundError occurs.
    """
    try:
        # Imported lazily so the notebook still loads without the inference dependencies
        from src.geointel_eye import GeoIntelEye

        detector = GeoIntelEye(model_path=model_path)

        # Sliced (SAHI) pass, with visualization and GeoJSON export disabled
        sliced_result = detector.scan(
            image_path=str(image_path),
            output_path=None,
            visualize=False,
            export_geojson=False
        )

        # Standard pass over the same image
        whole_image_result = detector.scan_standard(str(image_path))
    except ImportError as e:
        print(f"Missing dependency: {e}")
        print("Install with: pip install sahi ultralytics")
        return None, None
    except FileNotFoundError:
        print(f"Model not found at {model_path}")
        print("Train a model first or download pretrained weights")
        return None, None
    else:
        return sliced_result, whole_image_result

# Uncomment to run inference comparison
# if is_valid and Path(DEMO_MODEL_PATH).exists():
#     sample_image = list((Path(DATA_DIR) / 'images').glob('*.jpg'))[0]
#     sahi_result, std_result = run_sahi_comparison(sample_image, DEMO_MODEL_PATH)
#     
#     if sahi_result and std_result:
#         print(f"\nSAHI detections: {sahi_result.total_detections}")
#         print(f"Standard detections: {std_result.total_detections}")
#         improvement = sahi_result.total_detections - std_result.total_detections
#         print(f"Improvement: +{improvement} detections")

## Summary

### Key Findings:

1. **The Small Object Problem**: Military vehicles in satellite imagery are often 20x20 pixels or smaller

2. **Why Standard YOLO Fails**: Resizing 4000x4000 images to 640x640 reduces objects to ~3-5 pixels - impossible to detect

3. **SAHI Solution**: Process overlapping 512x512 tiles, maintaining full resolution for small objects

### Recommended Configuration:
- Tile size: 512x512
- Overlap: 20%
- Model: YOLOv8-Medium
- Post-processing: NMS with IoU threshold 0.5