 # Landing Strip Detection Training Pipeline



 This notebook implements a training pipeline for detecting landing strips using satellite imagery. The pipeline includes:



 - Loading input landing strip data.

 - Creating input areas around the landing strips.

 - Downloading Sentinel-2 imagery from Google Earth Engine.

 - Preparing a dataset for training.

 - Loading the Geo Foundation Model (GFM) for transfer learning.

 - Setting up a training loop with Weights & Biases (wandb) logging.



 **Note**: Ensure that you have authenticated with Google Earth Engine (GEE) using `ee.Authenticate()` and have initialized it with `ee.Initialize()`. Also, make sure `train_utils.py` is in your working directory or Python path.

 ## 1. Setup and Imports

In [1]:
import sys
import os
import random
import ee
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm  # PyTorch Image Models library
import logging
from pathlib import Path
import re

# Add the src directory to the sys.path
sys.path.append(os.path.abspath('..'))

# Append to sys path to use files in google drive from colab
sys.path.append('/content/drive/MyDrive/Secret_Runway_Detection')

# Import functions and constants from train_utils
from secret_runway_detection.model import (
    SegmentationHead,
    CombinedModel,
)
from secret_runway_detection.dataset import LandingStripDataset, SegmentationTransform
from secret_runway_detection.train_utils import (
    RANDOM_SEED
)

 ## 2. Configuration and Initialization

In [2]:
# %%
# Debug flag: Set to True to run on CPU, False to use GPU if available
# With DEBUG == True, test and train sets are reduced to 10 samples each
DEBUG = True

TRAINING_DATASET = 'point'
TRAIN_PERCENTAGE = 0.8

# Number of epochs to train for
NUM_EPOCHS = 10 if not DEBUG else 1  # Adjust as needed

# Device configuration
device = torch.device('cpu') if DEBUG else torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Set up logging
logging.basicConfig(level=logging.INFO)
# logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.DEBUG)
logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.INFO)

# Initialize wandb
wandb.init(project='secret-runway-detection',
           mode='online' if not DEBUG else 'dryrun')

# Authenticate and initialize Earth Engine
ee.Authenticate()
ee.Initialize()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Using device: cpu


 ## 5. Load Data into Dataset

In [3]:
train_dir = Path(f'../training_data_{TRAINING_DATASET}')

images_dir = train_dir / 'images'
labels_dir = train_dir / 'labels'

# Get all filenames in the images directory
all_filenames = os.listdir(images_dir)

# Initialize dictionaries and lists
strip_to_files = {}        # For files with strip numbers
possibly_empty_files = []  # For 'possibly_empty' files

# Regular expression pattern to match filenames with strip numbers
pattern = re.compile(r'^area_\d+_of_strip_(\d+)\.npy$')

# Process filenames
for filename in all_filenames:
    if 'possibly_empty' in filename:
        # This is a 'possibly_empty' file
        possibly_empty_files.append(filename)
    else:
        # Try to match the pattern to extract strip number
        match = pattern.match(filename)
        if match:
            strip_number = int(match.group(1))
            # Add filename to the list for this strip number
            strip_to_files.setdefault(strip_number, []).append(filename)
        else:
            print(f"Filename does not match expected pattern: {filename}")

# List of all unique strip numbers
strip_numbers = list(strip_to_files.keys())

# Shuffle strip numbers for random splitting
random.seed(RANDOM_SEED)  # Ensure reproducibility
random.shuffle(strip_numbers)

# Calculate split index for strips
num_strips = len(strip_numbers)
split_index = int(num_strips * TRAIN_PERCENTAGE)

# Split strip numbers into train and test sets
train_strip_numbers = strip_numbers[:split_index]
test_strip_numbers = strip_numbers[split_index:]

# Collect filenames for train and test sets based on strip numbers
train_files = []
for strip_num in train_strip_numbers:
    train_files.extend(strip_to_files[strip_num])

test_files = []
for strip_num in test_strip_numbers:
    test_files.extend(strip_to_files[strip_num])

# Now handle the 'possibly_empty' files
# Shuffle the possibly_empty files
random.shuffle(possibly_empty_files)

# Calculate split index for possibly_empty files
num_possibly_empty = len(possibly_empty_files)
split_index_empty = int(num_possibly_empty * TRAIN_PERCENTAGE)

