 ## Imports and Setup

In [32]:
import os
import sys
import numpy as np
import torch
import random
from tqdm.notebook import tqdm
import wandb
import timm
from pathlib import Path


# Add the src directory to the sys.path
sys.path.append(os.path.abspath('..'))

from secret_runway_detection.model import CombinedModel, SegmentationHead

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


 ## Configuration Parameters

In [2]:
DEBUG = True

# BUFFER_TYPE = 'cross'
RUN_PATH = 'esedx12/secret-runway-detection/621fojqi'

# IS_ON_WANDB = True
# RUN_NAME = 'neat-energy-31'
# MODEL_NAME = 'model:v7'

# # Model input and output dimensions
# INPUT_IMAGE_SIDE_LEN_PX = 224  # in pixels
# TILES_PER_INPUT_AREA_LEN = 224  # Number of tiles per side in one input area

# Threshold for converting model outputs to binary predictions
THRESHOLD = 0.5  # Adjust based on validation performance

In [3]:
# # Path to the trained model checkpoint
# MODEL_CHECKPOINT_PATH = '../checkpoints/trained_model.pth'  # Update this path

In [4]:
# Load the W&B run
train_run = wandb.Api().run(RUN_PATH)

 ## Load the Trained Model

In [42]:
from more_itertools import one

# Fetch the model artifact from the W&B run
artifacts = train_run.logged_artifacts()
artifact = one(artifacts)  # Ensures exactly one artifact
state_dict_dir = artifact.download(root='../artifacts/')
state_dict_dir = Path(state_dict_dir)
state_dict_path = state_dict_dir / f'{train_run.name}.pth'
state_dict_path

