# Inference and Evaluation Notebook: RetinaNet + U-Net

This notebook provides comprehensive inference and evaluation capabilities for both RetinaNet (object detection) and U-Net (landmark detection) models.

## Common Setup and Imports

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import glob
import json
import time
from collections import defaultdict
from natsort import natsorted
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd
from PIL import Image

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Fix the PyTorch and NumPy RNG seeds so repeated runs are reproducible.
# (np.random.seed stays last so the cell has no trailing output value.)
torch.manual_seed(42)
np.random.seed(42)

## 1. RetinaNet Inference and Evaluation

### RetinaNet Model Loading and Configuration

In [None]:
# RetinaNet imports
from retinanet import model
from retinanet.dataloader import CocoDataset, CSVDataset, collater, Resizer, AspectRatioBasedSampler, Normalizer
from retinanet import csv_eval, coco_eval
from retinanet.utils import BBoxTransform, ClipBoxes

# RetinaNet inference settings. Adjust the path and the class count to your run.
RETINANET_CONFIG = dict(
    MODEL_PATH='./retinanet_weights/best_retinanet_epoch_50.pt',  # Update with your model path
    DATASET_TYPE='csv',          # 'csv' or 'coco'
    CONFIDENCE_THRESHOLD=0.5,    # detections scoring at or below this are dropped
    NMS_THRESHOLD=0.5,           # IoU threshold handed to NMS in single-image inference
    MAX_DETECTIONS=100,
    BATCH_SIZE=1,
    IMAGE_SIZE=(512, 512),
    NUM_CLASSES=1,               # Update based on your dataset
)

# Locations of the evaluation data (annotation CSVs, COCO root, raw test images).
RETINANET_DATA_PATHS = dict(
    csv_test='./dataset/annotations_test.csv',
    csv_classes='./dataset/classes.csv',
    coco_path='./dataset/coco/',
    test_images='./dataset/test_images/',
)

print("RetinaNet inference configuration loaded.")
print(f"Model path: {RETINANET_CONFIG['MODEL_PATH']}")
print(f"Confidence threshold: {RETINANET_CONFIG['CONFIDENCE_THRESHOLD']}")

### RetinaNet Model Loading

In [None]:
def load_retinanet_model(model_path, num_classes, device):
    """
    Load a RetinaNet model from a checkpoint, falling back to a fresh model.

    Args:
        model_path: Path handed to ``torch.load``; the checkpoint is expected
            to contain a whole pickled model object (it is used directly).
        num_classes: Number of classes for the fallback model. Only used when
            the checkpoint cannot be loaded.
        device: Device the model is mapped/moved to.

    Returns:
        The model in eval mode on ``device``. If loading fails for any reason,
        a randomly initialised ``resnet50`` RetinaNet is returned instead and
        the error is printed (best-effort behaviour intended for testing).
    """
    # The local is deliberately NOT called ``model``: that name is the
    # ``retinanet.model`` module imported at the top of the notebook, and
    # shadowing it made the fallback branch raise UnboundLocalError.
    try:
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        net = torch.load(model_path, map_location=device)
        net.eval()
        net.to(device)

        print(f"RetinaNet model loaded successfully from: {model_path}")
        print(f"Model is on device: {next(net.parameters()).device}")

        return net

    except Exception as e:
        print(f"Error loading RetinaNet model: {e}")
        print("Creating new model with random weights for testing...")

        # Build an untrained network so the rest of the notebook can still run.
        net = model.resnet50(num_classes=num_classes, pretrained=False)
        net.eval()
        net.to(device)

        return net

# Load RetinaNet model (uncomment when model is available)
# retinanet_model = load_retinanet_model(RETINANET_CONFIG['MODEL_PATH'], RETINANET_CONFIG['NUM_CLASSES'], device)

### RetinaNet Dataset Loading for Evaluation

In [None]:
def create_retinanet_test_dataset(config, data_paths):
    """
    Create the test dataset and data loader for RetinaNet evaluation.

    Args:
        config: Mapping with ``'DATASET_TYPE'`` (``'csv'`` or ``'coco'``) and
            ``'BATCH_SIZE'``.
        data_paths: Mapping with ``'csv_test'`` and ``'csv_classes'`` (csv
            mode) or ``'coco_path'`` (coco mode).

    Returns:
        ``(dataloader_test, dataset_test)``.

    Raises:
        ValueError: If ``config['DATASET_TYPE']`` is not ``'csv'`` or ``'coco'``.
    """
    dataset_type = config['DATASET_TYPE']
    # Fail early with a clear message instead of hitting an unbound local below.
    if dataset_type not in ('csv', 'coco'):
        raise ValueError(
            f"Unsupported DATASET_TYPE {dataset_type!r}; expected 'csv' or 'coco'"
        )

    # Both dataset flavours share the same preprocessing pipeline.
    transform = transforms.Compose([Normalizer(), Resizer()])

    if dataset_type == 'csv':
        dataset_test = CSVDataset(
            train_file=data_paths['csv_test'],
            class_list=data_paths['csv_classes'],
            transform=transform
        )
    else:
        dataset_test = CocoDataset(
            data_paths['coco_path'],
            set_name='test2017',
            transform=transform
        )

    # Batches are formed by the aspect-ratio sampler, hence batch_sampler= here.
    sampler_test = AspectRatioBasedSampler(dataset_test, batch_size=config['BATCH_SIZE'], drop_last=False)
    dataloader_test = DataLoader(dataset_test, num_workers=0, collate_fn=collater, batch_sampler=sampler_test)

    return dataloader_test, dataset_test

# Create test dataset (uncomment when dataset is available)
# test_dataloader, test_dataset = create_retinanet_test_dataset(RETINANET_CONFIG, RETINANET_DATA_PATHS)
# print(f'Test dataset size: {len(test_dataset)}')

### RetinaNet Inference Functions

In [None]:
def inference_retinanet_single_image(model, image, device, config):
    """
    Perform inference on a single image.

    Args:
        model: Callable network returning ``(scores, classification, boxes)``
            tensors. This function indexes them as flat per-detection arrays.
            # NOTE(review): batch_inference_retinanet indexes the same outputs
            # per image instead -- confirm which layout the model returns.
        image: ``np.ndarray`` or tensor. A 3-D input gets a leading batch
            dimension added; NumPy input is converted to float32.
        device: Device the image is moved to before the forward pass.
        config: Mapping with ``'CONFIDENCE_THRESHOLD'`` and ``'NMS_THRESHOLD'``.

    Returns:
        ``(boxes, scores, classes, inference_time)``. The first three are NumPy
        arrays with one row/entry per kept detection, or empty lists when no
        detection exceeds the confidence threshold. ``inference_time`` is the
        wall-clock duration of the forward pass in seconds.
    """
    model.eval()

    with torch.no_grad():
        # Prepare image
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).float()

        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        image = image.to(device)

        # Inference (only the forward pass is timed)
        start_time = time.time()
        scores, classification, transformed_anchors = model(image)
        inference_time = time.time() - start_time

    # Post-processing happens on CPU NumPy copies.
    scores = scores.cpu().numpy()
    classification = classification.cpu().numpy()
    transformed_anchors = transformed_anchors.cpu().numpy()

    # Filter by confidence threshold
    indices = np.where(scores > config['CONFIDENCE_THRESHOLD'])[0]

    if len(indices) == 0:
        return [], [], [], inference_time

    filtered_scores = scores[indices]
    filtered_classes = classification[indices]
    filtered_boxes = transformed_anchors[indices]

    # Apply NMS
    keep_indices = torch.ops.torchvision.nms(
        torch.from_numpy(filtered_boxes),
        torch.from_numpy(filtered_scores),
        config['NMS_THRESHOLD']
    )
    # Convert to a NumPy index array before indexing NumPy data: indexing with
    # the torch tensor directly can be interpreted as a scalar index when only
    # one detection survives, which would drop the detection dimension.
    keep_indices = keep_indices.cpu().numpy()

    final_scores = filtered_scores[keep_indices]
    final_classes = filtered_classes[keep_indices]
    final_boxes = filtered_boxes[keep_indices]

    return final_boxes, final_scores, final_classes, inference_time

