# Enhanced AI Model Comparison for Remote Sensing with TorchGeo

This enhanced Jupyter Notebook allows you to:
- Use sample datasets (EuroSAT, RESISC45) or upload your own satellite imagery (custom dataset loading is currently a placeholder)
- Compare two different pre-trained models side-by-side
- Visualize results with detailed performance metrics
- Export comparison results

Built with [TorchGeo](https://github.com/microsoft/torchgeo) for seamless geospatial deep learning.

In [None]:
# Install necessary libraries (uncomment and run if not already installed)
# !pip install torch torchvision torchgeo matplotlib seaborn pandas scikit-learn ipywidgets
# !pip install rasterio pillow numpy pytorch-lightning

## Step 1: Import Libraries and Setup

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from dataclasses import dataclass
from pathlib import Path
import time
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# TorchGeo imports
from torchgeo.datasets import EuroSAT, RESISC45
from torchgeo.trainers import ClassificationTask
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, random_split

# For file upload functionality
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image
import io

# Select the compute device: a CUDA GPU when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Step 2: Configuration and Model Selection

In [None]:
# Available pre-trained models for comparison
AVAILABLE_MODELS = {
    'ResNet-18': 'resnet18',
    'ResNet-34': 'resnet34',
    'ResNet-50': 'resnet50',
    'EfficientNet-B0': 'efficientnet_b0',
    'EfficientNet-B1': 'efficientnet_b1',
    'MobileNet-V2': 'mobilenet_v2',
    'DenseNet-121': 'densenet121',
    'VGG-16': 'vgg16'
}

# Dataset options
AVAILABLE_DATASETS = {
    'EuroSAT': 'eurosat',
    'RESISC45': 'resisc45',
    'Custom Upload': 'custom'
}

# Configuration container. A dataclass provides __init__ keyword arguments,
# __repr__ (so `print(config)` shows every setting) and __eq__ for free.
@dataclass
class ModelComparisonConfig:
    """Settings shared by the model-comparison workflow.

    The widget callback ``update_config`` overwrites these fields in place,
    so the dataclass is intentionally left mutable (not frozen).
    """

    model_1: str = 'resnet18'    # torchvision builder name of the first model
    model_2: str = 'resnet34'    # torchvision builder name of the second model
    dataset: str = 'eurosat'     # one of 'eurosat', 'resisc45', 'custom' (see load_dataset)
    batch_size: int = 32
    num_epochs: int = 5
    learning_rate: float = 1e-3
    test_split: float = 0.2      # fraction of samples held out by load_dataset


config = ModelComparisonConfig()
print("Configuration initialized successfully!")

## Step 3: Interactive Model and Dataset Selection

In [None]:
# Build the interactive configuration controls. Every control is created with
# the same description style, so that keyword lives in one helper.
def _labelled(widget_cls, **kwargs):
    """Create *widget_cls* with ``style={'description_width': 'initial'}``."""
    return widget_cls(**kwargs, style={'description_width': 'initial'})


# Model pickers (labels come from AVAILABLE_MODELS).
model1_dropdown = _labelled(
    widgets.Dropdown,
    options=list(AVAILABLE_MODELS.keys()),
    value='ResNet-18',
    description='Model 1:',
)
model2_dropdown = _labelled(
    widgets.Dropdown,
    options=list(AVAILABLE_MODELS.keys()),
    value='ResNet-34',
    description='Model 2:',
)

# Dataset picker (labels come from AVAILABLE_DATASETS).
dataset_dropdown = _labelled(
    widgets.Dropdown,
    options=list(AVAILABLE_DATASETS.keys()),
    value='EuroSAT',
    description='Dataset:',
)

# Training hyper-parameters.
epochs_slider = _labelled(
    widgets.IntSlider,
    value=5,
    min=1,
    max=20,
    step=1,
    description='Epochs:',
)
batch_size_dropdown = _labelled(
    widgets.Dropdown,
    options=[16, 32, 64, 128],
    value=32,
    description='Batch Size:',
)

# Archive upload for custom datasets (single file, zip/tar formats).
file_upload = _labelled(
    widgets.FileUpload,
    accept='.zip,.tar,.tar.gz',
    multiple=False,
    description='Upload Dataset',
)

# Widget -> config synchronisation
def update_config(*args):
    """Copy the current widget selections into the global ``config``.

    Dropdown labels are translated to internal identifiers through
    AVAILABLE_MODELS / AVAILABLE_DATASETS. ``*args`` absorbs the change
    event that ``observe`` passes in; it is not used.
    """
    cfg = config
    cfg.model_1 = AVAILABLE_MODELS[model1_dropdown.value]
    cfg.model_2 = AVAILABLE_MODELS[model2_dropdown.value]
    cfg.dataset = AVAILABLE_DATASETS[dataset_dropdown.value]
    cfg.num_epochs = epochs_slider.value
    cfg.batch_size = batch_size_dropdown.value

# Attach observers: any value change on these controls re-syncs the config.
for _control in (model1_dropdown, model2_dropdown, dataset_dropdown,
                 epochs_slider, batch_size_dropdown):
    _control.observe(update_config, names='value')
del _control

# Display the controls as one vertical panel.
# NOTE(review): file_upload is shown but update_config never reads it, so
# uploads are not wired into the workflow yet.
_config_panel = widgets.VBox([
    widgets.HTML("<h3>Model Comparison Configuration</h3>"),
    widgets.HBox([model1_dropdown, model2_dropdown]),
    dataset_dropdown,
    widgets.HBox([epochs_slider, batch_size_dropdown]),
    file_upload,
])
display(_config_panel)

update_config()  # Push the widgets' initial values into config

## Step 4: Data Loading and Preprocessing

In [None]:
class NormalizeSample:
    """Rescale and normalize the ``'image'`` entry of a TorchGeo sample dict.

    TorchGeo datasets hand each sample to ``transforms`` as a dict holding
    ``'image'`` and ``'label'`` tensors, so torchvision image transforms such as
    ``T.ToTensor`` cannot be used there. This callable divides the image by
    ``scale`` and then applies per-channel ``(x - mean) / std``. The default
    statistics are the ImageNet ones expected by the pretrained backbones.

    It is a module-level class rather than a lambda/closure so the dataset
    stays picklable when DataLoader workers need it.
    """

    def __init__(self, scale, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.scale = float(scale)
        # Shape (C, 1, 1) so the statistics broadcast over a (C, H, W) image.
        self.mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        self.std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)

    def __call__(self, sample):
        sample = dict(sample)  # leave the caller's dict untouched
        image = sample['image'].float() / self.scale
        sample['image'] = (image - self.mean) / self.std
        return sample


def load_dataset(dataset_type, custom_path=None):
    """Load a dataset and split it into reproducible train/test subsets.

    Args:
        dataset_type: ``'eurosat'``, ``'resisc45'`` or ``'custom'``.
        custom_path: path of an uploaded dataset; required for ``'custom'``.

    Returns:
        ``(train_dataset, test_dataset, class_names)``. The test fraction is
        ``config.test_split`` and the split is seeded with 42. For ``'custom'``
        loading is not implemented yet and ``(None, None, None)`` is returned.

    Raises:
        ValueError: unknown ``dataset_type``, or ``'custom'`` without a path.
    """
    if dataset_type == 'eurosat':
        # TorchGeo's EuroSAT returns every Sentinel-2 band by default, but the
        # ImageNet-pretrained backbones expect 3 channels: keep only RGB.
        # NOTE(review): dividing by 10000 assumes raw Sentinel-2 digital
        # numbers (reflectance * 10000) -- confirm for the installed version.
        dataset = EuroSAT(
            root="data/eurosat",
            bands=EuroSAT.rgb_bands,
            download=True,
            transforms=NormalizeSample(scale=10000.0),
        )
        fallback_names = [
            'AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway',
            'Industrial', 'Pasture', 'PermanentCrop', 'Residential',
            'River', 'SeaLake'
        ]
    elif dataset_type == 'resisc45':
        # NOTE(review): assumes TorchGeo yields 8-bit RGB as floats in [0, 255].
        dataset = RESISC45(
            root="data/resisc45",
            download=True,
            transforms=NormalizeSample(scale=255.0),
        )
        fallback_names = [f"Class_{i}" for i in range(45)]
    elif dataset_type == 'custom':
        if not custom_path:
            raise ValueError("Custom dataset selected but no custom_path was provided")
        # Placeholder: implement loading for your own data format here.
        print("Custom dataset loading not fully implemented in this demo")
        return None, None, None
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    # Prefer the dataset's own (folder-derived) class names so their order
    # always matches the integer labels; fall back to the static list.
    class_names = list(getattr(dataset, 'classes', None) or fallback_names)

    # Only the dataset's default split is loaded; carve a seeded test split from it.
    total_size = len(dataset)
    test_size = int(config.test_split * total_size)
    train_size = total_size - test_size

    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    return train_dataset, test_dataset, class_names

def create_data_loaders(train_dataset, test_dataset, batch_size, num_workers=2):
    """Wrap the train/test datasets in DataLoaders.

    Args:
        train_dataset: dataset used for training (shuffled every epoch).
        test_dataset: dataset used for evaluation (kept in order).
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader. Pass 0 to load in the main
            process, e.g. in notebook setups where worker processes cannot
            import notebook-defined transforms.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _make(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    return _make(train_dataset, True), _make(test_dataset, False)


print("Data loading functions defined successfully!")

## Step 5: Model Setup and Training Functions

In [None]:
def create_model(model_name, num_classes, pretrained=True):
    """Build a torchvision classifier with a fresh ``num_classes`` output layer.

    Args:
        model_name: torchvision builder name. Supported: any ``resnet*``,
            ``efficientnet*`` or ``densenet*``, plus ``mobilenet_v2`` and ``vgg16``.
        num_classes: size of the new final linear layer.
        pretrained: load ImageNet weights (``IMAGENET1K_V1``, the weights the
            deprecated ``pretrained=True`` flag selected); ``False`` for random init.

    Returns:
        The model moved to the global ``device``.

    Raises:
        ValueError: ``model_name`` is not one of the supported families.
    """
    supported = (
        model_name.startswith(('resnet', 'efficientnet', 'densenet'))
        or model_name in ('mobilenet_v2', 'vgg16')
    )
    if not supported:
        raise ValueError(f"Model {model_name} not supported")

    # The multi-weight API (weights=...) replaces the deprecated pretrained=...
    weights = 'IMAGENET1K_V1' if pretrained else None
    model = getattr(models, model_name)(weights=weights)

    # Swap the ImageNet head for one sized to this dataset.
    if model_name.startswith('resnet'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name.startswith('densenet'):
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        # EfficientNet / MobileNetV2 keep their final Linear at classifier[1]
        # (after dropout); VGG16 keeps it at classifier[6].
        head_idx = 6 if model_name == 'vgg16' else 1
        model.classifier[head_idx] = nn.Linear(
            model.classifier[head_idx].in_features, num_classes
        )

    return model.to(device)

def _model_device(model):
    """Return the device holding *model*'s parameters (CPU if it has none)."""
    first = next(model.parameters(), None)
    return first.device if first is not None else torch.device('cpu')


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over *loader*: train when *optimizer* is given, else evaluate.

    Batches are dicts with ``'image'`` and ``'label'`` tensors. Assumes
    *criterion* averages over the batch (``reduction='mean'``).

    Returns:
        ``(mean loss per sample, accuracy in percent)``; ``(0.0, 0.0)`` for an
        empty loader.
    """
    training = optimizer is not None
    model.train(training)
    model_device = _model_device(model)
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            images = batch['image'].to(model_device)
            labels = batch['label'].to(model_device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            n = labels.size(0)
            # Weight by batch size so a short final batch doesn't skew the mean.
            loss_sum += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += n

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, 100.0 * correct / seen


def train_model(model, train_loader, test_loader, num_epochs, learning_rate, model_name):
    """Fine-tune *model* with Adam + cross-entropy and record per-epoch metrics.

    The "validation" metrics are computed on *test_loader*, i.e. the same
    split that is later used for final evaluation, not an independent
    validation set.

    Args:
        model: classifier returning per-class logits.
        train_loader: iterable of dict batches with ``'image'``/``'label'``.
        test_loader: iterable of dict batches used for per-epoch validation.
        num_epochs: number of passes over *train_loader*.
        learning_rate: Adam learning rate.
        model_name: label used in progress messages.

    Returns:
        ``(model, history)`` where history maps ``'train_loss'``,
        ``'train_acc'``, ``'val_loss'``, ``'val_acc'`` (accuracies in %) and
        ``'epoch_times'`` (seconds) to per-epoch lists.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'epoch_times': []
    }

    print(f"\nTraining {model_name}...")

    for epoch in range(num_epochs):
        start_time = time.perf_counter()
        avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        avg_val_loss, val_acc = _run_epoch(model, test_loader, criterion)
        epoch_time = time.perf_counter() - start_time

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)
        history['epoch_times'].append(epoch_time)

        print(f"Epoch [{epoch+1}/{num_epochs}] - "
              f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%, "
              f"Time: {epoch_time:.2f}s")

    return model, history

