In [1]:
import os
import json
import torch
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, resize, hflip, vflip, rotate, normalize
import random
import cv2

In [2]:
def apply_joint_transforms(image, mask, target_size=(256, 256)):
    """Apply one random augmentation identically to an image and its mask.

    Both inputs are first resized to ``target_size`` and then receive the same
    random horizontal flip, vertical flip and rotation, so the returned pair
    stays pixel-aligned even when the inputs arrive at different resolutions.

    Args:
        image: Input image; the caller in this notebook passes an RGB PIL image.
        mask: Segmentation mask; the caller in this notebook passes a
            single-channel PIL image holding 0/255.
        target_size: Output size handed to torchvision ``resize`` (which
            interprets a 2-sequence as ``(height, width)``).

    Returns:
        tuple: ``(image, mask)`` where ``image`` is a normalized float tensor
        and ``mask`` is a float tensor produced by ``to_tensor``.
    """
    # Resize BEFORE any rotation so both inputs share one pixel grid. Rotating
    # at different resolutions/aspect ratios and resizing afterwards does not
    # commute and would shift the mask relative to the image.
    image = resize(image, target_size)
    # Nearest-neighbour keeps the mask binary; the default bilinear filter
    # would introduce intermediate grey values along the mask boundary.
    mask = resize(mask, target_size, interpolation=transforms.InterpolationMode.NEAREST)

    # Random horizontal flip
    if random.random() > 0.5:
        image = hflip(image)
        mask = hflip(mask)

    # Random vertical flip
    if random.random() > 0.5:
        image = vflip(image)
        mask = vflip(mask)

    # Random rotation by the same angle (degrees) for both inputs
    angle = random.uniform(-15, 15)
    image = rotate(
        image, angle,
        interpolation=transforms.InterpolationMode.BILINEAR,
        expand=False,
    )
    mask = rotate(
        mask, angle,
        interpolation=transforms.InterpolationMode.NEAREST,
        expand=False,
    )

    # To tensor + normalize (presumably the mean/std conventionally used with
    # ImageNet-pretrained backbones; same values as the notebook's `transform`)
    image = to_tensor(image)
    image = normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    mask = to_tensor(mask)

    return image, mask.float()

In [3]:
def _corners_to_mask(corner_data, orig_size, target_size):
    """Rasterise four annotated corners into a binary PIL mask of ``target_size``.

    Args:
        corner_data: Mapping with ``bottom_right``, ``top_right``, ``top_left``
            and ``bottom_left`` entries, each an ``(x, y)`` pair in original
            image pixel coordinates.
        orig_size: ``(width, height)`` of the original image (PIL ``Image.size``).
        target_size: ``(height, width)`` of the mask to create.

    Returns:
        A single-channel PIL image (0 background, 255 inside the polygon), or
        ``None`` when any of the four corners is missing.
    """
    corner_keys = ("bottom_right", "top_right", "top_left", "bottom_left")
    if any(key not in corner_data for key in corner_keys):
        return None

    orig_w, orig_h = orig_size
    target_h, target_w = target_size
    scale_x = target_w / orig_w
    scale_y = target_h / orig_h

    polygon = []
    for key in corner_keys:
        x, y = corner_data[key]
        polygon.append([int(round(x * scale_x)), int(round(y * scale_y))])
    polygon = np.array(polygon, dtype=np.int32)  # shape (4, 2)

    mask_np = np.zeros((target_h, target_w), dtype=np.uint8)
    cv2.fillPoly(mask_np, [polygon], 255)
    return Image.fromarray(mask_np)


def preprocess_and_save_segmentation(root_dir, partition, output_path, transform, target_size=(256, 256)):
    """Build (image, mask) tensor pairs for one dataset partition and save them.

    Reads ``annotations.json`` from ``root_dir``, splits the image ids
    70% / 15% / 15% into train / valid / test with a fixed seed, rasterises
    each selected image's corner annotation into a binary mask and writes the
    resulting list of ``(image_tensor, mask_tensor)`` tuples with ``torch.save``.

    Image ids are used directly as positions into the corner annotation list,
    so the annotations are expected to be stored in image-id order.

    Args:
        root_dir: Directory containing ``annotations.json`` and ``all_images/``.
        partition: One of ``'train'``, ``'valid'`` or ``'test'``.
        output_path: Destination file for ``torch.save``.
        transform: Callable applied to the PIL image for the ``valid``/``test``
            partitions only. It is not checked against ``target_size``; it
            should produce images of that spatial size so they match the masks.
        target_size: ``(height, width)`` of the saved masks and of the
            augmented training images.

    Raises:
        ValueError: If ``partition`` is not a known partition name.
    """
    if partition not in ('train', 'valid', 'test'):
        raise ValueError(f"Unknown partition: {partition}")

    with open(os.path.join(root_dir, 'annotations.json'), 'r') as f:
        annotations = json.load(f)

    images = annotations['images']
    corners = annotations["annotations"]['corners']

    all_ids = list(range(len(corners)))

    # Train/val/test split: 70% train, then the remaining 30% is halved into
    # 15% valid and 15% test.
    train_ids, holdout_ids = train_test_split(all_ids, test_size=0.3, random_state=42)
    valid_ids, test_ids = train_test_split(holdout_ids, test_size=0.5, random_state=42)

    partition_ids = {'train': train_ids, 'valid': valid_ids, 'test': test_ids}
    # A set gives O(1) membership tests inside the per-image loop below.
    selected_ids = set(partition_ids[partition])

    data = []

    for img in images:
        img_id = img['id']
        if img_id not in selected_ids:
            continue

        # Membership in selected_ids guarantees 0 <= img_id < len(corners),
        # so this lookup cannot go out of range.
        corner_data = corners[img_id]['corners']

        img_path = os.path.join(root_dir, 'all_images', img['file_name'])
        # Context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy.
        with Image.open(img_path) as raw_image:
            orig_image = raw_image.convert('RGB')

        mask_pil = _corners_to_mask(corner_data, orig_image.size, target_size)
        if mask_pil is None:
            # A polygon with a made-up corner would be a wrong label; skip it.
            print(f"Warning: Incomplete corner data for image ID {img_id}")
            continue

        if partition == 'train':
            # Bring the image onto the mask's pixel grid before the joint
            # augmentation so the shared rotation keeps the pair aligned.
            base_image = resize(orig_image, target_size)
            for _ in range(3):  # three augmented copies per training image
                data.append(apply_joint_transforms(base_image, mask_pil, target_size))
        else:
            # Standard (deterministic) transform for validation/test partitions
            image = transform(orig_image)
            mask = to_tensor(mask_pil).float()  # (1, H, W)
            data.append((image, mask))

    torch.save(data, output_path)
    print(f"{partition} set saved to {output_path} with {len(data)} samples")

In [4]:
# Seed Python's RNG: the training augmentation draws its flips and rotation
# angles from `random`, so without a seed train.pt differs on every run.
SEED = 42
random.seed(SEED)

DATA_ROOT = "../Shared"

# Deterministic preprocessing used for the validation and test partitions.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Same order as before: train, then valid, then test.
for partition in ("train", "valid", "test"):
    preprocess_and_save_segmentation(DATA_ROOT, partition, f"{partition}.pt", transform)

train set saved to train.pt with 4362 samples
valid set saved to valid.pt with 312 samples
test set saved to test.pt with 312 samples
