 # Landing Strip Detection Training Pipeline



 This notebook implements a training pipeline for detecting landing strips using satellite imagery. The pipeline includes:



 - Loading input landing strip data.

 - Creating input areas around the landing strips.

 - Downloading Sentinel-2 imagery from Google Earth Engine.

 - Preparing a dataset for training.

 - Loading the Geo Foundation Model (GFM) for transfer learning.

 - Setting up a training loop with Weights & Biases (wandb) logging.



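 **Note**: Ensure that you have authenticated with Google Earth Engine (GEE) using `ee.Authenticate()` and have initialized it with `ee.Initialize()`. Also, make sure the `secret_runway_detection` package (which provides the `model`, `dataset`, and `train_utils` modules) is importable. The setup cell adds the parent directory to `sys.path` for this purpose.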

# *TODO*
* The maximum value of the model outputs can be quite small (0.6244 in one case). This leads the binary search to set the threshold lower, predicting all zeroes.
* `(buffered_labels.float() == 1).float()` shows `tensor([0., 0., 0.,  ..., 0., 0., 0.])`, while `(buffered_labels.float() == 1).float().mean()` returns `tensor(0.2678)` **(!!!)**
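For context on the second item: `nn.BCEWithLogitsLoss` accepts a `pos_weight` argument that scales the loss on positive targets, and a common choice is the ratio of negative to positive pixels. The sketch below uses a hypothetical helper, `pos_weight_from_mask` (NumPy only), to compute that ratio from a binary mask. With 26.78% positives, it comes out to roughly 2.73. Whether reweighting is appropriate here, or whether the buffer is simply too wide, is a separate question that this sketch does not settle.

```python
import numpy as np

def pos_weight_from_mask(mask):
    """Return (#negative / #positive) pixels for a binary mask.

    Hypothetical helper: the result could be passed as
    nn.BCEWithLogitsLoss(pos_weight=torch.tensor([w])).
    """
    positive = np.asarray(mask) == 1   # same test as in the TODO: label == 1
    num_pos = int(positive.sum())
    num_neg = positive.size - num_pos
    if num_pos == 0:
        raise ValueError("mask has no positive pixels")
    return num_neg / num_pos

# Toy example: a 10x10 mask with 25 positive pixels gives 75 / 25
example = np.zeros((10, 10))
example[:5, :5] = 1
print(pos_weight_from_mask(example))
```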

 ## 1. Setup and Imports

In [1]:
import sys
import os
import random
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm  # PyTorch Image Models library
import logging
from pathlib import Path
import re
from torcheval.metrics.functional import binary_accuracy

# If on Google Colab: mount Drive, copy the project files to local storage, and chdir to the notebooks folder
try:
    from google.colab import drive
    drive.mount('/content/drive')
    # Copy the 'colab-stuff' folder from the 'Secret Runway Detection Challenge' Drive folder to Colab local storage
    !cp -r '/content/drive/MyDrive/Secret Runway Detection Challenge/colab-stuff/' '/content/'
    # Change the current working directory to the notebooks folder in local storage
    os.chdir('/content/colab-stuff/notebooks')
    USING_COLAB = True
except Exception as e:
    print(e)
    USING_COLAB = False

# Add the parent directory to sys.path so the secret_runway_detection package is importable
sys.path.append(os.path.abspath('..'))

# Import the model, dataset, and training utilities from the secret_runway_detection package
from secret_runway_detection.model import (
    SegmentationHead,
    CombinedModel,
)
from secret_runway_detection.dataset import LandingStripDataset, SegmentationTransform
from secret_runway_detection.train_utils import (
    add_buffer_to_label,
    RANDOM_SEED
)

No module named 'google.colab'
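The setup cell detects Colab by attempting the import and treating any exception as "not on Colab". This also hides real failures, such as a failed `drive.mount()` or `cp`. Below is a sketch of a narrower check that uses only the standard library's `importlib.util.find_spec`. The helper name `running_on_colab` is hypothetical.

```python
import importlib.util

def running_on_colab() -> bool:
    """Return True if the google.colab package can be found."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # find_spec imports the parent package ("google"); if that package is
        # missing, it raises instead of returning None
        return False

USING_COLAB = running_on_colab()
print(f"USING_COLAB = {USING_COLAB}")
```

With the detection separated from the Colab-specific setup, errors from `drive.mount()` or the copy step would surface instead of silently setting `USING_COLAB = False`.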


 ## 2. Configuration and Initialization

In [2]:
# %%
# Debug flag: Set to True to run on CPU, False to use GPU if available
# With DEBUG == True, the train and validation sets are reduced to 10 samples each and training runs for 2 epochs
import pandas as pd


DEBUG = False

TRAINING_DATASET = 'cross'
TRAIN_PERCENTAGE = 0.8

# Number of epochs to train for
NUM_EPOCHS = 50 if not DEBUG else 2  # Adjust as needed

BATCH_SIZE = 32 if USING_COLAB else 4

# Device configuration
device = torch.device('cpu') if DEBUG else torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Set up logging
logging.basicConfig(level=logging.INFO)
# logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.DEBUG)
logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.INFO)

# Initialize wandb
wandb.init(project='secret-runway-detection',
           mode='online' if not DEBUG else 'offline',
           dir='..',
           tags=[TRAINING_DATASET, 'colab' if USING_COLAB else 'local'],
           job_type='train',
           )

if not wandb.run.name:
    wandb.run.name = f"Run from {pd.Timestamp.now()}"

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Using device: cpu
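The configuration cell seeds `random`, `numpy`, and `torch` inline. Here is the same idea as a reusable sketch, using a hypothetical `seed_everything` helper. The seed value 42 is only for illustration, since the notebook's actual `RANDOM_SEED` comes from `train_utils`. The `torch` call is guarded so the function also runs where PyTorch is not installed. Bit-exact reproducibility on GPU may additionally require deterministic cuDNN settings, which this sketch does not set.

```python
import random
import numpy as np

def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and (if installed) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNGs of all devices, CPU and CUDA
    except ImportError:
        pass

seed_everything(42)
first = (random.random(), np.random.rand())
seed_everything(42)
second = (random.random(), np.random.rand())
print(first == second)  # same seed, same draws
```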


 ## 5. Load Data into Dataset

In [3]:
train_dir = Path(f'../training_data/training_data_{TRAINING_DATASET}')

if USING_COLAB:
    # Unzip the training data archive at f"{train_dir}.zip" into train_dir
    import zipfile
    with zipfile.ZipFile(f"{train_dir}.zip", 'r') as zip_ref:
        zip_ref.extractall(f"{train_dir}")

images_dir = train_dir / 'images'
labels_dir = train_dir / 'labels'

# Get all filenames in the images directory
all_filenames = os.listdir(images_dir)

# Initialize dictionaries and lists
strip_to_files = {}        # For files with strip numbers
possibly_empty_files = []  # For 'possibly_empty' files

# Regular expression pattern to match filenames with strip numbers
pattern = re.compile(r'^area_\d+_of_strip_(\d+)\.npy$')

# Process filenames
for filename in all_filenames:
    if 'possibly_empty' in filename:
        # This is a 'possibly_empty' file
        possibly_empty_files.append(filename)
    else:
        # Try to match the pattern to extract strip number
        match = pattern.match(filename)
        if match:
            strip_number = int(match.group(1))
            # Add filename to the list for this strip number
            strip_to_files.setdefault(strip_number, []).append(filename)
        else:
            print(f"Filename does not match expected pattern: {filename}")

# List of all unique strip numbers
strip_numbers = list(strip_to_files.keys())

# Shuffle strip numbers for random splitting
random.seed(RANDOM_SEED)  # Ensure reproducibility
random.shuffle(strip_numbers)

# Calculate split index for strips
num_strips = len(strip_numbers)
split_index = int(num_strips * TRAIN_PERCENTAGE)

# Split strip numbers into train and validation sets
train_strip_numbers = strip_numbers[:split_index]
val_strip_numbers = strip_numbers[split_index:]

# Collect filenames for train and validation sets based on strip numbers
train_files = []
for strip_num in train_strip_numbers:
    train_files.extend(strip_to_files[strip_num])

val_files = []
for strip_num in val_strip_numbers:
    val_files.extend(strip_to_files[strip_num])

# Now handle the 'possibly_empty' files
# Shuffle the possibly_empty files
random.shuffle(possibly_empty_files)

# Calculate split index for possibly_empty files
num_possibly_empty = len(possibly_empty_files)
split_index_empty = int(num_possibly_empty * TRAIN_PERCENTAGE)

# Split possibly_empty files into train and validation sets
train_possibly_empty_files = possibly_empty_files[:split_index_empty]
val_possibly_empty_files = possibly_empty_files[split_index_empty:]

# Add the possibly_empty files to the train and validation file lists
train_files.extend(train_possibly_empty_files)
val_files.extend(val_possibly_empty_files)

# Output some information
print(f"Total files: {len(all_filenames)}")
print(f"Total strips: {len(strip_numbers)}")
print(f"Training files: {len(train_files)}")
print(f"Testing files: {len(val_files)}")

# Define your transform if you have one; otherwise, set to None
segmentation_transform = None  # Replace with your actual transform if any

# Create train dataset
train_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=train_files,
    transform=segmentation_transform
)

# Create validation dataset
val_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=val_files,
    transform=segmentation_transform
)

if DEBUG:
    train_dataset.samples = train_dataset.samples[:10]
    val_dataset.samples = val_dataset.samples[:10]

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

Total files: 2130
Total strips: 113
Training files: 1700
Testing files: 430
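The split above is grouped by strip number so that tiles cut from the same landing strip never end up in both the training and validation sets, which would leak near-duplicate imagery across the split. Below is a condensed, self-contained sketch of that logic as a function. The name `split_by_strip` and the toy filenames are hypothetical, while the filename pattern and the separate handling of `possibly_empty` files follow the cell above. One deliberate difference is that the strip numbers are sorted before shuffling. Because `os.listdir` returns entries in arbitrary order, the seeded split could otherwise differ between machines.

```python
import random
import re

# Same filename pattern as the notebook cell
PATTERN = re.compile(r'^area_\d+_of_strip_(\d+)\.npy$')

def split_by_strip(filenames, train_fraction, seed):
    """Split filenames so that no strip number appears in both returned lists."""
    rng = random.Random(seed)  # local RNG, independent of the global random state
    strip_to_files, possibly_empty = {}, []
    for name in filenames:
        if 'possibly_empty' in name:
            possibly_empty.append(name)
        else:
            match = PATTERN.match(name)
            if match:
                strip_to_files.setdefault(int(match.group(1)), []).append(name)

    # Sort before shuffling so the split does not depend on os.listdir() order
    strips = sorted(strip_to_files)
    rng.shuffle(strips)
    cut = int(len(strips) * train_fraction)
    train = [f for s in strips[:cut] for f in strip_to_files[s]]
    val = [f for s in strips[cut:] for f in strip_to_files[s]]

    # possibly_empty files carry no strip number, so they are split file by file
    possibly_empty.sort()
    rng.shuffle(possibly_empty)
    cut_empty = int(len(possibly_empty) * train_fraction)
    return train + possibly_empty[:cut_empty], val + possibly_empty[cut_empty:]

# Toy example with made-up filenames: 10 strips x 3 areas, plus 5 possibly_empty files
files = [f"area_{a}_of_strip_{s}.npy" for s in range(10) for a in range(3)]
files += [f"possibly_empty_{i}.npy" for i in range(5)]
train, val = split_by_strip(files, 0.8, seed=0)
print(len(train), len(val))
```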


 ## 6. Load the Geo Foundation Model (GFM)

In [4]:
def load_gfm_model(model_path):
    """
    Loads the Geo Foundation Model (GFM) from a checkpoint.

    Parameters:
    - model_path (str): Path to the model checkpoint.

    Returns:
    - model (torch.nn.Module): Loaded model.
    """
    model = timm.create_model(
        'swin_base_patch4_window7_224',
        pretrained=False,
        num_classes=0,  # Remove the classification head so the model outputs pooled features
    ).to(device)
    checkpoint = torch.load(model_path, map_location=device)

    # Extract the state dictionary
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Clean the state dictionary (remove 'module.' prefix if present)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v

    # Load the state dictionary
    model.load_state_dict(new_state_dict, strict=False)
    # The model was already moved to `device` when it was created
    print(f"Model loaded and moved to device {device}.")
    return model

! pip install yacs

# Path to the pre-trained GFM model
# Replace with your actual model path
backbone_model_path = '../simmim_pretrain/gfm.pth'

# Load the model
backbone_model = load_gfm_model(backbone_model_path)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m




Model loaded and moved to device cpu.
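`load_state_dict(..., strict=False)` silently skips keys that do not match. A checkpoint whose keys carry an unexpected prefix can therefore load "successfully" while leaving the backbone at its random initialization. The framework-free sketch below shows the prefix cleanup plus a report of matched, missing, and unexpected keys. The helper names are hypothetical, and plain dicts with made-up keys stand in for real state dicts.

```python
def strip_prefix(state_dict, prefix="module."):
    """Remove a leading prefix (e.g. added by DataParallel) from every key."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

def key_report(model_keys, checkpoint_keys):
    """Show what strict=False hides: matched, missing and unused keys."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        "matched": sorted(model_keys & checkpoint_keys),
        "missing": sorted(model_keys - checkpoint_keys),
        "unexpected": sorted(checkpoint_keys - model_keys),
    }

ckpt = {"module.patch_embed.proj.weight": 1, "module.head.weight": 2}
model_keys = ["patch_embed.proj.weight", "norm.weight"]
report = key_report(model_keys, strip_prefix(ckpt))
print(report)
```

In the notebook itself, PyTorch already provides this information: `load_state_dict` returns a named tuple, so `result = model.load_state_dict(new_state_dict, strict=False)` followed by printing `result.missing_keys` and `result.unexpected_keys` would show the same report.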


## 6.1. Add Segmentation Head

In [5]:
segmentation_head = SegmentationHead()

model = CombinedModel(backbone_model, segmentation_head).to(device)

 ## 7. Define Loss Function and Optimizer

In [6]:
# Define loss function and optimizer
# BCEWithLogitsLoss applies the sigmoid internally, so the model should output raw logits
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optionally, define a learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
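`StepLR(step_size=7, gamma=0.1)` multiplies the learning rate by 0.1 every 7 calls to `scheduler.step()`, and the training loop calls it once per epoch. The short closed-form sketch below uses a hypothetical helper, `step_lr`, to show the resulting schedule. With 50 epochs, the rate is at most 1e-8 from the 29th epoch onward, and the final epoch runs at 1e-11. A larger `step_size` (or fewer epochs) may therefore be worth considering.

```python
import math

def step_lr(base_lr: float, epoch: int, step_size: int = 7, gamma: float = 0.1) -> float:
    """Learning rate during 0-based `epoch` when StepLR.step() is called once per epoch."""
    return base_lr * gamma ** (epoch // step_size)

num_epochs = 50
schedule = [step_lr(1e-4, e) for e in range(num_epochs)]
for e in (0, 7, 14, 28, num_epochs - 1):
    print(f"epoch {e + 1:2d}: lr = {schedule[e]:.0e}")
```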

In [7]:
# Sanity check: list any frozen backbone parameters (no output means all are trainable)
for name, param in backbone_model.named_parameters():
    if not param.requires_grad:
        print(f"{name}: requires_grad={param.requires_grad}")

 ## 8. Training Loop with wandb Logging

In [8]:
# Before the training loop, watch the model
wandb.watch(model, criterion=criterion, log="all", log_freq=10)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        outputs = outputs.squeeze(1)  # Drop the singleton channel dimension, e.g. (B, 1, H, W) -> (B, H, W)

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()

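        # Log every 10 batches and after the last batch of the epoch
        if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:
            # running_loss covers the batches since the last reset: 10, or fewer in the final window
            avg_loss = running_loss / (i % 10 + 1)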
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] Training Loss: {avg_loss:.4f}")

            # Log metrics to wandb
            wandb.log({
                'epoch': epoch + 1,
                'batch': i + 1,
                'training_loss': avg_loss,
                'learning_rate': optimizer.param_groups[0]['lr']
            })

            running_loss = 0.0

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            outputs = outputs.squeeze(1)  # Drop the singleton channel dimension, e.g. (B, 1, H, W) -> (B, H, W)

            # Compute loss
            loss = criterion(outputs, labels)

            # Accumulate validation loss
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_dataloader)
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    # Log validation loss to wandb
    wandb.log({
        'epoch': epoch + 1,
        'validation_loss': avg_val_loss
    })

    # Step the scheduler
    scheduler.step()

print("Training complete.")

Epoch 1 Validation Loss: 0.7344
Epoch 2 Validation Loss: 0.7354
Training complete.
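The validation loss above is the mean of per-batch means. When the last batch is smaller than `BATCH_SIZE`, each of its samples therefore counts for more than the others. `BCEWithLogitsLoss` uses `reduction='mean'` by default, so weighting each batch mean by its batch size recovers the per-sample average (and the per-pixel average, when all masks have the same size). Below is a sketch with a hypothetical helper and made-up loss values.

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Combine per-batch mean losses into a per-sample mean."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Two full batches of 4 and a final batch of 2 (e.g. 10 samples with BATCH_SIZE = 4)
losses, sizes = [0.70, 0.72, 0.90], [4, 4, 2]
print(sum(losses) / len(losses))          # unweighted mean of batch means
print(weighted_mean_loss(losses, sizes))  # per-sample mean
```

Inside the validation loop, the same result comes from accumulating `loss.item() * inputs.size(0)` and dividing by the number of validation samples.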


## 8.1. Find accuracy-optimizing thresholds
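The TODO at the top describes a binary search for the threshold. However, accuracy as a function of the threshold is not guaranteed to be unimodal, so the sketch below swaps in a plain grid search instead. The grid spans the *observed* score range rather than [0, 1], so a small maximum output such as 0.6244 is still covered. The model emits logits, so scores can be passed through `torch.sigmoid` first. Because the sigmoid is monotonic, sweeping the raw logits selects the same partition of pixels. The helper name and the toy data are hypothetical. The notebook already imports torcheval's `binary_accuracy`, which takes a `threshold` keyword and could replace the NumPy comparison when working with tensors.

```python
import numpy as np

def best_accuracy_threshold(scores, targets, num_thresholds=101):
    """Grid-search the threshold that maximizes pixel accuracy.

    scores: model outputs (probabilities or logits), any shape
    targets: binary labels with the same number of elements
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()
    targets = np.asarray(targets).ravel() == 1
    # Span the observed score range, so small maxima (e.g. 0.6244) are still covered
    thresholds = np.linspace(scores.min(), scores.max(), num_thresholds)
    accuracies = np.array([np.mean((scores >= t) == targets) for t in thresholds])
    best = int(np.argmax(accuracies))  # first threshold reaching the maximum accuracy
    return float(thresholds[best]), float(accuracies[best])

scores = np.array([0.10, 0.20, 0.30, 0.40])
labels = np.array([0, 0, 1, 1])
t, acc = best_accuracy_threshold(scores, labels)
print(f"threshold={t:.3f}, accuracy={acc:.2f}")
```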

 ## 9. Save the Trained Model

In [9]:
wandb.run.name

'Run from 2024-11-01 16:01:03.670882'

In [10]:
# Create a 'checkpoints' directory in the parent directory
os.makedirs('../checkpoints', exist_ok=True)

# Define the model save path within the 'checkpoints' directory
model_save_path = f'../checkpoints/{wandb.run.name}.pth'

# Save the model's state_dict
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to '{model_save_path}'.")

# Create a W&B Artifact for the model
artifact = wandb.Artifact('model', type='model')

# Add the saved model file to the artifact
artifact.add_file(model_save_path)

# Log the artifact to W&B
wandb.log_artifact(artifact)

Model saved to '../checkpoints/Run from 2024-11-01 16:01:03.670882.pth'.


<Artifact model>
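The checkpoint path is built from `wandb.run.name`, which here contains spaces and colons (`Run from 2024-11-01 16:01:03.670882`). Colons are not permitted in Windows file names, and spaces need quoting in shells, so a sanitized name may travel better between machines. Below is a sketch using a hypothetical `safe_filename` helper.

```python
import re

def safe_filename(name: str) -> str:
    """Replace runs of characters outside [A-Za-z0-9._-] with a single underscore."""
    return re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("_")

print(safe_filename("Run from 2024-11-01 16:01:03.670882"))
```

In the cell above, this would amount to `model_save_path = f'../checkpoints/{safe_filename(wandb.run.name)}.pth'`.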

In [11]:
wandb.finish()


Run history:
epoch            ▁█
validation_loss  ▁█

Run summary:
epoch            2.0
validation_loss  0.73541


 ## 10. Conclusion

In [12]:
print("""
# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: {}
- **Epochs**: {}

Training has been completed and the model has been saved.
""".format(device, NUM_EPOCHS))


# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: cpu
- **Epochs**: 2

Training has been completed and the model has been saved.