def evaluate_model(model, test_loader, class_names):
    """Predict on *test_loader* and compute classification metrics.

    Args:
        model: classifier returning per-class logits.
        test_loader: iterable of dict batches with ``'image'``/``'label'``.
        class_names: class names indexed by label; its length fixes the size
            of the confusion matrix (``None`` lets sklearn infer it).

    Returns:
        dict with ``'accuracy'`` (fraction in [0, 1]), support-weighted
        ``'precision'``/``'recall'``/``'f1_score'``, ``'confusion_matrix'``
        (rows = actual, columns = predicted), and the raw ``'predictions'``
        and ``'labels'`` lists.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            outputs = model(batch['image'].to(model_device))
            all_predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(batch['label'].cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    # zero_division=0 makes the score for never-predicted classes explicit
    # instead of relying on sklearn's warning default.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )
    # Pin the matrix to every class index so it stays len(class_names) square
    # even if a class never occurs in labels or predictions; otherwise the
    # heatmap tick labels would no longer line up with rows/columns.
    cm_labels = list(range(len(class_names))) if class_names is not None else None
    cm = confusion_matrix(all_labels, all_predictions, labels=cm_labels)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'predictions': all_predictions,
        'labels': all_labels
    }


print("Model training and evaluation functions defined successfully!")

## Step 6: Visualization Functions

In [None]:
def plot_training_history(history1, history2, model1_name, model2_name):
    """Plot two training runs' loss and accuracy curves on a 2x2 grid.

    Panels: training loss, validation loss, training accuracy and validation
    accuracy. Model 1 is drawn in blue and model 2 in red. Each run is
    plotted against its own number of epochs, so runs of different lengths
    can be compared.
    """
    # (row, col, history key, panel title, y-axis label)
    panels = [
        (0, 0, 'train_loss', 'Training Loss', 'Loss'),
        (0, 1, 'val_loss', 'Validation Loss', 'Loss'),
        (1, 0, 'train_acc', 'Training Accuracy', 'Accuracy (%)'),
        (1, 1, 'val_acc', 'Validation Accuracy', 'Accuracy (%)'),
    ]
    runs = [(history1, 'b-', model1_name), (history2, 'r-', model2_name)]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for row, col, key, title, ylabel in panels:
        ax = axes[row, col]
        for history, fmt, name in runs:
            values = history[key]
            ax.plot(range(1, len(values) + 1), values, fmt, label=f'{name}')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrices(eval1, eval2, model1_name, model2_name, class_names):
    """Draw both models' confusion matrices as annotated side-by-side heatmaps.

    Model 1 uses the 'Blues' colour map (left), model 2 'Reds' (right). Rows
    are labelled 'Actual' and columns 'Predicted', with *class_names* as ticks.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(20, 8))

    for ax, result, name, cmap in ((left_ax, eval1, model1_name, 'Blues'),
                                   (right_ax, eval2, model2_name, 'Reds')):
        sns.heatmap(result['confusion_matrix'], annot=True, fmt='d',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax, cmap=cmap)
        ax.set_title(f'{name} Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')

    plt.tight_layout()
    plt.show()

