# Enhanced AI Model Comparison for Remote Sensing with TorchGeo

This enhanced Jupyter Notebook allows you to:
- Upload your own satellite imagery
- Compare two different pre-trained models side-by-side
- Visualize results with detailed performance metrics
- Export comparison results

Built with [TorchGeo](https://github.com/microsoft/torchgeo) for seamless geospatial deep learning.

In [None]:
# Install the notebook's dependencies. Run this cell once; comment the lines
# out if the packages are already installed. pip may ask for a kernel restart
# before newly installed or updated packages can be imported.
%pip install torch torchvision
%pip install rasterio pillow numpy pytorch-lightning torchgeo matplotlib seaborn pandas scikit-learn ipywidgets

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


## Step 1: Import Libraries and Setup

In [46]:
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import Compose, Normalize, ToTensor
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import requests
from io import BytesIO

# TorchGeo imports. A missing installation is reported with an install hint
# and the ImportError is re-raised so the rest of the notebook does not run.
try:
    import torchgeo
    # NOTE(review): the wildcard import pulls every public name of
    # torchgeo.datasets into the notebook namespace; none of them appear to be
    # used in the cells visible here -- confirm before narrowing it.
    from torchgeo.datasets import *
    from torchgeo.models import get_weight, list_models
    from torchgeo.transforms import AugmentationSequential
    print(f"TorchGeo version: {torchgeo.__version__}")
except ImportError:
    print("Please install torchgeo: pip install torchgeo")
    raise

# Set device: use CUDA when torch reports it as available, otherwise the CPU.
# Models and input tensors are moved to this device by the cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")



TorchGeo version: 0.7.0
Using device: cuda


## Step 2: Configuration and Model Selection

In [47]:
# Models offered in the comparison dropdowns.
# Key: label shown to the user. Value: architecture identifier that
# load_models() uses to pick a constructor. Insertion order is the order in
# which the dropdowns list the options.
available_models = dict((
    ('ResNet-18', 'resnet18'),
    ('ResNet-34', 'resnet34'),
    ('ResNet-50', 'resnet50'),
    ('U-Net', 'unet'),
))

# Notebook-wide state, filled in later by the loading and upload callbacks.
loaded_models = dict()
image = image_tensor = None

In [48]:
def load_models(output_area, selected_models):
    """Load the selected pre-trained models and store them in ``loaded_models``.

    Each display name is resolved through ``available_models`` to an
    architecture identifier, the matching torchvision constructor is called
    with ImageNet weights, and the model is switched to eval mode and moved
    to ``device``. Progress and errors are printed inside ``output_area``,
    which is cleared first. A model that fails to load is reported and
    skipped; it does not abort the remaining models.

    Args:
        output_area: Context manager that captures the printed progress
            (the notebook passes a ``widgets.Output``).
        selected_models: Iterable of display names (keys of
            ``available_models``).

    Returns:
        dict: display name -> loaded model. Empty when nothing could be
        loaded. The same dict is also published as the module-level
        ``loaded_models``.
    """
    global loaded_models

    # Local import keeps this block self-contained; the notebook's import
    # cell only brings in torchvision.transforms.
    from torchvision import models as tv_models

    # ResNets are matched by exact name. The previous substring test sent
    # every ResNet other than "resnet18" to resnet50, so choosing "ResNet-34"
    # silently loaded (and reported) the wrong network.
    resnet_builders = {
        'resnet18': tv_models.resnet18,
        'resnet34': tv_models.resnet34,
        'resnet50': tv_models.resnet50,
    }

    with output_area:
        clear_output(wait=True)
        print("Loading models...")

        loaded_models = {}

        for model_name in selected_models:
            try:
                # Both dropdowns may point at the same model; load it once.
                if model_name in loaded_models:
                    continue

                print(f"Loading {model_name}...")

                arch_name = available_models[model_name]

                # Attribute lookups stay inside the branches so an older
                # torchvision lacking e.g. swin_t only fails for that model.
                if arch_name in resnet_builders:
                    builder = resnet_builders[arch_name]
                elif 'efficientnet' in arch_name:
                    if 'b0' in arch_name:
                        builder = tv_models.efficientnet_b0
                    else:
                        builder = tv_models.efficientnet_b1
                elif 'vit' in arch_name:
                    builder = tv_models.vit_b_16
                elif 'swin' in arch_name:
                    builder = tv_models.swin_t
                else:
                    print(f"Model {model_name} not implemented, skipping...")
                    continue

                # `pretrained=True` is deprecated in torchvision; the
                # IMAGENET1K_V1 weights are what that flag used to select.
                model = builder(weights='IMAGENET1K_V1')
                model.eval()
                model.to(device)
                loaded_models[model_name] = model
                print(f"✅ {model_name} loaded successfully")

            except Exception as e:
                # UI boundary: report the failure and keep loading the rest.
                print(f"❌ Error loading {model_name}: {str(e)}")

        if loaded_models:
            print(f"\n🎉 Successfully loaded {len(loaded_models)} models!")
        else:
            print("❌ No models loaded successfully")
        return loaded_models

In [49]:
def preprocess_image(image):
    """Convert an image into a normalized one-image batch on ``device``.

    The image is resized to 224x224, converted to a tensor, normalized with
    the ImageNet per-channel mean and standard deviation, given a leading
    batch dimension and moved to ``device``.

    Returns None when ``image`` is None.
    """
    if image is None:
        return None

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    pipeline = Compose([
        transforms.Resize((224, 224)),
        ToTensor(),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    batch = pipeline(image).unsqueeze(0)
    return batch.to(device)
    
def display_image(image, title="Current Image for Inference"):
    """Render ``image`` in a 6x6-inch figure with ``title`` and hidden axes."""
    _, axis = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    axis.imshow(image)
    axis.set_title(title)
    axis.axis('off')
    plt.tight_layout()
    plt.show()
            
def handle_uploaded_image(uploaded_file, output_area):
    """Decode an uploaded file, preprocess it and show it.

    On success the decoded RGB image is stored in the module-level
    ``current_image`` and its preprocessed tensor in ``image_tensor``.
    All messages are printed inside ``output_area``, which is cleared first.

    Args:
        uploaded_file: Mapping whose ``'content'`` entry holds the raw image
            bytes.
        output_area: Context manager that captures the printed output.

    Returns:
        True when the image was decoded and preprocessed, False when any
        step raised.
    """
    global current_image, image_tensor

    with output_area:
        clear_output(wait=True)
        print("Processing uploaded image...")

        try:
            raw_bytes = uploaded_file['content']
            decoded = Image.open(BytesIO(raw_bytes)).convert('RGB')
            current_image = decoded
            image_tensor = preprocess_image(decoded)
            display_image(decoded)
            print("✅ Image uploaded and processed successfully!")
        except Exception as err:
            print(f"❌ Error processing uploaded image: {err!s}")
            return False
        else:
            return True
    

In [50]:
def run_inference(loaded_models, image_tensor, results_area):
    """Run every loaded model on ``image_tensor`` and collect its top-5 output.

    Args:
        loaded_models: Mapping of model name -> callable model.
        image_tensor: Input batch passed unchanged to each model.
        results_area: Context manager that captures the printed progress.

    Returns:
        None when there are no models or no image. Otherwise a dict mapping
        each model name to either None (that model raised) or a dict with
        'probabilities' (top-5 softmax values), 'indices' (their class
        indices) and 'raw_outputs' (the unnormalized model output), each
        taken from the first batch element as a NumPy array.
    """
    inputs_ready = bool(loaded_models) and image_tensor is not None
    if not inputs_ready:
        with results_area:
            print("❌ Please load models and an image first!")
        return None

    with results_area:
        clear_output(wait=True)
        print("🚀 Running inference on all models...")
        print("=" * 50)

        collected = {}

        # Gradients are never needed for inference.
        with torch.no_grad():
            for name, net in loaded_models.items():
                try:
                    print(f"Running {name}...")

                    logits = net(image_tensor)

                    # Softmax over the class dimension, then keep the 5 best.
                    class_probs = F.softmax(logits, dim=1)
                    best_probs, best_indices = torch.topk(class_probs, 5)

                    entry = {
                        'probabilities': best_probs.cpu().numpy()[0],
                        'indices': best_indices.cpu().numpy()[0],
                        'raw_outputs': logits.cpu().numpy()[0],
                    }
                    collected[name] = entry

                    print(f"✅ {name} completed")

                except Exception as err:
                    print(f"❌ Error with {name}: {err!s}")
                    collected[name] = None

        return collected

In [51]:
def display_results(results):
    """Plot and print a side-by-side comparison of per-model results.

    For every model with a usable result, one column is drawn: a bar chart
    of the top predictions on top and a histogram of (at most the first 100)
    raw outputs below. A textual summary and, when at least two models are
    present, an agreement/confidence comparison follow.

    Args:
        results: Mapping of model name -> result dict with the keys
            'probabilities', 'indices' and 'raw_outputs', or None for a
            model whose inference failed. ``None`` or an empty mapping is
            treated as "nothing to display".

    Returns:
        None.
    """
    print("\n" + "="*50)
    print("🎯 INFERENCE RESULTS COMPARISON")
    print("="*50)

    # `results` is None when inference was refused (no models / no image);
    # treat that like an empty result set instead of raising.
    valid_results = {
        name: result
        for name, result in (results or {}).items()
        if result is not None
    }

    if not valid_results:
        print("❌ No valid results to display")
        return

    n_models = len(valid_results)

    # squeeze=False always yields a 2-D axes array, so a single model is
    # indexed exactly like several. The former reshape-then-`axes[0]` path
    # handed an ndarray (not an Axes) to the plotting code and crashed.
    _, axes = plt.subplots(
        2, n_models, figsize=(5 * n_models, 10), squeeze=False
    )

    for col, (model_name, result) in enumerate(valid_results.items()):
        top_probs = result['probabilities']
        top_indices = result['indices']
        # Follow the number of predictions actually supplied rather than
        # assuming exactly five.
        n_top = len(top_probs)

        # Top predictions bar chart
        ax1 = axes[0, col]
        bars = ax1.bar(
            range(n_top), top_probs,
            color=plt.cm.viridis(np.linspace(0, 1, n_top)),
        )
        ax1.set_xlabel('Top 5 Classes')
        ax1.set_ylabel('Probability')
        ax1.set_title(f'{model_name}\nTop 5 Predictions')
        ax1.set_xticks(range(n_top))
        ax1.set_xticklabels([f"C{i}" for i in top_indices], rotation=45)

        # Add probability values on bars
        for bar, prob in zip(bars, top_probs):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                     f'{prob:.3f}', ha='center', va='bottom', fontsize=10)

        # Raw output distribution (first 100 values only, for clarity;
        # slicing is a no-op on shorter arrays)
        ax2 = axes[1, col]
        sample_outputs = result['raw_outputs'][:100]
        ax2.hist(sample_outputs, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        ax2.set_xlabel('Raw Output Value')
        ax2.set_ylabel('Frequency')
        ax2.set_title(f'{model_name}\nRaw Output Distribution')
        ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print detailed results
    print("\n📊 DETAILED RESULTS:")
    print("-" * 30)

    for model_name, result in valid_results.items():
        print(f"\n🔍 {model_name}:")
        print(f"   Top prediction: Class {result['indices'][0]} ({result['probabilities'][0]:.4f})")
        print(f"   Confidence: {result['probabilities'][0]:.1%}")
        print(f"   Top 3 classes: {', '.join([f'Class {i}' for i in result['indices'][:3]])}")

    # Model agreement analysis
    print("\n🤝 MODEL AGREEMENT ANALYSIS:")
    print("-" * 30)

    if n_models > 1:
        # Check if models agree on top prediction
        top_predictions = {result['indices'][0] for result in valid_results.values()}
        if len(top_predictions) == 1:
            print("✅ All models agree on the top prediction!")
        else:
            print("⚠️  Models disagree on top prediction:")
            for model_name, result in valid_results.items():
                print(f"   {model_name}: Class {result['indices'][0]}")

        # Confidence comparison: rank models by their top-1 probability
        def top_confidence(name):
            return valid_results[name]['probabilities'][0]

        most_confident = max(valid_results, key=top_confidence)
        least_confident = min(valid_results, key=top_confidence)

        print(f"🎯 Most confident: {most_confident} ({valid_results[most_confident]['probabilities'][0]:.1%})")
        print(f"🤔 Least confident: {least_confident} ({valid_results[least_confident]['probabilities'][0]:.1%})")


## Interactive Model and Image Selection

In [None]:
# Create widgets
def create_interface():
    """Create the interactive widgets and wire up their event handlers.

    Builds two model dropdowns, an image uploader, the "load" and "run"
    buttons and two output areas. The handlers read the module-level
    ``loaded_models`` and ``image_tensor`` at call time, so they always see
    the state most recently set by ``load_models`` and
    ``handle_uploaded_image``.

    Returns:
        dict with the keys 'model_selector_1', 'model_selector_2',
        'image_uploader', 'load_models_btn', 'run_inference_btn',
        'output_area' and 'results_area', each mapped to its widget.
    """

    # Model selection
    model_selector_1 = widgets.Dropdown(
        options=list(available_models.keys()),
        value='ResNet-18',
        description='Select model 1:',
        disabled=False,
    )

    model_selector_2 = widgets.Dropdown(
        options=list(available_models.keys()),
        value='ResNet-50',
        description='Select model 2 :',
        disabled=False,
    )

    # Image upload widget
    image_uploader = widgets.FileUpload(
        accept='image/*',
        multiple=False,
        description='Upload Image'
    )

    # Buttons
    load_models_btn = widgets.Button(
        description='Load Selected Models',
        button_style='primary',
        icon='download'
    )

    # Stays disabled until both models and an image are available.
    run_inference_btn = widgets.Button(
        description='Run Inference',
        button_style='success',
        icon='play',
        disabled=True
    )

    # Output areas
    output_area = widgets.Output()
    results_area = widgets.Output()

    # Event handlers
    def on_load_models_click(button):
        selected_models = [model_selector_1.value, model_selector_2.value]
        load_models(output_area, selected_models)
        # Recompute the state every time: a failed reload empties
        # `loaded_models`, and the Run button must not stay enabled then.
        run_inference_btn.disabled = not (
            loaded_models and image_tensor is not None
        )

    def on_run_inference_click(button):
        results = run_inference(loaded_models, image_tensor, results_area)
        if results is None:
            # run_inference has already explained what is missing.
            return
        # Capture the comparison in the results area; output produced by a
        # widget callback outside an Output widget may not be shown at all.
        with results_area:
            display_results(results)

    def on_upload_change(change):
        uploads = change['new']
        if not uploads:
            return
        # ipywidgets 7 exposes the value as {filename: file_info}, while
        # ipywidgets 8 exposes a tuple of file_info entries. Support both.
        if isinstance(uploads, dict):
            uploaded_file = next(iter(uploads.values()))
        else:
            uploaded_file = uploads[0]
        success = handle_uploaded_image(uploaded_file, output_area)
        if success and loaded_models:
            run_inference_btn.disabled = False

    # Attach event handlers
    load_models_btn.on_click(on_load_models_click)
    run_inference_btn.on_click(on_run_inference_click)
    image_uploader.observe(on_upload_change, names='value')

    return {
        'model_selector_1': model_selector_1,
        'model_selector_2': model_selector_2,
        'image_uploader': image_uploader,
        'load_models_btn': load_models_btn,
        'run_inference_btn': run_inference_btn,
        'output_area': output_area,
        'results_area': results_area
    }

# Create and display the interface
def display_interface():
    """Build the widgets, arrange them in three numbered sections and show them.

    The sections (model selection, image upload, inference) are stacked
    vertically with horizontal rules between them, followed by the two
    output areas.
    """
    ui = create_interface()

    def divider():
        # A fresh widget per separator, as each slot needs its own instance.
        return widgets.HTML("<hr>")

    model_section = widgets.VBox([
        widgets.HTML("<h3>1. Select Models to Compare</h3>"),
        ui['model_selector_1'],
        ui['model_selector_2'],
        ui['load_models_btn'],
    ])

    image_section = widgets.VBox([
        widgets.HTML("<h3>2. Choose Image Source</h3>"),
        widgets.HTML("<b>Upload your own image</b>"),
        ui['image_uploader'],
    ])

    inference_section = widgets.VBox([
        widgets.HTML("<h3>3. Run Inference</h3>"),
        ui['run_inference_btn'],
    ])

    layout = [
        model_section,
        divider(),
        image_section,
        divider(),
        inference_section,
        divider(),
        widgets.HTML("<h3>Output</h3>"),
        ui['output_area'],
        ui['results_area'],
    ]

    display(widgets.VBox(layout))

# Build the widgets and render the complete interface in this cell's output
display_interface()

VBox(children=(VBox(children=(HTML(value='<h3>1. Select Models to Compare</h3>'), Dropdown(description='Select…