 ## Imports and Setup

In [3]:
import os
import sys
import numpy as np
import torch
import random
from tqdm.notebook import tqdm
import wandb
import timm
from pathlib import Path


# Add the src directory to the sys.path
sys.path.append(os.path.abspath('..'))

# from secret_runway_detection.model import CombinedModel, SegmentationHead

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


 ## Configuration Parameters

In [4]:
DEBUG = True

# BUFFER_TYPE = 'cross'
RUN_PATH = 'esedx12/secret-runway-detection/380m92co'

# IS_ON_WANDB = True
# RUN_NAME = 'neat-energy-31'
# MODEL_NAME = 'model:v7'

# # Model input and output dimensions
# INPUT_IMAGE_SIDE_LEN_PX = 224  # in pixels
# TILES_PER_INPUT_AREA_LEN = 224  # Number of tiles per side in one input area

# Threshold for converting model outputs to binary predictions
THRESHOLD = 0.5  # Adjust based on validation performance

In [5]:
# # Path to the trained model checkpoint
# MODEL_CHECKPOINT_PATH = '../checkpoints/trained_model.pth'  # Update this path

In [6]:
# Load the W&B run
train_run = wandb.Api().run(RUN_PATH)
train_run.name

'skilled-glade-49'

 ## Load the Trained Model

In [7]:
from more_itertools import one

# Fetch the model artifact from the W&B run
artifacts = train_run.logged_artifacts()
model_artifacts = [a for a in artifacts if a.type == 'model']
artifact = one(model_artifacts)
state_dict_dir = artifact.download(root='../artifacts/')
state_dict_dir = Path(state_dict_dir)
state_dict_path = state_dict_dir / f'{train_run.name}.pth'
state_dict_path

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Downloading large artifact model:v20, 374.16MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


PosixPath('../artifacts/skilled-glade-49.pth')

In [8]:
model_type = train_run.config['model_type']
print(f"Model type: {model_type}")

Model type: simple


In [9]:
# Now load the type backbone_config artifact the same way you loaded the model and put it in a temp directory
backbone_config_artifacts = [a for a in artifacts if a.type == 'backbone_config']
backbone_config_artifact = one(backbone_config_artifacts)
backbone_config_dir = backbone_config_artifact.download(root='../artifacts/')
backbone_config_dir = Path(backbone_config_dir)
backbone_config_path = backbone_config_dir / f'{train_run.name}.yaml'
backbone_config_path


ValueError: too few items in iterable (expected 1)

In [12]:
list(artifacts)

for artifact in artifacts:
    print(artifact.type, artifact.name)

model model:v20
wandb-history run-380m92co-history:v0


In [None]:
# Then load the model like this
# model = get_model(config['model_type'], BACKBONE_CFG_PATH, backbone_weights_path, output_size=config['resolution']).to(device)

from GFM.models import build_model

# Load the model
model = build_model(model_type, backbone_config_path, state_dict_path, output_size=train_run.config['resolution']).to(device)

In [None]:
# # Load model checkpoint from ../checkpoints dir
# backbone = timm.create_model(
#         'swin_base_patch4_window7_224',
#         pretrained=False,
#         num_classes=0,  # Assuming binary classification
#     )

# segmentation_head = SegmentationHead()

# model = CombinedModel(backbone, segmentation_head)


In [None]:
backbone

In [None]:
# Load the Model from WandB, which we saved as state dict
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model.eval()

In [None]:
best_threshold = train_run.summary['best_threshold']

## Load validation images and labels

In [None]:
train_dir = Path(
    f'../training_data_{train_run.config['resolution']}/training_data_{train_run.config["training_dataset"]}')

images_dir = train_dir / 'images'
labels_dir = train_dir / 'labels'


In [None]:
# Get all filenames in the images directory
import re


all_filenames = os.listdir(images_dir)

# Initialize dictionaries and lists
strip_to_files = {}        # For files with strip numbers
possibly_empty_files = []  # For 'possibly_empty' files

# Regular expression pattern to match filenames with strip numbers
pattern = re.compile(r'^area_\d+_of_strip_(\d+)\.npy$')

# Process filenames
for filename in all_filenames:
    if 'possibly_empty' in filename:
        # This is a 'possibly_empty' file
        possibly_empty_files.append(filename)
    else:
        # Try to match the pattern to extract strip number
        match = pattern.match(filename)
        if match:
            strip_number = int(match.group(1))
            # Add filename to the list for this strip number
            strip_to_files.setdefault(strip_number, []).append(filename)
        else:
            print(f"Filename does not match expected pattern: {filename}")


In [None]:
val_strip_numbers = train_run.config['val_strip_numbers']

val_files = []
for strip_num in val_strip_numbers:
    val_files.extend(strip_to_files[strip_num])

train_strip_numbers = train_run.config['train_strip_numbers']

train_files = []
for strip_num in train_strip_numbers:
    train_files.extend(strip_to_files[strip_num])

In [None]:
# Define your transform if you have one; otherwise, set to None
from secret_runway_detection.dataset import LandingStripDataset, SegmentationTransform

segmentation_transform = SegmentationTransform()  # Replace with your actual transform if any
# segmentation_transform = None

# Create validation dataset
val_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=val_files,
    transform=segmentation_transform
)

# Create training dataset
train_dataset = LandingStripDataset(
    images_dir=images_dir,
    labels_dir=labels_dir,
    file_list=train_files,
    transform=segmentation_transform
)


## Visualize Image, Label and Prediction

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
import numpy as np
import torch

# Function to visualize predictions with optional label overlay
def visualize_predictions(model, dataset, num_samples=5, overlay=False):
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            sample = dataset[i]
            input_image = sample[0].unsqueeze(0)  # Add batch dimension
            label = sample[1]
            
            # Generate prediction
            prediction = model(input_image).squeeze(0)  # Remove batch dimension
            
            # Convert tensors to numpy arrays for visualization
            input_image_np = input_image.squeeze(0).numpy().transpose(1, 2, 0)  # HWC format
            label_np = label.numpy()
            prediction_np = prediction.numpy().squeeze(0)
            
            # Create subplots
            fig, ax = plt.subplots(1, 3 if not overlay else 3, figsize=(15, 5))
            
            # Display Input Image
            ax_idx = 0
            ax[ax_idx].imshow(input_image_np)
            ax[ax_idx].set_title('Input Image')
            ax[ax_idx].axis('off')
            
            if overlay:
                # Overlay Label on Input Image
                ax_idx += 1
                ax[ax_idx].imshow(input_image_np)
                ax[ax_idx].imshow(label_np, cmap='jet', alpha=0.5)  # Adjust alpha for transparency
                ax[ax_idx].set_title('Input Image with Label Overlay')
                ax[ax_idx].axis('off')
                
                # Display Prediction
                ax_idx += 1
                ax[ax_idx].imshow(prediction_np, cmap='gray')
                ax[ax_idx].set_title('Prediction')
                ax[ax_idx].axis('off')
            else:
                # Display Label
                ax_idx += 1
                ax[ax_idx].imshow(label_np, cmap='gray')
                ax[ax_idx].set_title('Label')
                ax[ax_idx].axis('off')
                
                # Display Prediction
                ax_idx += 1
                ax[ax_idx].imshow(prediction_np, cmap='gray')
                ax[ax_idx].set_title('Prediction')
                ax[ax_idx].axis('off')
            
            plt.show()


In [None]:
# Visualize predictions for a few VALIDATION samples
visualize_predictions(model, val_dataset, num_samples=5, overlay=True)

In [None]:
# Visualize predictions for a few TRAINING samples
visualize_predictions(model, train_dataset, num_samples=5)