def create_comparison_table(eval1, eval2, history1, history2, model1_name, model2_name):
    """Build a side-by-side metrics table for two trained models.

    Args:
        eval1, eval2: dicts from ``evaluate_model`` (accuracy as a fraction).
        history1, history2: dicts from ``train_model`` (``val_acc`` in %,
            ``epoch_times`` in seconds).
        model1_name, model2_name: column headers.

    Returns:
        DataFrame with a ``'Metric'`` column plus one column per model; all
        values are pre-formatted strings. When both names are equal the
        columns become ``'<name> (1)'`` and ``'<name> (2)'`` so neither model's
        results are dropped.
    """
    metric_names = [
        'Final Validation Accuracy (%)',
        'Test Accuracy (%)',
        'Precision',
        'Recall',
        'F1-Score',
        'Avg Training Time per Epoch (s)',
        'Total Training Time (s)'
    ]

    def _column(evaluation, history):
        """Format one model's metrics in the same order as *metric_names*."""
        epoch_times = history['epoch_times']
        return [
            f"{history['val_acc'][-1]:.2f}",
            f"{evaluation['accuracy']*100:.2f}",
            f"{evaluation['precision']:.4f}",
            f"{evaluation['recall']:.4f}",
            f"{evaluation['f1_score']:.4f}",
            f"{np.mean(epoch_times):.2f}",
            f"{sum(epoch_times):.2f}"
        ]

    # Identical names would collide as dict keys and silently drop a column.
    if model1_name == model2_name:
        col1, col2 = f"{model1_name} (1)", f"{model2_name} (2)"
    else:
        col1, col2 = model1_name, model2_name

    return pd.DataFrame({
        'Metric': metric_names,
        col1: _column(eval1, history1),
        col2: _column(eval2, history2),
    })

