# PyTorch Fundamentals

This notebook covers the fundamentals of working with PyTorch for deep learning, including:
- Loading data from Hugging Face
- Applying proper transformations and normalization
- Using pre-trained models (ResNet50)
- Setting up training with appropriate optimization parameters

## Setting Up the Environment

In [6]:
# Check for or install all necessary packages with conda from environment.yml
# %conda env update -f ../environment.yml

# For colab
# !pip install datasets

In [2]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from huggingface_hub import login
from datasets import load_dataset
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
import matplotlib.pyplot as plt
import numpy as np
# Import F1 score calculation
from sklearn.metrics import f1_score

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [8]:
# Login to Hugging Face with token if necessary
# login()

## Data Loading from Hugging Face

In [None]:
train_dataset = load_dataset("mpg-ranch/horse-detection", split="train")

In [3]:
# Create a validation split from the training data
train_val_split = train_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = train_val_split['train']
val_dataset = train_val_split['test']

In [None]:
# Index the row first: train_dataset["image"] would decode every image in the column
train_dataset[0]["image"]

## Model Preprocessing Requirements

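First, let's examine the preprocessing the pretrained ResNet50 expects. The weights object carries a metadata dictionary (`weights.meta`: class names, metrics, parameter count, and similar checkpoint details) and a bundled preset transform. The expected resize and crop sizes and the normalization mean and standard deviation are attributes of that transform, not entries in the metadata.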

In [None]:
# Get the preprocessing transforms directly from the weights
weights = ResNet50_Weights.IMAGENET1K_V1
default_transforms = weights.transforms()

# Print available metadata keys
print("Available metadata keys:")
if hasattr(weights, 'meta'):
    print(list(weights.meta.keys()))
else:
    print("No 'meta' attribute found")

# Print the preprocessing information safely
print("\nModel expects the following preprocessing:")
try:
    # Try different possible key names for input size
    if hasattr(weights, 'meta'):
        if 'input_size' in weights.meta:
            print(f"- Input size: {weights.meta['input_size']}")
        elif 'imageSize' in weights.meta:
            print(f"- Input size: {weights.meta['imageSize']}")
        else:
            print("- Input size: Not found in metadata")
            
        # Try different possible key names for mean/std
        if 'mean' in weights.meta:
            print(f"- Mean: {weights.meta['mean']}")
        else:
            print("- Mean: [0.485, 0.456, 0.406] (ImageNet standard)")
            
        if 'std' in weights.meta:
            print(f"- Std: {weights.meta['std']}")
        else:
            print("- Std: [0.229, 0.224, 0.225] (ImageNet standard)")
    else:
        print("Metadata not available, using standard ImageNet values:")
        print("- Input size: [3, 224, 224]")
        print("- Mean: [0.485, 0.456, 0.406]")
        print("- Std: [0.229, 0.224, 0.225]")
except Exception as e:
    print(f"Error accessing metadata: {e}")
    print("Using standard ImageNet values:")
    print("- Input size: [3, 224, 224]")
    print("- Mean: [0.485, 0.456, 0.406]")
    print("- Std: [0.229, 0.224, 0.225]")

# Print information about the default transforms
print("\nDefault transforms from weights:")
print(default_transforms)

## Custom Transforms for Our Dataset

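We'll now define custom transforms for the training and validation datasets. They reuse the model's ImageNet normalization but not its 224×224 input size: ResNet50 ends in adaptive average pooling, so it also accepts the much smaller 77×77 crops used here. The two pipelines differ as follows: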

1. **Training transforms**: Include data augmentation (random crops, flips, rotation) to improve model generalization, along with the required normalization
2. **Validation transforms**: Only include deterministic center cropping and normalization (no augmentation) to evaluate the model on consistent inputs

Note the specific cropping approach:
- Training: 115px center crop followed by a 77px random crop to simulate our camera setup and add variation
- Validation: 77px deterministic center crop (1m equivalent in our camera setup) for consistent evaluation

Both transforms include the same normalization parameters to match the pretrained model's expectations, but validation excludes augmentation to ensure consistent evaluation.

In [None]:
# Define transformations for training data (augmentations + the model's preprocessing)
train_transforms = transforms.Compose([
    transforms.CenterCrop(115),  # First, center crop to 1.5m (115 pixels)
    transforms.RandomCrop(77),   # Then, random crop to 1m (77 pixels)
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Validation transformations (no augmentation, only center crop and normalize)
val_transforms = transforms.Compose([
    transforms.CenterCrop(77),   # Deterministic crop to 1m (77 pixels)
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Print your custom transforms
print("Custom transforms defined for training and validation:")
print("\nTraining transforms:")
print(train_transforms)
print("\nValidation transforms:")
print(val_transforms)

## Applying Transformations to Dataset

In [None]:
# Function to apply transformations to the training dataset
def transform_train_dataset(examples):
    examples["pixel_values"] = [
        train_transforms(image.convert("RGB")) 
        for image in examples["image"]
    ]
    return examples

# Function to apply transformations to the validation dataset
def transform_val_dataset(examples):
    examples["pixel_values"] = [
        val_transforms(image.convert("RGB")) 
        for image in examples["image"]
    ]
    return examples

# Apply transformations to training set
transformed_train_dataset = train_dataset.map(
    transform_train_dataset,
    batched=True,
    remove_columns=["image"]  # Remove original images after transformation
)

# Apply transformations to test set
transformed_val_dataset = val_dataset.map(
    transform_val_dataset,
    batched=True,
    remove_columns=["image"]
)

# Set the format for PyTorch
transformed_train_dataset.set_format(type="torch", columns=["pixel_values", "Presence"])
transformed_val_dataset.set_format(type="torch", columns=["pixel_values", "Presence"])

## Creating DataLoaders

In [8]:
# Create DataLoaders
train_dataloader = DataLoader(
    transformed_train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

val_dataloader = DataLoader(
    transformed_val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4
)

## Visualizing a Batch

Let's visualize some images from our dataloader to verify transformations are applied correctly.

In [None]:
# Function to denormalize images for visualization
def denormalize(tensor):
    # Make sure tensor is the right shape and type
    if tensor.ndim != 3:  # If not a single image with 3 dimensions (C,H,W)
        if tensor.ndim == 4:  # If it's a batch of images (B,C,H,W)
            tensor = tensor[0]  # Take the first image
        else:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
    
    # Make sure we're working with the image tensor, not other data
    if tensor.shape[0] == 3:  # If first dimension is 3, it's likely the channel dimension
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        return tensor * std + mean
    else:
        raise ValueError(f"Expected 3 channels, got {tensor.shape[0]}")

# Get a batch from the dataloader
batch = next(iter(train_dataloader))

# Check the structure of the batch
print(f"Batch type: {type(batch)}")
if isinstance(batch, dict):
    print(f"Batch keys: {batch.keys()}")
    images = batch['pixel_values']  # Adjust based on your actual key
    Presences = batch['Presence']  # Adjust based on your actual key
else:
    # If it's a tuple or list, unpack accordingly
    images, Presences = batch

print(f"Images shape: {images.shape}")
print(f"Presences shape: {Presences.shape}")

# Visualize a few images from the batch
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < len(images):
        # Denormalize the image
        img = denormalize(images[i])
        img = img.permute(1, 2, 0).numpy()  # Change from CxHxW to HxWxC
        img = np.clip(img, 0, 1)  # Clip values to valid range
        
        ax.imshow(img)
        ax.set_title(f"Presence: {Presences[i].item()}")
        ax.axis("off")

plt.tight_layout()
plt.show()

## Model Architecture with ResNet50

In [None]:
# Load pre-trained ResNet50
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Unfreeze all layers from the start
for param in model.parameters():
    param.requires_grad = True

# Modify the final fully connected layer for binary classification
model.fc = nn.Linear(model.fc.in_features, 1)  # Output a single value

# Print model architecture summary
print(f"Model: ResNet50")
print(f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Training Configuration

We'll use default settings for the optimizer and only tune the learning rate and number of epochs.

In [None]:
# Define loss function for binary classification
criterion = nn.BCEWithLogitsLoss()

# Use a smaller learning rate for fine-tuning all layers
learning_rate = 0.0001  # Reduced from 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params:,}")

# Number of epochs is another parameter we'll tune
num_epochs = 10
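Because the new head outputs a single raw logit, `BCEWithLogitsLoss` applies the sigmoid internally (in a numerically stabler form than a separate `Sigmoid` followed by `BCELoss`), and the loops' prediction rule `sigmoid(output) > 0.5` is equivalent to `output > 0`. A small check on random logits and labels (synthetic values, not notebook data):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8)                     # raw outputs of the 1-unit head
labels = torch.randint(0, 2, (8,)).float()  # 0/1 Presence labels

# Fused loss vs. explicit sigmoid + binary cross-entropy
fused = nn.BCEWithLogitsLoss()(logits, labels)
manual = nn.BCELoss()(torch.sigmoid(logits), labels)
print(torch.allclose(fused, manual, atol=1e-6))  # True

# Thresholding probabilities at 0.5 equals thresholding logits at 0
preds_prob = (torch.sigmoid(logits) > 0.5).float()
preds_logit = (logits > 0).float()
print(torch.equal(preds_prob, preds_logit))  # True
```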



## Learning Rate Testing

In [None]:
# Define learning rates to test
learning_rates = [0.01, 0.001, 0.0001, 0.00001]

# Function to train model with a specific learning rate
def train_with_learning_rate(lr, num_epochs=3):
    # Initialize a fresh model (using the same architecture)
    test_model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    
    # Unfreeze all layers
    for param in test_model.parameters():
        param.requires_grad = True
    
    # Modify the final fully connected layer for binary classification
    test_model.fc = nn.Linear(test_model.fc.in_features, 1)
    test_model = test_model.to(device)
    
    # Define loss function
    criterion = nn.BCEWithLogitsLoss()
    
    # Use the specified learning rate
    optimizer = optim.Adam(test_model.parameters(), lr=lr)
    
    # Lists to store metrics
    train_losses = []
    val_losses = []
    val_accs = []
    val_f1s = []
    
    # Run for fewer epochs when testing learning rates
    for epoch in range(num_epochs):
        # Training phase
        test_model.train()
        running_loss = 0.0
        
        for batch in train_dataloader:
            # Extract data
            if isinstance(batch, dict):
                images = batch['pixel_values']
                Presences = batch['Presence']
            else:
                images, Presences = batch
            
            # Move to device and prepare
            images, Presences = images.to(device), Presences.to(device)
            Presences = Presences.float()
            
            # Forward pass
            optimizer.zero_grad()
            outputs = test_model(images)
            outputs = outputs.squeeze(1)
            loss = criterion(outputs, Presences)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Calculate average training loss
        train_loss = running_loss / len(train_dataloader)
        train_losses.append(train_loss)
        
        # Validation phase
        test_model.eval()
        val_running_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in val_dataloader:
                # Extract data
                if isinstance(batch, dict):
                    images = batch['pixel_values']
                    Presences = batch['Presence']
                else:
                    images, Presences = batch
                
                # Move to device and prepare
                images, Presences = images.to(device), Presences.to(device)
                Presences = Presences.float()
                
                # Forward pass
                outputs = test_model(images)
                outputs = outputs.squeeze(1)
                loss = criterion(outputs, Presences)
                
                # Calculate metrics
                val_running_loss += loss.item()
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += Presences.size(0)
                correct += predicted.eq(Presences).sum().item()
                
                # Store predictions for F1 score
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(Presences.cpu().numpy())
        
        # Calculate validation metrics
        val_loss = val_running_loss / len(val_dataloader)
        val_acc = 100. * correct / total
        val_f1 = f1_score(all_labels, all_preds, average='binary')
        
        # Store metrics
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_f1s.append(val_f1)
        
        # Print progress
        print(f'LR: {lr:.6f}, Epoch: {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
              f'Val Acc: {val_acc:.2f}%, Val F1: {val_f1:.4f}')
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'val_f1s': val_f1s,
        'final_val_acc': val_accs[-1],
        'final_val_f1': val_f1s[-1]
    }

# Set device (this should match the one used in the main training loop)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dictionary to store results
lr_results = {}

# Test each learning rate
for lr in learning_rates:
    print(f"\nTesting learning rate: {lr}")
    lr_results[lr] = train_with_learning_rate(lr)

# Visualize results
plt.figure(figsize=(15, 10))

# Plot training loss
plt.subplot(2, 2, 1)
for lr in learning_rates:
    plt.plot(range(1, len(lr_results[lr]['train_losses'])+1), 
             lr_results[lr]['train_losses'], 
             label=f'LR: {lr}')
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss by Learning Rate')
plt.legend()
plt.grid(True)

# Plot validation loss
plt.subplot(2, 2, 2)
for lr in learning_rates:
    plt.plot(range(1, len(lr_results[lr]['val_losses'])+1), 
             lr_results[lr]['val_losses'], 
             label=f'LR: {lr}')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Validation Loss by Learning Rate')
plt.legend()
plt.grid(True)

# Plot validation accuracy
plt.subplot(2, 2, 3)
for lr in learning_rates:
    plt.plot(range(1, len(lr_results[lr]['val_accs'])+1), 
             lr_results[lr]['val_accs'], 
             label=f'LR: {lr}')
plt.xlabel('Epoch')
plt.ylabel('Validation Accuracy (%)')
plt.title('Validation Accuracy by Learning Rate')
plt.legend()
plt.grid(True)

# Plot validation F1 score
plt.subplot(2, 2, 4)
for lr in learning_rates:
    plt.plot(range(1, len(lr_results[lr]['val_f1s'])+1), 
             lr_results[lr]['val_f1s'], 
             label=f'LR: {lr}')
plt.xlabel('Epoch')
plt.ylabel('Validation F1 Score')
plt.title('Validation F1 Score by Learning Rate')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Print summary of results
print("\nSummary of Learning Rate Performance:")
print("-" * 60)
print(f"{'Learning Rate':<15} {'Final Val Acc':<15} {'Final Val F1':<15}")
print("-" * 60)
for lr in learning_rates:
    print(f"{lr:<15.6f} {lr_results[lr]['final_val_acc']:<15.2f} {lr_results[lr]['final_val_f1']:<15.4f}")

# Identify best learning rate based on validation accuracy
best_lr = max(learning_rates, key=lambda lr: lr_results[lr]['final_val_acc'])
print(f"\nBest learning rate based on validation accuracy: {best_lr}")

# Recommended learning rate to use in the main training loop
print(f"\nRecommendation: Use learning_rate = {best_lr} in the main training configuration")
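The sweep picks the learning rate with the highest final validation accuracy. If horses are rare in the tiles (an assumption worth checking with the label counts), accuracy can reward a model that mostly predicts "absent", and the F1 score the sweep already records may be the safer criterion. A sketch of selecting by either metric, with a hypothetical `best_by_metric` helper and illustrative numbers that are not results from this notebook:

```python
def best_by_metric(results, metric="final_val_f1"):
    """Return the learning rate whose results dict has the highest `metric`."""
    return max(results, key=lambda lr: results[lr][metric])

# Illustrative values only, shaped like the lr_results dictionary
example_results = {
    0.001: {"final_val_acc": 92.0, "final_val_f1": 0.55},
    0.0001: {"final_val_acc": 90.5, "final_val_f1": 0.71},
}
print(best_by_metric(example_results, "final_val_acc"))  # 0.001
print(best_by_metric(example_results, "final_val_f1"))   # 0.0001
```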

## Training Loop

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

# Lists to store metrics for plotting
train_losses = []
train_accs = []
train_f1s = []  # New list for training F1 scores
val_losses = []
val_accs = []
val_f1s = []    # New list for validation F1 scores

# Check the structure of a batch to understand the data format
sample_batch = next(iter(train_dataloader))
print(f"Batch type: {type(sample_batch)}")
if isinstance(sample_batch, dict):
    print(f"Batch keys: {sample_batch.keys()}")
    # Adjust these based on your actual keys
    image_key = 'pixel_values' if 'pixel_values' in sample_batch else 'img'
    Presence_key = 'Presence' if 'Presence' in sample_batch else 'Presences'
    print(f"Using keys - Images: '{image_key}', Presences: '{Presence_key}'")
else:
    print(f"Batch is a {type(sample_batch)} with {len(sample_batch)} elements")
    for i, item in enumerate(sample_batch):
        print(f"  Item {i} type: {type(item)}")

# Training loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    all_train_preds = []  # To store all training predictions
    all_train_labels = [] # To store all training labels

    for batch in train_dataloader:
        # Extract images and Presences based on the batch structure
        if isinstance(batch, dict):
            # Dictionary format (common with Hugging Face datasets)
            images = batch['pixel_values'] if 'pixel_values' in batch else batch['img']
            Presences = batch['Presence'] if 'Presence' in batch else batch['Presences']
        else:
            # Tuple/list format (common with PyTorch datasets)
            images, Presences = batch

        # Move data to device
        images, Presences = images.to(device), Presences.to(device)

        # Convert Presences to float for BCEWithLogitsLoss
        Presences = Presences.float()

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        outputs = outputs.squeeze(1)  # Change from [batch_size, 1] to [batch_size]
        loss = criterion(outputs, Presences)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()

        # For binary classification with BCEWithLogitsLoss
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        total += Presences.size(0)
        correct += predicted.eq(Presences).sum().item()

        # Collect predictions and labels for F1 calculation
        all_train_preds.extend(predicted.cpu().numpy())
        all_train_labels.extend(Presences.cpu().numpy())

    # Calculate training metrics
    train_loss = running_loss / len(train_dataloader)
    train_acc = 100. * correct / total
    train_f1 = f1_score(all_train_labels, all_train_preds, average='binary')

    # Store metrics
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_f1s.append(train_f1)

    # Validation phase
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    all_val_preds = []    # To store all validation predictions
    all_val_labels = []   # To store all validation labels

    with torch.no_grad():
        for batch in val_dataloader:
            # Extract images and Presences based on the batch structure
            if isinstance(batch, dict):
                # Dictionary format (common with Hugging Face datasets)
                images = batch['pixel_values'] if 'pixel_values' in batch else batch['img']
                Presences = batch['Presence'] if 'Presence' in batch else batch['Presences']
            else:
                # Tuple/list format (common with PyTorch datasets)
                images, Presences = batch

            # Move data to device
            images, Presences = images.to(device), Presences.to(device)

            # Convert Presences to float for BCEWithLogitsLoss
            Presences = Presences.float()

            outputs = model(images)
            outputs = outputs.squeeze(1)  # Change from [batch_size, 1] to [batch_size]
            loss = criterion(outputs, Presences)

            val_running_loss += loss.item()

            # For binary classification with BCEWithLogitsLoss
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += Presences.size(0)
            correct += predicted.eq(Presences).sum().item()

            # Collect predictions and labels for F1 calculation
            all_val_preds.extend(predicted.cpu().numpy())
            all_val_labels.extend(Presences.cpu().numpy())

    # Calculate validation metrics
    val_loss = val_running_loss / len(val_dataloader)
    val_acc = 100. * correct / total
    val_f1 = f1_score(all_val_labels, all_val_preds, average='binary')

    # Store metrics
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    val_f1s.append(val_f1)

    # Print epoch results including F1 scores
    print(f'Epoch: {epoch+1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Train F1: {train_f1:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val F1: {val_f1:.4f}')
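The training loop reports validation F1 every epoch but keeps only the final weights, even if an earlier epoch scored better. A minimal sketch of keeping the best checkpoint, using a hypothetical `BestCheckpoint` helper; a toy linear layer and made-up F1 values stand in for the ResNet50 and real metrics:

```python
import copy

import torch
import torch.nn as nn

class BestCheckpoint:
    """Hypothetical helper: keeps a copy of the weights with the best validation F1."""

    def __init__(self):
        self.best_f1 = float("-inf")
        self.best_epoch = None
        self.best_state = None

    def update(self, model, val_f1, epoch):
        if val_f1 > self.best_f1:
            self.best_f1 = val_f1
            self.best_epoch = epoch
            # deepcopy so later optimizer steps don't overwrite the saved tensors
            self.best_state = copy.deepcopy(model.state_dict())
            return True
        return False

toy = nn.Linear(4, 1)
tracker = BestCheckpoint()
for epoch, f1 in enumerate([0.60, 0.72, 0.68]):  # made-up validation F1 values
    with torch.no_grad():
        toy.weight.add_(1.0)  # stand-in for a training step changing the weights
    tracker.update(toy, f1, epoch)

toy.load_state_dict(tracker.best_state)  # restore the best weights after training
print(tracker.best_epoch, tracker.best_f1)  # 1 0.72
```

In the notebook's loop, `tracker.update(model, val_f1, epoch)` would go right after `val_f1` is computed.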

## Visualizing Training Progress

In [None]:
# Plot training and validation metrics
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot loss
ax1.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
ax1.plot(range(1, num_epochs+1), val_losses, label='Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Plot accuracy
ax2.plot(range(1, num_epochs+1), train_accs, label='Train Accuracy')
ax2.plot(range(1, num_epochs+1), val_accs, label='Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

## Key Points to Remember

1. **Data Normalization is Critical**
   - Always normalize your input data using mean and standard deviation
   - For transfer learning with pre-trained models, use the same normalization values that were used during pre-training (e.g., ImageNet stats)

2. **Data Transformations**
   - Apply appropriate augmentations for training data (flips, rotations, crops)
   - Use only deterministic resizing or cropping plus normalization for validation/test data
   - Transformations help prevent overfitting and improve model generalization

3. **Model Architecture**
   - Start with a pre-trained model like ResNet50
   - Modify only the final layer (head) to match your specific task
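   - This notebook fine-tunes all layers from the start with a small learning rate (1e-4); freezing the pre-trained layers at first is a cheaper alternative, especially with little data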

4. **Optimization Settings**
   - Start with default optimizer settings
   - Focus on tuning learning rate and number of epochs first
   - Monitor validation metrics to prevent overfitting

5. **Progressive Unfreezing**
   - As an alternative to full fine-tuning, train only the new head first, then unfreeze deeper layers gradually
   - Use a smaller learning rate for the pre-trained layers than for the new head

## Next Steps

1. **Hyperparameter Tuning**
   - Retrain with the best rate from the learning-rate sweep, or refine the sweep around it
   - Experiment with different batch sizes
   - Test different optimizers (SGD with momentum, AdamW)

2. **Model Improvements**
   - Compare full fine-tuning with a frozen backbone or progressive unfreezing
   - Try different pre-trained architectures (EfficientNet, ViT)
   - Implement learning rate scheduling

3. **Advanced Techniques**
   - Implement data augmentation strategies like mixup or cutmix
   - Try different loss functions
   - Implement ensemble methods
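For the optimizer and scheduling items, a sketch of AdamW combined with cosine learning-rate decay. The `weight_decay=0.01` value and the per-epoch schedule are assumptions, and a tiny linear layer stands in for the ResNet50 so the snippet runs quickly:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # stand-in for the ResNet50 classifier
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
num_epochs = 10
# Cosine decay from the initial lr down to eta_min (default 0) over num_epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

lrs = []
for epoch in range(num_epochs):
    # ... one epoch of training (forward, backward, optimizer.step per batch) ...
    optimizer.step()   # PyTorch expects optimizer.step() before scheduler.step()
    scheduler.step()   # update the learning rate once per epoch
    lrs.append(scheduler.get_last_lr()[0])

print([f"{lr:.2e}" for lr in lrs])
```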

## Orthomosaic Testing Section
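The per-orthomosaic training below boils down to grouping row indices by their `orthomosaic` value and splitting each group 80/20. A framework-free sketch of that bookkeeping, with a hypothetical `split_by_group` helper and toy group labels; the resulting index lists could be passed to a Hugging Face `Dataset.select`:

```python
import random
from collections import defaultdict

def split_by_group(groups, test_size=0.2, seed=42):
    """Hypothetical helper: for each group label, shuffle that group's row
    indices and split them into (train, val) lists."""
    by_group = defaultdict(list)
    for idx, group in enumerate(groups):
        by_group[group].append(idx)

    rng = random.Random(seed)
    splits = {}
    for group, indices in by_group.items():
        rng.shuffle(indices)
        n_train = int((1 - test_size) * len(indices))
        splits[group] = (indices[:n_train], indices[n_train:])
    return splits

# Toy orthomosaic labels, one per dataset row
groups = ["a", "a", "a", "a", "a", "b", "b", "b", "b", "b"]
splits = split_by_group(groups)
for group, (train_idx, val_idx) in splits.items():
    print(group, len(train_idx), len(val_idx))  # 4 train rows, 1 val row per group
```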


In [3]:
# ==== Orthomosaic Testing Section ====

# Import necessary libraries
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import Dataset, Features, Image, Value, load_dataset
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import seaborn as sns

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define transformations
train_transforms = transforms.Compose([
    transforms.CenterCrop(115),
    transforms.RandomCrop(77),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

val_transforms = transforms.Compose([
    transforms.CenterCrop(77),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Transformation functions, applied lazily with with_transform so the random
# training augmentations are re-sampled every epoch (.map() would freeze them).
# They return only the model inputs and labels so the default collate can stack them.
# Note: on spawn-based platforms (macOS, Windows), DataLoader workers may fail to
# pickle functions defined in a notebook; use num_workers=0 there.
def transform_train_dataset(examples):
    return {
        "pixel_values": [train_transforms(image.convert("RGB")) for image in examples["image"]],
        "Presence": examples["Presence"],
    }

def transform_val_dataset(examples):
    return {
        "pixel_values": [val_transforms(image.convert("RGB")) for image in examples["image"]],
        "Presence": examples["Presence"],
    }

# Load and prepare data
print("Loading dataset...")
original_dataset = load_dataset("mpg-ranch/horse-detection", split="train")

# Read only the orthomosaic column; converting the whole dataset to a DataFrame
# would decode every image into memory
ortho_column = original_dataset["orthomosaic"]
orthomosaics = original_dataset.unique("orthomosaic")
print(f"Found {len(orthomosaics)} orthomosaics")

# Dictionary to store results
orthomosaic_results = {}

# Training configuration
num_epochs = 10
learning_rate = 0.0001

# Train model for each orthomosaic
for ortho in orthomosaics:
    print(f"\nTraining model for orthomosaic: {ortho}")

    # Select this orthomosaic's rows, shuffle, and split them 80/20
    ortho_indices = [i for i, o in enumerate(ortho_column) if o == ortho]
    ortho_split = original_dataset.select(ortho_indices).train_test_split(test_size=0.2, seed=42)
    train_dataset = ortho_split["train"]
    val_dataset = ortho_split["test"]

    # Apply transformations on the fly (replaces .map() + set_format)
    transformed_train_dataset = train_dataset.with_transform(transform_train_dataset)
    transformed_val_dataset = val_dataset.with_transform(transform_val_dataset)
    
    # Create dataloaders
    train_dataloader = DataLoader(
        transformed_train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )
    
    val_dataloader = DataLoader(
        transformed_val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4
    )
    
    # Initialize model
    model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    for param in model.parameters():
        param.requires_grad = True
    model.fc = nn.Linear(model.fc.in_features, 1)
    model = model.to(device)
    
    # Define loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    best_f1 = 0.0
    train_losses = []
    val_losses = []
    train_f1s = []
    val_f1s = []
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        all_train_preds = []
        all_train_labels = []
        
        for batch in train_dataloader:
            images = batch['pixel_values'].to(device)
            presences = batch['Presence'].float().to(device)
            
            optimizer.zero_grad()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, presences)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            all_train_preds.extend(predicted.cpu().numpy())
            all_train_labels.extend(presences.cpu().numpy())
        
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        all_val_preds = []
        all_val_labels = []
        
        with torch.no_grad():
            for batch in val_dataloader:
                images = batch['pixel_values'].to(device)
                presences = batch['Presence'].float().to(device)
                
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, presences)
                val_running_loss += loss.item()
                
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                all_val_preds.extend(predicted.cpu().numpy())
                all_val_labels.extend(presences.cpu().numpy())
        
        # Calculate metrics
        train_loss = running_loss / len(train_dataloader)
        val_loss = val_running_loss / len(val_dataloader)
        train_f1 = f1_score(all_train_labels, all_train_preds, average='binary')
        val_f1 = f1_score(all_val_labels, all_val_preds, average='binary')
        
        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_f1s.append(train_f1)
        val_f1s.append(val_f1)
        
        # Update best F1
        if val_f1 > best_f1:
            best_f1 = val_f1
        
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}")
    
    # Store results
    orthomosaic_results[ortho] = {
        'best_val_f1': best_f1,
        'num_train_samples': len(train_dataset),
        'num_val_samples': len(val_dataset),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_f1s': train_f1s,
        'val_f1s': val_f1s
    }
    
    print(f"Completed training for {ortho}")
    print(f"Best validation F1: {best_f1:.4f}")
    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# Create histogram of F1 scores
plt.figure(figsize=(12, 6))
f1_scores = [results['best_val_f1'] for results in orthomosaic_results.values()]

# Create histogram
sns.histplot(f1_scores, bins=20, kde=True)
plt.axvline(np.mean(f1_scores), color='r', linestyle='--', label=f'Mean: {np.mean(f1_scores):.3f}')
plt.axvline(np.median(f1_scores), color='g', linestyle='--', label=f'Median: {np.median(f1_scores):.3f}')

# Add labels and title
plt.xlabel('F1 Score')
plt.ylabel('Count')
plt.title('Distribution of Best Validation F1 Scores Across Orthomosaics')
plt.legend()

# Add text with statistics
stats_text = f"""
Statistics:
Mean: {np.mean(f1_scores):.3f}
Median: {np.median(f1_scores):.3f}
Std: {np.std(f1_scores):.3f}
Min: {np.min(f1_scores):.3f}
Max: {np.max(f1_scores):.3f}
"""
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, 
         verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
os.makedirs('../results/figures', exist_ok=True)
print("Created or verified results/figures directory")
plt.savefig('../results/figures/f1_score_distribution.png')
plt.show()

# Print summary
print("\nSummary of Orthomosaic Models:")
print("-" * 80)
print(f"{'Orthomosaic':<20} {'Val F1':<10} {'Train Samples':<15} {'Val Samples':<15}")
print("-" * 80)
for ortho, results in orthomosaic_results.items():
    print(f"{ortho:<20} {results['best_val_f1']:.4f} {results['num_train_samples']:<15} {results['num_val_samples']:<15}")

Using device: cpu
Loading dataset...
Found 10 orthomosaics

Training model for orthomosaic: 240424_upperpartridge

KeyboardInterrupt: 