def batch_inference_retinanet(model, dataloader, device, config):
    """
    Perform batch inference on the test dataset.

    Args:
        model: Callable network returning ``(scores, classification, boxes)``.
            The outputs are indexed per image along their first axis.
            # NOTE(review): inference_retinanet_single_image treats the same
            # outputs as flat per-detection arrays -- confirm the real layout.
        dataloader: Iterable of batches; each batch is a mapping with an
            ``'img'`` tensor and an ``'annot'`` entry.
        device: Device the images are moved to.
        config: Mapping with ``'CONFIDENCE_THRESHOLD'``.

    Returns:
        ``(all_predictions, all_ground_truths, inference_times)`` where
        ``all_predictions`` holds one ``{'boxes', 'scores', 'classes'}`` dict of
        NumPy arrays per image, ``all_ground_truths`` holds the per-image
        entries of ``data['annot']``, and ``inference_times`` holds one forward
        pass duration (seconds) per batch.
    """
    model.eval()

    all_predictions = []
    all_ground_truths = []
    inference_times = []
    confidence_threshold = config['CONFIDENCE_THRESHOLD']

    with torch.no_grad():
        for data in tqdm(dataloader, desc="Running inference"):
            images = data['img'].to(device)
            annotations = data['annot']

            # Inference (only the forward pass is timed)
            start_time = time.time()
            scores, classification, transformed_anchors = model(images)
            inference_time = time.time() - start_time

            for i in range(images.shape[0]):
                img_scores = scores[i].cpu().numpy()
                img_classes = classification[i].cpu().numpy()
                img_boxes = transformed_anchors[i].cpu().numpy()

                # Filter by confidence. Indexing with an empty index array
                # already yields correctly shaped/typed empty results, so no
                # special case is needed for images without detections.
                indices = np.where(img_scores > confidence_threshold)[0]
                all_predictions.append({
                    'boxes': img_boxes[indices],
                    'scores': img_scores[indices],
                    'classes': img_classes[indices]
                })

            all_ground_truths.extend(annotations)
            inference_times.append(inference_time)

    return all_predictions, all_ground_truths, inference_times

# Example usage (uncomment when model and data are ready)
# predictions, ground_truths, times = batch_inference_retinanet(retinanet_model, test_dataloader, device, RETINANET_CONFIG)
# print(f'Average inference time per batch: {np.mean(times):.4f} seconds')

### RetinaNet Evaluation Metrics

In [None]:
def calculate_iou(box1, box2):
    """
    Intersection-over-Union of two axis-aligned boxes.

    Both boxes are indexable as ``[x1, y1, x2, y2]``. Returns ``0.0`` when the
    boxes do not overlap (touching edges count as no overlap) or when the
    union area is not positive.
    """
    # Corners of the overlap rectangle.
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Disjoint or merely touching boxes share no area.
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    first_area, second_area = (
        (box[2] - box[0]) * (box[3] - box[1]) for box in (box1, box2)
    )
    total = first_area + second_area - overlap

    if total > 0:
        return overlap / total
    return 0.0

def _match_class_detections(class_detections, class_annotations, iou_threshold):
    """Greedily mark one class's detections as TP/FP in descending score order."""
    ranked = sorted(class_detections, key=lambda det: det['score'], reverse=True)

    # Index annotations by image so a detection is only compared with the
    # ground truth of its own image.
    annotations_by_image = defaultdict(list)
    for ann_idx, annotation in enumerate(class_annotations):
        annotations_by_image[annotation['image_id']].append((ann_idx, annotation))

    tp = np.zeros(len(ranked))
    fp = np.zeros(len(ranked))
    matched_annotations = set()

    for rank, detection in enumerate(ranked):
        best_iou = 0.0
        best_ann_idx = None
        for ann_idx, annotation in annotations_by_image.get(detection['image_id'], ()):
            iou = calculate_iou(detection['bbox'], annotation['bbox'])
            if iou > best_iou:
                best_iou = iou
                best_ann_idx = ann_idx

        # A ground-truth box can be claimed only once; later (lower-scored)
        # detections on the same box count as false positives.
        if (best_ann_idx is not None and best_iou >= iou_threshold
                and best_ann_idx not in matched_annotations):
            tp[rank] = 1
            matched_annotations.add(best_ann_idx)
        else:
            fp[rank] = 1

    return tp, fp


def _eleven_point_ap(tp, fp, num_annotations):
    """11-point interpolated average precision from ranked TP/FP flags."""
    tp_cumsum = np.cumsum(tp)
    fp_cumsum = np.cumsum(fp)

    precisions = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-8)
    recalls = tp_cumsum / num_annotations

    ap = 0.0
    for t in np.arange(0, 1.1, 0.1):
        reached = recalls >= t
        if reached.any():
            ap += np.max(precisions[reached]) / 11.0
    return ap


def _mean_ap_and_true_positives(detections_by_class, annotations_by_class, iou_threshold):
    """Return (mean AP over annotated classes, total matched true positives)."""
    aps = []
    total_tp = 0.0
    for class_id, class_annotations in annotations_by_class.items():
        tp, fp = _match_class_detections(
            detections_by_class.get(class_id, []), class_annotations, iou_threshold)
        aps.append(_eleven_point_ap(tp, fp, len(class_annotations)))
        total_tp += float(tp.sum())

    mean_ap = float(np.mean(aps)) if aps else 0.0
    return mean_ap, total_tp


def evaluate_retinanet_predictions(predictions, ground_truths, iou_threshold=0.5):
    """
    Evaluate RetinaNet predictions using mAP and precision/recall metrics.

    Args:
        predictions: One dict per image with ``'boxes'``, ``'scores'`` and
            ``'classes'`` sequences of equal length.
        ground_truths: One entry per image: ``None``/empty, or rows indexable
            as ``[x1, y1, x2, y2, class]``. Rows whose class is ``-1`` are
            skipped. Torch tensors are converted to NumPy first.
        iou_threshold: IoU needed for a detection to match a ground-truth box
            when computing ``'mAP'``, precision, recall and F1.

    Returns:
        Dict with ``'mAP'`` (11-point mAP at ``iou_threshold``, averaged over
        classes that have annotations), ``'mAP_50'`` and ``'mAP_75'`` (same at
        IoU 0.5 / 0.75), ``'precision'``, ``'recall'``, ``'f1_score'``,
        ``'num_detections'`` and ``'num_annotations'``. All scores are ``0.0``
        when there are no detections or no annotations.

    Precision and recall use the one-to-one greedy matching, so a ground-truth
    box detected several times contributes a single true positive.
    """
    all_detections = []
    all_annotations = []

    for image_id, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        # Process predictions
        if len(pred['boxes']) > 0:
            for box, score, cls in zip(pred['boxes'], pred['scores'], pred['classes']):
                all_detections.append({
                    'image_id': image_id,
                    'bbox': box,
                    'score': score,
                    'class': int(cls)
                })

        # Process ground truth
        if gt is not None and len(gt) > 0:
            if hasattr(gt, 'detach'):
                # Tensor annotations (e.g. straight from a DataLoader batch).
                gt = gt.detach().cpu().numpy()
            for annotation in gt:
                if annotation[4] != -1:  # Valid annotation
                    all_annotations.append({
                        'image_id': image_id,
                        'bbox': annotation[:4],
                        'class': int(annotation[4])
                    })

    num_detections = len(all_detections)
    num_annotations = len(all_annotations)

    if num_detections == 0 or num_annotations == 0:
        return {
            'mAP': 0.0,
            'mAP_50': 0.0,
            'mAP_75': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0,
            'num_detections': num_detections,
            'num_annotations': num_annotations
        }

    # Group by class
    detections_by_class = defaultdict(list)
    annotations_by_class = defaultdict(list)

    for det in all_detections:
        detections_by_class[det['class']].append(det)

    for ann in all_annotations:
        annotations_by_class[ann['class']].append(ann)

    mean_ap, total_tp = _mean_ap_and_true_positives(
        detections_by_class, annotations_by_class, iou_threshold)

    # Fixed-threshold variants; reuse the main result when thresholds coincide.
    if iou_threshold == 0.5:
        map_50 = mean_ap
    else:
        map_50 = _mean_ap_and_true_positives(detections_by_class, annotations_by_class, 0.5)[0]
    if iou_threshold == 0.75:
        map_75 = mean_ap
    else:
        map_75 = _mean_ap_and_true_positives(detections_by_class, annotations_by_class, 0.75)[0]

    # Detections of classes without any annotation stay in the denominator,
    # i.e. they count as false positives for precision.
    precision = total_tp / num_detections
    recall = total_tp / num_annotations
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        'mAP': mean_ap,
        'mAP_50': map_50,
        'mAP_75': map_75,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'num_detections': num_detections,
        'num_annotations': num_annotations
    }

