In [None]:
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as TF, InterpolationMode
from torch.utils.data import Dataset

class UnlabeledLesionDataset(Dataset):
    """Dataset that yields images (no masks) loaded from a list of file paths.

    Args:
        image_paths: Indexable, sized collection of image file paths.
        transform: Optional callable applied to each loaded RGB image.
            When falsy (e.g. ``None``), the PIL image is returned as-is.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at ``idx``, convert it to RGB and apply the transform."""
        path = self.image_paths[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            return self.transform(image)
        return image

class LabeledLesionDataset(Dataset):
    """Dataset that yields ``(image, mask)`` pairs loaded from parallel path lists.

    Args:
        image_paths: Indexable, sized collection of image file paths.
        mask_paths: Indexable collection of mask file paths; entry ``i`` is
            paired with ``image_paths[i]``.
        transform: Optional callable invoked as ``transform(img, mask)`` that
            returns the transformed ``(img, mask)`` pair. When ``None``, the
            loaded PIL images are returned unchanged.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def _apply_transform(self, img, mask):
        """Apply the joint transform to the pair, or pass it through if unset."""
        # ``transform`` defaults to None; calling it unconditionally raised
        # TypeError for datasets built without a transform.
        if self.transform is None:
            return img, mask
        return self.transform(img, mask)

    def __getitem__(self, idx):
        """Load the image (RGB) and mask (single-channel "L") at ``idx``.

        Returns:
            The ``(img, mask)`` pair after the optional joint transform.
        """
        img = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")
        img, mask = self._apply_transform(img, mask)
        return img, mask

class SegmentationTransform:
    """Joint resize-and-tensorize transform for an image and an optional mask.

    Args:
        size: Target size passed to ``TF.resize`` for both image and mask.
    """

    def __init__(self, size=(224, 224)):
        self.size = size

    def __call__(self, img, mask=None):
        """Transform ``img`` and, when given, ``mask``.

        The image is resized with bilinear interpolation, converted with
        ``TF.to_tensor`` and then mapped through ``x * 2 - 1`` (i.e. to
        [-1, 1] assuming ``to_tensor`` yields values in [0, 1]).

        The mask is resized with nearest-neighbour interpolation so no new
        intermediate values are introduced, then binarised: every value
        greater than zero becomes ``1.0``, everything else ``0.0``.

        Returns:
            The image tensor alone when ``mask`` is ``None``; otherwise an
            ``(image_tensor, mask_tensor)`` tuple.
        """
        resized_img = TF.resize(
            img, self.size, interpolation=InterpolationMode.BILINEAR
        )
        image_tensor = TF.to_tensor(resized_img) * 2.0 - 1.0

        # Image-only call: nothing further to do.
        if mask is None:
            return image_tensor

        resized_mask = TF.resize(
            mask, self.size, interpolation=InterpolationMode.NEAREST
        )
        mask_tensor = TF.to_tensor(resized_mask)
        binary_mask = (mask_tensor > 0).float()
        return image_tensor, binary_mask