def visualize_sample_predictions(model1, model2, test_loader, class_names,
                               model1_name, model2_name, num_samples=8):
    """Show the first test images with both models' predicted labels.

    Takes up to *num_samples* images from the first batch of *test_loader*
    and draws a 2 x N grid: row 0 shows model 1's predictions, row 1 model 2's,
    one column per image. Each title also shows the true class.
    """
    model1.eval()
    model2.eval()

    batch = next(iter(test_loader))
    images = batch['image'][:num_samples]
    labels = batch['label'][:num_samples]
    # The first batch may hold fewer images than requested.
    count = images.shape[0]
    if count == 0:
        return

    # NOTE(review): assumes both models live on the same device.
    first_param = next(model1.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    images = images.to(model_device)

    with torch.no_grad():
        pred1 = model1(images).argmax(dim=1).cpu()
        pred2 = model2(images).argmax(dim=1).cpu()

    # Undo the ImageNet mean/std normalization so imshow receives [0, 1] RGB.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # squeeze=False keeps a 2-D axes array even when only one image is shown.
    fig, axes = plt.subplots(2, count, figsize=(20, 8), squeeze=False)

    for i in range(count):
        img = (images[i].cpu() * std + mean).permute(1, 2, 0).clamp(0, 1).numpy()
        true_name = class_names[int(labels[i])]
        for row, (name, preds) in enumerate(((model1_name, pred1), (model2_name, pred2))):
            ax = axes[row, i]
            ax.imshow(img)
            ax.set_title(f"{name}: {class_names[int(preds[i])]} (True: {true_name})")
            ax.axis('off')

    plt.tight_layout()
    plt.show()


print("Visualization functions defined successfully!")