# Example evaluation (uncomment when predictions are available)
# retinanet_metrics = evaluate_retinanet_predictions(predictions, ground_truths)
# print("RetinaNet Evaluation Results:")
# for key, value in retinanet_metrics.items():
#     print(f"{key}: {value:.4f}")

## 2. U-Net Inference and Evaluation

### U-Net Model Loading and Configuration

In [None]:
# U-Net imports
from Unet.loss import dice_loss, dice
from Unet.Unet import UNet
from Unet.preprocessing import *
from Unet.datagenerater import Dental_Single_Data_Generator
from Unet.utils import *

# U-Net inference settings. Adjust the paths and the landmark count to your run.
UNET_CONFIG = dict(
    MODEL_PATH='./unet_weights/',  # Directory containing landmark models
    MULTICLASS_MODEL_PATH='./unet_weights/best_multiclass_unet.pth',
    IMAGE_SIZE=(512, 512),
    NUM_LANDMARKS=14,
    BATCH_SIZE=1,
    CONFIDENCE_THRESHOLD=0.5,      # probability cut-off applied to predicted masks
    USE_MULTICLASS=False,          # Set to True for single multi-class model
    ENCODER_NAME='vgg16',
    DISTANCE_THRESHOLD=10.0,       # Pixel distance threshold for landmark accuracy
)

# Locations of the evaluation data and the share held out for testing.
UNET_DATA_PATHS = dict(
    image_path='./dataset/images/',
    label_path='./dataset/labels/',
    test_split=0.2,                # 20% for testing
)

print("U-Net inference configuration loaded.")
print(f"Number of landmarks: {UNET_CONFIG['NUM_LANDMARKS']}")
print(f"Multi-class model: {UNET_CONFIG['USE_MULTICLASS']}")
print(f"Distance threshold: {UNET_CONFIG['DISTANCE_THRESHOLD']} pixels")

### U-Net Model Loading

In [None]:
def _build_unet(config, num_classes):
    """Create an untrained single-channel U-Net with ``num_classes`` outputs."""
    try:
        import segmentation_models_pytorch as smp
    except ImportError:
        # segmentation_models_pytorch is optional; use the project's own U-Net.
        return UNet(n_channels=1, n_classes=num_classes)

    return smp.Unet(
        encoder_name=config['ENCODER_NAME'],
        # The trained state dict is loaded right afterwards, so do not fetch
        # pretrained encoder weights that would only be overwritten.
        encoder_weights=None,
        decoder_attention_type='scse',
        in_channels=1,
        classes=num_classes
    )


def _load_unet_weights(net, weights_path, device, loaded_message, error_prefix):
    """Best-effort load of a state dict into ``net``; returns it in eval mode on ``device``."""
    try:
        state_dict = torch.load(weights_path, map_location=device)
        net.load_state_dict(state_dict)
        print(loaded_message)
    except Exception as e:
        # Deliberately non-fatal so the notebook can be exercised without
        # trained weights; the message makes the degraded state visible.
        print(f"{error_prefix}: {e}")
        print("Using model with random weights...")

    net.eval()
    net.to(device)
    return net


def load_unet_models(config, device):
    """
    Load U-Net models for landmark detection.

    Args:
        config: Mapping with ``'USE_MULTICLASS'``, ``'NUM_LANDMARKS'``,
            ``'ENCODER_NAME'`` and either ``'MULTICLASS_MODEL_PATH'``
            (multi-class mode) or ``'MODEL_PATH'`` (per-landmark mode, weights
            read from ``<MODEL_PATH>/<landmark_idx>/weight.pth``).
        device: Device the weights are mapped to and the models moved to.

    Returns:
        A list of models in eval mode: one model with ``NUM_LANDMARKS`` output
        channels in multi-class mode, otherwise one single-channel model per
        landmark. A model whose weights cannot be loaded is kept with its
        initial weights and the error is printed.
    """
    if config['USE_MULTICLASS']:
        # Load single multi-class model
        weights_path = config['MULTICLASS_MODEL_PATH']
        net = _build_unet(config, config['NUM_LANDMARKS'])
        return [_load_unet_weights(
            net, weights_path, device,
            f"Multi-class U-Net model loaded from: {weights_path}",
            "Error loading multi-class model"
        )]

    # Load separate models for each landmark
    models = []
    for landmark_idx in range(config['NUM_LANDMARKS']):
        model_path = os.path.join(config['MODEL_PATH'], str(landmark_idx), 'weight.pth')
        net = _build_unet(config, 1)
        models.append(_load_unet_weights(
            net, model_path, device,
            f"Landmark {landmark_idx} model loaded from: {model_path}",
            f"Error loading model for landmark {landmark_idx}"
        ))

    return models

# Load U-Net models (uncomment when models are available)
# unet_models = load_unet_models(UNET_CONFIG, device)
# print(f"Loaded {len(unet_models)} U-Net model(s)")

### U-Net Dataset Loading for Evaluation

In [None]:
def prepare_unet_test_dataset(data_paths, config):
    """
    Prepare the U-Net test split as parallel lists of image and label paths.

    Args:
        data_paths: Mapping with ``'image_path'`` (directory of ``*.png``),
            ``'label_path'`` (directory of ``*.npy``) and ``'test_split'``
            (fraction of matched pairs used for testing).
        config: Unused; kept for interface compatibility.

    Returns:
        ``(x_test, y_test)``: image paths and their label paths, taken from
        the tail of the naturally sorted, matched pairs.

    Labels are paired with the image whose file stem equals the label's stem.
    Only if no such image exists does it fall back to the first image whose
    file name contains the label's base name, so label ``2`` is no longer
    paired with ``12.png`` when ``2.png`` exists.
    """
    # Get all image and label files
    image_files = natsorted(glob.glob(os.path.join(data_paths['image_path'], '*.png')))
    label_files = natsorted(glob.glob(os.path.join(data_paths['label_path'], '*.npy')))

    # Exact-stem lookup table; setdefault keeps the first image per stem.
    image_by_stem = {}
    for image_file in image_files:
        stem = os.path.splitext(os.path.basename(image_file))[0]
        image_by_stem.setdefault(stem, image_file)

    # Match image and label files
    matched_pairs = []
    for label_file in label_files:
        label_name = os.path.basename(label_file)
        full_stem = os.path.splitext(label_name)[0]
        base_name = label_name.split('.')[0]

        image_file = image_by_stem.get(full_stem)
        if image_file is None:
            image_file = image_by_stem.get(base_name)
        if image_file is None:
            # Legacy behaviour: first image whose name contains the base name.
            image_file = next(
                (img for img in image_files if base_name in os.path.basename(img)),
                None
            )

        if image_file is not None:
            matched_pairs.append((image_file, label_file))

    # The first (1 - test_split) share is left for training; the rest is test.
    split_idx = int(len(matched_pairs) * (1 - data_paths['test_split']))
    test_pairs = matched_pairs[split_idx:]

    x_test = [pair[0] for pair in test_pairs]
    y_test = [pair[1] for pair in test_pairs]

    print(f'Test samples: {len(x_test)}')

    return x_test, y_test

