# NOTE
- Images must be padded (or resized) so that their height and width are divisible by 32; the U-Net architecture only accepts inputs of that size.
- Contrast Limited Adaptive Histogram Equalization (CLAHE) preprocessing improves performance.
- GAN?

In [1]:
"""PATH"""
import os

# Folder holding the DRIVE training data. Set the DRIVE_TRAINING_PATH
# environment variable to run on another machine; when it is unset, the
# original machine-specific location below is used.
training_path = os.environ.get(
    "DRIVE_TRAINING_PATH",
    "C:/Users/Mardeen/Desktop/UT/deep-learning-med/DRIVE_dataset/datasets/training/",
)

In [23]:
import os
import numpy as np
from glob import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, random_split
from models import *
from losses import *

class DRIVEDataset(Dataset):
    """Paired images and masks read from a DRIVE-style folder.

    Images are listed from ``<data_path>/images`` and masks from
    ``<data_path>/1st_manual``. Both listings are sorted and the i-th image
    is paired with the i-th mask.

    Each sample is an ``(image, mask)`` tuple of float tensors:

    ## image  shape (3, H, W): RGB, values scaled from [0, 255] to [0, 1].
    ## label  shape (1, H, W): single channel, values scaled from [0, 255] to [0, 1].

    ``(W, H)`` is the ``size`` argument, 512 x 512 by default.
    """

    def __init__(self, data_path, size=(512, 512)):
        """
        :param data_path: Folder containing the ``images`` and ``1st_manual`` sub-folders.
        :param size: Target ``(width, height)`` every image and mask is resized to.
        :raises ValueError: If the number of mask files differs from the number of image files.
        """
        super().__init__()

        self.size = tuple(size)
        self.images_path = sorted(glob(os.path.join(data_path, "images", "*")))
        self.masks_path = sorted(glob(os.path.join(data_path, "1st_manual", "*")))
        self.n_samples = len(self.images_path)

        if self.n_samples == 0:
            print("No images found! Check the dataset path.")

        # Pairing is purely positional, so unequal counts would silently
        # mismatch images and masks or fail later on a missing index.
        if len(self.masks_path) != self.n_samples:
            raise ValueError(
                f"Found {self.n_samples} images but {len(self.masks_path)} masks in {data_path!r}."
            )

    def __getitem__(self, index):
        """Load, resize and normalise the image/mask pair at ``index``.

        :param index: Position in the sorted file listings.
        :return: ``(image, mask)`` float tensors shaped ``(3, H, W)`` and ``(1, H, W)``.
        :raises IndexError: If ``index`` is out of range.
        :raises RuntimeError: If either file cannot be read or decoded.
        """
        # Look the paths up outside the error handling so an out-of-range
        # index raises IndexError; plain `for sample in dataset` iteration
        # relies on that exception to stop.
        image_file = self.images_path[index]
        mask_file = self.masks_path[index]

        try:
            # The context managers release the file handles as soon as the
            # pixel data has been converted into new in-memory images.
            with Image.open(image_file) as raw_image:
                image = raw_image.convert('RGB').resize(self.size, resample=Image.Resampling.NEAREST)
            with Image.open(mask_file) as raw_mask:
                mask = raw_mask.convert('L').resize(self.size, resample=Image.Resampling.NEAREST)
        except (OSError, ValueError) as err:
            # Fail loudly instead of returning (None, None): a None sample
            # only resurfaces later as an unrelated error when it is batched.
            raise RuntimeError(f"Error loading file at index {index}: {err}") from err

        # convert('RGB') always yields an (H, W, 3) array; move channels first.
        data = torch.from_numpy(np.array(image).transpose(2, 0, 1)).float() / 255
        # convert('L') yields an (H, W) array; add the channel axis.
        label = torch.from_numpy(np.array(mask)).float().unsqueeze(0) / 255

        return data, label

    def __len__(self):
        """Return the number of image files found under ``images``."""
        return self.n_samples

def split_dataset(data_path, train_ratio=0.8, seed=42):
    """
    Build a DRIVEDataset from ``data_path`` and randomly partition it in two.

    :param data_path: Path to the dataset folder.
    :param train_ratio: Fraction of the samples placed in the first (training) subset;
        the remainder forms the second subset. Default is 0.8.
    :param seed: Value passed to ``torch.manual_seed`` right before the split.
    :return: train_dataset, test_dataset
    """
    full_dataset = DRIVEDataset(data_path)

    # int() truncates, so any fractional sample goes to the test subset.
    train_size = int(train_ratio * len(full_dataset))
    test_size = len(full_dataset) - train_size

    # Seeds torch's global RNG, which random_split draws its permutation from.
    torch.manual_seed(seed)
    subsets = random_split(full_dataset, [train_size, test_size])

    print(f"Dataset split: {train_size} training samples, {test_size} testing samples.")
    return subsets[0], subsets[1]



