 # Landing Strip Detection Training Pipeline



 This notebook implements a training pipeline for detecting landing strips using satellite imagery. The pipeline includes:



 - Loading input landing strip data.

 - Creating input areas around the landing strips.

 - Downloading Sentinel-2 imagery from Google Earth Engine.

 - Preparing a dataset for training.

 - Loading the Geo Foundation Model (GFM) for transfer learning.

 - Setting up a training loop with Weights & Biases (wandb) logging.



 **Note**: Ensure that you have authenticated with Google Earth Engine (GEE) using `ee.Authenticate()` and have initialized it with `ee.Initialize()`. Also, make sure `train_utils.py` is in your working directory or Python path.

 ## 1. Setup and Imports

In [5]:
# Check if the secret_runway_detection package is installed
!pip list

import secret_runway_detection



Package                     Version     Editable project location
--------------------------- ----------- ------------------------------------------
aenum                       3.1.15
affine                      2.4.0
aiohappyeyeballs            2.4.3
aiohttp                     3.10.10
aiosignal                   1.3.1
annotated-types             0.7.0
antlr4-python3-runtime      4.9.3
asttokens                   2.4.1
attrs                       24.2.0
bitsandbytes                0.44.1
blessings                   1.7
branca                      0.8.0
cachetools                  5.5.0
certifi                     2024.8.30
charset-normalizer          3.4.0
click                       8.1.7
click-plugins               1.1.1
cligj                       0.7.2
comm                        0.2.2
contextily                  1.6.2
contourpy                   1.3.0
cycler                      0.12.1
debugpy                     1.8.7
decorator                   5.1.1
docker-pycreds             

ModuleNotFoundError: No module named 'secret_runway_detection'

In [1]:
# %%
import sys
import os
import random
import ee
import wandb
import numpy as np
import pandas as pd
import geopandas as gpd
import pyproj
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm  # PyTorch Image Models library
from shapely.geometry import Polygon, Point
import matplotlib.pyplot as plt


# Add the src directory to the sys.path
# sys.path.append(os.path.abspath(os.path.join('..', 'src')))

# Import functions and constants from train_utils
from secret_runway_detection.train_utils import (
    landing_strips_to_enclosing_input_areas,
    input_area_to_input_image,
    make_label_tensor,
    TILE_SIDE_LEN,
    TILES_PER_AREA_LEN,
    INPUT_IMAGE_HEIGHT,
    INPUT_IMAGE_WIDTH,
    RANDOM_SEED
)

from secret_runway_detection.dataset import LandingStripDataset


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'secret_runway_detection'

 ## 2. Configuration and Initialization

In [2]:
# %%
# Debug flag: Set to True to run on CPU, False to use GPU if available
DEBUG = True

# Device configuration
device = torch.device('cpu') if DEBUG else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Initialize wandb
wandb.init(project='secret-runway-detection', mode='online' if not DEBUG else 'dryrun')

# Authenticate and initialize Earth Engine
ee.Authenticate()
ee.Initialize()


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Using device: cpu


 ## 3. Load Landing Strips Data

In [3]:
# %%
# Path to the landing strips shapefile
landing_strips_shp = '../pac_2024_training/pac_2024_training.shp'  # Update this path as needed

# Load the landing strips shapefile
landing_strips = gpd.read_file(landing_strips_shp)

# Ensure CRS is WGS84
if landing_strips.crs != 'EPSG:4326':
    landing_strips = landing_strips.to_crs('EPSG:4326')

print(f"Loaded {len(landing_strips)} landing strips.")


Loaded 154 landing strips.


 ## 4. Create Input Areas Around Landing Strips

In [4]:
# %%
# Use the function from train_utils to create input areas
num_tiles_per_area_side_len = TILES_PER_AREA_LEN  # From train_utils constants
input_areas = landing_strips_to_enclosing_input_areas(landing_strips, num_tiles_per_area_side_len)

print(f"Created {len(input_areas)} input areas.")


AttributeError: 'GeoDataFrame' object has no attribute 'append'

 ## 5. Define the Dataset

In [None]:
# %%
# Define transformations
transform = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Standard ImageNet normalization
                         std=[0.229, 0.224, 0.225])
])

# Create dataset
dataset = LandingStripDataset(input_areas, landing_strips, transform=transform)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

print(f"Dataset size: {len(dataset)} samples")


 ## 6. Load the Geo Foundation Model (GFM)

In [6]:
# %%
def load_gfm_model(model_path):
    """
    Loads the Geo Foundation Model (GFM) from a checkpoint.
    
    Parameters:
    - model_path (str): Path to the model checkpoint.
    
    Returns:
    - model (torch.nn.Module): Loaded model.
    """
    model = timm.create_model(
        'swin_base_patch4_window7_224',
        pretrained=False,
        num_classes=1  # Assuming binary classification
    )
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # Extract the state dictionary
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    
    # Clean the state dictionary (remove 'module.' prefix if present)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v
    
    # Load the state dictionary
    model.load_state_dict(new_state_dict, strict=False)
    model = model.to(device)
    print("Model loaded and moved to device.")
    return model

# Path to the pre-trained GFM model
model_path = '../simmim_pretrain/gfm.pth'  # Replace with your actual model path

# Load the model
model = load_gfm_model(model_path)


  checkpoint = torch.load(model_path, map_location='cpu')


Model loaded and moved to device.


 ## 7. Define Loss Function and Optimizer

In [None]:
# %%
# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Suitable for binary classification with logits
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optionally, define a learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


 ## 8. Training Loop with wandb Logging

In [None]:
# %%
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        outputs = outputs.squeeze(1)  # Adjust dimensions if necessary
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        if i % 10 == 9 or i == len(dataloader) - 1:  # Log every 10 batches or last batch
            avg_loss = running_loss / 10
            print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch + 1, 'batch': i + 1, 'loss': avg_loss})
            running_loss = 0.0
    
    # Step the scheduler
    scheduler.step()
    
    # Optionally, log learning rate
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({'learning_rate': current_lr})
    print(f"Epoch {epoch + 1} completed. Learning Rate: {current_lr}")

print("Training complete.")


 ## 9. Save the Trained Model

In [None]:
# %%
# Save the trained model
model_save_path = 'trained_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to '{model_save_path}'.")


 ## 10. Conclusion

In [None]:
# %%
print("""
# Training Summary

- **Model**: Swin Transformer (GFM) loaded from pre-trained checkpoint.
- **Dataset**: Landing strips with Sentinel-2 imagery.
- **Loss Function**: BCEWithLogitsLoss.
- **Optimizer**: Adam with learning rate scheduler.
- **Logging**: Weights & Biases (wandb) for experiment tracking.
- **Device**: {}
- **Epochs**: {}

Training has been completed and the model has been saved.
""".format(device, num_epochs))