def create_unet_test_dataloaders(x_test, y_test, config):
    """
    Build the non-shuffled test loader(s) used for U-Net evaluation.

    With ``config['USE_MULTICLASS']`` set, a single ``DataLoader`` is returned
    whose dataset is created with ``landmark_num=-1``. Otherwise a list with
    one ``DataLoader`` per landmark index in ``range(config['NUM_LANDMARKS'])``
    is returned. All loaders use ``config['BATCH_SIZE']``.
    """
    transform_test = transforms.Compose([ToTensor()])

    def make_loader(landmark_num):
        # One dataset/loader pair for the requested landmark selector.
        dataset = Dental_Single_Data_Generator(
            config['IMAGE_SIZE'], x_test, y_test,
            landmark_num=landmark_num, mode="test", transform=transform_test
        )
        return DataLoader(dataset, batch_size=config['BATCH_SIZE'], shuffle=False)

    if not config['USE_MULTICLASS']:
        # Separate datasets for each landmark
        return [make_loader(landmark_idx) for landmark_idx in range(config['NUM_LANDMARKS'])]

    # Single multi-class dataset
    return make_loader(-1)

# Prepare test dataset (uncomment when dataset is available)
# x_test, y_test = prepare_unet_test_dataset(UNET_DATA_PATHS, UNET_CONFIG)
# test_loaders = create_unet_test_dataloaders(x_test, y_test, UNET_CONFIG)
# print(f"Created {len(test_loaders) if isinstance(test_loaders, list) else 1} test loader(s)")

### U-Net Inference Functions

In [None]:
def extract_landmark_coordinates(mask, threshold=0.5):
    """
    Extract a landmark position from a segmentation mask.

    Args:
        mask: 2-D NumPy array of scores/probabilities.
        threshold: Pixels strictly greater than this value count as foreground.

    Returns:
        ``(cx, cy)`` integer pixel coordinates (x first) of the centroid of the
        largest foreground contour, or ``None`` when nothing exceeds the
        threshold.
    """
    # Apply threshold
    binary_mask = (mask > threshold).astype(np.uint8)

    # findContours returns (contours, hierarchy) in OpenCV 4 and
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contour list in both.
    contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if len(contours) == 0:
        return None

    # Get the largest contour
    largest_contour = max(contours, key=cv2.contourArea)

    # Centroid from the contour moments.
    M = cv2.moments(largest_contour)
    if M["m00"] != 0:
        return (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"]))

    # A contour enclosing no area (single pixel, thin line) has m00 == 0 even
    # though a blob was found; use the mean of its points instead of
    # reporting the landmark as missing.
    mean_x, mean_y = largest_contour.reshape(-1, 2).mean(axis=0)
    return (int(mean_x), int(mean_y))

def inference_unet_single_image(models, image, device, config):
    """
    Perform U-Net inference on a single image.

    Args:
        models: List of models; only ``models[0]`` is used in multi-class mode,
            otherwise each model predicts one landmark.
        image: ``np.ndarray`` or tensor. A 2-D ``(H, W)`` image gets channel
            and batch dimensions added; a 3-D input gets a batch dimension.
        device: Device the image is moved to.
        config: Mapping with ``'USE_MULTICLASS'``, ``'NUM_LANDMARKS'`` and
            ``'CONFIDENCE_THRESHOLD'``.

    Returns:
        ``(predictions, inference_times)``: one coordinate (or ``None``) per
        landmark as returned by ``extract_landmark_coordinates``, and one
        duration in seconds per forward pass (sigmoid included).
    """
    predictions = []
    inference_times = []

    # Prepare image
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image).float()

    if len(image.shape) == 2:
        # Plain grayscale (H, W): add the channel axis first.
        image = image.unsqueeze(0)

    if len(image.shape) == 3:
        image = image.unsqueeze(0)

    image = image.to(device)
    threshold = config['CONFIDENCE_THRESHOLD']

    with torch.no_grad():
        if config['USE_MULTICLASS']:
            # Single multi-class model
            model = models[0]
            model.eval()

            start_time = time.time()
            output = torch.sigmoid(model(image))
            inference_times.append(time.time() - start_time)

            # Drop only the batch axis. A blanket squeeze() would also remove
            # the channel axis when there is a single landmark.
            output_np = output.cpu().numpy()[0]

            for landmark_idx in range(config['NUM_LANDMARKS']):
                predictions.append(
                    extract_landmark_coordinates(output_np[landmark_idx], threshold)
                )
        else:
            # Separate models for each landmark
            for model in models:
                model.eval()

                start_time = time.time()
                output = torch.sigmoid(model(image))
                inference_times.append(time.time() - start_time)

                mask = output.cpu().numpy().squeeze()
                predictions.append(extract_landmark_coordinates(mask, threshold))

    return predictions, inference_times

def batch_inference_unet(models, test_loaders, device, config):
    """
    Perform batch inference on the U-Net test dataset.

    Args:
        models: List of models; only ``models[0]`` is used in multi-class mode.
        test_loaders: In multi-class mode a single iterable of batches;
            otherwise a list with one loader per model. Every batch is a
            mapping with ``'image'`` and ``'landmarks'`` tensors. In
            per-landmark mode the first loader must expose ``.dataset`` (its
            length is the sample count) and the loaders are assumed to yield
            samples in the same, unshuffled order.
        device: Device the images are moved to.
        config: Mapping with ``'USE_MULTICLASS'``, ``'NUM_LANDMARKS'`` and
            ``'CONFIDENCE_THRESHOLD'``.

    Returns:
        ``(all_predictions, all_ground_truths, all_inference_times)``. The
        first two hold, per sample, one coordinate (or ``None``) per landmark;
        ground-truth masks are thresholded at 0.5. Times are per forward pass
        in seconds (sigmoid included).
    """
    all_inference_times = []
    threshold = config['CONFIDENCE_THRESHOLD']

    def timed_forward(model, images):
        # Time the forward pass + sigmoid and record the duration.
        start_time = time.time()
        outputs = torch.sigmoid(model(images))
        all_inference_times.append(time.time() - start_time)
        return outputs

    if config['USE_MULTICLASS']:
        # Multi-class model inference
        all_predictions = []
        all_ground_truths = []
        num_landmarks = config['NUM_LANDMARKS']

        model = models[0]
        model.eval()

        with torch.no_grad():
            for sample in tqdm(test_loaders, desc="Running U-Net inference"):
                images = sample['image'].to(device)
                landmarks = sample['landmarks'].cpu().numpy()
                outputs = timed_forward(model, images).cpu().numpy()

                for i in range(images.shape[0]):
                    all_predictions.append([
                        extract_landmark_coordinates(outputs[i][landmark_idx], threshold)
                        for landmark_idx in range(num_landmarks)
                    ])
                    all_ground_truths.append([
                        extract_landmark_coordinates(landmarks[i, landmark_idx], 0.5)
                        for landmark_idx in range(num_landmarks)
                    ])

        return all_predictions, all_ground_truths, all_inference_times

    # Separate models inference
    if len(models) == 0 or len(test_loaders) == 0:
        return [], [], all_inference_times

    num_samples = len(test_loaders[0].dataset)
    num_heads = min(len(models), len(test_loaders))

    # One slot per processed landmark for every sample, filled in below.
    predictions_by_sample = [[None] * num_heads for _ in range(num_samples)]
    ground_truths_by_sample = [[None] * num_heads for _ in range(num_samples)]

    for landmark_idx, (model, test_loader) in enumerate(zip(models, test_loaders)):
        model.eval()
        # Running offset into the dataset. Derived from the actual batch sizes
        # so results land on the right sample even if the loader's batch size
        # differs from config['BATCH_SIZE'].
        sample_offset = 0

        with torch.no_grad():
            for sample in tqdm(test_loader, desc=f"Landmark {landmark_idx + 1}/{config['NUM_LANDMARKS']}"):
                images = sample['image'].to(device)
                landmarks = sample['landmarks'].cpu().numpy()
                outputs = timed_forward(model, images)

                for i in range(images.shape[0]):
                    sample_idx = sample_offset + i

                    mask = outputs[i].cpu().numpy().squeeze()
                    predictions_by_sample[sample_idx][landmark_idx] = (
                        extract_landmark_coordinates(mask, threshold)
                    )

                    gt_mask = landmarks[i].squeeze()
                    ground_truths_by_sample[sample_idx][landmark_idx] = (
                        extract_landmark_coordinates(gt_mask, 0.5)
                    )

                sample_offset += images.shape[0]

    return predictions_by_sample, ground_truths_by_sample, all_inference_times