[34m[1mwandb[0m: Downloading large artifact model:v7, 358.74MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


PosixPath('../artifacts/neat-energy-31.pth')

In [43]:
# Load model checkpoint from ../checkpoints dir
backbone = model = timm.create_model(
        'swin_base_patch4_window7_224',
        pretrained=False,
        num_classes=0,  # Assuming binary classification
    )

segmentation_head = SegmentationHead()

model = CombinedModel(backbone, segmentation_head)


In [48]:
# Load the Model from WandB, which we saved as state dict
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()

  model.load_state_dict(torch.load(state_dict_path, map_location=device))


CombinedModel(
  (backbone): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (layers): Sequential(
      (0): SwinTransformerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=128, out_features=384, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=128, out_features=128, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path1): Identity()
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=128, out_features=51

In [50]:
best_threshold = train_run.config['best_threshold']

KeyError: 'best_threshold'

## Load validation images and labels

In [None]:
train_dir = Path(
    f'../training_data/training_data_{train_run.config["training_dataset"]}')

images_dir = train_dir / 'images'
labels_dir = train_dir / 'labels'


In [None]:
# Get all filenames in the images directory
import re


all_filenames = os.listdir(images_dir)

# Initialize dictionaries and lists
strip_to_files = {}        # For files with strip numbers
possibly_empty_files = []  # For 'possibly_empty' files

# Regular expression pattern to match filenames with strip numbers
pattern = re.compile(r'^area_\d+_of_strip_(\d+)\.npy$')

# Process filenames
for filename in all_filenames:
    if 'possibly_empty' in filename:
        # This is a 'possibly_empty' file
        possibly_empty_files.append(filename)
    else:
        # Try to match the pattern to extract strip number
        match = pattern.match(filename)
        if match:
            strip_number = int(match.group(1))
            # Add filename to the list for this strip number
            strip_to_files.setdefault(strip_number, []).append(filename)
        else:
            print(f"Filename does not match expected pattern: {filename}")


In [None]:
val_strip_numbers = train_run.config['val_strip_numbers']

val_files = []
for strip_num in val_strip_numbers:
    val_files.extend(strip_to_files[strip_num])

In [None]:
# Define your transform if you have one; otherwise, set to None
from dataset import LandingStripDataset


segmentation_transform = None  # Replace with your actual transform if any

# Create validation dataset
val_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=val_files,
    transform=segmentation_transform
)

## Visualize Image, Label and Prediction

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Load the model
state_dict_path = "path/to/your/state_dict.pth"  # Replace with your actual path
state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

# Function to visualize predictions
def visualize_predictions(model, dataset, num_samples=5):
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            sample = dataset[i]
            input_image = sample['image'].unsqueeze(0)  # Add batch dimension
            label = sample['label']
            
            # Generate prediction
            prediction = model(input_image).squeeze(0)  # Remove batch dimension
            
            # Convert tensors to numpy arrays for visualization
            input_image_np = input_image.squeeze(0).numpy().transpose(1, 2, 0)  # HWC format
            label_np = label.numpy()
            prediction_np = prediction.numpy()
            
            # Plot input image, label, and prediction
            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
            ax[0].imshow(input_image_np)
            ax[0].set_title('Input Image')
            ax[0].axis('off')
            
            ax[1].imshow(label_np, cmap='gray')
            ax[1].set_title('Label')
            ax[1].axis('off')
            
            ax[2].imshow(prediction_np, cmap='gray')
            ax[2].set_title('Prediction')
            ax[2].axis('off')
            
            plt.show()

# Visualize predictions for a few samples
visualize_predictions(model, val_dataset, num_samples=5)

## Visualize one AOI confidence map and the corresponding image

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from skimage.transform import resize
import matplotlib.colors as mcolors
import torch

# Function to normalize and prepare the satellite image for display
def prepare_satellite_image(image_data):
    # Stack bands into RGB format
    img_rgb = np.dstack((image_data[0], image_data[1], image_data[2]))
    # Normalize the image for display
    img_rgb = img_rgb.astype(float)
    img_rgb = (img_rgb - img_rgb.min()) / (img_rgb.max() - img_rgb.min())
    return img_rgb

# Function to overlay the confidence map on the satellite image
def overlay_confidence_map(satellite_image, confidence_map):
    print("Overlaying confidence map...")
    # Convert confidence_map to NumPy array if it's a PyTorch tensor
    if isinstance(confidence_map, torch.Tensor):
        confidence_map = confidence_map.detach().cpu().numpy()
    
    plt.figure(figsize=(12, 12))
    plt.imshow(satellite_image)
    plt.title("Satellite Image with Confidence Map Overlay")
    plt.axis('off')
    
    # Resize the confidence map to match the satellite image dimensions
    confidence_map_resized = resize(confidence_map, (satellite_image.shape[0], satellite_image.shape[1]),
                                    order=1, preserve_range=True, anti_aliasing=False)
    
    # Create a color map for the confidence map
    cmap = plt.cm.Reds
    cmap.set_under(color='none')  # Make values below the threshold transparent
    
    # Define a threshold for visualization
    THRESHOLD = 0.5  # Adjust based on your data
    
    # Overlay the confidence map
    plt.imshow(confidence_map_resized, cmap=cmap, alpha=0.5, vmin=THRESHOLD, vmax=1)
    
    # Add a colorbar
    plt.colorbar(label='Confidence Score')
    
    plt.show()
    print("Confidence map plot displayed.")

# Function to overlay the has-strip map on the satellite image
def overlay_has_strip_map(satellite_image, has_strip_map):
    print("Overlaying has-strip map...")
    # Convert has_strip_map to NumPy array if it's a PyTorch tensor
    if isinstance(has_strip_map, torch.Tensor):
        has_strip_map = has_strip_map.detach().cpu().numpy()
    
    plt.figure(figsize=(12, 12))
    plt.imshow(satellite_image)
    plt.title("Satellite Image with Has-Strip Map Overlay")
    plt.axis('off')
    
    # Resize the has-strip map to match the satellite image dimensions
    has_strip_map_resized = resize(has_strip_map.astype(float), (satellite_image.shape[0], satellite_image.shape[1]),
                                   order=0, preserve_range=True, anti_aliasing=False)
    
    # Create a colormap for the has-strip map
    cmap = mcolors.ListedColormap(['none', 'red'])
    bounds = [0, 0.5, 1]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)
    
    # Overlay the has-strip map
    plt.imshow(has_strip_map_resized, cmap=cmap, norm=norm, alpha=0.5)
    
    # Add a legend
    legend_elements = [
        Patch(facecolor='red', edgecolor='red', label='Has-Strip Area')
    ]
    plt.legend(handles=legend_elements, loc='lower right')
    
    plt.show()
    print("Has-strip map plot displayed.")

# Example usage
# Assuming you have 'aoi_image', 'confidence_map', and 'has_strip_map' variables
satellite_image = prepare_satellite_image(aoi_image)
overlay_confidence_map(satellite_image, confidence_map)
overlay_has_strip_map(satellite_image, has_strip_map)