In [24]:
dataset = DRIVEDataset(training_path)

# Inspect the first image/mask pair of the full dataset.
image, mask = dataset[0]
print(f"Image shape: {image.shape}", f"Mask shape: {mask.shape}", sep="\n")

# 80/20 random split; the second subset is used as the validation set.
train_dataset, val_dataset = split_dataset(training_path, train_ratio=0.8)

# Repeat the shape check on the first training sample, if there is one.
if len(train_dataset):
    image, mask = train_dataset[0]
    print(f"Train Image shape: {image.shape}", f"Train Mask shape: {mask.shape}", sep="\n")


Image shape: torch.Size([3, 512, 512])
Mask shape: torch.Size([1, 512, 512])
Dataset split: 16 training samples, 4 testing samples.
Train Image shape: torch.Size([3, 512, 512])
Train Mask shape: torch.Size([1, 512, 512])


# Dataloader

In [25]:
from torch.utils.data import DataLoader

batch_size = 4
num_workers = 0
# Pinned host memory only helps host-to-GPU copies, so request it only when
# CUDA is actually available instead of unconditionally.
pin_memory = torch.cuda.is_available()

# Shuffle the training batches each epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(val_dataset)}")

# Checking the first batch
for images, masks in train_loader:
    print(f"Batch Image shape: {images.shape}")  # Expected: (batch_size, 3, 512, 512)
    print(f"Batch Mask shape: {masks.shape}")  # Expected: (batch_size, 1, 512, 512)
    break  # Only check the first batch

Train dataset size: 16
Test dataset size: 4
Batch Image shape: torch.Size([4, 3, 512, 512])
Batch Mask shape: torch.Size([4, 1, 512, 512])


# Train loop

In [26]:
def train_unet(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device="cuda"):
    """Train ``model`` with Adam on a combined BCE-with-logits + Dice loss.

    After every epoch the model is evaluated on ``val_loader`` without
    gradients and the average train/validation losses are printed.

    :param model: Network called as ``model(images)``; its output is passed to
        ``BCEWithLogitsLoss``, i.e. it is treated as raw logits.
    :param train_loader: Sized iterable of ``(images, masks)`` batches used for optimisation.
    :param val_loader: Sized iterable of ``(images, masks)`` batches used for validation.
    :param num_epochs: Number of passes over ``train_loader``.
    :param lr: Learning rate for the Adam optimiser.
    :param device: Device the model and every batch are moved to.
    :return: Dict with ``"train_loss"`` and ``"val_loss"`` lists holding one
        average per epoch (NaN for a loader that yields no batches).
    """
    model = model.to(device)
    # Go through the torch package explicitly so this function does not depend
    # on `optim` / `nn` happening to be re-exported by a wildcard import.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()  # Binary Cross Entropy Loss on logits

    def combined_loss(outputs, masks):
        # Combined BCE + Dice Loss.
        # NOTE(review): dice_loss receives the raw logits here; confirm that it
        # applies a sigmoid internally.
        return criterion(outputs, masks) + dice_loss(outputs, masks)

    def average(total, n_batches):
        # An empty loader yields NaN instead of raising ZeroDivisionError.
        return total / n_batches if n_batches else float("nan")

    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            loss = combined_loss(model(images), masks)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = average(total_train_loss, len(train_loader))

        # Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                total_val_loss += combined_loss(model(images), masks).item()

        avg_val_loss = average(total_val_loss, len(val_loader))

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return history


In [28]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Fresh U-Net trained for 10 epochs with the loaders built above.
model = UNet()
train_unet(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device=device)


Using device: cpu
Epoch [1/10] - Train Loss: 1.7805, Val Loss: 1.7902
Epoch [2/10] - Train Loss: 1.7771, Val Loss: 1.7868
Epoch [3/10] - Train Loss: 1.7730, Val Loss: 1.7823
Epoch [4/10] - Train Loss: 1.7674, Val Loss: 1.7759
Epoch [5/10] - Train Loss: 1.7589, Val Loss: 1.7648
Epoch [6/10] - Train Loss: 1.7429, Val Loss: 1.7407
Epoch [7/10] - Train Loss: 1.7131, Val Loss: 1.6959
Epoch [8/10] - Train Loss: 1.7069, Val Loss: 1.6973
Epoch [9/10] - Train Loss: 1.6946, Val Loss: 1.6932
Epoch [10/10] - Train Loss: 1.6931, Val Loss: 1.6931


# Augmentation

In [None]:
# import os
# import numpy as np
# import pickle
# import cv2
# import torch
# from torch.utils.data import Dataset
# from PIL import Image
# from glob import glob