# Example usage (uncomment when models and data are ready)
# unet_predictions, unet_ground_truths, unet_times = batch_inference_unet(unet_models, test_loaders, device, UNET_CONFIG)
# print(f'U-Net inference completed on {len(unet_predictions)} samples')
# print(f'Average inference time: {np.mean(unet_times):.4f} seconds')

### U-Net Evaluation Metrics

In [None]:
def calculate_euclidean_distance(coord1, coord2):
    """
    Straight-line distance between two 2-D points.

    Each argument is indexable as ``(x, y)`` or ``None``. If either point is
    ``None`` the distance is reported as ``float('inf')``.
    """
    if coord1 is not None and coord2 is not None:
        delta_x = coord1[0] - coord2[0]
        delta_y = coord1[1] - coord2[1]
        return np.sqrt(delta_x**2 + delta_y**2)

    # A missing point is treated as infinitely far away.
    return float('inf')

def _distance_stats(finite_distances):
    """Return (mean, median, std) of the distances, or (inf, inf, 0.0) if empty."""
    if not finite_distances:
        return float('inf'), float('inf'), 0.0
    return (
        float(np.mean(finite_distances)),
        float(np.median(finite_distances)),
        float(np.std(finite_distances)),
    )


def evaluate_unet_predictions(predictions, ground_truths, config):
    """
    Evaluate U-Net predictions using landmark detection metrics.

    Args:
        predictions: Per sample, a list of predicted ``(x, y)`` or ``None``.
        ground_truths: Per sample, a list of reference ``(x, y)`` or ``None``.
        config: Mapping with ``'NUM_LANDMARKS'`` and ``'DISTANCE_THRESHOLD'``.

    Returns:
        Dict with ``'landmark_metrics'`` (one dict per landmark) and
        ``'overall_metrics'``.

    Only landmarks that have a ground-truth coordinate are scored. A missing
    prediction counts as inaccurate and is excluded from the distance
    statistics. ``detection_rate`` counts predictions within the distance
    threshold; per landmark it is normalised by the scored samples, overall by
    ``total_samples * total_landmarks`` (``0.0`` when that product is zero).
    """
    num_landmarks = config['NUM_LANDMARKS']
    distance_threshold = config['DISTANCE_THRESHOLD']

    # Initialize metrics storage
    landmark_distances = [[] for _ in range(num_landmarks)]
    landmark_accuracies = [[] for _ in range(num_landmarks)]
    detection_rates = [0] * num_landmarks

    total_samples = len(predictions)

    # Process each sample
    for pred_landmarks, gt_landmarks in zip(predictions, ground_truths):
        for landmark_idx in range(num_landmarks):
            pred_coord = pred_landmarks[landmark_idx] if landmark_idx < len(pred_landmarks) else None
            gt_coord = gt_landmarks[landmark_idx] if landmark_idx < len(gt_landmarks) else None

            if gt_coord is None:
                # Nothing to compare against.
                continue

            if pred_coord is None:
                # No prediction made
                landmark_distances[landmark_idx].append(float('inf'))
                landmark_accuracies[landmark_idx].append(False)
                continue

            distance = calculate_euclidean_distance(pred_coord, gt_coord)
            is_accurate = bool(distance <= distance_threshold)

            landmark_distances[landmark_idx].append(distance)
            landmark_accuracies[landmark_idx].append(is_accurate)
            if is_accurate:
                detection_rates[landmark_idx] += 1

    # Calculate metrics for each landmark
    landmark_metrics = []
    for landmark_idx in range(num_landmarks):
        accuracies = landmark_accuracies[landmark_idx]
        # Missed predictions (inf) are left out of the distance statistics.
        finite_distances = [d for d in landmark_distances[landmark_idx] if d != float('inf')]
        mean_distance, median_distance, std_distance = _distance_stats(finite_distances)

        landmark_metrics.append({
            'landmark_id': landmark_idx,
            'mean_distance': mean_distance,
            'median_distance': median_distance,
            'std_distance': std_distance,
            'accuracy': float(np.mean(accuracies)) if accuracies else 0.0,
            'detection_rate': detection_rates[landmark_idx] / len(accuracies) if accuracies else 0.0,
            'num_samples': len(accuracies)
        })

    # Calculate overall metrics
    all_distances = [d for distances in landmark_distances for d in distances if d != float('inf')]
    all_accuracies = [acc for accuracies in landmark_accuracies for acc in accuracies]
    overall_mean, overall_median, overall_std = _distance_stats(all_distances)

    # Guard against an empty evaluation set / zero landmarks.
    landmark_slots = total_samples * num_landmarks

    overall_metrics = {
        'overall_mean_distance': overall_mean,
        'overall_median_distance': overall_median,
        'overall_std_distance': overall_std,
        'overall_accuracy': float(np.mean(all_accuracies)) if all_accuracies else 0.0,
        'overall_detection_rate': sum(detection_rates) / landmark_slots if landmark_slots else 0.0,
        'total_samples': total_samples,
        'total_landmarks': num_landmarks
    }

    return {
        'landmark_metrics': landmark_metrics,
        'overall_metrics': overall_metrics
    }

