# In[1]:
import os
import cv2
import numpy as np
from pathlib import Path

# In[2]:
def load_yolo_bbox(label_path):
    """
    Load YOLO format bounding boxes from a label file.

    Each non-blank line must hold at least five whitespace-separated numbers:
    ``class x_center y_center width height``. Any further columns on the line
    (for example a confidence score) are parsed but discarded. Blank or
    whitespace-only lines are skipped.

    Args:
        label_path: Path of the label text file.

    Returns:
        List of tuples (class_id, x_center, y_center, width, height).
        All five values are floats, including class_id.

    Raises:
        ValueError: If a non-blank line has fewer than five fields or contains
            a field that cannot be parsed as a number.
    """
    boxes = []
    with open(label_path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank line (typically a trailing newline at end of file);
                # unpacking it would raise, so just move on.
                continue
            try:
                # Parse YOLO format: class x_center y_center width height [extra...]
                class_id, x_center, y_center, width, height, *_ = map(float, fields)
            except ValueError as err:
                # Same exception type as before, but say where the bad line is.
                raise ValueError(
                    f"{label_path}, line {line_no}: invalid YOLO label {line.strip()!r}"
                ) from err
            boxes.append((class_id, x_center, y_center, width, height))
    return boxes

def yolo_to_pixel_coords(box, img_width, img_height):
    """
    Turn one relative YOLO box into clamped pixel corner coordinates.

    `box` is (class_id, x_center, y_center, width, height); the class id is
    ignored and the remaining four values are scaled by the image size.
    Corners are truncated with int() and then limited to the ranges
    [0, img_width] and [0, img_height].

    Returns (x_min, y_min, x_max, y_max)
    """
    rel_cx, rel_cy, rel_w, rel_h = box[1:]

    # Scale the relative centre and extent up to pixels
    cx = rel_cx * img_width
    cy = rel_cy * img_height
    half_w = (rel_w * img_width) / 2
    half_h = (rel_h * img_height) / 2

    # Corner positions, truncated to whole pixels
    left = int(cx - half_w)
    top = int(cy - half_h)
    right = int(cx + half_w)
    bottom = int(cy + half_h)

    # Keep every corner inside the image
    return (
        max(0, left),
        max(0, top),
        min(img_width, right),
        min(img_height, bottom),
    )

def extract_patches(image_dir, label_dir, output_dir=None, patch_size=32):
    """
    Extract fixed-size patches from images based on YOLO format bounding boxes.

    Every ``*.png`` in image_dir that has a same-named ``.txt`` file in
    label_dir is processed. For each box, a patch_size x patch_size window is
    placed around the centre of the (clamped) pixel box. A window that would
    cross an image border is shifted back inside the image; if the image is
    smaller than patch_size along an axis, the window covers that whole axis.

    If output_dir is provided, patches are saved there as
    ``<image stem>_patch_<box index>.png``.
    Otherwise, each patch is displayed using cv2.imshow() and the function
    waits for a key press before continuing.

    Args:
        image_dir: Directory containing the PNG images.
        label_dir: Directory containing the YOLO label files.
        output_dir: Directory for the saved patches (created if missing), or
            a falsy value to display patches instead.
        patch_size: Side length of the patches in pixels. Odd and even
            sizes are both supported.
    """
    image_dir = Path(image_dir)
    label_dir = Path(label_dir)

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Process all PNG files that have corresponding label files
    for img_path in image_dir.glob('*.png'):
        # Find corresponding label file
        label_path = label_dir / f"{img_path.stem}.txt"
        if not label_path.exists():
            print(f"No label file found for {img_path}")
            continue

        # Read image and get dimensions
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"Could not read image {img_path}")
            continue

        img_height, img_width = img.shape[:2]

        # The window can never be larger than the image itself. This is the
        # same for every box of the image, so compute it once.
        win_w = min(patch_size, img_width)
        win_h = min(patch_size, img_height)

        # Load bounding boxes
        boxes = load_yolo_bbox(label_path)

        # Extract and save/display patches
        for i, box in enumerate(boxes):
            x_min, y_min, x_max, y_max = yolo_to_pixel_coords(box, img_width, img_height)

            # Calculate center point
            x_center = (x_min + x_max) // 2
            y_center = (y_min + y_max) // 2

            # Place the window around the centre, then slide it back inside
            # the image if it sticks out. Deriving the far edge from the near
            # edge keeps the size exact for odd patch sizes too (the previous
            # "centre +/- patch_size // 2" form came out one pixel short for
            # odd sizes and was then mistaken for an edge hit).
            x_min = max(0, min(x_center - win_w // 2, img_width - win_w))
            y_min = max(0, min(y_center - win_h // 2, img_height - win_h))
            x_max = x_min + win_w
            y_max = y_min + win_h

            # Extract patch
            patch = img[y_min:y_max, x_min:x_max]

            if patch.size == 0:
                print(f"Warning: Empty patch extracted from {img_path.name}, box {i}")
                continue

            if output_dir:
                # Save patch
                patch_path = output_dir / f"{img_path.stem}_patch_{i}.png"
                try:
                    saved = cv2.imwrite(str(patch_path), patch)
                except Exception as e:
                    print(f"Error saving patch from {img_path.name}, box {i}: {str(e)}")
                    print(f"Patch shape: {patch.shape}")
                else:
                    # cv2.imwrite reports most failures through its return
                    # value rather than by raising, so check it as well.
                    if not saved:
                        print(f"Error saving patch from {img_path.name}, box {i}: could not write {patch_path}")
                        print(f"Patch shape: {patch.shape}")
            else:
                # Display patch
                cv2.imshow(f"Patch {i} from {img_path.name}", patch)
                cv2.waitKey(0)

    if not output_dir:
        cv2.destroyAllWindows()

# Visualization helper (example usage is in the __main__ block below)
def visualize_boxes(image_dir, label_dir, output_dir=None, patch_size=32, draw_original=False):
    """
    Draw the fixed-size patch windows on the images.

    Every ``*.png`` in image_dir that has a same-named ``.txt`` file in
    label_dir is processed; images without a label file, and images that
    cannot be read, are skipped silently.

    Green box: patch_size x patch_size window around each box centre, shifted
        back inside the image where it would cross a border.
    Red box: original YOLO bounding box, drawn only when draw_original is True.

    If output_dir is provided, the annotated image is saved there as
    ``<image stem>_boxes.png``. Otherwise it is displayed using cv2.imshow()
    and the function waits for a key press before continuing.

    Args:
        image_dir: Directory containing the PNG images.
        label_dir: Directory containing the YOLO label files.
        output_dir: Directory for the annotated images (created if missing),
            or a falsy value to display them instead.
        patch_size: Side length of the patch window in pixels. Odd and even
            sizes are both supported.
        draw_original: Also draw the original bounding box in red.
    """
    image_dir = Path(image_dir)
    label_dir = Path(label_dir)

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    for img_path in image_dir.glob('*.png'):
        # Find corresponding label file
        label_path = label_dir / f"{img_path.stem}.txt"
        if not label_path.exists():
            continue

        # Read image and get dimensions
        img = cv2.imread(str(img_path))
        if img is None:
            continue

        img_height, img_width = img.shape[:2]
        vis_img = img.copy()

        # The window can never be larger than the image itself. This is the
        # same for every box of the image, so compute it once.
        win_w = min(patch_size, img_width)
        win_h = min(patch_size, img_height)

        # Load bounding boxes
        boxes = load_yolo_bbox(label_path)

        for box in boxes:
            # Get original box coordinates
            x_min, y_min, x_max, y_max = yolo_to_pixel_coords(box, img_width, img_height)

            if draw_original:
                # Draw original box in red (OpenCV colours are BGR)
                cv2.rectangle(vis_img, (x_min, y_min), (x_max, y_max), (0, 0, 255), 1)

            # Calculate expanded box coordinates
            x_center = (x_min + x_max) // 2
            y_center = (y_min + y_max) // 2

            # Place the window around the centre, then slide it back inside
            # the image if it sticks out. Deriving the far edge from the near
            # edge keeps the size exact for odd patch sizes too.
            exp_x_min = max(0, min(x_center - win_w // 2, img_width - win_w))
            exp_y_min = max(0, min(y_center - win_h // 2, img_height - win_h))
            exp_x_max = exp_x_min + win_w
            exp_y_max = exp_y_min + win_h

            # Draw expanded box in green
            cv2.rectangle(vis_img, (exp_x_min, exp_y_min), (exp_x_max, exp_y_max), (0, 255, 0), 1)

        if output_dir:
            # Save visualization; cv2.imwrite signals failure by returning False
            vis_path = output_dir / f"{img_path.stem}_boxes.png"
            if not cv2.imwrite(str(vis_path), vis_img):
                print(f"Could not write visualization {vis_path}")
        else:
            # Display visualization
            cv2.imshow(f"Boxes - {img_path.name}", vis_img)
            cv2.waitKey(0)

    if not output_dir:
        cv2.destroyAllWindows()

if __name__ == "__main__":
    # Input: training split of the stacked YOLO dataset
    image_dir = "/mnt/storage/ji/brain_mri_valdo_mayo/YOLO_valdo_stacked_temp/images/train"
    label_dir = "/mnt/storage/ji/brain_mri_valdo_mayo/YOLO_valdo_stacked_temp/labels/train"

    # Output folders, relative to the current working directory
    output_dir = "YOLO_valdo_stacked_temp_extracted_patches"
    vis_output_dir = "YOLO_valdo_stacked_temp_visualized_boxes"

    # First cut the 32x32 patches, then write the images with the patch
    # windows drawn on them.
    extract_patches(image_dir, label_dir, output_dir=output_dir, patch_size=32)
    visualize_boxes(image_dir, label_dir, output_dir=vis_output_dir, patch_size=32)