In [1]:

import os
import cv2
import numpy as np
from patchify import patchify
import shutil

In [2]:
def root_only_2023(image_dir, mask_dir, output_image_dir, output_mask_dir, mask_suffix='_root_mask'):
    """
    Build a copy of the 2023 dataset that keeps only images that have a root mask.

    Images are the ``.png`` files directly inside ``image_dir`` and root masks are the
    ``.tif`` files directly inside ``mask_dir`` (neither directory is searched
    recursively). An image and a mask form a pair when the mask file is named
    ``<image stem><mask_suffix>.tif``.

    Args:
        image_dir: Directory containing the ``.png`` images.
        mask_dir: Directory containing the ``.tif`` masks.
        output_image_dir: Directory the paired images are copied to (created if missing).
        output_mask_dir: Directory the paired masks are copied to (created if missing).
        mask_suffix: Suffix that ends the stem of a root mask filename.

    Returns:
        None. Copies the paired files and prints a short summary.
    """
    # Collect all image base names
    images = {os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith('.png')}

    # Collect the base names of all root masks. The suffix is only stripped from the END
    # of the stem: removing it from anywhere in the name would pair e.g.
    # "a_root_mask_v2.tif" with "a_v2.png" and then fail to find "a_v2_root_mask.tif".
    masks = set()
    for f in os.listdir(mask_dir):
        if not f.endswith('.tif'):
            continue
        stem = os.path.splitext(f)[0]
        if stem.endswith(mask_suffix):
            # len() arithmetic instead of stem[:-len(...)] so an empty suffix keeps the stem
            masks.add(stem[:len(stem) - len(mask_suffix)])

    # Identify valid images (only those with root masks)
    valid_images = images & masks

    # Create output directories
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)

    # Copy valid images and their masks (sorted only to make the copy order reproducible)
    for image in sorted(valid_images):
        image_name = f"{image}.png"
        mask_name = f"{image}{mask_suffix}.tif"
        shutil.copy(os.path.join(image_dir, image_name), os.path.join(output_image_dir, image_name))
        shutil.copy(os.path.join(mask_dir, mask_name), os.path.join(output_mask_dir, mask_name))

    # Log excluded files
    excluded_images = images - valid_images
    excluded_masks = masks - valid_images
    print(f"Copied {len(valid_images)} image/mask pairs; excluded {len(excluded_images)} images "
          f"without a root mask and {len(excluded_masks)} root masks without an image.")
    print("Action performed on 2023 dataset. A new dataset with root masks and images only is created.")


In [3]:
def root_only_2024(image_dir, mask_dir, output_image_dir, output_mask_dir, mask_suffix='_root_mask'):
    """
    Clean the mentor-based dataset (2024 structure) by ensuring each image has a corresponding root mask.

    Images are the ``.png`` files directly inside ``image_dir``. Masks are searched
    recursively below ``mask_dir`` (the mentor sub-folders); every ``.tif`` file whose
    name contains ``mask_suffix`` counts as a root mask, and its base name is the file
    stem with ``mask_suffix`` removed. All paired masks are written into the single flat
    folder ``output_mask_dir`` as ``<base name><mask_suffix>.tif``.

    Args:
        image_dir: Directory containing images.
        mask_dir: Directory containing masks (searched recursively).
        output_image_dir: Directory to save cleaned images (created if missing).
        output_mask_dir: Directory to save cleaned masks (created if missing).
        mask_suffix: Suffix used in root mask filenames to differentiate from images.

    Returns:
        None. Saves cleaned dataset to specified output directories.
    """
    # Collect all image filenames
    image_files = {os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith('.png')}

    # Map every root mask base name to the path of the file that provides it
    mask_paths = {}
    duplicate_masks = set()

    # Process mentor subfolders. Directories and files are visited in sorted order so that,
    # when two folders hold a mask with the same base name, the one that wins (the last
    # one visited) does not depend on the file system's listing order.
    for root, dirs, files in os.walk(mask_dir):
        dirs.sort()  # in-place edit steers the order in which os.walk descends
        for f in sorted(files):
            if f.endswith('.tif') and mask_suffix in f:
                base_name = os.path.splitext(f)[0].replace(mask_suffix, '')
                if base_name in mask_paths:
                    duplicate_masks.add(base_name)
                mask_paths[base_name] = os.path.join(root, f)

    mask_files = set(mask_paths)

    # Identify valid images (only those with root masks)
    valid_images = image_files & mask_files

    # Create output directories
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)

    # Copy valid images and their masks
    for image in sorted(valid_images):
        shutil.copy(os.path.join(image_dir, f"{image}.png"), os.path.join(output_image_dir, f"{image}.png"))

        # Save masks in a single folder, named with the caller's suffix (not a hard-coded
        # "_root_mask") so a later lookup of "<image><mask_suffix>.tif" finds them.
        shutil.copy(mask_paths[image], os.path.join(output_mask_dir, f"{image}{mask_suffix}.tif"))

    # Log excluded files
    excluded_images = image_files - valid_images
    excluded_masks = mask_files - valid_images
    print(f"Copied {len(valid_images)} image/mask pairs; excluded {len(excluded_images)} images "
          f"without a root mask and {len(excluded_masks)} root masks without an image.")
    if duplicate_masks:
        print(f"Warning: {len(duplicate_masks)} mask names occur in more than one folder; "
              f"the last one in sorted order was used.")
    print("Action performed on 2024 dataset. A new dataset with root masks and images only is created.")


In [4]:
# Process 2023 dataset: copy every training image that has a root mask, together with
# that mask, into ./root_only_2023. All paths are relative to the working directory
# the notebook is run from.
root_only_2023(
    image_dir="./../client_data/Y2B_23/images/train",
    mask_dir="./../client_data/Y2B_23/masks",
    output_image_dir="root_only_2023/images",
    output_mask_dir="root_only_2023/masks",
)

In [5]:
# Process 2024 dataset: the masks directory is searched recursively, and the images
# with a root mask are copied, with their masks, into the flat ./root_only_2024
# folders. All paths are relative to the working directory the notebook is run from.
root_only_2024(
    image_dir="./../client_data/Y2B_24/images",
    mask_dir="./../client_data/Y2B_24/masks",
    output_image_dir="root_only_2024/images",
    output_mask_dir="root_only_2024/masks",
)

In [6]:
def process_petri_and_masks(image_dir, mask_dir, mask_suffix='_root_mask', output_size=(1024, 1024)):
    """
    Process images and masks in memory by detecting Petri dish regions, aligning masks, and resizing.

    For every ``.png`` in ``image_dir`` the mask ``<image stem><mask_suffix>.tif`` is
    looked up in ``mask_dir``. The dish region is taken to be the bounding box of the
    largest external contour of the Otsu-thresholded grayscale image; image and mask are
    cropped to that box and resized to ``output_size``.

    Pairs are skipped (with a printed message) when the mask file is missing, when either
    file cannot be decoded, when image and mask differ in height/width, or when no
    contour is found.

    Args:
        image_dir: Directory containing images.
        mask_dir: Directory containing masks.
        mask_suffix: Suffix used in mask filenames.
        output_size: Tuple specifying the output dimensions (width, height).

    Returns:
        A dictionary mapping each image stem to ``{"image": ..., "mask": ...}`` with the
        processed image and mask as numpy arrays.
    """
    processed_data = {}

    # Sorted so the insertion order of the result does not depend on the file system
    for image_file in sorted(os.listdir(image_dir)):
        if not image_file.endswith('.png'):
            continue

        image_name = os.path.splitext(image_file)[0]
        mask_file = f"{image_name}{mask_suffix}.tif"
        image_path = os.path.join(image_dir, image_file)
        mask_path = os.path.join(mask_dir, mask_file)

        if not os.path.exists(mask_path):
            print(f"Mask not found for {image_file}. Skipping.")
            continue

        # Read image and mask
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)

        # cv2.imread reports an unreadable file by returning None instead of raising
        if image is None or mask is None:
            print(f"Could not read image or mask for {image_file}. Skipping.")
            continue

        if mask.ndim > 2:
            # IMREAD_UNCHANGED keeps an alpha channel if the file has one
            to_gray = cv2.COLOR_BGRA2GRAY if mask.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            mask = cv2.cvtColor(mask, to_gray)

        # The crop box is computed on the image and applied to the mask, which is only
        # meaningful when both share the same pixel grid.
        if mask.shape[:2] != image.shape[:2]:
            print(f"Image and mask sizes differ for {image_file}. Skipping.")
            continue

        # Detect Petri dish in the image
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            print(f"No Petri dish contour found in {image_file}. Skipping.")
            continue
        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)

        # Crop and resize image and mask; nearest-neighbour keeps the mask's label values
        cropped_image = image[y:y + h, x:x + w]
        cropped_mask = mask[y:y + h, x:x + w]
        resized_image = cv2.resize(cropped_image, output_size, interpolation=cv2.INTER_AREA)
        resized_mask = cv2.resize(cropped_mask, output_size, interpolation=cv2.INTER_NEAREST)

        # Store processed data in memory
        processed_data[image_name] = {
            "image": resized_image,
            "mask": resized_mask
        }

    print(f"Processed {len(processed_data)} images and masks in memory.")
    return processed_data


# Process Petri dish and masks for 2023 dataset in memory.
# Reads the cleaned folders written by root_only_2023 above; the default mask suffix
# and the default 1024x1024 output size are used.
processed_2023 = process_petri_and_masks(
    image_dir="root_only_2023/images",
    mask_dir="root_only_2023/masks",
)

# Process Petri dish and masks for 2024 dataset in memory.
# Reads the cleaned folders written by root_only_2024 above, with the same defaults.
processed_2024 = process_petri_and_masks(
    image_dir="root_only_2024/images",
    mask_dir="root_only_2024/masks",
)

In [7]:
import random
import matplotlib.pyplot as plt


def test_random_overlays(processed_data, num_samples=3):
    """
    Test overlay alignment for randomly selected images and masks.
    
    Args:
        processed_data: Dictionary containing processed images and masks.
        num_samples: Number of random samples to test.
    
    Returns:
        None. Displays overlays.
    """
    # Randomly select samples
    keys = random.sample(list(processed_data.keys()), min(num_samples, len(processed_data)))

    for key in keys:
        image_data = processed_data[key]["image"]
        mask_data = processed_data[key]["mask"]

        # Ensure the mask is in grayscale
        if len(mask_data.shape) == 3:
            mask_data = cv2.cvtColor(mask_data, cv2.COLOR_BGR2GRAY)

        # Create an overlay
        overlay = image_data.copy()
        overlay[mask_data > 0] = [255, 0, 0]  # Highlight mask regions in red

        # Display the image, mask, and overlay
        plt.figure(figsize=(12, 4))
        plt.suptitle(f"Test: {key} (Dimensions: {image_data.shape})")

        plt.subplot(1, 3, 1)
        plt.title("Image")
        plt.imshow(cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB))
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.title("Mask")
        plt.imshow(mask_data, cmap="gray")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.title("Overlay")
        plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
        plt.axis("off")

        plt.show()


# Test overlays for 2023 dataset (visual check that the cropped masks still line up
# with the cropped images; the samples are random, so they change between runs)
print("Testing overlays for 2023 dataset...")
test_random_overlays(processed_2023)

# Test overlays for 2024 dataset (same visual check)
print("Testing overlays for 2024 dataset...")
test_random_overlays(processed_2024)


In [8]:
def process_2023_patches_with_patchify(processed_2023, patch_size, output_dir="dataset/train"):
    """
    Cut every processed 2023 image/mask pair into non-overlapping patches and save them as PNGs.

    All patches go to the training split: images to ``<output_dir>/train_images/train``
    and masks to ``<output_dir>/train_masks/train``. The patch in grid row ``i`` and
    column ``j`` is saved as ``<name>_patch_<i>_<j>.png`` (image) and
    ``<name>_patch_<i>_<j>_root_mask.png`` (mask).

    Args:
        processed_2023: Dictionary mapping a base name to ``{"image": ..., "mask": ...}``.
            Images are passed to patchify as 3-channel arrays and masks as 2-D arrays.
        patch_size: ``(patch_height, patch_width)`` in pixels.
        output_dir: Root folder of the patch dataset (created if missing).

    Returns:
        None. Writes the patch files and prints a message for any file that could not be written.
    """
    image_out_dir = os.path.join(output_dir, "train_images/train")
    mask_out_dir = os.path.join(output_dir, "train_masks/train")
    os.makedirs(image_out_dir, exist_ok=True)
    os.makedirs(mask_out_dir, exist_ok=True)

    patch_h, patch_w = patch_size

    # The step must equal the patch size along each axis for the patches to tile the
    # image. A single integer step is only right for square patches; for rectangular
    # ones a per-axis step is needed, otherwise the width is stepped by the patch height.
    if patch_h == patch_w:
        image_step = mask_step = patch_h
    else:
        image_step = (patch_h, patch_w, 3)
        mask_step = (patch_h, patch_w)

    for base_name, data in processed_2023.items():
        image, mask = data["image"], data["mask"]

        # Extract patches using patchify
        image_patches = patchify(image, (patch_h, patch_w, 3), step=image_step)
        mask_patches = patchify(mask, (patch_h, patch_w), step=mask_step)

        for i in range(image_patches.shape[0]):
            for j in range(image_patches.shape[1]):
                # The third index is the channel-window position, of which there is only one
                patch_image = image_patches[i, j, 0]
                patch_mask = mask_patches[i, j]

                image_path = os.path.join(image_out_dir, f"{base_name}_patch_{i}_{j}.png")
                mask_path = os.path.join(mask_out_dir, f"{base_name}_patch_{i}_{j}_root_mask.png")

                # cv2.imwrite returns False on failure instead of raising
                if not cv2.imwrite(image_path, patch_image):
                    print(f"Failed to write {image_path}")
                if not cv2.imwrite(mask_path, patch_mask):
                    print(f"Failed to write {mask_path}")

In [9]:
def process_2024_patches_with_patchify(processed_2024, patch_size, output_dir="dataset"):
    """
    Cut every processed 2024 image/mask pair into non-overlapping patches and save them as PNGs.

    Entries whose base name starts with ``val_`` go to the validation split
    (``<output_dir>/val_images/val`` and ``<output_dir>/val_masks/val``); all others go
    to the training split (``<output_dir>/train_images/train`` and
    ``<output_dir>/train_masks/train``). The patch in grid row ``i`` and column ``j`` is
    saved as ``<name>_patch_<i>_<j>.png`` (image) and
    ``<name>_patch_<i>_<j>_root_mask.png`` (mask).

    Args:
        processed_2024: Dictionary mapping a base name to ``{"image": ..., "mask": ...}``.
            Images are passed to patchify as 3-channel arrays and masks as 2-D arrays.
        patch_size: ``(patch_height, patch_width)`` in pixels.
        output_dir: Root folder of the patch dataset (created if missing).

    Returns:
        None. Writes the patch files and prints a message for any file that could not be written.
    """
    for split in ("train", "val"):
        os.makedirs(os.path.join(output_dir, f"{split}_images/{split}"), exist_ok=True)
        os.makedirs(os.path.join(output_dir, f"{split}_masks/{split}"), exist_ok=True)

    patch_h, patch_w = patch_size

    # The step must equal the patch size along each axis for the patches to tile the
    # image. A single integer step is only right for square patches; for rectangular
    # ones a per-axis step is needed, otherwise the width is stepped by the patch height.
    if patch_h == patch_w:
        image_step = mask_step = patch_h
    else:
        image_step = (patch_h, patch_w, 3)
        mask_step = (patch_h, patch_w)

    for base_name, data in processed_2024.items():
        image, mask = data["image"], data["mask"]
        folder_prefix = "val" if base_name.startswith("val_") else "train"
        image_out_dir = os.path.join(output_dir, f"{folder_prefix}_images/{folder_prefix}")
        mask_out_dir = os.path.join(output_dir, f"{folder_prefix}_masks/{folder_prefix}")

        # Extract patches using patchify
        image_patches = patchify(image, (patch_h, patch_w, 3), step=image_step)
        mask_patches = patchify(mask, (patch_h, patch_w), step=mask_step)

        for i in range(image_patches.shape[0]):
            for j in range(image_patches.shape[1]):
                # The third index is the channel-window position, of which there is only one
                patch_image = image_patches[i, j, 0]
                patch_mask = mask_patches[i, j]

                image_path = os.path.join(image_out_dir, f"{base_name}_patch_{i}_{j}.png")
                mask_path = os.path.join(mask_out_dir, f"{base_name}_patch_{i}_{j}_root_mask.png")

                # cv2.imwrite returns False on failure instead of raising
                if not cv2.imwrite(image_path, patch_image):
                    print(f"Failed to write {image_path}")
                if not cv2.imwrite(mask_path, patch_mask):
                    print(f"Failed to write {mask_path}")


In [10]:
# Generate patches for the 2023 dataset.
# output_dir is passed explicitly as "dataset" (the function's own default is
# "dataset/train"), so these patches land in the same tree as the 2024 ones.
# 128-pixel patches divide the default 1024x1024 processed images evenly.
process_2023_patches_with_patchify(
    processed_2023=processed_2023,
    patch_size=(128, 128),
    output_dir="dataset",
)

In [11]:
# Generate patches for the 2024 dataset.
# Entries whose name starts with "val_" are written to the validation folders, all
# others to the training folders, under the same "dataset" root as the 2023 patches.
process_2024_patches_with_patchify(
    processed_2024=processed_2024,
    patch_size=(128, 128),
    output_dir="dataset",
)

In [12]:
def overlay_image_and_mask(image_path, mask_path):
    """
    Overlay a specific image and mask to visually confirm alignment.

    Shows a three-panel figure: the image, the mask, and the image with every non-zero
    mask pixel painted red.

    Args:
        image_path (str): Path to the image patch.
        mask_path (str): Path to the mask patch.

    Returns:
        None. Displays the image, mask, and overlay using matplotlib.

    Raises:
        OSError: If the image or the mask cannot be read.
        AssertionError: If image and mask differ in height or width.
    """
    # Load image (BGR) and mask (single channel)
    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None for a missing or undecodable file; fail with a clear message
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    if mask is None:
        raise OSError(f"Could not read mask: {mask_path}")

    # Check if image and mask dimensions match
    assert image.shape[:2] == mask.shape[:2], "Image and mask dimensions do not match!"

    # Overlay mask on the image. The image is BGR and is converted to RGB only for
    # display, so red has to be written as (0, 0, 255); (255, 0, 0) would be shown as blue.
    overlay = image.copy()
    overlay[mask > 0] = [0, 0, 255]

    # Plot the results
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ("Image", cv2.cvtColor(image, cv2.COLOR_BGR2RGB), None),
        ("Mask", mask, "gray"),
        ("Overlay", cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB), None),
    )
    for ax, (title, pixels, cmap) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(pixels, cmap=cmap)
        ax.axis("off")

    plt.show()

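# Example usage (replace the placeholders with the paths of a patch and its mask):
# overlay_image_and_mask(
#     image_path=r"<path to image patch>.png",
#     mask_path=r"<path to mask patch>_root_mask.png",
# )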