def calculate_dice_score_batch(predictions, ground_truths, mask_size=(512, 512), radius=5):
    """
    Approximate a Dice score for landmark predictions using disc-shaped proxy masks.

    No real segmentation masks are used. For every landmark that has both a
    predicted and a ground-truth coordinate, a filled circle is drawn at each
    coordinate on a blank mask and the Dice overlap of the two discs is taken.
    A landmark with a missing prediction or ground truth scores 0.0.

    Args:
        predictions: Iterable of per-sample landmark lists. Each landmark is a
            2-element coordinate or ``None``. The coordinate is handed to
            ``cv2.circle`` as the centre, so it is presumably ``(x, y)`` in
            pixels - TODO confirm against the code that produces it.
        ground_truths: Same layout as ``predictions``; samples and landmarks
            are paired positionally with ``zip``.
        mask_size: ``(rows, cols)`` of the scratch masks. Discs that fall
            outside the mask are clipped by OpenCV.
        radius: Radius in pixels of the disc drawn around each coordinate.

    Returns:
        float: Mean over samples of the per-sample mean landmark Dice score.
        Returns 0.0 when there are no samples; a sample without landmarks
        contributes 0.0 (instead of NaN).
    """
    # Scratch masks are allocated once and cleared per landmark instead of
    # creating two fresh full-size arrays for every landmark pair.
    pred_mask = np.zeros(mask_size, dtype=np.uint8)
    gt_mask = np.zeros(mask_size, dtype=np.uint8)

    def as_pixel(coord):
        # cv2.circle only accepts a tuple of plain integers as the centre;
        # float or numpy-array coordinates would be rejected.
        return int(round(float(coord[0]))), int(round(float(coord[1])))

    dice_scores = []

    for pred_landmarks, gt_landmarks in zip(predictions, ground_truths):
        sample_dice_scores = []

        for pred_coord, gt_coord in zip(pred_landmarks, gt_landmarks):
            if pred_coord is None or gt_coord is None:
                sample_dice_scores.append(0.0)
                continue

            pred_mask.fill(0)
            gt_mask.fill(0)

            # Thickness -1 asks OpenCV for a filled circle.
            cv2.circle(pred_mask, as_pixel(pred_coord), radius, 1, -1)
            cv2.circle(gt_mask, as_pixel(gt_coord), radius, 1, -1)

            intersection = np.count_nonzero(pred_mask & gt_mask)
            # Sum of both areas: the denominator of the Dice coefficient.
            union = np.count_nonzero(pred_mask) + np.count_nonzero(gt_mask)

            sample_dice_scores.append(2 * intersection / union if union > 0 else 0.0)

        # Guard against np.mean([]) -> NaN for a sample with no landmarks.
        dice_scores.append(float(np.mean(sample_dice_scores)) if sample_dice_scores else 0.0)

    return float(np.mean(dice_scores)) if dice_scores else 0.0

# Example evaluation (uncomment when predictions are available)
# unet_metrics = evaluate_unet_predictions(unet_predictions, unet_ground_truths, UNET_CONFIG)

# print("\nU-Net Evaluation Results:")
# print("=" * 50)
# print(f"Overall Accuracy: {unet_metrics['overall_metrics']['overall_accuracy']:.4f}")
# print(f"Overall Mean Distance: {unet_metrics['overall_metrics']['overall_mean_distance']:.2f} pixels")
# print(f"Overall Detection Rate: {unet_metrics['overall_metrics']['overall_detection_rate']:.4f}")

# print("\nPer-Landmark Results:")
# for landmark_metric in unet_metrics['landmark_metrics']:
#     print(f"Landmark {landmark_metric['landmark_id'] + 1:2d}: "
#           f"Acc={landmark_metric['accuracy']:.3f}, "
#           f"Dist={landmark_metric['mean_distance']:.2f}px, "
#           f"Det={landmark_metric['detection_rate']:.3f}")

## 3. Visualization Functions

In [None]:
def visualize_retinanet_predictions(images, predictions, ground_truths, class_names=None, num_samples=4):
    """
    Show ground-truth and predicted bounding boxes for a few samples.

    Draws a two-row grid with one column per sample: the top row overlays the
    ground-truth boxes in green, the bottom row overlays the predicted boxes in
    red together with a ``"<class>: <score>"`` caption.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        predictions: Sequence of dicts with parallel ``'boxes'``, ``'scores'``
            and ``'classes'`` entries. Boxes are read as ``(x1, y1, x2, y2)``.
        ground_truths: Sequence of per-image annotation collections (or
            ``None``). Each row is read as ``(x1, y1, x2, y2, label)`` and
            rows whose label is ``-1`` are skipped.
        class_names: Optional lookup indexed by ``int(class_id)``. Ids missing
            from the lookup fall back to ``"Class <id>"``.
        num_samples: Maximum number of samples (grid columns) to draw.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there is no sample to display.
    """
    # Size the grid by what can actually be drawn so there are no blank panels
    # and no IndexError when the three sequences differ in length.
    n_shown = min(num_samples, len(images), len(predictions), len(ground_truths))
    if n_shown <= 0:
        print("No samples to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, n_shown, figsize=(n_shown * 5, 10), squeeze=False)

    def label_for(cls):
        class_id = int(cls)
        if class_names:
            try:
                return class_names[class_id]
            except (IndexError, KeyError):
                pass  # unknown id: use the generic label below
        return f'Class {class_id}'

    def draw_box(ax, box, color):
        # Rectangle takes the anchor corner plus width and height.
        ax.add_patch(patches.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            linewidth=2, edgecolor=color, facecolor='none'
        ))

    for i in range(n_shown):
        image, pred, gt = images[i], predictions[i], ground_truths[i]
        gt_ax, pred_ax = axes[0, i], axes[1, i]

        # Top row: image with ground-truth boxes
        gt_ax.imshow(image, cmap='gray')
        gt_ax.set_title(f'Ground Truth - Sample {i+1}')
        gt_ax.axis('off')

        if gt is not None and len(gt) > 0:
            for annotation in gt:
                if annotation[4] != -1:  # -1 marks a row that is not a valid annotation
                    draw_box(gt_ax, annotation[:4], 'green')

        # Bottom row: image with predicted boxes and score captions
        pred_ax.imshow(image, cmap='gray')
        pred_ax.set_title(f'Predictions - Sample {i+1}')
        pred_ax.axis('off')

        for box, score, cls in zip(pred['boxes'], pred['scores'], pred['classes']):
            draw_box(pred_ax, box, 'red')
            pred_ax.text(
                box[0], box[1] - 5, f'{label_for(cls)}: {score:.2f}',
                color='red', fontsize=8, weight='bold'
            )

    plt.tight_layout()
    plt.show()

