In [1]:
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import cv2 
from skimage.filters import frangi 

In [2]:

def load_dataset(base_dir, img_folder, mask_folder, img_ext=".jpg", mask_ext=".tif", mask_suffix="", verbose=False):
    """Load every image in ``img_folder`` that has a matching mask in ``mask_folder``.

    The mask file name is derived from the image file name as
    ``<stem><mask_suffix><mask_ext>``, where ``<stem>`` is the image name
    without ``img_ext``. A trailing ``_training`` on the stem is dropped first,
    so ``21_training.tif`` pairs with ``21_manual1.gif`` when
    ``mask_suffix='_manual1'`` and ``mask_ext='.gif'``.

    Args:
        base_dir: Root directory of the dataset.
        img_folder: Image sub-directory, relative to ``base_dir``.
        mask_folder: Mask sub-directory, relative to ``base_dir``.
        img_ext: Only files whose name ends with this string are considered.
        mask_ext: Extension appended to build the mask file name.
        mask_suffix: Text inserted between the stem and ``mask_ext``.
        verbose: When true, print each image path and candidate mask path.

    Returns:
        A tuple ``(images, masks)`` of equally long lists of PIL images, in
        sorted file-name order. Images without an existing mask are skipped.
    """
    img_dir = os.path.join(base_dir, img_folder)
    mask_dir = os.path.join(base_dir, mask_folder)

    # Suffix that the image stem carries but the mask name does not.
    training_tag = "_training"

    images, masks = [], []

    # os.listdir() returns entries in arbitrary order; sort so that sample
    # indices are reproducible between runs and machines.
    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.endswith(img_ext):
            continue

        img_path = os.path.join(img_dir, img_name)

        # Slice by explicit length so an empty img_ext keeps the whole name.
        stem = img_name[:len(img_name) - len(img_ext)]
        if stem.endswith(training_tag):
            stem = stem[:-len(training_tag)]
        mask_name = stem + mask_suffix + mask_ext
        mask_path = os.path.join(mask_dir, mask_name)

        if verbose:
            print(img_path)
            print('mask path', mask_path)

        if not os.path.exists(mask_path):
            continue

        # Copy the pixel data and close the file right away so that large
        # datasets do not keep one open file handle per image.
        with Image.open(img_path) as img_file:
            images.append(img_file.copy())
        with Image.open(mask_path) as mask_file:
            masks.append(mask_file.copy())

    return images, masks


def apply_retinex(image, sigma_list=(2, 5, 10)):
    """Average log-domain differences between ``image`` and Gaussian-blurred copies.

    For every sigma the term ``log1p(image) - log1p(GaussianBlur(image, sigma))``
    is accumulated; the mean over all sigmas is then min-max rescaled to
    ``[0, 255]`` and returned as ``uint8``.

    Args:
        image: Array-like image; it is converted to ``float32`` before any
            arithmetic so that neither the logarithm nor the blur is computed
            at reduced integer/half precision.
        sigma_list: Iterable of Gaussian sigmas (must be non-empty).

    Returns:
        ``uint8`` array with the same shape as ``image``. A constant response
        (nothing to rescale) yields an all-zero array instead of NaNs.
    """
    img = np.asarray(image, dtype=np.float32)

    # log1p(image) does not depend on sigma, so compute it once.
    log_img = np.log1p(img)

    retinex = np.zeros_like(img, dtype=np.float32)
    for sigma in sigma_list:
        # Kernel size (0, 0) lets OpenCV derive it from sigma.
        retinex += log_img - np.log1p(cv2.GaussianBlur(img, (0, 0), sigma))
    retinex = retinex / len(sigma_list)

    lowest = retinex.min()
    span = retinex.max() - lowest
    if span == 0:
        # Flat response: min-max scaling would divide by zero.
        return np.zeros(retinex.shape, dtype=np.uint8)
    return np.uint8(255 * (retinex - lowest) / span)


