In [None]:
import os
import json
import random
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm

# --- Paired Rotation Class ---
class RandomRotatePair:
    """Rotate an image and its segmentation mask by one shared random angle.

    On every call, a single angle is drawn uniformly from
    ``[-degrees, degrees]`` with the module-level ``random`` generator, and
    both inputs are rotated by it. The mask is resampled with nearest-neighbour
    interpolation so label values are never blended. The image uses
    ``TF.rotate``'s default interpolation.
    """

    def __init__(self, degrees):
        # Half-width of the symmetric range the rotation angle is drawn from.
        self.degrees = degrees

    def __call__(self, image, mask):
        """Return ``(rotated_image, rotated_mask)``, both rotated by the same angle."""
        low, high = -self.degrees, self.degrees
        theta = random.uniform(low, high)
        nearest = transforms.InterpolationMode.NEAREST
        rotated_image = TF.rotate(image, theta)
        rotated_mask = TF.rotate(mask, theta, interpolation=nearest)
        return rotated_image, rotated_mask

# --- Dataset-like class ---
class SemiSegmentationExporter:
    """Export randomly rotated, transformed copies of an image/mask dataset.

    The annotation file is a JSON list of entries. Each entry has ``'image'``
    and ``'mask'`` filenames and an optional ``'class_id'``. ``export`` writes
    one augmented image/mask pair per entry, plus a matching
    ``annotations.json`` describing the augmented set.

    Args:
        images_dir: Directory containing the source images.
        masks_dir: Directory containing the source masks.
        annotation_path: Path to the JSON annotation list (read as UTF-8).
        transform: Optional callable applied to each rotated PIL image.
        mask_transform: Optional callable applied to each rotated PIL mask.
            It must return a tensor.
        rotation_degrees: Rotation angles are drawn uniformly from
            ``[-rotation_degrees, rotation_degrees]``. The default of 360 is
            the value previously hard-coded in ``export``.
    """

    def __init__(self, images_dir, masks_dir, annotation_path, transform=None,
                 mask_transform=None, rotation_degrees=360):
        with open(annotation_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.rotation_degrees = rotation_degrees

    def export(self, save_dir_images, save_dir_masks, max_samples=None):
        """Write augmented image/mask pairs and a new annotation file.

        Pair ``idx`` is saved as ``aug_{idx:04d}.png`` in both
        ``save_dir_images`` and ``save_dir_masks``. The new annotation list is
        written to ``annotations.json`` in the parent directory of
        ``save_dir_images``.

        Args:
            save_dir_images: Output directory for images (created if missing).
            save_dir_masks: Output directory for masks (created if missing).
            max_samples: If given, only the first ``max_samples`` entries
                are exported.

        Raises:
            ValueError: If a mask contains label values outside ``[0, 255]``.
                Such values cannot be stored in an 8-bit PNG without being
                silently wrapped.
        """
        os.makedirs(save_dir_images, exist_ok=True)
        os.makedirs(save_dir_masks, exist_ok=True)

        samples = self.annotations if max_samples is None else self.annotations[:max_samples]
        # One augmenter shared by all samples instead of a new one per iteration.
        rotate_pair = RandomRotatePair(self.rotation_degrees)
        new_annotations = []

        for idx, ann in tqdm(enumerate(samples), total=len(samples)):
            image, mask = self._load_pair(ann)
            image, mask = rotate_pair(image, mask)

            image_t = self.transform(image) if self.transform else image
            mask_t = self._apply_mask_transform(mask)

            # Images and masks live in separate directories, so they can share a name.
            aug_name = f"aug_{idx:04d}.png"

            if isinstance(image_t, torch.Tensor):
                save_img = transforms.ToPILImage()(image_t)
            else:
                save_img = image_t
            save_img.save(os.path.join(save_dir_images, aug_name))

            save_mask = Image.fromarray(self._mask_to_uint8_array(mask_t))
            save_mask.save(os.path.join(save_dir_masks, aug_name))

            new_ann = {'image': aug_name, 'mask': aug_name}
            class_id = ann.get('class_id', None)  # keep class_id if present
            if class_id is not None:
                new_ann['class_id'] = class_id
            new_annotations.append(new_ann)

        # normpath strips a trailing separator; without it, dirname() would
        # return save_dir_images itself instead of its parent.
        parent_dir = os.path.dirname(os.path.normpath(save_dir_images))
        new_ann_path = os.path.join(parent_dir, 'annotations.json')
        with open(new_ann_path, 'w', encoding='utf-8') as f:
            json.dump(new_annotations, f, indent=2)
        print(f"Saved new annotations file to: {new_ann_path}")

    def _load_pair(self, ann):
        """Load the RGB image and the raw mask named by one annotation entry."""
        img_path = os.path.join(self.images_dir, ann['image'])
        mask_path = os.path.join(self.masks_dir, ann['mask'])
        # The context managers close the file handles even if decoding fails.
        # convert()/copy() return fully loaded, independent images.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.copy()
        return image, mask

    def _apply_mask_transform(self, mask):
        """Turn a rotated mask into an int64 tensor via ``mask_transform`` if set."""
        if self.mask_transform:
            mask_t = self.mask_transform(mask)
            # Drop a leading singleton channel dim so the mask saves as 2-D.
            if mask_t.dim() == 3 and mask_t.shape[0] == 1:
                mask_t = mask_t.squeeze(0)
            return mask_t.long()
        return torch.from_numpy(np.array(mask, dtype=np.int64))

    @staticmethod
    def _mask_to_uint8_array(mask_t):
        """Convert an integer mask tensor to a uint8 numpy array for PNG export.

        Raises:
            ValueError: If any value lies outside ``[0, 255]``. This avoids
                ``astype(np.uint8)`` silently wrapping label ids.
        """
        arr = mask_t.cpu().numpy()
        if arr.size and (arr.min() < 0 or arr.max() > 255):
            raise ValueError(
                f"mask values span [{arr.min()}, {arr.max()}], "
                "which cannot be stored as uint8 without corruption"
            )
        return arr.astype(np.uint8)


# --- Transform setup ---
# Images and masks must end up the same spatial size. A single constant keeps
# the two Resize calls below from drifting apart.
TARGET_SIZE = (256, 256)

image_transform = transforms.Compose([
    transforms.Resize(TARGET_SIZE),
    transforms.ToTensor(),
])

# Nearest-neighbour resizing keeps mask pixels as exact label values (no
# blending). It uses the torchvision InterpolationMode enum, the same form
# RandomRotatePair uses for mask rotation, rather than a PIL integer constant.
mask_transform = transforms.Compose([
    transforms.Resize(TARGET_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    # Collapse to a single 2-D label map: channel 0 when the mask has one
    # channel, otherwise the per-pixel maximum across channels.
    transforms.Lambda(lambda x: x[0] if x.shape[0] == 1 else x.max(dim=0)[0]),
])

exporter = SemiSegmentationExporter(
    images_dir='train-semi',
    masks_dir='train-semi-segmentation',
    annotation_path='train_semi_annotations_with_seg_ids.json',
    transform=image_transform,
    mask_transform=mask_transform
)

exporter.export(
    save_dir_images='augmented/train-semi',
    save_dir_masks='augmented/train-semi-segmentation',
    max_samples=None
)

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:18<00:00, 26.85it/s]

Saved new annotations file to: augmented/annotations.json