# import torch
# from torchvision.transforms import functional as tf
# from torchvision.transforms import Compose, RandomHorizontalFlip, RandomVerticalFlip



# def pipeline_tranforms():
#     return Compose([RandomHorizontalFlip(p=0.5),
#                     RandomVerticalFlip(p=0.5),
#                     Fix_RandomRotation(),
#                     ])


# class Fix_RandomRotation(object):

#     def __init__(self, degrees=360, expand=False, center=None):
#         self.degrees = degrees
#         self.expand = expand
#         self.center = center

#     @staticmethod
#     def get_params():
#         p = torch.rand(1)

#         if p >= 0 and p < 0.25:
#             angle = -180
#         elif p >= 0.25 and p < 0.5:
#             angle = -90
#         elif p >= 0.5 and p < 0.75:
#             angle = 90
#         else:
#             angle = 0
#         return angle

#     def __call__(self, img):
#         angle = self.get_params()
#         return tf.rotate(img, angle, expand=self.expand, center=self.center)

#     def __repr__(self):
#         format_string = self.__class__.__name__ + \
#             '(degrees={0}'.format(self.degrees)
#         # format_string += ', resample={0}'.format(self.resample)
#         format_string += ', expand={0}'.format(self.expand)
#         if self.center is not None:
#             format_string += ', center={0}'.format(self.center)
#         format_string += ')'
#         return format_string
    
# class AUGMENT(Dataset):

#     def __init__(self, CFG, images_path, mask_paths, mode='train'):
#         super(AUGMENT, self).__init__()
#         self.mode = mode
#         self.images_path = images_path
#         self.masks_path = mask_paths
#         self.transforms = pipeline_tranforms()
#         self.CFG = CFG
        
#         self.n_samples = len(self.images_path)
    
#     def __getitem__(self, index):
#         """ Reading image """
#         image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
#         image = clahe_equalized(image)
#         image = cv2.resize(image, (self.CFG.size, self.CFG.size), interpolation=cv2.INTER_NEAREST)

#         image = image / 255.0  # type: ignore # (512, 512, 3) Normalizing to range (0,1)
#         image = np.transpose(image, (2, 0, 1))  # (3, 512, 512)
#         image = image.astype(np.float32)
#         image = torch.from_numpy(image)

#         """ Reading mask """
#         mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
#         mask = cv2.resize(mask,  (self.CFG.size, self.CFG.size), interpolation=cv2.INTER_NEAREST)
#         mask = mask / 255.0  # type: ignore # (512, 512)
#         mask = np.expand_dims(mask, axis=0)  # (1, 512, 512)
#         mask = mask.astype(np.float32)
#         mask = torch.from_numpy(mask)

#         # common transform
#         if self.mode == 'train':
#             seed = torch.seed()
#             torch.manual_seed(seed)
#             image = self.transforms(image) # type: ignore
#             torch.manual_seed(seed)
#             mask = self.transforms(mask) # type: ignore

#         return image, mask

#     def __len__(self):
#         return self.n_samples

In [None]:
# import numpy as np
# import torch
# from torch.utils.data import DataLoader, SubsetRandomSampler

# def dataset_loader(dataset):
#     """
#     Splits the dataset into training and validation sets and returns DataLoaders.

#     Args:
#         dataset (Dataset): A PyTorch Dataset (e.g., DRIVEDataset).
#
#     Note:
#         `batch_size` and `num_workers` are read from module-level variables
#         (there is no CFG argument), and the shuffle seed is fixed at 42.

#     Returns:
#         train_loader (DataLoader): DataLoader for the training set.
#         val_loader (DataLoader): DataLoader for the validation set.
#     """

#     # Split dataset into train and validation
#     validation_split = 0.2  # 20% for validation
#     shuffle_dataset = True

#     # Creating indices for training and validation split
#     dataset_size = len(dataset)
#     indices = list(range(dataset_size))
#     split = int(np.floor(validation_split * dataset_size))

#     if shuffle_dataset:
#         np.random.seed(42)
#         np.random.shuffle(indices)

#     train_indices, val_indices = indices[split:], indices[:split]

#     # Create subset samplers
#     train_sampler = SubsetRandomSampler(train_indices)
#     valid_sampler = SubsetRandomSampler(val_indices)

#     # Create DataLoaders
#     train_loader = DataLoader(
#         dataset, batch_size=batch_size, pin_memory=True,
#         sampler=train_sampler, drop_last=True, num_workers=num_workers
#     )

#     val_loader = DataLoader(
#         dataset, batch_size=batch_size, pin_memory=True,
#         sampler=valid_sampler, drop_last=True, num_workers=num_workers
#     )

#     print(f"Dataset split: {len(train_indices)} training samples, {len(val_indices)} validation samples.")

#     return train_loader, val_loader
