In [1]:
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

# Directories. All inputs/outputs live under DATA_ROOT, which defaults to the original
# author's project folder; set the ML_DATA_ROOT environment variable to run this
# notebook on another machine without editing the cell.
DATA_ROOT = os.environ.get("ML_DATA_ROOT", "/Users/nilsmanni/Desktop/MA_3/ML")
# Directory containing the RGB images.
# NOTE(review): this folder is named "masks_4_categories" — confirm it really holds the RGB inputs.
image_dir = os.path.join(DATA_ROOT, "muticlass_model", "masks_4_categories")
gt_dir = os.path.join(DATA_ROOT, "ground_truth_img")      # Directory containing the ground truth masks
output_dir = os.path.join(DATA_ROOT, "processed_data")    # Directory to save processed patches
patch_size = 256  # Side length (pixels) of the square patches

# Ensure output directories exist
image_output_dir = os.path.join(output_dir, "images")
gt_output_dir = os.path.join(output_dir, "groundtruths")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(image_output_dir, exist_ok=True)
os.makedirs(gt_output_dir, exist_ok=True)

# Transformation applied to each RGB image before patching:
# ToTensor converts the PIL image to a (C, H, W) tensor, then Normalize applies the
# ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def extract_patches(image, patch_size, stride=None):
    """Cut an image tensor into square patches laid out on a regular grid.

    Args:
        image: Tensor whose last two dimensions are (H, W), e.g. a (C, H, W)
            image or a (1, H, W) mask. Any leading dimensions are kept as-is,
            so 2-D (H, W) and batched inputs work too.
        patch_size: Side length of each square patch, in pixels. Must be > 0.
        stride: Step between neighbouring patch origins. Defaults to
            ``patch_size``, i.e. non-overlapping tiling.

    Returns:
        List of patches in row-major order (left to right, then top to bottom),
        each of shape ``image.shape[:-2] + (patch_size, patch_size)``. Pixels in a
        trailing strip narrower than one patch along H or W are dropped, and an
        image smaller than ``patch_size`` in either dimension yields ``[]``.
        Patches are slices (views) of ``image``, not copies.

    Raises:
        ValueError: if ``patch_size`` or ``stride`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    step = patch_size if stride is None else stride
    if step <= 0:
        raise ValueError(f"stride must be positive, got {step}")

    height, width = image.shape[-2], image.shape[-1]
    # `...` keeps every leading dimension, so only the spatial axes are sliced.
    return [
        image[..., y:y + patch_size, x:x + patch_size]
        for y in range(0, height - patch_size + 1, step)
        for x in range(0, width - patch_size + 1, step)
    ]

def preprocess_images_and_masks(image_dir, gt_dir, output_dir, patch_size):
    """Tile every (image, ground-truth mask) pair into square patches saved as ``.pt`` files.

    Each image is loaded as RGB and passed through the module-level ``transform``.
    Each mask is loaded as single-channel grayscale and stored as a float32 tensor
    of shape (1, H, W). Both are cut with ``extract_patches`` into
    ``patch_size`` x ``patch_size`` tiles.

    Output layout (directories are created if missing):
        ``<output_dir>/images/<base>_patch_<i>.pt``
        ``<output_dir>/groundtruths/<base>_gt_patch_<i>.pt``
    where ``<base>`` is the image file name without its extension.

    Images and masks are paired by position after sorting each directory's
    regular, non-hidden file names. A count mismatch emits a warning, and only
    the first ``min(#images, #masks)`` pairs are processed. A pair whose image and
    mask sizes differ is skipped with a warning.
    """
    import warnings  # local import: only needed for the pairing diagnostics below

    def _list_files(directory):
        # Sorted regular, non-hidden files only. A single stray entry (e.g. macOS
        # `.DS_Store` or a sub-folder) would shift every positional pairing below.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    image_files = _list_files(image_dir)
    gt_files = _list_files(gt_dir)
    if len(image_files) != len(gt_files):
        warnings.warn(
            f"{len(image_files)} images but {len(gt_files)} ground-truth masks; pairing is by "
            f"sorted position, so only the first {min(len(image_files), len(gt_files))} pairs "
            "are processed and they may be misaligned."
        )

    # Derive the output folders from the `output_dir` argument, not module globals.
    image_out_dir = os.path.join(output_dir, "images")
    gt_out_dir = os.path.join(output_dir, "groundtruths")
    os.makedirs(image_out_dir, exist_ok=True)
    os.makedirs(gt_out_dir, exist_ok=True)

    pairs = list(zip(image_files, gt_files))
    for image_file, gt_file in tqdm(pairs, total=len(pairs), desc="Processing Images"):
        image_path = os.path.join(image_dir, image_file)
        gt_path = os.path.join(gt_dir, gt_file)

        # `with` closes the file handles; `.convert` returns an independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')  # Ensure image is RGB
        with Image.open(gt_path) as raw_gt:
            gt = raw_gt.convert('L')  # Convert ground truth to grayscale

        if image.size != gt.size:
            # Patches from differently sized image/mask would not correspond pixel-for-pixel.
            warnings.warn(
                f"Skipping {image_file!r}: size {image.size} differs from mask "
                f"{gt_file!r} size {gt.size}"
            )
            continue

        image_tensor = transform(image)
        gt_tensor = torch.tensor(np.array(gt), dtype=torch.float32).unsqueeze(0)  # Add channel dim (1, H, W)

        image_patches = extract_patches(image_tensor, patch_size)
        gt_patches = extract_patches(gt_tensor, patch_size)

        base_name = os.path.splitext(image_file)[0]
        for i, (img_patch, gt_patch) in enumerate(zip(image_patches, gt_patches)):
            # Patches are views into the full-image tensor, and torch.save serialises
            # the entire underlying storage. Clone so each file stores only its patch.
            torch.save(img_patch.clone(), os.path.join(image_out_dir, f"{base_name}_patch_{i}.pt"))
            torch.save(gt_patch.clone(), os.path.join(gt_out_dir, f"{base_name}_gt_patch_{i}.pt"))

# Run the patch extraction over the whole dataset; this writes the image and
# ground-truth patch tensors (.pt files) to disk.
preprocess_images_and_masks(
    image_dir=image_dir,
    gt_dir=gt_dir,
    output_dir=output_dir,
    patch_size=patch_size,
)


Processing Images:  99%|█████████▉| 190/191 [02:31<00:00,  1.26it/s]


In [2]:
def collect_patch_paths(directory, marker):
    """Group the patch files in `directory` by their trailing patch index.

    Files are expected to be named ``<base><marker><index>.pt``, matching the
    preprocessing step's ``<base>_patch_<i>.pt`` / ``<base>_gt_patch_<i>.pt``
    naming. Files that do not match that pattern are skipped.

    Returns:
        ``{patch_index: [path, ...]}`` with keys in ascending order. Each list is
        sorted by ``<base>``, so lists built from the image and ground-truth
        folders line up entry by entry.
    """
    grouped = {}
    for file_name in os.listdir(directory):
        stem, ext = os.path.splitext(file_name)
        base, found, index_str = stem.rpartition(marker)
        if ext != ".pt" or not found or not index_str.isdecimal():
            continue  # not a patch file
        grouped.setdefault(int(index_str), []).append((base, os.path.join(directory, file_name)))
    return {
        patch_idx: [path for _, path in sorted(entries)]
        for patch_idx, entries in sorted(grouped.items())
    }


# Group patch files by patch index: {patch_index: [path, ...]} for images and ground truths.
images_paths_dict = collect_patch_paths(image_output_dir, "_patch_")
groundtruths_paths_dict = collect_patch_paths(gt_output_dir, "_gt_patch_")

# Every image patch must have a matching ground-truth patch (same index, same base name).
assert images_paths_dict.keys() == groundtruths_paths_dict.keys(), \
    "patch indices differ between images and ground truths"
assert all(
    len(images_paths_dict[k]) == len(groundtruths_paths_dict[k]) for k in images_paths_dict
), "patch counts differ between images and ground truths"

# Load dataset
# NOTE(review): `PatchesDataset` is not defined or imported in any cell visible here,
# which is why this line raises NameError (see the cell output). Define or import it
# before this cell. Presumably it takes the {patch_index: [paths]} dicts for images
# and ground truths plus a dataset name — TODO confirm against its definition.
dataset = PatchesDataset(images_paths_dict, groundtruths_paths_dict, "AntsSegmentation")


NameError: name 'PatchesDataset' is not defined