 # Landing Strip Detection Training Pipeline



 This notebook implements a training pipeline for detecting landing strips using satellite imagery. The pipeline includes:



 - Loading input landing strip data.

 - Creating input areas around the landing strips.

 - Downloading Sentinel-2 imagery from Google Earth Engine.

 - Preparing a dataset for training.

 - Loading the Geo Foundation Model (GFM) for transfer learning.

 - Setting up a training loop with Weights & Biases (wandb) logging.



 **Note**: Ensure that you have authenticated with Google Earth Engine (GEE) using `ee.Authenticate()` and have initialized it with `ee.Initialize()`. Also, make sure `train_utils.py` is in your working directory or Python path.

 ## 1. Setup and Imports

In [1]:
# %%
import sys
import os
import random
import ee
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm  # PyTorch Image Models library
import logging
from pathlib import Path

# Add the src directory to the sys.path
sys.path.append(os.path.abspath('..'))

# Import functions and constants from train_utils
from secret_runway_detection.train_utils import (
    TILES_PER_AREA_LEN,
    RANDOM_SEED
)

from secret_runway_detection.dataset import LandingStripDataset

from secret_runway_detection.model import (
    SegmentationHead,
    CombinedModel,
)


 ## 2. Configuration and Initialization

In [2]:
# %%
# Debug flag: Set to True to run on CPU, False to use GPU if available
DEBUG = True

# Number of epochs to train for
NUM_EPOCHS = 10  # Adjust as needed

# Device configuration
device = torch.device('cpu') if DEBUG else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Set up logging
logging.basicConfig(level=logging.INFO)
# logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.DEBUG)
logging.getLogger('secret_runway_detection.train_utils').setLevel(logging.INFO)

# Initialize wandb
wandb.init(project='secret-runway-detection', mode='online' if not DEBUG else 'dryrun')

# Authenticate and initialize Earth Engine
ee.Authenticate()
ee.Initialize()


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Using device: cpu


 ## 5. Load Data into Dataset

In [3]:
# Define paths to the images and labels directories
images_dir = Path('../training_data/images')
labels_dir = Path('../training_data/labels')

# # Optionally, define transformations
# transform = transforms.Compose([
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Standard ImageNet normalization
#                          std=[0.229, 0.224, 0.225])
# ])

# Create the dataset
dataset = LandingStripDataset(images_dir, labels_dir)
print(f"Dataset size: {len(dataset)} samples")

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


Dataset size: 9 samples


 ## 6. Load the Geo Foundation Model (GFM)

In [4]:
def load_gfm_model(model_path):
    """
    Loads the Geo Foundation Model (GFM) from a checkpoint.
    
    Parameters:
    - model_path (str): Path to the model checkpoint.
    
    Returns:
    - model (torch.nn.Module): Loaded model.
    """
    model = timm.create_model(
        'swin_base_patch4_window7_224',
        pretrained=False,
        num_classes=0,  # Assuming binary classification
    )
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # Extract the state dictionary
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    
    # Clean the state dictionary (remove 'module.' prefix if present)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v
    
    # Load the state dictionary
    model.load_state_dict(new_state_dict, strict=False)
    model = model.to(device)
    print("Model loaded and moved to device.")
    return model

# Path to the pre-trained GFM model
backbone_model_path = '../simmim_pretrain/gfm.pth'  # Replace with your actual model path

# Load the model
backbone_model = load_gfm_model(backbone_model_path)


  checkpoint = torch.load(model_path, map_location='cpu')


Model loaded and moved to device.


## 6.1 Add Segmentation Head

In [13]:
segmentation_head = SegmentationHead()

model = CombinedModel(backbone_model, segmentation_head)

In [20]:
# Time one forward pass
import time
start_time = time.time()

model(dataset[0][0].unsqueeze(0))

print(f"Time taken for forward pass: {time.time() - start_time:.2f} seconds")
print(dataset[0][1].shape)


Time taken for forward pass: 0.39 seconds
torch.Size([200, 200])


 ## 7. Define Loss Function and Optimizer

In [8]:
# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Suitable for binary classification with logits
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optionally, define a learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


 ## 8. Training Loop with wandb Logging

In [9]:
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        outputs = outputs.squeeze(1)  # Adjust dimensions if necessary
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        if i % 10 == 9 or i == len(dataloader) - 1:  # Log every 10 batches or last batch
            avg_loss = running_loss / 10
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch + 1, 'batch': i + 1, 'loss': avg_loss})
            running_loss = 0.0
    
    # Step the scheduler
    scheduler.step()
    
    # Optionally, log learning rate
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({'learning_rate': current_lr})
    print(f"Epoch {epoch + 1} completed. Learning Rate: {current_lr}")

print("Training complete.")


[Epoch 1, Batch 3] Loss: 0.2594
Epoch 1 completed. Learning Rate: 0.0001
[Epoch 2, Batch 3] Loss: 0.2598
Epoch 2 completed. Learning Rate: 0.0001
[Epoch 3, Batch 3] Loss: 0.2570
Epoch 3 completed. Learning Rate: 0.0001
[Epoch 4, Batch 3] Loss: 0.2569
Epoch 4 completed. Learning Rate: 0.0001
[Epoch 5, Batch 3] Loss: 0.2561
Epoch 5 completed. Learning Rate: 0.0001
[Epoch 6, Batch 3] Loss: 0.2540
Epoch 6 completed. Learning Rate: 0.0001
[Epoch 7, Batch 3] Loss: 0.2553
Epoch 7 completed. Learning Rate: 1e-05
[Epoch 8, Batch 3] Loss: 0.2537
Epoch 8 completed. Learning Rate: 1e-05
[Epoch 9, Batch 3] Loss: 0.2527
Epoch 9 completed. Learning Rate: 1e-05
[Epoch 10, Batch 3] Loss: 0.2528
Epoch 10 completed. Learning Rate: 1e-05
Training complete.


 ## 9. Save the Trained Model

In [10]:
# Save the trained model
model_save_path = '../checkpoints/trained_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to '{model_save_path}'.")


Model saved to 'trained_model.pth'.


 ## 10. Conclusion

In [11]:
print("""
# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: {}
- **Epochs**: {}

Training has been completed and the model has been saved.
""".format(device, NUM_EPOCHS))



# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: cpu
- **Epochs**: 10

Training has been completed and the model has been saved.