# Split possibly_empty files into train and test sets
train_possibly_empty_files = possibly_empty_files[:split_index_empty]
test_possibly_empty_files = possibly_empty_files[split_index_empty:]

# Add the possibly_empty files to the train and test file lists
train_files.extend(train_possibly_empty_files)
test_files.extend(test_possibly_empty_files)

# Output some information
print(f"Total files: {len(all_filenames)}")
print(f"Total strips: {len(strip_numbers)}")
print(f"Training files: {len(train_files)}")
print(f"Testing files: {len(test_files)}")

# Define your transform if you have one; otherwise, set to None
segmentation_transform = None  # Replace with your actual transform if any

# Create train dataset
train_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=train_files,
    transform=segmentation_transform
)

# Create test dataset
test_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=test_files,
    transform=segmentation_transform
)

if DEBUG:
    train_dataset.samples = train_dataset.samples[:10]
    test_dataset.samples = test_dataset.samples[:10]

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Total files: 472
Total strips: 94
Training files: 376
Testing files: 96


 ## 6. Load the Geo Foundation Model (GFM)

In [4]:
def load_gfm_model(model_path):
    """
    Loads the Geo Foundation Model (GFM) from a checkpoint.

    Parameters:
    - model_path (str): Path to the model checkpoint.

    Returns:
    - model (torch.nn.Module): Loaded model.
    """
    model = timm.create_model(
        'swin_base_patch4_window7_224',
        pretrained=False,
        num_classes=0,  # Assuming binary classification
    )
    checkpoint = torch.load(model_path, map_location='cpu')

    # Extract the state dictionary
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Clean the state dictionary (remove 'module.' prefix if present)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v

    # Load the state dictionary
    model.load_state_dict(new_state_dict, strict=False)
    model = model.to(device)
    print("Model loaded and moved to device.")
    return model


# Path to the pre-trained GFM model
# Replace with your actual model path
backbone_model_path = '../simmim_pretrain/gfm.pth'

# Load the model
backbone_model = load_gfm_model(backbone_model_path)

  checkpoint = torch.load(model_path, map_location='cpu')


Model loaded and moved to device.


## 6.1. Add Segmentation Head

In [5]:
segmentation_head = SegmentationHead()

model = CombinedModel(backbone_model, segmentation_head)

 ## 7. Define Loss Function and Optimizer

In [6]:
# Define loss function and optimizer
# Suitable for binary classification with logits
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optionally, define a learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

 ## 8. Training Loop with wandb Logging

In [7]:
# Before the training loop, watch the model
wandb.watch(model, criterion=criterion, log="all", log_freq=10)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        outputs = outputs.squeeze(1)  # Adjust dimensions if necessary

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()

        # Log every 10 batches or last batch
        if (i + 1) % 10 == 0 or i == len(train_dataloader):
            avg_loss = running_loss / 10
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] Training Loss: {avg_loss:.4f}")

            # Log metrics to wandb
            wandb.log({
                'epoch': epoch + 1,
                'batch': i + 1,
                'training_loss': avg_loss,
                'learning_rate': optimizer.param_groups[0]['lr']
            })

            running_loss = 0.0

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            outputs = outputs.squeeze(1)  # Adjust dimensions if necessary

            # Compute loss
            loss = criterion(outputs, labels)

            # Accumulate validation loss
            val_loss += loss.item()

    avg_val_loss = val_loss / len(test_dataloader)
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    # Log validation loss to wandb
    wandb.log({
        'epoch': epoch + 1,
        'validation_loss': avg_val_loss
    })

    # Step the scheduler
    scheduler.step()

print("Training complete.")

Epoch 1 Validation Loss: 0.9336
Training complete.


In [8]:
wandb.finish()

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
validation_loss,▁

0,1
epoch,1.0
validation_loss,0.93359


 ## 9. Save the Trained Model

In [9]:
# Save the trained model
model_save_path = '../checkpoints/trained_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to '{model_save_path}'.")

Model saved to '../checkpoints/trained_model.pth'.


 ## 10. Conclusion

In [10]:
print("""
# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: {}
- **Epochs**: {}

Training has been completed and the model has been saved.
""".format(device, NUM_EPOCHS))


# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: cpu
- **Epochs**: 1

Training has been completed and the model has been saved.

