In [None]:
import os
import numpy as np
from pycocotools.coco import COCO
from PIL import Image, ImageDraw
import cv2
import torch
from torchvision import transforms

def coco_ann_to_mask(ann, height, width):
    """Decode one COCO annotation's ``segmentation`` field into a mask.

    Three encodings are distinguished:

    * a list of polygons -> each polygon is encoded to RLE and merged;
    * a dict whose ``counts`` is a list (uncompressed RLE) -> re-encoded
      via ``frPyObjects``;
    * any other dict -> passed to ``decode`` unchanged (expected to be a
      compressed RLE already).

    Args:
        ann: A COCO annotation dict containing a ``'segmentation'`` key.
        height: Image height used when encoding polygons / uncompressed RLE.
        width: Image width used when encoding polygons / uncompressed RLE.

    Returns:
        The array returned by ``pycocotools.mask.decode`` for the RLE.
    """
    from pycocotools import mask as maskUtils

    seg = ann['segmentation']

    if isinstance(seg, list):
        # Polygon form: one RLE per polygon, collapsed into a single RLE.
        per_polygon = maskUtils.frPyObjects(seg, height, width)
        return maskUtils.decode(maskUtils.merge(per_polygon))

    if isinstance(seg['counts'], list):
        # Uncompressed RLE: convert to the compressed representation first.
        return maskUtils.decode(maskUtils.frPyObjects(seg, height, width))

    return maskUtils.decode(seg)

def _combined_class_mask(coco, img_id, height, width):
    """Rasterise every annotation of ``img_id`` into one class-index map.

    Returns an int32 array of shape ``(height, width)``. Pixels covered by no
    annotation are 0. Where annotations overlap, ``np.maximum`` keeps the
    larger ``category_id``.
    """
    # int32 rather than uint8 so category ids above 255 cannot overflow.
    combined = np.zeros((height, width), dtype=np.int32)
    for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
        binary = coco_ann_to_mask(ann, height, width).astype(np.int32)
        np.maximum(combined, binary * ann['category_id'], out=combined)
    return combined


def preprocess_for_vit(base_path, rgb_ann_file, thermal_ann_file, target_size=(224, 224), max_images=100):
    """Load paired RGB/thermal images plus COCO masks as ViT-ready tensors.

    Args:
        base_path: Directory containing ``rgb/`` and ``thermal/`` sub-folders.
        rgb_ann_file: COCO annotation file for the RGB images. Image metadata
            and the segmentation annotations are read from it.
        thermal_ann_file: COCO annotation file for the thermal images. Only
            image metadata (``file_name``) is read from it. Images are paired
            by id, so both files must use the same image ids.
        target_size: ``(height, width)`` that images and masks are resized to.
        max_images: Number of RGB image ids to consider (``None`` for all).
            Images skipped because a file is missing still count toward it.

    Returns:
        A list of dicts with keys ``'file_name'``, ``'rgb'`` (``[3, H, W]``
        float tensor, normalised with mean=std=0.5), ``'thermal'``
        (``[1, H, W]``, same normalisation) and ``'mask'`` (``[H, W]`` long
        tensor of category ids, 0 where no annotation).
    """
    rgb_coco = COCO(rgb_ann_file)
    thermal_coco = COCO(thermal_ann_file)

    image_ids = rgb_coco.getImgIds()
    dataset = []

    # torchvision's Resize takes (height, width); PIL's Image.resize takes
    # (width, height). Unpack once so the mask can be resized consistently.
    out_h, out_w = target_size

    tf_rgb = transforms.Compose([
        transforms.ToTensor(),  # converts to [0,1] and channels-first
        transforms.Resize(target_size),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
    ])

    tf_thermal = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(target_size),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    for img_id in image_ids[:max_images]:
        rgb_info = rgb_coco.loadImgs(img_id)[0]
        # Pairs modalities by image id; assumes the thermal file shares RGB ids.
        thermal_info = thermal_coco.loadImgs(img_id)[0]

        rgb_path = os.path.join(base_path, "rgb", rgb_info['file_name'])
        thermal_path = os.path.join(base_path, "thermal", thermal_info['file_name'])

        if not os.path.exists(rgb_path) or not os.path.exists(thermal_path):
            print(f"Missing: {rgb_path} or {thermal_path}, skipping...")
            continue

        # Open inside `with` so the file handles are closed; convert() returns
        # a converted in-memory copy that outlives the context.
        with Image.open(rgb_path) as src:
            rgb_img = src.convert("RGB")
        with Image.open(thermal_path) as src:
            thermal_img = src.convert("L")

        # Masks are decoded at the RGB image's native resolution.
        orig_w, orig_h = rgb_img.size  # PIL reports (width, height)
        combined_mask = _combined_class_mask(rgb_coco, img_id, orig_h, orig_w)

        # NEAREST keeps class ids intact; size given as (width, height) for PIL
        # so the mask matches the [.., out_h, out_w] image tensors.
        mask_img = Image.fromarray(combined_mask).resize((out_w, out_h), resample=Image.NEAREST)

        rgb_tensor = tf_rgb(rgb_img)
        thermal_tensor = tf_thermal(thermal_img)
        mask_tensor = torch.from_numpy(np.array(mask_img)).long()  # no normalization for masks

        dataset.append({
            'file_name': rgb_info['file_name'],
            'rgb': rgb_tensor,              # [3, H, W]
            'thermal': thermal_tensor,      # [1, H, W]
            'mask': mask_tensor             # [H, W] with class indices
        })

    print(f"Processed {len(dataset)} samples for ViT segmentation.")
    return dataset
