In [3]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [6]:
class Focal_Tversky(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta  = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits, targets: (N, 1, H, W) or (N, H, W)
        probs = torch.sigmoid(logits)
        N = targets.size(0)

        # flatten per sample
        probs  = probs.view(N, -1)
        targets = targets.view(N, -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        smooth = torch.tensor(self.smooth, device=probs.device, dtype=probs.dtype)
        tversky = (TP + smooth) / (TP + self.alpha * FP + self.beta * FN + smooth)

        # focal modulation
        loss = (1 - tversky) ** self.gamma

        return loss.mean()


In [7]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p

class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out * s

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+out_c, out_c)

    def forward(self, x, s):
        x = self.up(x)
        s = self.ag(x, s)
        x = torch.cat([x, s], axis=1)
        x = self.c1(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.e1 = encoder_block(1, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        self.b1 = conv_block(512, 1024)

        self.d1 = decoder_block([1024, 512], 512)
        self.d2 = decoder_block([512, 256], 256)
        self.d3 = decoder_block([256, 128], 128)
        self.d4 = decoder_block([128, 64], 64)
        
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b1 = self.b1(p4)

        d1 = self.d1(b1, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.output(d4)
        return output

In [None]:
#  Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [9]:
seeding(42)
MODEL_NAME = "Green_Attention_UNet_Focal_Tversky_Full"
MODEL_DIRECTORY = "Green_Attention_Model_UNet_Focal_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Attention_Results_UNet_Focal_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("../final_dataset/train/images/*"))
train_masks  = sorted(glob("../final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test 
test_images = sorted(glob("../final_dataset/test/images/*"))
test_masks  = sorted(glob("../final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = Focal_Tversky()

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [11]:
patience = 15
counter = 0

for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        counter = 0  # reset counter
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")
    else:
        counter += 1
        print(f"No improvement in validation loss for {counter} epoch(s)")

    if counter >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break


                                                                      

Epoch 01 | Time: 11m 47s
	Train Loss: 0.7787 | Valid Loss: 0.6501
	Accuracy: 0.9883 | F1: 0.2538 | Dice: 0.3707 | Recall: 0.3760 | Precision: 0.3656 | Jaccard: 0.1564
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 02 | Time: 11m 17s
	Train Loss: 0.5590 | Valid Loss: 0.6155
	Accuracy: 0.9912 | F1: 0.3070 | Dice: 0.3834 | Recall: 0.3151 | Precision: 0.4894 | Jaccard: 0.2022
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 03 | Time: 11m 17s
	Train Loss: 0.5280 | Valid Loss: 0.6723
	Accuracy: 0.9917 | F1: 0.2772 | Dice: 0.3410 | Recall: 0.2392 | Precision: 0.5936 | Jaccard: 0.1859
No improvement in validation loss for 1 epoch(s)


                                                                      

Epoch 04 | Time: 11m 27s
	Train Loss: 0.5056 | Valid Loss: 0.6321
	Accuracy: 0.9917 | F1: 0.3047 | Dice: 0.3777 | Recall: 0.2818 | Precision: 0.5724 | Jaccard: 0.2043
No improvement in validation loss for 2 epoch(s)


                                                                      

Epoch 05 | Time: 11m 17s
	Train Loss: 0.4848 | Valid Loss: 0.5877
	Accuracy: 0.9918 | F1: 0.3359 | Dice: 0.4073 | Recall: 0.3261 | Precision: 0.5423 | Jaccard: 0.2262
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 06 | Time: 11m 18s
	Train Loss: 0.4736 | Valid Loss: 0.5761
	Accuracy: 0.9915 | F1: 0.3378 | Dice: 0.4071 | Recall: 0.3488 | Precision: 0.4888 | Jaccard: 0.2265
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 07 | Time: 11m 17s
	Train Loss: 0.4682 | Valid Loss: 0.5979
	Accuracy: 0.9920 | F1: 0.3349 | Dice: 0.4025 | Recall: 0.3089 | Precision: 0.5772 | Jaccard: 0.2281
No improvement in validation loss for 1 epoch(s)


                                                                      

Epoch 08 | Time: 11m 16s
	Train Loss: 0.4541 | Valid Loss: 0.5533
	Accuracy: 0.9909 | F1: 0.3475 | Dice: 0.4234 | Recall: 0.3971 | Precision: 0.4534 | Jaccard: 0.2343
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 09 | Time: 11m 18s
	Train Loss: 0.4463 | Valid Loss: 0.5414
	Accuracy: 0.9920 | F1: 0.3754 | Dice: 0.4352 | Recall: 0.3639 | Precision: 0.5414 | Jaccard: 0.2587
Best Green_Attention_UNet_Focal_Tversky_Full Saved


                                                                      

Epoch 10 | Time: 11m 18s
	Train Loss: 0.4499 | Valid Loss: 0.7903
	Accuracy: 0.9913 | F1: 0.0947 | Dice: 0.1083 | Recall: 0.0616 | Precision: 0.4477 | Jaccard: 0.0584
No improvement in validation loss for 1 epoch(s)


                                                                      

Epoch 11 | Time: 11m 18s
	Train Loss: 0.4269 | Valid Loss: 0.6657
	Accuracy: 0.9916 | F1: 0.2179 | Dice: 0.2613 | Recall: 0.1914 | Precision: 0.4114 | Jaccard: 0.1436
No improvement in validation loss for 2 epoch(s)


                                                                      

Epoch 12 | Time: 11m 19s
	Train Loss: 0.4791 | Valid Loss: 0.6652
	Accuracy: 0.9901 | F1: 0.2444 | Dice: 0.3084 | Recall: 0.2631 | Precision: 0.3726 | Jaccard: 0.1644
No improvement in validation loss for 3 epoch(s)


                                                                      

Epoch 13 | Time: 11m 19s
	Train Loss: 0.5673 | Valid Loss: 0.8540
	Accuracy: 0.9911 | F1: 0.0270 | Dice: 0.0320 | Recall: 0.0173 | Precision: 0.2117 | Jaccard: 0.0165
No improvement in validation loss for 4 epoch(s)


                                                                      

Epoch 14 | Time: 11m 23s
	Train Loss: 0.4086 | Valid Loss: 0.5940
	Accuracy: 0.9911 | F1: 0.3000 | Dice: 0.3567 | Recall: 0.2921 | Precision: 0.4578 | Jaccard: 0.2042
No improvement in validation loss for 5 epoch(s)


                                                                      

Epoch 15 | Time: 11m 27s
	Train Loss: 0.4999 | Valid Loss: 0.6774
	Accuracy: 0.9888 | F1: 0.1959 | Dice: 0.2708 | Recall: 0.2367 | Precision: 0.3164 | Jaccard: 0.1240
No improvement in validation loss for 6 epoch(s)


                                                                      

Epoch 16 | Time: 11m 23s
	Train Loss: 0.4088 | Valid Loss: 0.6441
	Accuracy: 0.9917 | F1: 0.2302 | Dice: 0.2876 | Recall: 0.2069 | Precision: 0.4719 | Jaccard: 0.1539
No improvement in validation loss for 7 epoch(s)


                                                                      

Epoch 17 | Time: 11m 22s
	Train Loss: 0.3711 | Valid Loss: 0.5725
	Accuracy: 0.9917 | F1: 0.2842 | Dice: 0.3498 | Recall: 0.2692 | Precision: 0.4992 | Jaccard: 0.1893
No improvement in validation loss for 8 epoch(s)


                                                                      

Epoch 18 | Time: 11m 19s
	Train Loss: 0.3600 | Valid Loss: 0.5505
	Accuracy: 0.9919 | F1: 0.3170 | Dice: 0.3752 | Recall: 0.3004 | Precision: 0.4996 | Jaccard: 0.2135
No improvement in validation loss for 9 epoch(s)


                                                                      

Epoch 19 | Time: 11m 21s
	Train Loss: 0.3392 | Valid Loss: 0.5520
	Accuracy: 0.9882 | F1: 0.3005 | Dice: 0.3671 | Recall: 0.3135 | Precision: 0.4428 | Jaccard: 0.2013
No improvement in validation loss for 10 epoch(s)


                                                                      

Epoch 20 | Time: 11m 25s
	Train Loss: 0.3277 | Valid Loss: 0.6085
	Accuracy: 0.9907 | F1: 0.2437 | Dice: 0.3064 | Recall: 0.2216 | Precision: 0.4964 | Jaccard: 0.1616
No improvement in validation loss for 11 epoch(s)


                                                                      

Epoch 21 | Time: 11m 19s
	Train Loss: 0.3218 | Valid Loss: 0.5705
	Accuracy: 0.9907 | F1: 0.2859 | Dice: 0.3412 | Recall: 0.2544 | Precision: 0.5182 | Jaccard: 0.1955
No improvement in validation loss for 12 epoch(s)


                                                                      

Epoch 22 | Time: 11m 20s
	Train Loss: 0.3144 | Valid Loss: 0.5575
	Accuracy: 0.9903 | F1: 0.2976 | Dice: 0.3610 | Recall: 0.2811 | Precision: 0.5045 | Jaccard: 0.2017
No improvement in validation loss for 13 epoch(s)


                                                                      

Epoch 23 | Time: 11m 19s
	Train Loss: 0.3127 | Valid Loss: 0.5668
	Accuracy: 0.9879 | F1: 0.2846 | Dice: 0.3548 | Recall: 0.2786 | Precision: 0.4882 | Jaccard: 0.1933
No improvement in validation loss for 14 epoch(s)


                                                                      

Epoch 24 | Time: 11m 19s
	Train Loss: 0.3118 | Valid Loss: 0.5621
	Accuracy: 0.9884 | F1: 0.2941 | Dice: 0.3684 | Recall: 0.2902 | Precision: 0.5044 | Jaccard: 0.1994
No improvement in validation loss for 15 epoch(s)
Early stopping triggered after 15 epochs with no improvement.




In [12]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


  model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
Testing: 100%|██████████| 113/113 [00:22<00:00,  5.13it/s]

Jaccard: 0.2050, F1: 0.3038, Recall: 0.3414, Precision: 0.4254, Accuracy: 0.9942
FPS: 647.0579388624111