def visualize_unet_predictions(images, predictions, ground_truths, config, num_samples=4):
    """
    Show ground-truth and predicted landmarks for a few samples.

    Draws a three-row grid with one column per sample: the plain image, the
    image with ground-truth landmarks (``o`` markers), and the image with
    predicted landmarks (``x`` markers). On the prediction row a dashed red
    line joins each ground-truth/prediction pair and is annotated with the
    distance returned by ``calculate_euclidean_distance``.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        predictions: Sequence of per-sample landmark lists. Each landmark is a
            2-element coordinate (index 0 is plotted on the x axis, index 1 on
            the y axis) or ``None``.
        ground_truths: Same layout as ``predictions``.
        config: Mapping providing ``'NUM_LANDMARKS'``, used to size the colour
            palette.
        num_samples: Maximum number of samples (grid columns) to draw.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there is no sample to display.
    """
    # Size the grid by what can actually be drawn so there are no blank panels
    # and no IndexError when the three sequences differ in length.
    n_shown = min(num_samples, len(images), len(predictions), len(ground_truths))
    if n_shown <= 0:
        print("No samples to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(3, n_shown, figsize=(n_shown * 5, 15), squeeze=False)

    # At least one colour, and lookups wrap around, so a sample carrying more
    # landmarks than NUM_LANDMARKS cannot index past the palette.
    palette = plt.cm.tab20(np.linspace(0, 1, max(int(config['NUM_LANDMARKS']), 1)))

    def scatter_landmarks(ax, landmarks, marker, with_labels):
        for landmark_idx, coord in enumerate(landmarks):
            if coord is None:
                continue
            ax.scatter(
                coord[0], coord[1],
                c=[palette[landmark_idx % len(palette)]], s=50, marker=marker,
                label=f'L{landmark_idx+1}' if with_labels else None
            )

    for i in range(n_shown):
        image = images[i]
        pred_landmarks = predictions[i]
        gt_landmarks = ground_truths[i]
        image_ax, gt_ax, pred_ax = axes[0, i], axes[1, i], axes[2, i]

        # Row 0: original image
        image_ax.imshow(image, cmap='gray')
        image_ax.set_title(f'Original Image - Sample {i+1}')
        image_ax.axis('off')

        # Row 1: ground-truth landmarks (only the first column feeds the legend)
        gt_ax.imshow(image, cmap='gray')
        gt_ax.set_title(f'Ground Truth Landmarks - Sample {i+1}')
        gt_ax.axis('off')
        scatter_landmarks(gt_ax, gt_landmarks, 'o', with_labels=(i == 0))

        # Row 2: predicted landmarks
        pred_ax.imshow(image, cmap='gray')
        pred_ax.set_title(f'Predicted Landmarks - Sample {i+1}')
        pred_ax.axis('off')
        scatter_landmarks(pred_ax, pred_landmarks, 'x', with_labels=(i == 0))

        # Connect each ground-truth/prediction pair and annotate the error
        for gt_coord, pred_coord in zip(gt_landmarks, pred_landmarks):
            if gt_coord is None or pred_coord is None:
                continue
            distance = calculate_euclidean_distance(gt_coord, pred_coord)
            pred_ax.plot(
                [gt_coord[0], pred_coord[0]],
                [gt_coord[1], pred_coord[1]],
                'r--', alpha=0.5, linewidth=1
            )
            pred_ax.text(
                (gt_coord[0] + pred_coord[0]) / 2,
                (gt_coord[1] + pred_coord[1]) / 2,
                f'{distance:.1f}px',
                fontsize=8, color='red', weight='bold'
            )

    # Only add a legend where labelled markers exist; calling legend() on an
    # axes without labelled artists makes matplotlib emit a warning.
    for legend_ax in (axes[1, 0], axes[2, 0]):
        handles, _ = legend_ax.get_legend_handles_labels()
        if handles:
            legend_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

def plot_evaluation_metrics(retinanet_metrics, unet_metrics, config):
    """
    Plot a 2x2 dashboard of evaluation metrics for both models.

    Panels:
        * top-left: RetinaNet mAP / precision / recall / F1 as bars
        * top-right: U-Net overall accuracy and detection rate
        * bottom-left: U-Net accuracy per landmark
        * bottom-right: U-Net mean distance per landmark, in pixels

    Args:
        retinanet_metrics: Mapping that may contain ``'mAP'``, ``'precision'``,
            ``'recall'`` and ``'f1_score'``; missing keys are plotted as 0,
            matching how the report helpers in this notebook read them. Falsy
            values switch the RetinaNet panel off.
        unet_metrics: Mapping with ``'overall_metrics'`` and
            ``'landmark_metrics'``. Falsy values switch the three U-Net panels
            off.
        config: Not used by this function; kept so existing calls keep working.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        neither model has metrics.
    """
    if not retinanet_metrics and not unet_metrics:
        print("No metrics available to plot.")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def annotated_score_bars(ax, labels, values, title):
        # Bar chart on a fixed 0-1 scale with the value printed above each bar.
        ax.bar(labels, values)
        ax.set_title(title)
        ax.set_ylabel('Score')
        ax.set_ylim(0, 1)
        for position, value in enumerate(values):
            ax.text(position, value + 0.01, f'{value:.3f}', ha='center', va='bottom')

    # RetinaNet metrics
    if retinanet_metrics:
        retinanet_values = [
            retinanet_metrics.get(key, 0)
            for key in ('mAP', 'precision', 'recall', 'f1_score')
        ]
        annotated_score_bars(
            axes[0, 0], ['mAP', 'Precision', 'Recall', 'F1-Score'],
            retinanet_values, 'RetinaNet Metrics'
        )
    else:
        axes[0, 0].axis('off')  # hide the unused panel instead of showing empty axes

    if unet_metrics:
        # U-Net overall metrics
        overall = unet_metrics['overall_metrics']
        annotated_score_bars(
            axes[0, 1], ['Accuracy', 'Detection Rate'],
            [overall['overall_accuracy'], overall['overall_detection_rate']],
            'U-Net Overall Metrics'
        )

        per_landmark = unet_metrics['landmark_metrics']
        landmark_ids = [f'L{i+1}' for i in range(len(per_landmark))]

        # U-Net per-landmark accuracy
        axes[1, 0].bar(landmark_ids, [m['accuracy'] for m in per_landmark])
        axes[1, 0].set_title('U-Net Per-Landmark Accuracy')
        axes[1, 0].set_ylabel('Accuracy')
        axes[1, 0].set_ylim(0, 1)
        axes[1, 0].tick_params(axis='x', rotation=45)

        # U-Net per-landmark mean distance. Non-finite values (inf is used
        # elsewhere in the notebook for "no finite distance") are drawn as a
        # zero-height bar because they cannot be plotted.
        landmark_distances = [
            m['mean_distance'] if np.isfinite(m['mean_distance']) else 0
            for m in per_landmark
        ]
        axes[1, 1].bar(landmark_ids, landmark_distances)
        axes[1, 1].set_title('U-Net Per-Landmark Mean Distance')
        axes[1, 1].set_ylabel('Distance (pixels)')
        axes[1, 1].tick_params(axis='x', rotation=45)
    else:
        for unused_ax in (axes[0, 1], axes[1, 0], axes[1, 1]):
            unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example visualization (uncomment when data is available)
# visualize_retinanet_predictions(sample_images, predictions[:4], ground_truths[:4])
# visualize_unet_predictions(sample_images, unet_predictions[:4], unet_ground_truths[:4], UNET_CONFIG)
# plot_evaluation_metrics(retinanet_metrics, unet_metrics, UNET_CONFIG)

## 4. Performance Analysis and Comparison

In [None]:
def _summarize_inference_times(times):
    """Return mean, standard deviation and FPS for a collection of per-inference durations."""
    # Going through an array avoids `if times:`, which raises ValueError for
    # numpy arrays with more than one element, and also accepts None.
    durations = np.asarray([] if times is None else times, dtype=float).ravel()
    if durations.size == 0:
        return {'mean_inference_time': 0.0, 'std_inference_time': 0.0, 'fps': 0.0}

    mean_time = float(durations.mean())
    return {
        'mean_inference_time': mean_time,
        'std_inference_time': float(durations.std()),
        # Throughput is the reciprocal of the mean latency; 0 when undefined.
        'fps': 1.0 / mean_time if mean_time > 0 else 0.0
    }


def create_performance_report(retinanet_metrics, unet_metrics, retinanet_times, unet_times):
    """
    Bundle evaluation metrics and timing statistics for both models.

    Args:
        retinanet_metrics: RetinaNet evaluation results, stored as-is.
        unet_metrics: U-Net evaluation results, stored as-is.
        retinanet_times: Per-inference durations for RetinaNet (list, tuple or
            numpy array). ``None`` or an empty collection yields all-zero
            statistics. The report printers label these values as seconds.
        unet_times: Same as ``retinanet_times``, for U-Net.

    Returns:
        dict: ``{'timestamp', 'retinanet', 'unet'}`` where each model entry
        holds ``'metrics'`` and a ``'performance'`` dict with
        ``'mean_inference_time'``, ``'std_inference_time'`` and ``'fps'``.
        The timestamp is local time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    return {
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'retinanet': {
            'metrics': retinanet_metrics,
            'performance': _summarize_inference_times(retinanet_times)
        },
        'unet': {
            'metrics': unet_metrics,
            'performance': _summarize_inference_times(unet_times)
        }
    }

def save_results_to_csv(report, filename='inference_results.csv'):
    """
    Write one single-row CSV summary per model.

    Two files are produced next to each other: ``retinanet_<name>`` and
    ``unet_<name>``, where ``<name>`` is the final component of ``filename``.
    Any directory part of ``filename`` is preserved, so
    ``'results/run.csv'`` yields ``'results/retinanet_run.csv'`` and
    ``'results/unet_run.csv'``.

    Args:
        report: Report dict with ``'retinanet'`` and ``'unet'`` entries, each
            holding ``'metrics'`` and ``'performance'`` (the layout built by
            ``create_performance_report``). Missing or empty metrics are
            written as 0.
        filename: Base file name or path for the two CSV files.

    Returns:
        None.
    """
    # Treat absent metrics as empty mappings so every column falls back to 0.
    retinanet_metrics = report['retinanet']['metrics'] or {}
    unet_overall = (report['unet']['metrics'] or {}).get('overall_metrics') or {}

    # Create RetinaNet results
    retinanet_data = {
        'Model': 'RetinaNet',
        'mAP': retinanet_metrics.get('mAP', 0),
        'Precision': retinanet_metrics.get('precision', 0),
        'Recall': retinanet_metrics.get('recall', 0),
        'F1-Score': retinanet_metrics.get('f1_score', 0),
        'Inference Time (s)': report['retinanet']['performance']['mean_inference_time'],
        'FPS': report['retinanet']['performance']['fps']
    }

    # Create U-Net results
    unet_data = {
        'Model': 'U-Net',
        'Accuracy': unet_overall.get('overall_accuracy', 0),
        'Detection Rate': unet_overall.get('overall_detection_rate', 0),
        'Mean Distance (px)': unet_overall.get('overall_mean_distance', 0),
        'Inference Time (s)': report['unet']['performance']['mean_inference_time'],
        'FPS': report['unet']['performance']['fps']
    }

    # Prefix only the file name: prefixing the whole string would turn
    # 'results/run.csv' into the non-existent 'retinanet_results/run.csv'.
    directory, base_name = os.path.split(filename)
    retinanet_path = os.path.join(directory, f'retinanet_{base_name}')
    unet_path = os.path.join(directory, f'unet_{base_name}')

    # Save to CSV
    pd.DataFrame([retinanet_data]).to_csv(retinanet_path, index=False)
    pd.DataFrame([unet_data]).to_csv(unet_path, index=False)

    print(f"Results saved to {retinanet_path} and {unet_path}")

def print_performance_summary(report):
    """
    Write a human-readable digest of a performance report to stdout.

    The digest has a banner with the report timestamp followed by one section
    per model (RetinaNet, then U-Net). Each section lists the model's metrics,
    or a "no metrics" notice when they are empty, and then its timing figures.

    Args:
        report: Dict with ``'timestamp'``, ``'retinanet'`` and ``'unet'``
            entries; each model entry holds ``'metrics'`` and ``'performance'``.

    Returns:
        None.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 30

    def show_timing(model_key):
        # Latency (mean ± std) and throughput lines shared by both sections.
        timing = report[model_key]['performance']
        print(f"Inference Time: {timing['mean_inference_time']:.4f} ± {timing['std_inference_time']:.4f} seconds")
        print(f"FPS: {timing['fps']:.2f}")

    # Banner
    print("\n" + heavy_rule)
    print("PERFORMANCE SUMMARY")
    print(heavy_rule)
    print(f"Report generated: {report['timestamp']}")

    # RetinaNet section: every metric is optional and defaults to 0
    print("\n🎯 RETINANET RESULTS:")
    print(light_rule)
    detector_scores = report['retinanet']['metrics']
    if not detector_scores:
        print("No RetinaNet metrics available")
    else:
        for caption, key in (
            ("mAP", 'mAP'),
            ("Precision", 'precision'),
            ("Recall", 'recall'),
            ("F1-Score", 'f1_score'),
        ):
            print(f"{caption}: {detector_scores.get(key, 0):.4f}")
    show_timing('retinanet')

    # U-Net section: reads the aggregated 'overall_metrics' entry
    print("\n🎯 U-NET RESULTS:")
    print(light_rule)
    landmark_results = report['unet']['metrics']
    if not landmark_results:
        print("No U-Net metrics available")
    else:
        overall = landmark_results['overall_metrics']
        print(f"Overall Accuracy: {overall['overall_accuracy']:.4f}")
        print(f"Detection Rate: {overall['overall_detection_rate']:.4f}")
        print(f"Mean Distance: {overall['overall_mean_distance']:.2f} pixels")
        print(f"Total Samples: {overall['total_samples']}")
        print(f"Total Landmarks: {overall['total_landmarks']}")
    show_timing('unet')

    print("\n" + heavy_rule)

# Example usage (uncomment when all results are available)
# performance_report = create_performance_report(retinanet_metrics, unet_metrics, times, unet_times)
# print_performance_summary(performance_report)
# save_results_to_csv(performance_report)

## 5. Complete Inference Pipeline

In [None]:
def run_complete_inference_pipeline():
    """
    Scaffold for the end-to-end RetinaNet + U-Net inference workflow.

    Every working statement below is still commented out, so in its current
    form the function only prints a progress line for each of the seven stages
    and returns ``None``. Uncomment the calls under each stage, and the final
    ``return``, once the models, datasets and helpers they refer to are
    available.
    """
    def announce_stage(step, description):
        # Emits a blank line followed by "<step>. <description>..."
        print(f"\n{step}. {description}...")

    print("Starting complete inference pipeline...")

    # Stage 1: load models
    announce_stage(1, "Loading models")
    # retinanet_model = load_retinanet_model(RETINANET_CONFIG['MODEL_PATH'], RETINANET_CONFIG['NUM_CLASSES'], device)
    # unet_models = load_unet_models(UNET_CONFIG, device)

    # Stage 2: prepare test datasets
    announce_stage(2, "Preparing test datasets")
    # test_dataloader, test_dataset = create_retinanet_test_dataset(RETINANET_CONFIG, RETINANET_DATA_PATHS)
    # x_test, y_test = prepare_unet_test_dataset(UNET_DATA_PATHS, UNET_CONFIG)
    # test_loaders = create_unet_test_dataloaders(x_test, y_test, UNET_CONFIG)

    # Stage 3: run inference
    announce_stage(3, "Running inference")
    # retinanet_predictions, retinanet_ground_truths, retinanet_times = batch_inference_retinanet(retinanet_model, test_dataloader, device, RETINANET_CONFIG)
    # unet_predictions, unet_ground_truths, unet_times = batch_inference_unet(unet_models, test_loaders, device, UNET_CONFIG)

    # Stage 4: evaluate results
    announce_stage(4, "Evaluating results")
    # retinanet_metrics = evaluate_retinanet_predictions(retinanet_predictions, retinanet_ground_truths)
    # unet_metrics = evaluate_unet_predictions(unet_predictions, unet_ground_truths, UNET_CONFIG)

    # Stage 5: generate report
    announce_stage(5, "Generating performance report")
    # performance_report = create_performance_report(retinanet_metrics, unet_metrics, retinanet_times, unet_times)
    # print_performance_summary(performance_report)

    # Stage 6: save results
    announce_stage(6, "Saving results")
    # save_results_to_csv(performance_report)

    # Stage 7: create visualizations
    announce_stage(7, "Creating visualizations")
    # visualize_retinanet_predictions(sample_images, retinanet_predictions[:4], retinanet_ground_truths[:4])
    # visualize_unet_predictions(sample_images, unet_predictions[:4], unet_ground_truths[:4], UNET_CONFIG)
    # plot_evaluation_metrics(retinanet_metrics, unet_metrics, UNET_CONFIG)

    print("\n✅ Inference pipeline completed successfully!")

    # return performance_report

# Run the complete pipeline (uncomment when ready)
# final_report = run_complete_inference_pipeline()

# Closing notes for the reader: how to enable the notebook and what it offers.
# A single print with sep="\n" emits exactly one line per argument.
print(
    "\n📋 USAGE INSTRUCTIONS:",
    "1. Update model paths in RETINANET_CONFIG and UNET_CONFIG",
    "2. Update dataset paths in RETINANET_DATA_PATHS and UNET_DATA_PATHS",
    "3. Uncomment the relevant sections to run inference",
    "4. Run run_complete_inference_pipeline() for end-to-end evaluation",
    "\n🎯 This notebook provides:",
    "- Single image and batch inference for both models",
    "- Comprehensive evaluation metrics",
    "- Performance analysis and timing",
    "- Visualization of results",
    "- CSV export of results",
    "- Complete inference pipeline",
    sep="\n",
)