def apply_clahe(image, clip_limit=4.0, tile_grid_size=(32, 32)):
    """Run OpenCV's CLAHE on ``image`` and return the equalised result.

    Args:
        image: Input passed straight to ``CLAHE.apply`` (presumably a
            single-channel 8-bit array -- confirm against the caller).
        clip_limit: Forwarded as ``clipLimit`` to ``cv2.createCLAHE``.
        tile_grid_size: Forwarded as ``tileGridSize`` to ``cv2.createCLAHE``.

    Returns:
        Whatever ``CLAHE.apply`` returns for ``image``.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    enhanced = equalizer.apply(image)
    return enhanced


def apply_gamma_correction(image, gamma=1.0):
    """Apply gamma correction to ``image`` through a 256-entry lookup table.

    Each level ``i`` in ``0..255`` is mapped to
    ``(i / 255) ** (1 / gamma) * 255``, cast to ``uint8``, and the table is
    applied with ``cv2.LUT``.

    Args:
        image: Input handed to ``cv2.LUT`` (presumably an 8-bit array --
            confirm against the caller).
        gamma: Gamma value; the exponent used is its reciprocal.

    Returns:
        The result of ``cv2.LUT(image, table)``.
    """
    exponent = 1.0 / gamma

    lut_values = []
    for level in range(256):
        lut_values.append((level / 255.0) ** exponent * 255)

    lookup = np.array(lut_values).astype("uint8")
    return cv2.LUT(image, lookup)


def apply_frangi(image):
    """Apply the Frangi filter to ``image`` and rescale the response to ``uint8``.

    The input is divided by 255 before filtering; the filter response is then
    min-max rescaled to ``[0, 255]``.

    Args:
        image: Numeric array (presumably 8-bit grayscale, given the division
            by 255 -- confirm against the caller).

    Returns:
        ``uint8`` array with the shape of the filter response. A constant
        response (e.g. from a blank image) yields an all-zero array rather
        than the NaNs a 0/0 rescale would produce.
    """
    image_normalized = image / 255.0
    response = frangi(image_normalized)

    lowest = response.min()
    span = response.max() - lowest
    if span == 0:
        # Nothing to stretch: avoid dividing by zero.
        return np.zeros(response.shape, dtype=np.uint8)
    return np.uint8(255 * (response - lowest) / span)


class ConsistentRotation:
    """Rotate an image and its mask by the same randomly chosen angle."""

    def __init__(self, angles=(0, 15, 30, 45, 90, 100, 120), p=1.0):
        """Store the candidate angles and the probability of rotating.

        Args:
            angles: Sequence of angles to choose from; passed unchanged to
                the ``rotate`` method of the image and mask.
            p: Probability in ``[0, 1]`` that a rotation is applied per call.
        """
        # Keep a private copy: the default must not be a list shared between
        # instances, and later edits to the caller's sequence must not leak in.
        self.angles = list(angles)
        self.p = p

    def __call__(self, img, mask):
        """Return ``(img, mask)``, both rotated by one shared angle with probability ``p``.

        Both arguments only need a ``rotate(angle)`` method (presumably PIL
        images -- confirm against the caller).
        """
        if random.random() < self.p:
            # Draw once so image and mask stay aligned.
            angle = random.choice(self.angles)
            img = img.rotate(angle)
            mask = mask.rotate(angle)
        return img, mask


class RetinalVesselDataset(Dataset):
    """Dataset of (pre-processed green channel, vessel mask) pairs.

    All images and masks are loaded eagerly in ``__init__`` through
    ``load_dataset``. Each item is produced by taking channel index 1 of the
    image, applying gamma correction and CLAHE, rotating image and mask
    together, and finally applying the optional transforms.
    """

    def __init__(self, folder_path, dataset_name='drive', transform=None, mask_transform=None, gamma=1.0, clahe_clip_limit=4.0, clahe_tile_grid_size=(32, 32)):
        """Load the requested dataset and store the pre-processing settings.

        Args:
            folder_path: Root directory of the dataset.
            dataset_name: One of ``'drive'``, ``'chase'`` or ``'hrf'``.
            transform: Optional callable applied to the image. It is also
                applied to the mask when ``mask_transform`` is not given.
            mask_transform: Optional callable applied to the mask instead of
                ``transform``.
            gamma: Gamma passed to ``apply_gamma_correction``.
            clahe_clip_limit: Clip limit passed to ``apply_clahe``.
            clahe_tile_grid_size: Tile grid size passed to ``apply_clahe``.

        Raises:
            ValueError: If ``dataset_name`` is not a known dataset.
        """
        # Lazily evaluated so that only the selected dataset is read from disk.
        dataset_loaders = {
            'drive': lambda: load_dataset(folder_path, 'training/images', 'training/1st_manual', '.tif', '.gif', '_manual1'),
            'chase': lambda: load_dataset(folder_path, 'Images', 'Masks', '.jpg', '.png', '_1stHO'),
            'hrf': lambda: load_dataset(folder_path, 'images', 'manual1', '.JPG', '.tif')
        }

        if dataset_name not in dataset_loaders:
            raise ValueError(f"Dataset {dataset_name} no recognized. Use: 'drive', 'chase' o 'hrf'.")

        self.images, self.masks = dataset_loaders[dataset_name]()
        self.transform = transform
        self.mask_transform = mask_transform
        self.rotation = ConsistentRotation()
        self.gamma = gamma
        self.clahe_clip_limit = clahe_clip_limit
        self.clahe_tile_grid_size = clahe_tile_grid_size  # Now configurable

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the pre-processed ``(image, mask)`` pair at position ``idx``."""
        image, mask = self.images[idx], self.masks[idx]

        # Channel index 1 of the image array (green for RGB input).
        green_channel = np.array(image)[:, :, 1]
        green_channel = apply_gamma_correction(green_channel, gamma=self.gamma)
        green_channel = apply_clahe(
            green_channel,
            clip_limit=self.clahe_clip_limit,
            tile_grid_size=self.clahe_tile_grid_size,
        )
        image = Image.fromarray(green_channel)

        # Same angle for both so the mask stays aligned with the image.
        image, mask = self.rotation(image, mask)

        if self.transform:
            image = self.transform(image)

        # Prefer the dedicated mask transform; fall back to the image
        # transform so callers that only pass `transform` see no change.
        if self.mask_transform:
            mask = self.mask_transform(mask)
        elif self.transform:
            mask = self.transform(mask)

        return image, mask
