In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [5]:
class DiceBCELoss(nn.Module):
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # BCE component (stable, with logits)
        bce_loss = self.bce_fn(logits, targets)

        # Dice component (per-sample)
        probs = torch.sigmoid(logits)
        batch_size = probs.shape[0]
        dice_losses = []
        for i in range(batch_size):
            p = probs[i].view(-1)
            g = targets[i].view(-1)
            inter = (p * g).sum()
            dice = 1 - (2*inter + self.smooth) / (p.sum() + g.sum() + self.smooth)
            dice_losses.append(dice)
        dice_loss = torch.stack(dice_losses).mean()

        # Combined
        return self.alpha * dice_loss + (1 - self.alpha) * bce_loss


In [None]:
# Model Architecture 
class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = ConvBlock(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        x = self.conv(x)
        p = self.pool(x)
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_c * 2, out_c)

    def forward(self, x, skip):
        x = self.up(x)
        x = nn.functional.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], axis=1)
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = EncoderBlock(1, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)
        self.b = ConvBlock(512, 1024)
        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512, 256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128, 64)
        self.outputs = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        b = self.b(p4)
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        return self.outputs(d4)


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [8]:
seeding(42)
MODEL_NAME = "Green_UNet_DiceBCE_Full"
MODEL_DIRECTORY = "Green_Model_UNet_DiceBCE_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_UNet_DiceBCE_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("../final_dataset/train/images/*"))
train_masks  = sorted(glob("../final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("../final_dataset/test/images/*"))
test_masks  = sorted(glob("../final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = DiceBCELoss()

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [10]:
for epoch in range(40):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                                

Epoch 01 | Time: 9m 5s
	Train Loss: 0.4535 | Valid Loss: 0.4004
	Accuracy: 0.9915 | F1: 0.2798 | Dice: 0.3516 | Recall: 0.2524 | Precision: 0.5796 | Jaccard: 0.1854
Best Green_UNet_DiceBCE_Full Saved


                                                                                

Epoch 02 | Time: 8m 36s
	Train Loss: 0.3405 | Valid Loss: 0.3838
	Accuracy: 0.9917 | F1: 0.3135 | Dice: 0.3823 | Recall: 0.2867 | Precision: 0.5733 | Jaccard: 0.2086
Best Green_UNet_DiceBCE_Full Saved


                                                                                

Epoch 03 | Time: 8m 36s
	Train Loss: 0.3240 | Valid Loss: 0.4222
	Accuracy: 0.9918 | F1: 0.2414 | Dice: 0.2896 | Recall: 0.1883 | Precision: 0.6273 | Jaccard: 0.1602


                                                                                

Epoch 04 | Time: 8m 36s
	Train Loss: 0.3128 | Valid Loss: 0.3753
	Accuracy: 0.9922 | F1: 0.3237 | Dice: 0.3793 | Recall: 0.2836 | Precision: 0.5723 | Jaccard: 0.2194
Best Green_UNet_DiceBCE_Full Saved


                                                                                

Epoch 05 | Time: 8m 35s
	Train Loss: 0.3028 | Valid Loss: 0.3591
	Accuracy: 0.9921 | F1: 0.3572 | Dice: 0.4257 | Recall: 0.3387 | Precision: 0.5729 | Jaccard: 0.2444
Best Green_UNet_DiceBCE_Full Saved


                                                                                

Epoch 06 | Time: 8m 35s
	Train Loss: 0.2947 | Valid Loss: 0.3739
	Accuracy: 0.9922 | F1: 0.3236 | Dice: 0.3806 | Recall: 0.2757 | Precision: 0.6143 | Jaccard: 0.2200


                                                                                

Epoch 07 | Time: 8m 35s
	Train Loss: 0.2891 | Valid Loss: 0.3616
	Accuracy: 0.9923 | F1: 0.3515 | Dice: 0.4136 | Recall: 0.3253 | Precision: 0.5676 | Jaccard: 0.2405


                                                                                

Epoch 08 | Time: 8m 37s
	Train Loss: 0.2832 | Valid Loss: 0.3637
	Accuracy: 0.9923 | F1: 0.3433 | Dice: 0.4039 | Recall: 0.3138 | Precision: 0.5668 | Jaccard: 0.2364


                                                                                

Epoch 09 | Time: 8m 37s
	Train Loss: 0.2783 | Valid Loss: 0.3964
	Accuracy: 0.9923 | F1: 0.2892 | Dice: 0.3436 | Recall: 0.2376 | Precision: 0.6202 | Jaccard: 0.1967


                                                                                

Epoch 10 | Time: 8m 35s
	Train Loss: 0.2750 | Valid Loss: 0.3709
	Accuracy: 0.9917 | F1: 0.3315 | Dice: 0.3964 | Recall: 0.3095 | Precision: 0.5509 | Jaccard: 0.2247


                                                                                

Epoch 11 | Time: 8m 36s
	Train Loss: 0.2717 | Valid Loss: 0.3606
	Accuracy: 0.9922 | F1: 0.3636 | Dice: 0.4243 | Recall: 0.3215 | Precision: 0.6238 | Jaccard: 0.2505


                                                                                

Epoch 12 | Time: 8m 35s
	Train Loss: 0.2567 | Valid Loss: 0.3665
	Accuracy: 0.9924 | F1: 0.3456 | Dice: 0.4077 | Recall: 0.2992 | Precision: 0.6398 | Jaccard: 0.2378


                                                                                

Epoch 13 | Time: 8m 36s
	Train Loss: 0.2519 | Valid Loss: 0.3691
	Accuracy: 0.9922 | F1: 0.3403 | Dice: 0.4017 | Recall: 0.2994 | Precision: 0.6101 | Jaccard: 0.2340


                                                                                

Epoch 14 | Time: 8m 36s
	Train Loss: 0.2494 | Valid Loss: 0.3759
	Accuracy: 0.9900 | F1: 0.3321 | Dice: 0.4053 | Recall: 0.3109 | Precision: 0.5823 | Jaccard: 0.2289


                                                                                

Epoch 15 | Time: 8m 37s
	Train Loss: 0.2473 | Valid Loss: 0.3717
	Accuracy: 0.9916 | F1: 0.3455 | Dice: 0.4140 | Recall: 0.3246 | Precision: 0.5716 | Jaccard: 0.2378


                                                                                

Epoch 16 | Time: 8m 38s
	Train Loss: 0.2452 | Valid Loss: 0.3734
	Accuracy: 0.9882 | F1: 0.3426 | Dice: 0.4197 | Recall: 0.3396 | Precision: 0.5492 | Jaccard: 0.2368


                                                                                

Epoch 17 | Time: 8m 37s
	Train Loss: 0.2437 | Valid Loss: 0.3709
	Accuracy: 0.9914 | F1: 0.3345 | Dice: 0.4022 | Recall: 0.3078 | Precision: 0.5802 | Jaccard: 0.2308


                                                                                

Epoch 18 | Time: 8m 37s
	Train Loss: 0.2414 | Valid Loss: 0.3775
	Accuracy: 0.9893 | F1: 0.3279 | Dice: 0.4037 | Recall: 0.3114 | Precision: 0.5737 | Jaccard: 0.2262


                                                                                

Epoch 19 | Time: 8m 36s
	Train Loss: 0.2409 | Valid Loss: 0.3718
	Accuracy: 0.9900 | F1: 0.3380 | Dice: 0.4126 | Recall: 0.3295 | Precision: 0.5519 | Jaccard: 0.2328


                                                                                

Epoch 20 | Time: 8m 36s
	Train Loss: 0.2406 | Valid Loss: 0.3923
	Accuracy: 0.9833 | F1: 0.3107 | Dice: 0.3952 | Recall: 0.2983 | Precision: 0.5854 | Jaccard: 0.2138


                                                                                

Epoch 21 | Time: 8m 36s
	Train Loss: 0.2404 | Valid Loss: 0.3715
	Accuracy: 0.9900 | F1: 0.3469 | Dice: 0.4218 | Recall: 0.3355 | Precision: 0.5679 | Jaccard: 0.2400


                                                                                

Epoch 22 | Time: 8m 36s
	Train Loss: 0.2401 | Valid Loss: 0.3916
	Accuracy: 0.9919 | F1: 0.3066 | Dice: 0.3669 | Recall: 0.2650 | Precision: 0.5963 | Jaccard: 0.2116


                                                                                

Epoch 23 | Time: 8m 37s
	Train Loss: 0.2399 | Valid Loss: 0.3869
	Accuracy: 0.9917 | F1: 0.3179 | Dice: 0.3836 | Recall: 0.2825 | Precision: 0.5974 | Jaccard: 0.2196


                                                                                

Epoch 24 | Time: 8m 37s
	Train Loss: 0.2396 | Valid Loss: 0.3760
	Accuracy: 0.9918 | F1: 0.3318 | Dice: 0.3976 | Recall: 0.2977 | Precision: 0.5987 | Jaccard: 0.2297


                                                                                

Epoch 25 | Time: 8m 37s
	Train Loss: 0.2396 | Valid Loss: 0.3889
	Accuracy: 0.9919 | F1: 0.3131 | Dice: 0.3763 | Recall: 0.2744 | Precision: 0.5988 | Jaccard: 0.2155


                                                                                

Epoch 26 | Time: 8m 42s
	Train Loss: 0.2395 | Valid Loss: 0.3758
	Accuracy: 0.9900 | F1: 0.3443 | Dice: 0.4214 | Recall: 0.3365 | Precision: 0.5637 | Jaccard: 0.2379


                                                                                

Epoch 27 | Time: 8m 37s
	Train Loss: 0.2395 | Valid Loss: 0.3884
	Accuracy: 0.9919 | F1: 0.3096 | Dice: 0.3721 | Recall: 0.2699 | Precision: 0.5986 | Jaccard: 0.2134


                                                                                

Epoch 28 | Time: 8m 36s
	Train Loss: 0.2395 | Valid Loss: 0.3800
	Accuracy: 0.9914 | F1: 0.3246 | Dice: 0.3914 | Recall: 0.2961 | Precision: 0.5771 | Jaccard: 0.2248


                                                                                

Epoch 29 | Time: 8m 36s
	Train Loss: 0.2395 | Valid Loss: 0.3787
	Accuracy: 0.9916 | F1: 0.3210 | Dice: 0.3842 | Recall: 0.2855 | Precision: 0.5872 | Jaccard: 0.2222


                                                                                

Epoch 30 | Time: 8m 37s
	Train Loss: 0.2394 | Valid Loss: 0.3928
	Accuracy: 0.9920 | F1: 0.3160 | Dice: 0.3736 | Recall: 0.2721 | Precision: 0.5960 | Jaccard: 0.2192


                                                                                

Epoch 31 | Time: 8m 38s
	Train Loss: 0.2394 | Valid Loss: 0.3753
	Accuracy: 0.9909 | F1: 0.3320 | Dice: 0.4021 | Recall: 0.3086 | Precision: 0.5768 | Jaccard: 0.2295


                                                                                

Epoch 32 | Time: 8m 37s
	Train Loss: 0.2394 | Valid Loss: 0.3896
	Accuracy: 0.9912 | F1: 0.3103 | Dice: 0.3799 | Recall: 0.2812 | Precision: 0.5851 | Jaccard: 0.2138


                                                                                

Epoch 33 | Time: 8m 36s
	Train Loss: 0.2394 | Valid Loss: 0.3797
	Accuracy: 0.9888 | F1: 0.3271 | Dice: 0.4073 | Recall: 0.3104 | Precision: 0.5920 | Jaccard: 0.2248


                                                                                

Epoch 34 | Time: 8m 35s
	Train Loss: 0.2394 | Valid Loss: 0.3779
	Accuracy: 0.9900 | F1: 0.3243 | Dice: 0.4005 | Recall: 0.3071 | Precision: 0.5757 | Jaccard: 0.2241


                                                                                

Epoch 35 | Time: 8m 36s
	Train Loss: 0.2394 | Valid Loss: 0.3767
	Accuracy: 0.9912 | F1: 0.3340 | Dice: 0.4020 | Recall: 0.3096 | Precision: 0.5728 | Jaccard: 0.2306


                                                                                

Epoch 36 | Time: 8m 35s
	Train Loss: 0.2394 | Valid Loss: 0.3709
	Accuracy: 0.9908 | F1: 0.3380 | Dice: 0.4102 | Recall: 0.3174 | Precision: 0.5799 | Jaccard: 0.2333


                                                                                

Epoch 37 | Time: 8m 36s
	Train Loss: 0.2394 | Valid Loss: 0.3746
	Accuracy: 0.9909 | F1: 0.3295 | Dice: 0.4000 | Recall: 0.3040 | Precision: 0.5844 | Jaccard: 0.2279


                                                                                

Epoch 38 | Time: 8m 38s
	Train Loss: 0.2394 | Valid Loss: 0.3792
	Accuracy: 0.9900 | F1: 0.3239 | Dice: 0.4011 | Recall: 0.3095 | Precision: 0.5696 | Jaccard: 0.2230


                                                                                

Epoch 39 | Time: 8m 38s
	Train Loss: 0.2394 | Valid Loss: 0.3865
	Accuracy: 0.9836 | F1: 0.3199 | Dice: 0.3991 | Recall: 0.3077 | Precision: 0.5679 | Jaccard: 0.2206


                                                                                

Epoch 40 | Time: 8m 36s
	Train Loss: 0.2394 | Valid Loss: 0.3977
	Accuracy: 0.9916 | F1: 0.2920 | Dice: 0.3552 | Recall: 0.2524 | Precision: 0.5996 | Jaccard: 0.1996




In [11]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


  model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
Testing: 100%|████████████████████████████████| 113/113 [00:19<00:00,  5.75it/s]

Jaccard: 0.2379, F1: 0.3457, Recall: 0.3594, Precision: 0.4845, Accuracy: 0.9945
FPS: 826.0346025344518



